from tqdm import tqdm
from PIL import Image
import numpy as np
import diffusers
import argparse
import torch

from utils import load_model, load_resize_image


# Device used by run_pnp for the scheduler timesteps, prompt encoding and the
# input image tensor. Hard-coded to the (current) CUDA device; there is no CPU
# fallback. NOTE(review): presumably load_model() places the pipeline on this
# same device — confirm.
device = torch.device("cuda")


# ==============================================================================
# from https://github.com/MichalGeyer/pnp-diffusers
# ==============================================================================

def register_time(model, t):
    """Record the current sampling timestep ``t`` on every module PnP may consult.

    Sets the attribute ``t`` on the decoder ResNet block
    ``up_blocks[1].resnets[1]`` and on ``attn1`` of the first transformer block
    of each listed attention layer: decoder levels 1-3 (layers 0-2), encoder
    levels 0-2 (layers 0-1) and the mid block. The patched forward functions
    elsewhere in this file read ``self.t`` to decide whether to inject.

    Args:
        model: object exposing ``model.unet`` with ``up_blocks``,
            ``down_blocks`` and ``mid_block``.
        t: timestep value to store (this file passes ``t.item()``).
    """
    unet = model.unet
    unet.up_blocks[1].resnets[1].t = t

    # (UNet stage attribute, {level: attention-layer indices}), visited in order.
    layout = (
        ("up_blocks", {1: (0, 1, 2), 2: (0, 1, 2), 3: (0, 1, 2)}),
        ("down_blocks", {0: (0, 1), 1: (0, 1), 2: (0, 1)}),
    )
    for stage, levels in layout:
        for level, indices in levels.items():
            for idx in indices:
                getattr(unet, stage)[level].attentions[idx].transformer_blocks[0].attn1.t = t

    unet.mid_block.attentions[0].transformer_blocks[0].attn1.t = t


def register_attention_control_efficient(model, injection_schedule):
    """Patch decoder self-attention layers so they can inject source queries/keys.

    Replaces ``forward`` on ``attn1`` of the first transformer block of the
    decoder attention layers listed in ``res_dict`` below, and stores
    ``injection_schedule`` on each patched module. The replacement computes
    plain scaled dot-product attention. For self-attention only, when the
    module's ``t`` (set by ``register_time``) is in ``injection_schedule`` (or
    equals 1000), the queries and keys of the first third of the batch are
    copied over the second and third thirds before attention is computed.
    Values are never injected.

    Args:
        model: object exposing ``model.unet.up_blocks``.
        injection_schedule: timesteps at which to inject (anything supporting
            ``in``, e.g. a tensor or list); ``None`` disables injection.
    """
    def sa_forward(self):
        # When to_out is a ModuleList only its first element is applied, as in
        # the reference implementation.
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            batch_size = x.shape[0]
            h = self.heads

            is_cross = encoder_hidden_states is not None
            encoder_hidden_states = encoder_hidden_states if is_cross else x

            q = self.to_q(x)
            k = self.to_k(encoder_hidden_states)

            # Self-attention only: at scheduled timesteps overwrite the q/k of the
            # second and third batch thirds (presumably [source, uncond, cond])
            # with those of the first, so every branch reuses its attention maps.
            if not is_cross and self.injection_schedule is not None and (
                    self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = q.shape[0] // 3
                # inject unconditional
                q[source_batch_size:2 * source_batch_size] = q[:source_batch_size]
                k[source_batch_size:2 * source_batch_size] = k[:source_batch_size]
                # inject conditional
                q[2 * source_batch_size:] = q[:source_batch_size]
                k[2 * source_batch_size:] = k[:source_batch_size]

            q = self.head_to_batch_dim(q)
            k = self.head_to_batch_dim(k)
            v = self.head_to_batch_dim(self.to_v(encoder_hidden_states))

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

            if attention_mask is not None:
                # Boolean mask: True = attend, False positions are filled below.
                attention_mask = attention_mask.reshape(batch_size, -1)
                max_neg_value = -torch.finfo(sim.dtype).max
                # head_to_batch_dim orders rows batch-major (b0h0, b0h1, ..., b1h0, ...),
                # so each sample's mask must be repeated h times consecutively, as
                # diffusers' prepare_attention_mask does. Tensor.repeat would tile
                # (b0, b1, ..., b0, b1, ...) and give heads another sample's mask.
                attention_mask = attention_mask[:, None, :].repeat_interleave(h, dim=0)
                sim.masked_fill_(~attention_mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}  # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res, blocks in res_dict.items():
        for block in blocks:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
            module.injection_schedule = injection_schedule


def register_conv_control_efficient(model, injection_schedule):
    """Patch one decoder ResNet block so it can inject source-image features.

    Replaces ``forward`` on ``model.unet.up_blocks[1].resnets[1]`` and stores
    ``injection_schedule`` on that module. The replacement recomputes the
    residual block and, when the module's ``t`` is in ``injection_schedule``
    (or equals 1000), copies the first third of the batch's ``conv2`` output
    over the second and third thirds before the skip connection is added.

    Args:
        model: object exposing ``model.unet.up_blocks``.
        injection_schedule: timesteps at which to inject (anything supporting
            ``in``); ``None`` disables injection.
    """
    def conv_forward(self):
        """Return a replacement ``forward`` closed over the ResNet module ``self``."""
        def forward(input_tensor, temb):
            # NOTE(review): appears to reproduce the ResNet-block forward of the
            # diffusers version this script targets, with the injection step added
            # after conv2 — re-check against the installed diffusers if it changes.
            hidden_states = input_tensor

            hidden_states = self.norm1(hidden_states)
            hidden_states = self.nonlinearity(hidden_states)

            if self.upsample is not None:
                # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
                if hidden_states.shape[0] >= 64:
                    input_tensor = input_tensor.contiguous()
                    hidden_states = hidden_states.contiguous()
                input_tensor = self.upsample(input_tensor)
                hidden_states = self.upsample(hidden_states)
            elif self.downsample is not None:
                input_tensor = self.downsample(input_tensor)
                hidden_states = self.downsample(hidden_states)

            hidden_states = self.conv1(hidden_states)

            # Project the time embedding; the two trailing singleton dims let it
            # broadcast against hidden_states.
            if temb is not None:
                temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]

            if temb is not None and self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb

            hidden_states = self.norm2(hidden_states)

            if temb is not None and self.time_embedding_norm == "scale_shift":
                scale, shift = torch.chunk(temb, 2, dim=1)
                hidden_states = hidden_states * (1 + scale) + shift

            hidden_states = self.nonlinearity(hidden_states)

            hidden_states = self.dropout(hidden_states)
            hidden_states = self.conv2(hidden_states)
            # PnP feature injection. self.t must already be set on this module
            # (register_time does this). The batch is assumed to be three equal
            # chunks, presumably [source, unconditional, conditional] as the
            # comments below and denoise_step's batching suggest; the first
            # chunk overwrites the other two. Only the residual branch is
            # injected — input_tensor (the skip path) keeps each sample's values.
            if self.injection_schedule is not None and (self.t in self.injection_schedule or self.t == 1000):
                source_batch_size = int(hidden_states.shape[0] // 3)
                # inject unconditional
                hidden_states[source_batch_size:2 * source_batch_size] = hidden_states[:source_batch_size]
                # inject conditional
                hidden_states[2 * source_batch_size:] = hidden_states[:source_batch_size]

            if self.conv_shortcut is not None:
                input_tensor = self.conv_shortcut(input_tensor)

            output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

            return output_tensor

        return forward

    conv_module = model.unet.up_blocks[1].resnets[1]
    conv_module.forward = conv_forward(conv_module)
    setattr(conv_module, 'injection_schedule', injection_schedule)


@torch.no_grad()
def ddim_inversion(model, cond, latent):
    """Run deterministic DDIM inversion from a clean latent toward noise.

    Walks ``model.scheduler.timesteps`` in reverse (least noisy first). At each
    timestep ``t`` it predicts the noise at ``t``, estimates ``x_0`` from the
    current latent using the previous noise level, and re-noises that estimate
    to level ``t``.

    Args:
        model: pipeline exposing ``unet`` and a scheduler with ``timesteps``,
            ``alphas_cumprod`` and ``final_alpha_cumprod``.
        cond: text embedding; repeated along dim 0 to match the latent batch.
        latent: latent of the image to invert (batch along dim 0).

    Returns:
        ``(x_T, latents)``: the fully inverted latent, and a dict mapping each
        integer timestep to the latent reached at that timestep.
    """
    scheduler = model.scheduler
    # The conditioning is identical at every step, so expand it to the batch once.
    cond_batch = cond.repeat(latent.shape[0], 1, 1)
    latents = {}

    # Before the first step the "previous" level is the clean end of the schedule.
    alpha_prod_t_prev = scheduler.final_alpha_cumprod
    for t in tqdm(reversed(scheduler.timesteps)):
        alpha_prod_t = scheduler.alphas_cumprod[t]

        mu = alpha_prod_t ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        eps = model.unet(latent, t, encoder_hidden_states=cond_batch).sample

        # Estimate x_0 from the less-noisy latent, then move it to noise level t.
        pred_x0 = (latent - sigma_prev * eps) / mu_prev
        latent = mu * pred_x0 + sigma * eps

        latents[t.item()] = latent
        # This step's level is the next step's "previous" level.
        alpha_prod_t_prev = alpha_prod_t

    return latent, latents


@torch.no_grad()
def denoise_step(model, x, t, text_embed_input, latents: dict, guidance_scale: float = 7.5):
    """Perform one guided Plug-and-Play sampling step.

    The UNet is run on the batch ``[source latent at t, x, x]``: the inverted
    source latent stored for ``t`` in ``latents`` plus two copies of ``x`` (for
    the unconditional and conditional guidance branches). The source branch's
    prediction is discarded; the other two are combined with classifier-free
    guidance and passed to the scheduler.

    Args:
        model: pipeline exposing ``unet`` and ``scheduler``.
        x: current latent being edited.
        t: current timestep tensor.
        text_embed_input: embeddings matching the three batch entries.
        latents: timestep (int) -> inverted source latent.
        guidance_scale: classifier-free guidance weight.

    Returns:
        The scheduler's ``prev_sample`` for this step.
    """
    step = t.item()
    source_latents = latents[step]
    latent_model_input = torch.cat([source_latents, x, x])

    # Publish the timestep so the patched PnP modules know whether to inject.
    register_time(model, step)

    noise_pred = model.unet(latent_model_input, t, encoder_hidden_states=text_embed_input)['sample']

    # Drop the source branch and apply classifier-free guidance to the rest.
    _, uncond_eps, cond_eps = noise_pred.chunk(3)
    guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    return model.scheduler.step(guided_eps, t, x)['prev_sample']


def init_pnp(model, conv_injection_t, qk_injection_t):
    """Install PnP injection hooks active for the first N sampling timesteps.

    Args:
        model: pipeline whose ``scheduler.timesteps`` holds the sampling schedule.
        conv_injection_t: number of leading timesteps with ResNet feature
            injection; a negative value disables it (empty schedule).
        qk_injection_t: number of leading timesteps with self-attention q/k
            injection; a negative value disables it (empty schedule).
    """
    def leading_timesteps(count):
        # Negative counts mean "never inject"; otherwise take the first `count` steps.
        if count < 0:
            return []
        return model.scheduler.timesteps[:count]

    qk_schedule = leading_timesteps(qk_injection_t)
    conv_schedule = leading_timesteps(conv_injection_t)

    register_attention_control_efficient(model, qk_schedule)
    register_conv_control_efficient(model, conv_schedule)


@torch.no_grad()
def run_pnp(
    model, image_pil: Image.Image, prompt: str, num_inference_steps: int = 50, guidance_scale: float = 7.5,
    negative_prompt: str = "", pnp_f_t: float = 1.0, pnp_attn_t: float = 1.0
):
    """Edit ``image_pil`` toward ``prompt`` with Plug-and-Play feature injection.

    1. Encode the image with the VAE and DDIM-invert it over a 999-step
       schedule, conditioned on ``negative_prompt``. Every intermediate latent
       is kept.
    2. Re-sample from the inverted latent over ``num_inference_steps`` steps
       with classifier-free guidance on ``prompt``. The stored latents drive a
       source branch whose features / self-attention q,k are injected during
       the leading fraction of steps given by ``pnp_f_t`` / ``pnp_attn_t``.

    Args:
        model: Stable-Diffusion-style pipeline (``unet``, ``vae``,
            ``scheduler``, ``_encode_prompt``). The scheduler must provide
            ``final_alpha_cumprod`` (used by the inversion).
        image_pil: input image, converted via ``np.array`` to (1, C, H, W) in
            [-1, 1]. NOTE(review): assumes a 3-channel RGB image.
        prompt: target (edited) prompt.
        num_inference_steps: number of sampling steps.
        guidance_scale: classifier-free guidance weight.
        negative_prompt: unconditional prompt for guidance; also the
            inversion prompt (the source prompt for negative-prompt inversion).
        pnp_f_t: fraction of sampling steps with conv-feature injection.
        pnp_attn_t: fraction of sampling steps with self-attention injection.

    Returns:
        The edited image as a PIL ``Image``.
    """
    # ---- Inversion -----------------------------------------------------------
    model.scheduler.set_timesteps(999, device=device)

    # A previous call may have left patched modules with an injection schedule
    # and a stale `t`. If that `t` is scheduled, the single-sample inversion
    # batch would try to inject (copying an empty slice into a size-1 slice).
    # Parking `t` at -1, which is never a scheduler timestep nor 1000, disables
    # injection until denoise_step registers real timesteps. On a fresh model
    # it only adds an attribute the unpatched modules do not read.
    register_time(model, -1)

    cond = model._encode_prompt(negative_prompt, device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)
    image = torch.from_numpy(np.array(image_pil) / 127.5 - 1).float().permute(2, 0, 1).unsqueeze(0).to(device)
    # 0.18215: latent scaling constant (presumably the SD v1 VAE scaling factor).
    latent = 0.18215 * model.vae.encode(image)["latent_dist"].mean

    x_T, latents = ddim_inversion(model, cond, latent)

    # ---- Sampling ------------------------------------------------------------
    model.scheduler.set_timesteps(num_inference_steps, device=device)

    text_embeds = model._encode_prompt(prompt, device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
    pnp_guidance_embeds = model._encode_prompt("", device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)
    # Order must match denoise_step's batch: [source, (presumably) uncond, cond].
    text_embed_input = torch.cat([pnp_guidance_embeds, text_embeds], dim=0)

    conv_injection_steps = int(num_inference_steps * pnp_f_t)
    attn_injection_steps = int(num_inference_steps * pnp_attn_t)
    init_pnp(model, conv_injection_t=conv_injection_steps, qk_injection_t=attn_injection_steps)

    x = x_T
    for t in tqdm(model.scheduler.timesteps):
        x = denoise_step(model, x, t, text_embed_input, latents, guidance_scale)

    decoded = model.vae.decode(1 / 0.18215 * x)['sample']
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    pixels = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
    # Round instead of truncating when quantizing to 8 bits, as diffusers'
    # numpy_to_pil does; truncation biases every channel downward.
    return Image.fromarray((pixels * 255).round().astype(np.uint8))


def main(args):
    """Run a Plug-and-Play edit from parsed CLI ``args``; save to ``output_pnp.png``.

    Uses fixed injection fractions (0.8 of the steps for conv features, 0.5 for
    self-attention). With ``args.npi`` the source prompt is used as the
    negative prompt (negative-prompt inversion); otherwise it is empty.
    """
    # Fixed seeds for reproducible runs.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    model = load_model()

    print("Prompt: ", args.prompt)
    print("Edited Prompt: ", args.edited)

    feature_fraction, attention_fraction = 0.8, 0.5

    source_image = load_resize_image(args.image)

    uncond_prompt = args.prompt if args.npi else ""

    result = run_pnp(
        model, source_image, args.edited, args.step, args.cfg,
        uncond_prompt, feature_fraction, attention_fraction,
    )
    result.save("output_pnp.png")


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`; this text is
    # meant as the description.
    parser = argparse.ArgumentParser(description="image2image using Plug-and-Play")

    # --image and --edited are used on every run. Without them main() would
    # pass None to load_resize_image / run_pnp, so reject that up front.
    parser.add_argument("--image", type=str, required=True, help="Path of input image")
    parser.add_argument("--prompt", type=str, help="Input prompt")
    parser.add_argument("--edited", type=str, required=True, help="Edited prompt")

    parser.add_argument("--step", type=int, default=50, help="Number of steps to generate")
    parser.add_argument("--cfg", type=float, default=7.5, help="Classifier-free Guidance scale")
    parser.add_argument("--npi", action="store_true", help="Use negative-prompt inversion")

    args = parser.parse_args()

    # Negative-prompt inversion uses --prompt as the negative prompt.
    if args.npi and args.prompt is None:
        parser.error("--npi requires --prompt")

    if diffusers.__version__ != "0.17.1":
        import warnings
        warnings.warn("In certain version of  'diffusers', this script may encounter errors due to modifications in keyword arguments. Please update 'diffusers' to version 0.17.1.")

    main(args)