from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
import torch
from diffusers import StableDiffusionPipeline
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
)
from dbst.lora import load_lora, train_lora
from dbst.utils import slerp


class DBSTPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that interpolates frames between two images.

    On top of ``StableDiffusionPipeline`` this class adds a thin wrapper
    around ``dbst.lora.train_lora``, DDIM inversion of the two endpoint
    images, and a denoising loop that blends the inverted noise, the prompt
    embeddings and the two LoRA weight sets with a per-frame coefficient
    ``alpha``.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """Forward every component to ``StableDiffusionPipeline`` unchanged."""
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )

        # Stores handed to every ``LoadProcessor`` built in ``__call__``.
        # Nothing in this class writes to them.
        self.img0_dict = dict()
        self.img1_dict = dict()

    def train_lora(
        self,
        image: torch.Tensor,
        prompt: str,
        save_lora_path: str,
        lora_steps=200,
        lora_lr=2e-4,
        lora_rank=16,
    ):
        """Run ``dbst.lora.train_lora`` with this pipeline's components.

        Args:
            image: Training image, forwarded as-is.
            prompt: Text prompt, forwarded as-is.
            save_lora_path: Forwarded as-is; presumably the location the
                trained weights are written to — confirm in ``dbst.lora``.
            lora_steps: Forwarded as ``lora_steps``.
            lora_lr: Forwarded as ``lora_lr``.
            lora_rank: Forwarded as ``lora_rank``.

        Returns:
            None. Whatever ``dbst.lora.train_lora`` returns is discarded.
        """
        train_lora(
            image,
            prompt,
            save_lora_path,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            vae=self.vae,
            unet=self.unet,
            noise_scheduler=self.scheduler,
            lora_steps=lora_steps,
            lora_lr=lora_lr,
            lora_rank=lora_rank,
        )

    @torch.no_grad()
    def __call__(
        self,
        img_0: torch.Tensor,
        img_1: torch.Tensor,
        prompt_0: str,
        prompt_1: str,
        lora_path_0: str,
        lora_path_1: str,
        num_frames: int,
        num_inference_steps: int,
        alpha_list: list[float] = None,
    ):
        """Generate the frames of a morph from ``img_0`` to ``img_1``.

        Both endpoint images are DDIM-inverted under their own LoRA weights
        and prompt. For every interior frame the LoRA weights are blended,
        the attention processors whose name starts with ``"up"`` are wrapped
        in ``LoadProcessor``, and the blended noise is denoised and decoded.

        Args:
            img_0: First endpoint image. Presumably a batch in the [0, 1]
                range, since ``image2latent`` rescales with ``x * 2 - 1`` —
                confirm with callers.
            img_1: Second endpoint image, same assumptions as ``img_0``.
            prompt_0: Prompt paired with ``img_0``.
            prompt_1: Prompt paired with ``img_1``.
            lora_path_0: File read with ``torch.load`` for the first LoRA.
            lora_path_1: File read with ``torch.load`` for the second LoRA.
            num_frames: Total frame count including both endpoints. Values
                below 3 produce no interior frame.
            num_inference_steps: Step count given to the scheduler.
            alpha_list: Blend coefficient per frame; only indices
                ``1 .. num_frames - 2`` are read. When ``None``, an evenly
                spaced schedule from 0 to 1 is used.

        Returns:
            ``torch.cat`` of ``img_0``, the generated interior frames and
            ``img_1``, in that order.

        Raises:
            ValueError: If ``alpha_list`` is given but too short to provide
                a coefficient for every interior frame.
        """
        if alpha_list is None:
            # The declared default used to crash on ``alpha_list[i]``; fall
            # back to a linear schedule covering both endpoints.
            last_index = max(num_frames - 1, 1)
            alpha_list = [idx / last_index for idx in range(num_frames)]
        elif num_frames > 2 and len(alpha_list) < num_frames - 1:
            # Fail before the two expensive inversions rather than midway
            # through the frame loop.
            raise ValueError(
                f"alpha_list has {len(alpha_list)} entries but indices up to "
                f"{num_frames - 2} are needed for num_frames={num_frames}"
            )

        self.scheduler.set_timesteps(num_inference_steps)

        # SECURITY: torch.load unpickles its input; only pass LoRA files that
        # come from a trusted source.
        lora_0 = torch.load(lora_path_0)
        lora_1 = torch.load(lora_path_1)

        text_embeddings_0 = self.get_text_embeddings(prompt_0)
        text_embeddings_1 = self.get_text_embeddings(prompt_1)

        # Invert each endpoint with its own LoRA active (blend 0, then 1).
        self.unet = load_lora(self.unet, lora_0, lora_1, 0)
        img_noise_0 = self.ddim_inversion(self.image2latent(img_0), text_embeddings_0)
        self.unet = load_lora(self.unet, lora_0, lora_1, 1)
        img_noise_1 = self.ddim_inversion(self.image2latent(img_1), text_embeddings_1)

        frames = [img_0]
        for frame_idx in range(1, num_frames - 1):
            alpha = alpha_list[frame_idx]
            self.unet = load_lora(self.unet, lora_0, lora_1, alpha)

            # Read the property once per frame instead of once per key.
            wrapped = {}
            for name, processor in self.unet.attn_processors.items():
                if name.startswith("up"):
                    # Avoid stacking wrappers if one from a previous frame is
                    # still installed.
                    if isinstance(processor, LoadProcessor):
                        processor = processor.original_processor
                    processor = LoadProcessor(
                        processor,
                        name,
                        self.img0_dict,
                        self.img1_dict,
                        alpha,
                    )
                wrapped[name] = processor
            self.unet.set_attn_processor(wrapped)

            latents = self.cal_latent(
                num_inference_steps,
                img_noise_0,
                img_noise_1,
                text_embeddings_0,
                text_embeddings_1,
                lora_0,
                lora_1,
                alpha,
            )
            frames.append(self.latent2image(latents))

        frames.append(img_1)
        return torch.cat(frames)

    @torch.no_grad()
    def get_text_embeddings(self, prompt: str):
        """Tokenize ``prompt`` and return the text encoder's first output.

        The prompt is padded to 77 tokens and truncated to that length, and
        the token ids are moved to the text encoder's device instead of
        being forced onto CUDA.
        """
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = text_input.input_ids.to(self.text_encoder.device)
        return self.text_encoder(input_ids)[0]

    @torch.no_grad()
    def ddim_inversion(self, latent, cond):
        """Walk ``latent`` through the scheduler's timesteps back to front.

        At each timestep the UNet predicts the noise, a clean-latent estimate
        is formed from the previous timestep's cumulative alpha, and that
        estimate is re-noised to the current timestep.

        Args:
            latent: Latent to invert; its first dimension is used as the
                batch size for ``cond``.
            cond: Text embeddings, repeated along dim 0 to match ``latent``.

        Returns:
            The latent after the last (largest) timestep.
        """
        # Assumes scheduler.timesteps is a tensor, so reversed() yields
        # something that can be indexed below.
        timesteps = reversed(self.scheduler.timesteps)
        alphas_cumprod = self.scheduler.alphas_cumprod
        # Neither cond nor the batch size changes inside the loop.
        cond_batch = cond.repeat(latent.shape[0], 1, 1)

        for i, t in enumerate(timesteps):
            alpha_prod_t = alphas_cumprod[t]
            if i > 0:
                alpha_prod_t_prev = alphas_cumprod[timesteps[i - 1]]
            else:
                alpha_prod_t_prev = self.scheduler.final_alpha_cumprod

            mu = alpha_prod_t**0.5
            mu_prev = alpha_prod_t_prev**0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample

            pred_x0 = (latent - sigma_prev * eps) / mu_prev
            latent = mu * pred_x0 + sigma * eps
        return latent

    @torch.no_grad()
    def image2latent(self, image: torch.Tensor):
        """Encode ``image`` with the VAE and return the scaled latent mean."""
        image = image * 2 - 1
        latents = self.vae.encode(image)["latent_dist"].mean
        # Same fixed factor that latent2image divides out again.
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def latent2image(self, latents: torch.Tensor):
        """Decode ``latents`` with the VAE into an image clamped to [0, 1]."""
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)["sample"]
        image = (image / 2 + 0.5).clamp(0, 1)

        return image

    @torch.no_grad()
    def cal_latent(
        self,
        num_inference_steps,
        img_noise_0,
        img_noise_1,
        text_embeddings_0,
        text_embeddings_1,
        lora_0,
        lora_1,
        alpha,
    ):
        """Denoise the ``alpha``-blend of the two inverted endpoints.

        The starting latent is ``slerp(img_noise_0, img_noise_1, alpha,
        adain=True)``, the conditioning is the linear blend of the two text
        embeddings, and the UNet carries the LoRA blend for ``alpha``.

        Returns:
            The latent after stepping through every scheduler timestep.
        """
        latents = slerp(img_noise_0, img_noise_1, alpha, adain=True)
        text_embeddings = (1 - alpha) * text_embeddings_0 + alpha * text_embeddings_1

        self.scheduler.set_timesteps(num_inference_steps)
        # NOTE(review): __call__ installs LoadProcessor wrappers right before
        # calling this method; whether load_lora keeps them in place cannot
        # be told from this file — confirm.
        self.unet = load_lora(self.unet, lora_0, lora_1, alpha)

        for t in self.scheduler.timesteps:
            noise_pred = self.unet(
                latents, t, encoder_hidden_states=text_embeddings
            ).sample
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents


class StoreProcessor:
    """Attention-processor wrapper that records inputs before delegating.

    Each call made without ``encoder_hidden_states`` stores a detached copy
    of ``hidden_states`` in ``value_dict[name]`` under a running integer
    index. Every call, recorded or not, is forwarded unchanged to the wrapped
    processor. In diffusers' attention blocks a missing
    ``encoder_hidden_states`` presumably marks self-attention — confirm for
    the processors in use.
    """

    def __init__(self, original_processor, value_dict, name):
        self.name = name
        self.original_processor = original_processor
        self.value_dict = value_dict
        # Start from an empty slot; any earlier entry under this name is
        # replaced.
        value_dict[name] = {}
        # Index under which the next recorded tensor is stored.
        self.id = 0

    def __call__(
        self,
        attn,
        hidden_states,
        *args,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        """Record ``hidden_states`` when applicable, then delegate."""
        record_call = encoder_hidden_states is None
        if record_call:
            slot = self.value_dict[self.name]
            slot[self.id] = hidden_states.detach()
            self.id += 1

        return self.original_processor(
            attn,
            hidden_states,
            *args,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )


class LoadProcessor:
    """Pass-through attention-processor wrapper that counts some calls.

    Every call is forwarded unchanged to the wrapped processor. Calls made
    without ``encoder_hidden_states`` also advance ``self.id``, which wraps
    back to 0 once it reaches the number of entries stored under ``name`` in
    ``img0_dict``. ``img1_dict`` and ``alpha`` are kept on the instance but
    are not read by ``__call__``.
    """

    def __init__(self, original_processor, name, img0_dict, img1_dict, alpha):
        self.original_processor = original_processor
        self.name = name
        self.img0_dict = img0_dict
        self.img1_dict = img1_dict
        self.alpha = alpha
        # Number of counted calls since the last wrap-around.
        self.id = 0

    def __call__(
        self,
        attn,
        hidden_states,
        *args,
        encoder_hidden_states=None,
        attention_mask=None,
        **kwargs,
    ):
        """Delegate to the wrapped processor and update the call counter.

        Returns:
            Whatever the wrapped processor returns.
        """
        # Both branches of the earlier implementation made this same call.
        res = self.original_processor(
            attn,
            hidden_states,
            *args,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        if encoder_hidden_states is None:
            self.id += 1
            # Tolerate a store with no entry for this processor: indexing it
            # directly raised KeyError whenever nothing had been recorded
            # under ``name``.
            stored = self.img0_dict.get(self.name)
            if stored is not None and self.id == len(stored):
                self.id = 0

        return res
