# Sample images for 1000 class labels with a latent diffusion model
# (optionally post-training-quantized via --ptq) and write them as PNG files.
import sys
# Make the project-local source trees importable. These paths are relative,
# so the script must be launched from the repository root.
sys.path.append("./mainldm")
sys.path.append("./mainddpm")
sys.path.append('./src/taming-transformers')
sys.path.append('.')
print(sys.path)  # debug: show the resolved import path
import argparse
import os, gc
# Pin this process to physical GPU 1. It is set before `import torch` below so
# that CUDA only ever sees this device (it then appears as device 0).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import time
import logging
import wandb
import numpy as np
import torch.distributed as dist

import torch
# torch.set_grad_enabled(False)
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler, DDIMSampler_trainer
from imwatermark import WatermarkEncoder
from PIL import Image
from einops import rearrange
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from evalution.sfid import test_fid_sfid
from quant.utils import AttentionMap, seed_everything, Fisher 
from quant.quant_model import QModel
from quant.quant_block import Change_LDM_model_SpatialTransformer
# NOTE(review): the __main__ block's PTQ path uses
# set_weight_quantize_params_Imagenet / set_act_quantize_params_Imagenet,
# which are not imported in this header (only the *_cond variants are).
from quant.set_quantize_params import set_act_quantize_params_cond, set_weight_quantize_params_cond
from quant.recon_Qmodel import recon_Qmodel, skip_Model
from quant.quant_layer import QuantModule
from quant.adaptive_rounding import AdaRoundQuantizer
# Module-level logger; handlers are configured in the __main__ block.
logger = logging.getLogger(__name__)


def load_model_from_config(config, ckpt):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``. Either a dict
            holding the weights under ``"state_dict"`` or a bare state dict.

    Returns:
        The model, moved to the current CUDA device and put in eval mode.
    """
    print(f"Loading model from {ckpt}")
    # Map every tensor to CPU regardless of the device it was saved from.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # Accept both wrapped ({"state_dict": ...}) and bare state dicts.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them rather than silently
    # discarding the result, so a wrong checkpoint does not go unnoticed.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"missing keys: {len(missing)}")
    if unexpected:
        print(f"unexpected keys: {len(unexpected)}")
    model.cuda()
    model.eval()
    return model


def get_model(config_path="./mainldm/configs/latent-diffusion/cin256-v2.yaml",
              ckpt_path="./models/ldm/cin256/model.ckpt"):
    """Load the latent diffusion model used by this script.

    Args:
        config_path: OmegaConf YAML describing the model (defaults to the
            cin256-v2 config used originally).
        ckpt_path: Checkpoint file handed to ``load_model_from_config``.

    Returns:
        Whatever ``load_model_from_config`` returns for this config/checkpoint.
    """
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path)


def put_watermark(img, wm_encoder=None):
    """Embed an invisible watermark into ``img`` using ``wm_encoder``.

    When ``wm_encoder`` is None, ``img`` is returned untouched. Otherwise the
    image is converted to a BGR numpy array, encoded with the 'dwtDct' method,
    flipped back to RGB channel order and wrapped in a new PIL image.
    """
    if wm_encoder is None:
        return img
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    encoded = wm_encoder.encode(bgr, 'dwtDct')
    rgb = encoded[:, :, ::-1]  # BGR -> RGB
    return Image.fromarray(rgb)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_samples', type=int, default=50000)
    parser.add_argument('--sample_batch', type=int, default=50)
    # Not used for device placement: the model is moved with ``.cuda()`` (the
    # current CUDA device). Kept so existing launch commands keep working.
    parser.add_argument("--local_rank", type=int, default=1)
    parser.add_argument("--scale", type=float, default=1.5)
    parser.add_argument('--ddim_steps', type=int, default=250)
    parser.add_argument("--ddim_eta", type=float, default=0.0)
    parser.add_argument('--seed', type=int, default=1234+9)

    parser.add_argument("--replicate_interval", type=int, default=20)
    parser.add_argument("--sm_abit", type=int, default=8)
    # NOTE: --quant_act, --split and --my_steps use action="store_true" with
    # default=True, so they are always True and cannot be disabled from the CLI.
    parser.add_argument("--quant_act", action="store_true", default=True)
    parser.add_argument("--weight_bit", type=int, default=8)
    parser.add_argument("--act_bit", type=int, default=8)
    parser.add_argument("--quant_mode", type=str, default="qdiff", choices=["qdiff"])
    parser.add_argument("--split", action="store_true", default=True)
    parser.add_argument("--my_steps", action='store_true', default=True)
    parser.add_argument("--ptq", action="store_true", default=False)

    parser.add_argument("--nonuniform", action='store_true')
    parser.add_argument("--pow", type=float, default=1.5)
    args = parser.parse_args()
    # ``mode`` is part of the cache / checkpoint file names loaded below.
    args.mode = "my_opt" if args.my_steps else "uni"
    seed_everything(args.seed)

    # logging.FileHandler does not create parent directories.
    os.makedirs("./output", exist_ok=True)
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
        handlers=[
            logging.FileHandler("./output/run.log"),
            logging.StreamHandler()
        ]
    )
    logger.info(args)

    # Precomputed step schedule, calibration inputs and cached features.
    cache_path = "./pretraining/imageNet{}_cache{}_{}.pth".format(
        args.ddim_steps, args.replicate_interval, args.mode)
    logger.info(cache_path)
    interval_seq, all_cali_data, all_t, all_cond, all_uncond, all_cali_t, all_cache1, all_cache2 \
        = torch.load(cache_path)
    logger.info(interval_seq)
    args.interval_seq = interval_seq
    model = get_model()

    if args.ptq:
        # These calibration helpers are not imported in the file header (only
        # the *_cond variants are), so referencing them here used to raise
        # NameError. Import them lazily so only the PTQ path depends on them.
        # NOTE(review): assumes quant.set_quantize_params defines both names.
        from quant.set_quantize_params import (
            set_act_quantize_params_Imagenet,
            set_weight_quantize_params_Imagenet,
        )

        wq_params = {'n_bits': args.weight_bit, 'symmetric': False, 'channel_wise': True, 'scale_method': 'max'}
        aq_params = {'n_bits': args.act_bit, 'symmetric': False, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': args.quant_act, "prob": 1.0, "num_timesteps": args.ddim_steps}
        q_unet = QModel(model.model.diffusion_model, args, wq_params=wq_params, aq_params=aq_params)
        q_unet.cuda()
        q_unet.eval()

        logger.info("Setting the first and the last layer to 8-bit")
        q_unet.set_first_last_layer_to_8bit()
        q_unet.set_quant_state(False, False)

        if args.split:
            q_unet.model.split_shortcut = True

        # Duplicate each latent / timestep so it lines up with the
        # [uncond, cond] context pair (presumably mirroring the guided sampling
        # batch), then draw a random subset of 32 for weight calibration.
        cali_data = torch.cat([torch.cat([x] * 2) for x in all_cali_data])
        t = torch.cat([torch.cat([ts] * 2) for ts in all_t])
        context = torch.cat([torch.cat([u, c]) for u, c in zip(all_uncond, all_cond)])
        idx = torch.randperm(len(cali_data))[:32]
        cali_data, t, context = cali_data[idx], t[idx], context[idx]

        set_weight_quantize_params_Imagenet(args, q_unet, cali_data=(cali_data, t, context))
        del cali_data, t, context
        gc.collect()
        set_act_quantize_params_Imagenet(args, q_unet, [all_cali_data[0]], all_t, all_cond, all_uncond, all_cache1, all_cache2)
        del all_cali_data, all_t, all_cond, all_uncond, all_cali_t, all_cache1, all_cache2
        gc.collect()
        q_unet.set_quant_state(True, True)

        # Replace plain weight quantizers with AdaRound quantizers, presumably
        # so the module structure matches the trained checkpoint loaded below.
        round_mode = 'learned_hard_sigmoid'
        for module in q_unet.modules():
            if not isinstance(module, QuantModule):
                continue
            if module.split == 0:
                if not isinstance(module.weight_quantizer, AdaRoundQuantizer):
                    module.weight_quantizer = AdaRoundQuantizer(
                        uaq=module.weight_quantizer, round_mode=round_mode,
                        weight_tensor=module.org_weight.data)
            else:
                # Split layers use two quantizers: one for the weight slice
                # [:, :split] and one for [:, split:].
                if not isinstance(module.weight_quantizer, AdaRoundQuantizer):
                    module.weight_quantizer = AdaRoundQuantizer(
                        uaq=module.weight_quantizer, round_mode=round_mode,
                        weight_tensor=module.org_weight.data[:, :module.split, ...])
                if not isinstance(module.weight_quantizer_0, AdaRoundQuantizer):
                    module.weight_quantizer_0 = AdaRoundQuantizer(
                        uaq=module.weight_quantizer_0, round_mode=round_mode,
                        weight_tensor=module.org_weight.data[:, module.split:, ...])

        model.model.diffusion_model = q_unet

        checkpoint = torch.load("./pretraining/ImageNet{}_int{}_cache{}_{}_trained.pth".format(args.ddim_steps, args.weight_bit, args.replicate_interval, args.mode))
        model.load_state_dict(checkpoint)
        del checkpoint

        sampler = DDIMSampler(model, slow_steps=args.interval_seq)
        model.model.reset_no_cache(no_cache=False)
        sampler.quant_sample = True
    else:
        # Full-precision path: the calibration data is not needed.
        del all_cali_data, all_t, all_cond, all_uncond, all_cali_t, all_cache1, all_cache2
        model.model.reset_no_cache(no_cache=False)
        sampler = DDIMSampler(model, slow_steps=args.interval_seq)

    imglogdir = "./mainldm/imagenet/"
    # PIL's Image.save does not create missing directories.
    os.makedirs(imglogdir, exist_ok=True)
    base_count = 0
    # NOTE(review): the watermark encoder rewrites pixel values of every saved
    # image; confirm this is intended for images used in FID evaluation.
    wm = "StableDiffusionV1"
    wm_encoder = WatermarkEncoder()
    wm_encoder.set_watermark('bytes', wm.encode('utf-8'))

    logger.info("sampling...")
    seed_everything(args.seed)
    iterator = tqdm(range(1000), desc='DDIM Sampler')
    with torch.no_grad():
        with model.ema_scope():
            # Label 1000 is passed as the unconditional conditioning for guidance.
            uc = model.get_learned_conditioning(
                {model.cond_stage_key: torch.tensor(args.sample_batch*[1000]).to(model.device)}
            )
            for class_label in iterator:
                if base_count >= args.num_samples:
                    break
                xc = torch.tensor(args.sample_batch*[class_label])
                c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})

                samples_ddim, _ = sampler.sample(S=args.ddim_steps,
                                                 conditioning=c,
                                                 batch_size=args.sample_batch,
                                                 shape=[3, 64, 64],
                                                 verbose=False,
                                                 unconditional_guidance_scale=args.scale,
                                                 unconditional_conditioning=uc,
                                                 eta=args.ddim_eta,
                                                 replicate_interval=args.replicate_interval,
                                                 nonuniform=args.nonuniform, pow=args.pow)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                             min=0.0, max=1.0)
                # Channels-first decoder output in [0, 1] -> channels-last numpy
                # arrays, which is the layout PIL expects.
                x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                for x_sample in x_samples_ddim:
                    if base_count >= args.num_samples:
                        break
                    img = Image.fromarray((255. * x_sample).astype(np.uint8))
                    img = put_watermark(img, wm_encoder)
                    img.save(os.path.join(imglogdir, f"{base_count:05}.png"))
                    base_count += 1
