from tqdm import tqdm
from PIL import Image
import argparse
import torch

from utils import load_model, load_resize_image, negative_prompt_inversion, generate
from null_text_inversion import NullInversion
import ptp_functions
import ptp_utils

# Device handed to every model._encode_prompt call in this script; hard-coded to CUDA.
device = torch.device("cuda")


@torch.no_grad()
def text2image_ldm_stable(
    model, prompt, controller, num_inference_steps: int = 50, guidance_scale: float = 7.5, latents = None, uncond_embeddings = None,
):
    """Run the attention-controlled denoising loop and return the last decoded image.

    Args:
        model: Pipeline object; this function uses its ``_encode_prompt``
            method and its ``scheduler`` and ``vae`` members.
        prompt: Sequence of prompts; its length is used as the batch size.
        controller: Attention controller registered on ``model`` and passed
            to every ``ptp_utils.diffusion_step`` call.
        num_inference_steps: Number of steps the scheduler is configured for.
        guidance_scale: Guidance scale forwarded to ``ptp_utils.diffusion_step``.
        latents: Start latents; repeated ``len(prompt)`` times along dim 0.
            Required despite the ``None`` default.
        uncond_embeddings: Optional sequence of unconditional embeddings,
            indexed by step number. When ``None`` the embedding of the empty
            prompt is used at every step.

    Returns:
        A PIL image built from the last entry of the decoded batch.

    Raises:
        ValueError: If ``latents`` is ``None`` or ``uncond_embeddings`` has
            fewer entries than the scheduler has timesteps.
    """
    if latents is None:
        # The original code failed here with an opaque AttributeError on None.repeat.
        raise ValueError("latents must be provided (e.g. the inverted x_T)")

    batch_size = len(prompt)
    ptp_utils.register_attention_control(model, controller)

    text_embeddings = model._encode_prompt(prompt, device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)

    latents = latents.repeat(batch_size, 1, 1, 1)
    model.scheduler.set_timesteps(num_inference_steps)
    timesteps = model.scheduler.timesteps

    if uncond_embeddings is None:
        # Expand the single empty-prompt embedding to the text batch so the
        # unconditional and conditional halves of the context have equal size,
        # exactly as the per-step branch below does.
        default_uncond = model._encode_prompt("", device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None)
        default_uncond = default_uncond.expand(*text_embeddings.shape)
    else:
        default_uncond = None
        # Fail before sampling starts rather than with an IndexError mid-loop.
        if len(uncond_embeddings) < len(timesteps):
            raise ValueError(
                f"uncond_embeddings has {len(uncond_embeddings)} entries but the scheduler uses {len(timesteps)} timesteps"
            )

    for i, t in enumerate(tqdm(timesteps)):
        if default_uncond is None:
            uncond = uncond_embeddings[i].expand(*text_embeddings.shape)
        else:
            uncond = default_uncond
        # Unconditional embeddings first, then the text embeddings.
        context = torch.cat([uncond, text_embeddings])

        latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)

    image = ptp_utils.latent2image(model.vae, latents)
    # Only the last image of the batch is returned (main() passes the edited prompt last).
    return Image.fromarray(image[-1])


def main(args):
    """Invert the input image, regenerate it under the edited prompt and save the result.

    Args:
        args: Parsed command-line namespace. Reads ``image``, ``prompt``,
            ``edited``, ``replace``, ``cross_replace``, ``self_replace``,
            ``blend_words``, ``eq_word``, ``eq_value``, ``step``, ``cfg``,
            ``npi`` and ``nti``.

    Raises:
        ValueError: If only one of ``eq_word`` / ``eq_value`` is supplied.

    The edited image is written to ``output_p2p.png``.
    """
    # Fixed seeds for the CPU and CUDA generators.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    model  = load_model()

    # Parameter setting
    cross_replace_steps = {'default_': args.cross_replace}
    self_replace_steps = args.self_replace
    if args.blend_words is None:
        blend_words = None
    else:
        blend_words = (((args.blend_words[0],), (args.blend_words[1],)))

    # --eq_word / --eq_value are optional and default to None; calling
    # tuple(None) raised TypeError whenever they were omitted.
    if args.eq_word is None and args.eq_value is None:
        # NOTE(review): assumes make_controller treats eq_params=None as
        # "no attention re-weighting" -- confirm against ptp_functions.
        eq_params = None
    elif args.eq_word is None or args.eq_value is None:
        raise ValueError("--eq_word and --eq_value must be given together")
    else:
        eq_params = {"words": tuple(args.eq_word), "values": tuple(args.eq_value)}

    prompts = [
        args.prompt,
        args.edited,
    ]
    print("Prompt: ", args.prompt)
    print("Edited Prompt: ", args.edited)

    controller = ptp_functions.make_controller(model, prompts, args.step, args.replace, cross_replace_steps, self_replace_steps, blend_words, eq_params)

    # Load image
    input_image = load_resize_image(args.image)

    # Inversion
    if args.nti: # null-text inversion
        null_inversion = NullInversion(model, args.step, args.cfg)
        x_T, uncond_embeds = null_inversion.invert(input_image, args.prompt)
    else:
        # negative-prompt inversion
        x_T, uncond_embed = negative_prompt_inversion(model, input_image, args.prompt, args.step)
        if not args.npi: # only DDIM inversion
            # Replace the returned embedding with the empty-prompt embedding.
            with torch.no_grad():
                uncond_embed = model._encode_prompt("", device, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=None).detach()

        # The same embedding is reused for every step.
        uncond_embeds = [uncond_embed] * args.step

    edited_image = text2image_ldm_stable(model, prompts, controller, args.step, args.cfg, latents=x_T, uncond_embeddings=uncond_embeds)
    edited_image.save("output_p2p.png")


if __name__ == "__main__":
    # Command-line entry point: parse the options, reject conflicting
    # inversion flags, then run main().
    parser = argparse.ArgumentParser("image2image using Prompt-to-Prompt")

    parser.add_argument("--image", type=str, help="Path of input image")
    parser.add_argument("--prompt", type=str, help="Input prompt")
    parser.add_argument("--edited", type=str, help="Edited prompt")

    parser.add_argument("--replace", action="store_true", help="Replace objects")
    parser.add_argument("--cross_replace", type=float, default=0.8, help="cross_replace_steps, Parameter for prompt-to-prompt")
    parser.add_argument("--self_replace", type=float, default=0.3, help="self_replace_steps, Parameter for prompt-to-prompt")
    parser.add_argument("--blend_words", type=str, nargs=2, default=None, help="blend_word, Parameter for prompt-to-prompt")
    parser.add_argument("--eq_word", type=str, nargs="+", help="words of eq_params, Parameter for prompt-to-prompt")
    parser.add_argument("--eq_value", type=float, nargs="+", help="values of eq_params, Parameter for prompt-to-prompt")

    parser.add_argument("--step", type=int, default=50, help="Number of steps to generate")
    parser.add_argument("--cfg", type=float, default=7.5, help="Classifier-free Guidance scale")
    parser.add_argument("--npi", action="store_true", help="Use negative-prompt inversion")
    parser.add_argument("--nti", action="store_true", help="Use null-text inversion")

    args = parser.parse_args()

    if args.npi and args.nti:
        # The previous sys.quit() call raised NameError: sys was never imported
        # and has no quit attribute. parser.error prints the usage plus this
        # message to stderr and exits with status 2.
        parser.error("Only one of '--npi' and '--nti' can be used")

    main(args)
