from tqdm import tqdm
from PIL import Image
import numpy as np
import argparse
import torch

from utils import load_model, load_resize_image, negative_prompt_inversion, generate
from null_text_inversion import NullInversion
import ptp_utils

# Device for all tensors created in this script; falls back to CPU when CUDA is
# unavailable (identical to the previous hard-coded "cuda" on GPU machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def sdedit(
    model,
    image_pil,
    prompt: str,
    latent=None,
    uncond_embeddings=None,
    num_ddim_steps: int = 50,
    guidance_scale: float = 7.5,
    strength: float = 1.0,
    original_prompt: str = None,
):
    """Edit an image with SDEdit: start from an intermediate timestep and denoise under ``prompt``.

    ``t0_idx = int(num_ddim_steps * strength)`` of the ``num_ddim_steps`` scheduler
    timesteps (the last ``t0_idx`` of the schedule) are run with the edited prompt
    and classifier-free guidance.

    Args:
        model: Diffusion pipeline; this function uses its ``_encode_prompt``,
            ``scheduler``, ``unet`` and ``vae`` attributes.
        image_pil: Input image, only used when ``latent`` is None. It is converted
            with ``np.array`` and mapped from [0, 255] to [-1, 1] (presumably an
            HxWx3 image already resized for the VAE -- TODO confirm against the
            image loader).
        prompt: Edited prompt that conditions the guided sampling loop.
        latent: Optional starting latent (e.g. from an inversion). When given, no
            fresh noise is added; instead the latent is first denoised with
            ``original_prompt`` (conditional only, no guidance) over the
            timesteps preceding the SDEdit start timestep.
        uncond_embeddings: Optional sequence of unconditional embeddings indexed by
            scheduler step; must cover indices up to ``num_ddim_steps - 1``. When
            None, the empty-prompt embedding is used at every step.
        num_ddim_steps: Number of scheduler timesteps.
        guidance_scale: Classifier-free guidance scale of the sampling loop.
        strength: Fraction of the schedule to re-run; must select between 1 and
            ``num_ddim_steps`` steps. 1.0 starts from the noisiest timestep.
        original_prompt: Prompt describing the input image; required when
            ``latent`` is given.

    Returns:
        The last image of the decoded batch as a ``PIL.Image``.

    Raises:
        ValueError: If ``strength`` selects zero steps or more than
            ``num_ddim_steps`` steps, or if ``latent`` is given without
            ``original_prompt``.
    """
    # Validate before touching the model so bad arguments fail fast.
    t0_idx = int(num_ddim_steps * strength)
    # ``timesteps[-0]`` / ``timesteps[-0:]`` would silently select the *whole*
    # schedule (and overflow the per-step uncond index), so reject 0 explicitly.
    if not 1 <= t0_idx <= num_ddim_steps:
        raise ValueError(
            f"strength={strength} selects {t0_idx} of {num_ddim_steps} steps; "
            "it must select between 1 and num_ddim_steps steps"
        )
    if latent is not None and original_prompt is None:
        raise ValueError("original_prompt is required when a starting latent is given")

    def encode(text):
        # Conditional embedding only; guidance is assembled manually below.
        return model._encode_prompt(
            text, device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None
        )

    text_embeddings = encode(prompt)
    # Fixed empty-prompt embedding, used only when no per-step embeddings are given.
    uncond_embeddings_ = encode("") if uncond_embeddings is None else None

    model.scheduler.set_timesteps(num_ddim_steps)
    timesteps = model.scheduler.timesteps
    # First timestep of the guided sampling loop; its noise level is the SDEdit start.
    t0 = timesteps[-t0_idx]
    # Index of that first step within the full schedule (offsets uncond_embeddings).
    start_idx = num_ddim_steps - t0_idx

    if latent is None:
        image = torch.from_numpy(np.array(image_pil) / 127.5 - 1).float().permute(2, 0, 1).unsqueeze(0).to(device)
        # 0.18215 is the Stable Diffusion v1 VAE latent scaling factor.
        latent = 0.18215 * model.vae.encode(image)["latent_dist"].mean

        # Reseed the global RNGs so the added noise is reproducible across runs.
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        noise = torch.randn_like(latent)
        alpha_prod_t = model.scheduler.alphas_cumprod[t0]
        # Forward diffusion to t0: sqrt(a) * x_0 + sqrt(1 - a) * eps.
        latent = alpha_prod_t ** 0.5 * latent + (1 - alpha_prod_t) ** 0.5 * noise
    else:
        # Denoise the given latent with the original prompt down to t0.
        # The prompt embedding is loop-invariant, so encode it once.
        # NOTE(review): uncond_embeddings are not used in this phase, so per-step
        # (e.g. null-text-optimized) embeddings only affect the steps from t0 on.
        context = encode(original_prompt)
        for t in tqdm(timesteps[:-t0_idx]):
            model_output = model.unet(latent, t, encoder_hidden_states=context).sample
            latent = model.scheduler.step(model_output, t, latent).prev_sample

    # Guided sampling with the edited prompt from t0 to the end of the schedule.
    for i, t in enumerate(tqdm(timesteps[-t0_idx:])):
        if uncond_embeddings_ is None:
            uncond = uncond_embeddings[start_idx + i]
        else:
            uncond = uncond_embeddings_
        context = torch.cat([uncond, text_embeddings])

        # One UNet call for both the unconditional and conditional branches.
        model_output = model.unet(latent.repeat(2, 1, 1, 1), t, encoder_hidden_states=context).sample
        pred_uncond, pred_text = model_output.chunk(2)
        model_output = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latent = model.scheduler.step(model_output, t, latent).prev_sample

    image = ptp_utils.latent2image(model.vae, latent)
    return Image.fromarray(image[-1])


def main(args):
    """Invert ``args.image`` (optionally) and save its SDEdit edit to ``output_sde.png``.

    Reads ``args.image``, ``args.prompt``, ``args.edited``, ``args.step``,
    ``args.cfg``, ``args.t0``, ``args.nti`` and ``args.npi``.

    Modes:
        - ``args.nti``: null-text inversion; per-step unconditional embeddings.
        - ``args.npi``: negative-prompt inversion; one embedding reused per step.
        - neither: no inversion; ``sdedit`` noises the encoded input image itself
          and uses the empty-prompt embedding at every step.
    """
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    model = load_model()

    print("Prompt: ", args.prompt)
    print("Edited Prompt: ", args.edited)

    input_image = load_resize_image(args.image)

    # Defaults for the no-inversion mode: sdedit then adds noise to the encoded
    # image and builds the empty-prompt embedding itself. (The previous code ran
    # an inversion here and discarded its latent.)
    x_t = None
    uncond_embeds = None
    if args.nti:  # null-text inversion
        null_inversion = NullInversion(model, args.step, args.cfg)
        x_t, uncond_embeds = null_inversion.invert(input_image, args.prompt)
    elif args.npi:  # negative-prompt inversion
        x_t, uncond_embed = negative_prompt_inversion(model, input_image, args.prompt, args.step)
        uncond_embeds = [uncond_embed] * args.step

    # original_prompt is only consulted by sdedit when a latent is supplied.
    edited_image = sdedit(
        model,
        input_image,
        args.edited,
        latent=x_t,
        uncond_embeddings=uncond_embeds,
        num_ddim_steps=args.step,
        guidance_scale=args.cfg,
        strength=args.t0,
        original_prompt=args.prompt,
    )

    edited_image.save("output_sde.png")


if __name__ == "__main__":
    # The string is a description, not the program name (ArgumentParser's first
    # positional parameter is ``prog``).
    parser = argparse.ArgumentParser(description="image2image using SDEdit")

    # All three are used unconditionally by main(), so they are required.
    parser.add_argument("--image", type=str, required=True, help="Path of input image")
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument("--edited", type=str, required=True, help="Edited prompt")

    parser.add_argument("--t0", type=float, default=0.7, help="t0, Parameter for sdedit")

    parser.add_argument("--step", type=int, default=50, help="Number of steps to generate")
    parser.add_argument("--cfg", type=float, default=7.5, help="Classifier-free Guidance scale")

    # argparse rejects "--npi --nti" with a usage error. This replaces the old
    # manual check, whose `sys.quit()` raised NameError (sys was never imported,
    # and sys has no `quit`).
    inversion = parser.add_mutually_exclusive_group()
    inversion.add_argument("--npi", action="store_true", help="Use negative-prompt inversion")
    inversion.add_argument("--nti", action="store_true", help="Use null-text inversion")

    args = parser.parse_args()

    if args.step < 1:
        parser.error("--step must be a positive integer")
    if not 0 < args.t0 <= 1:
        parser.error("--t0 must be in the interval (0, 1]")

    main(args)
