import os
import os.path as osp
from collections import OrderedDict
import math,copy
import torch.nn as nn
from datasets import Action_DATASETS
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
import shutil
from pathlib import Path
import yaml
from dotmap import DotMap
import pprint
import time


import matplotlib.pyplot as plt
import numpy as np
import math
from utils.KLLoss import *
from test import validate
from utils.Augmentation import *
from utils.solver import _optimizer, _lr_scheduler
from utils.tools import *
from utils.saving import *
import csv

import clip 
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level CLIP SimpleTokenizer instance, created once at import time.
# MultiModalPromptLearner uses it to count the encoded tokens of each class name.
_tokenizer = _Tokenizer()


class TextEncoder(nn.Module):
    """Run already-embedded prompts through the text branch of a CLIP model.

    The transformer, positional embedding, final layer norm, text projection
    and dtype are all taken from ``clip_model``; this class adds no weights
    of its own.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Registration order is kept stable so parameter/state_dict ordering
        # does not change.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
        """Encode ``prompts`` and return one projected feature per prompt.

        Args:
            prompts: prompt embeddings, batch-first (permuted to
                sequence-first before entering the transformer).
            tokenized_prompts: token ids of the same prompts; only used to
                locate, per row, the position with the largest token id.
            compound_prompts_deeper_text: deep text prompts forwarded to the
                transformer unchanged.

        Returns:
            The normalised hidden state at the selected position of every
            prompt, multiplied by ``text_projection``.
        """
        positional = self.positional_embedding.type(self.dtype)
        hidden = (prompts + positional).permute(1, 0, 2)  # NLD -> LND

        # The transformer takes and returns a list; slot 0 carries the hidden
        # states. The trailing 0 is presumably a layer counter consumed by the
        # modified transformer -- TODO confirm against the clip package.
        hidden = self.transformer([hidden, compound_prompts_deeper_text, 0])[0]

        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # Pick, for every prompt, the position holding the highest token id
        # (presumably the end-of-text token in CLIP's vocabulary).
        row_index = torch.arange(hidden.shape[0])
        pooled_position = tokenized_prompts.argmax(dim=-1)
        return hidden[row_index, pooled_position] @ self.text_projection
    
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
    
class MultiModalPromptLearner(nn.Module):
    """Learnable prompt context shared between the text and vision branches.

    Holds:
      * ``ctx`` -- the learnable context vectors placed between the start
        token and the class-name tokens of every text prompt;
      * ``proj`` -- a linear map applied to ``ctx`` for the vision branch;
      * ``compound_prompts_text`` -- ``PROMPT_DEPTH - 1`` extra learnable
        text prompts, each with its own linear projection in
        ``compound_prompt_projections``.

    Args:
        cfg: configuration object; reads ``cfg.mm_prompt.N_CTX``,
            ``cfg.mm_prompt.CTX_INIT`` and ``cfg.mm_prompt.PROMPT_DEPTH``.
        classnames: iterable of class-name strings (underscores are replaced
            by spaces).
        clip_model: CLIP model providing ``dtype``, ``ln_final`` and
            ``token_embedding``.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.mm_prompt.N_CTX
        ctx_init = cfg.mm_prompt.CTX_INIT
        dtype = clip_model.dtype
        # Width of the text transformer, read from its final layer norm.
        ctx_dim = clip_model.ln_final.weight.shape[0]
        # Output width of the text -> vision projections.
        # NOTE(review): hard-coded; presumably the vision transformer width -- confirm.
        vision_dim = 768
        self.compound_prompts_depth = cfg.mm_prompt.PROMPT_DEPTH

        if ctx_init and n_ctx <= 4:
            # Initialise the context from the embedding of a text template.
            ctx_init = ctx_init.replace("_", " ").replace("-", " ")
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip position 0 (start token) and keep the next n_ctx embeddings.
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialisation; "X" placeholders reserve n_ctx token slots.
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.proj = nn.Linear(ctx_dim, vision_dim)
        self.proj.half()  # cast to fp16
        self.ctx = nn.Parameter(ctx_vectors)

        # Deep text prompts. Their width must be ctx_dim: they are consumed
        # both by the text transformer and by the Linear(ctx_dim, vision_dim)
        # projections below. (Previously hard-coded to 512, which only works
        # for backbones whose text width is 512.)
        self.compound_prompts_text = nn.ParameterList(
            [
                nn.Parameter(torch.empty(n_ctx, ctx_dim))
                for _ in range(self.compound_prompts_depth - 1)
            ]
        )
        for single_para in self.compound_prompts_text:
            nn.init.normal_(single_para, std=0.02)
        # One independent projection per deep prompt, all cloned from the
        # same freshly initialised layer.
        single_layer = nn.Linear(ctx_dim, vision_dim)
        self.compound_prompt_projections = _get_clones(
            single_layer, self.compound_prompts_depth - 1
        )

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Frozen pieces of every prompt: everything except the n_ctx context
        # slots, which are filled with the learnable ``ctx`` in forward().
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate ``prefix``, ``ctx`` and ``suffix`` along dim 1.

        Args:
            ctx: context tokens, shape (dim0, n_ctx, ctx_dim).
            prefix: start-token embeddings, shape (n_cls, 1, ctx_dim).
            suffix: remaining token embeddings, shape (n_cls, *, ctx_dim).
            label: optional index used to select rows of ``prefix`` and
                ``suffix`` before concatenation.

        Returns:
            Tensor of shape (dim0, 1 + n_ctx + *, ctx_dim).
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        return torch.cat([prefix, ctx, suffix], dim=1)

    def forward(self):
        """Build the prompts and their vision-side projections.

        Returns:
            A 4-tuple ``(prompts, shared_ctx, deep_text, deep_vision)``:
            the full text prompt embeddings for every class, ``proj(ctx)``,
            the list of deep text prompts, and the list of their projections.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Share the same context across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prompts = self.construct_prompts(ctx, self.token_prefix, self.token_suffix)

        visual_deep_prompts = [
            layer(text_prompt)
            for layer, text_prompt in zip(
                self.compound_prompt_projections, self.compound_prompts_text
            )
        ]
        return prompts, self.proj(self.ctx), self.compound_prompts_text, visual_deep_prompts



class CustomCLIP(nn.Module):
    """CLIP wrapper that combines a prompt learner with CLIP's two encoders.

    ``encode`` produces image and text features; ``forward`` turns such
    features into a cross-entropy loss (when the prompt learner is in
    training mode) or into similarity logits (otherwise).
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image_features, text_features, label=None):
        """Score image features against text features.

        Both inputs are L2-normalised along the last dimension, then
        multiplied and scaled by ``exp(logit_scale)``.

        Returns:
            ``F.cross_entropy(logits, label)`` while the prompt learner is in
            training mode, otherwise the logits themselves.
        """
        image_unit = image_features / image_features.norm(dim=-1, keepdim=True)
        text_unit = text_features / text_features.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = scale * image_unit @ text_unit.t()

        if not self.prompt_learner.training:
            return logits
        return F.cross_entropy(logits, label)

    def encode(self, image):
        """Return ``(image_features, text_features)`` for ``image``.

        The prompt learner supplies the text prompts plus the shared and deep
        prompts that are handed to the text and image encoders.
        """
        prompts, shared_ctx, deep_text, deep_vision = self.prompt_learner()
        text_features = self.text_encoder(prompts, self.tokenized_prompts, deep_text)
        image_features = self.image_encoder(
            image.type(self.dtype), shared_ctx, deep_vision
        )
        return image_features, text_features
    



def print_time(seconds):
    """Format a duration in seconds as ``H:MM:SS``.

    The duration is first reduced modulo 24 hours, so whole days are
    discarded; fractional parts are truncated by the ``%d`` formatting.
    """
    within_day = seconds % (24 * 3600)
    hour, remainder = divmod(within_day, 3600)
    minutes, secs = divmod(remainder, 60)
    return "%d:%02d:%02d" % (hour, minutes, secs)


def _append_epoch_log(txt_path, entries):
    """Append ``entries`` to the text log, separated from earlier content by a blank line."""
    already_exists = os.path.exists(txt_path)
    with open(txt_path, "a") as f:
        if already_exists:
            f.write("\n")
        f.writelines(entries)


def main():
    """Train the prompt-learning CLIP model described by a YAML config.

    Command-line arguments:
        --config / -cfg: path to the YAML configuration file.
        --traning_name: accepted but not used by this function.
        --few_shot: when set to anything other than 0, the value is appended
            to ``training_name`` and substituted into ``data.train_list``
            (``train`` -> ``train_<few_shot>``).

    Side effects: creates the working directory, copies the config into it,
    writes ``log.txt``, ``Graph_plot.png`` and checkpoints there, and sets the
    module globals ``args`` and ``best_prec1``.
    """
    global args, best_prec1
    global global_step
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-cfg", default="")
    parser.add_argument("--traning_name", default="")
    parser.add_argument("--few_shot", default=0)
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # argparse delivers command-line values as strings, so "--few_shot 0"
    # arrives as "0"; compare on the string form so 0 really means "disabled".
    few_shot = str(args.few_shot)
    if few_shot not in ("", "0"):
        config["training_name"] = config["training_name"] + "_" + few_shot
        config["data"]["train_list"] = config["data"]["train_list"].replace(
            "train", "train_" + few_shot
        )
    working_dir = os.path.join(
        config["weight_save_dir"],
        config["network"]["type"],
        config["network"]["arch"],
        config["data"]["dataset"],
        config["training_name"],
    )
    print("-" * 80)
    print(" " * 20, "working dir: {}".format(working_dir))
    print("-" * 80)

    print("-" * 80)
    print(" " * 30, "Config")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(config)
    print("-" * 80)

    config = DotMap(config)

    Path(working_dir).mkdir(parents=True, exist_ok=True)
    shutil.copy(args.config, working_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    design_details = {
        "vision_depth": 0,
        "language_depth": 0,
        "vision_ctx": 0,
        "language_ctx": 0,
    }
    # The model is built on CPU; it is moved to ``device`` as part of customCLIP.
    model, _ = clip.load(
        config.network.arch,
        config,
        device=torch.device("cpu"),
        jit=False,
        tsm=config.network.tsm,
        T=config.data.num_segments,
        dropout=config.network.drop_out,
        emb_dropout=config.network.emb_dropout,
        pretrain=config.network.init,
        joint=config.network.joint,
        design_details=design_details,
    )
    transform_train = get_augmentation(True, config)
    transform_val = get_augmentation(False, config)

    if config.data.randaug.N > 0:
        transform_train = randAugment(transform_train, config)

    print("train transforms: {}".format(transform_train.transforms))
    print("val transforms: {}".format(transform_val.transforms))

    # ------------------------------ data loaders ------------------------------
    train_data = Action_DATASETS(
        config.data.train_list,
        config.data.label_list,
        num_segments=config.data.num_segments,
        image_tmpl=config.data.image_tmpl,
        random_shift=config.data.random_shift,
        transform=transform_train,
    )
    train_loader = DataLoader(
        train_data,
        batch_size=config.data.batch_size,
        num_workers=config.data.workers,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )

    val_data = Action_DATASETS(
        config.data.val_list,
        config.data.label_list,
        random_shift=False,
        num_segments=config.data.num_segments,
        image_tmpl=config.data.image_tmpl,
        transform=transform_val,
    )
    # NOTE(review): drop_last=True also discards the final partial validation
    # batch, so those samples are never scored -- confirm this is intended.
    val_loader = DataLoader(
        val_data,
        batch_size=config.data.batch_size,
        num_workers=config.data.workers,
        shuffle=False,
        pin_memory=False,
        drop_last=True,
    )

    # ------------------------------ model ------------------------------
    classnames = [name for _, name in train_data.classes]
    customCLIP = CustomCLIP(config, classnames, model).to(device)
    # Freeze everything whose name mentions neither "prompt" nor "Adapter".
    for name, param in customCLIP.named_parameters():
        if "prompt_learner" not in name and "prompt" not in name and "Adapter" not in name:
            param.requires_grad_(False)
    # .to(device) instead of .cuda(): the script selects a CPU fallback above,
    # and an unconditional .cuda() would fail on a machine without CUDA.
    customCLIP = torch.nn.DataParallel(customCLIP, device_ids=[0]).to(device)

    total_params = sum(p.numel() for p in customCLIP.parameters())
    print(f"Total number of parameters: {total_params / 1_000_000:.3f}M")

    trainable_params = (
        sum(np.prod(p.size()) for p in customCLIP.parameters() if p.requires_grad)
        / 1_000_000
    )
    print("Modified CLIP_model Trainable Parameters: %.3fM" % trainable_params)

    start_epoch = config.solver.start_epoch

    # ------------------------------ checkpoints ------------------------------
    # NOTE(review): checkpoints are written from ``customCLIP`` (see
    # epoch_saving/best_saving below) but restored into ``model`` with
    # strict=False, which silently ignores non-matching keys -- confirm the
    # key names actually line up.
    if config.pretrain:
        if os.path.isfile(config.pretrain):
            print(("=> loading checkpoint '{}'".format(config.pretrain)))
            # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
            checkpoint = torch.load(config.pretrain, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            del checkpoint
        else:
            print(("=> no checkpoint found at '{}'".format(config.pretrain)))

    if config.resume:
        if os.path.isfile(config.resume):
            print(("=> loading checkpoint '{}'".format(config.resume)))
            checkpoint = torch.load(config.resume, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"], strict=False)
            start_epoch = checkpoint["epoch"]
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    config.resume, start_epoch
                )
            )
            del checkpoint
        else:
            print(("=> no checkpoint found at '{}'".format(config.resume)))

    # NOTE(review): the optimizer is built from ``model`` (the raw CLIP
    # network). The parameters owned by customCLIP.prompt_learner are not part
    # of ``model``, so unless ``_optimizer`` obtains them some other way they
    # are never updated -- confirm.
    optimizer = _optimizer(config, model)
    lr_scheduler = _lr_scheduler(config, optimizer)

    loss = []
    top_1_acc = []

    best_prec1 = 0.0

    if config.solver.evaluate:
        validate(
            start_epoch,
            val_loader,
            device,
            customCLIP,
            config,
            working_dir,
            config.data.dataset,
            is_Train=True,
        )
        return

    for k, v in model.named_parameters():
        if v.requires_grad:
            print("{}: {}".format(k, v.requires_grad))

    # Defined up front so that epochs which skip validation (see eval_freq)
    # do not hit an unbound name; they reuse the most recent accuracy.
    prec1 = 0.0

    # ------------------------------ training ------------------------------
    for epoch in range(start_epoch, config.solver.epochs):
        print(
            "------------------------------------------------------------------------"
        )
        print("Epoch %d start .." % epoch)
        customCLIP.train()
        tic = time.time()
        epoch_loss = []
        # Last iteration index / loss of the epoch, reported in the log file.
        step, loss_value = -1, float("nan")
        for step, (prompt_images, list_id) in enumerate(train_loader):
            prompt_images = prompt_images.to(device)
            list_id = list_id.to(device)
            if config.solver.type != "monitor":
                # Update the learning rate on the first and every 10th iteration.
                if (step + 1) == 1 or (step + 1) % 10 == 0:
                    lr_scheduler.step(epoch + step / len(train_loader))
            optimizer.zero_grad()

            # (b, t, 3, h, w) -> (b * t, 3, h, w): every frame is encoded on its own.
            prompt_images = prompt_images.view(
                (-1, config.data.num_segments, 3) + prompt_images.size()[-2:]
            )
            b, t, c, h, w = prompt_images.size()
            prompt_images = prompt_images.view(-1, c, h, w)

            image_embedding, text_features = customCLIP.module.encode(prompt_images)

            # Average the per-frame embeddings of each clip.
            image_embedding = image_embedding.view(b, t, -1)
            image_embedding = image_embedding.mean(dim=1, keepdim=False)

            total_loss = customCLIP(image_embedding, text_features, list_id)

            loss_value = total_loss.item()
            epoch_loss.append(loss_value)
            total_loss.backward()

            # ``device`` is a torch.device; comparing it with the string "cpu"
            # is always False, so compare its ``type`` instead.
            if device.type == "cpu":
                optimizer.step()
            else:
                # Step in fp32, then cast the weights back.
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)

            if step % 100 == 0:
                print(
                    "Epoch:%d  iteration:%d/%d, total loss:%f, lr:%f "
                    % (
                        epoch,
                        step,
                        len(train_loader),
                        loss_value,
                        optimizer.param_groups[0]["lr"],
                    )
                )

        if epoch % config.logging.eval_freq == 0:  # and epoch>0
            print("{} val accuracy".format(config.data.dataset))
            prec1, prec5 = validate(
                epoch,
                val_loader,
                device,
                customCLIP,
                config,
                working_dir,
                config.data.dataset,
                is_Train=True,
            )
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        print("{} Testing: {}/{}".format(config.data.dataset, prec1, best_prec1))

        _append_epoch_log(
            "{}/log.txt".format(working_dir),
            [
                "Epoch:%d  iteration:%d/%d, total loss:%f, lr:%f \n"
                % (
                    epoch,
                    step,
                    len(train_loader),
                    loss_value,
                    optimizer.param_groups[0]["lr"],
                ),
                "{} Testing: {}/{}\n".format(config.data.dataset, prec1, best_prec1),
            ],
        )

        print("Saving:")
        filename1 = "{}/last_model.pt".format(working_dir)
        top_1_acc.append(prec1 / 100)
        loss.append(np.mean(epoch_loss))
        epoch_saving(epoch, customCLIP, optimizer, filename1)
        if is_best:
            print(
                "Saving best weight based on {} accuracy at epoch {}".format(
                    config.data.dataset, epoch
                )
            )
            best_saving(working_dir, epoch, customCLIP, optimizer, config.data.dataset)

        print("Epoch %d end .." % epoch)

        # ------------------------------ progress plot ------------------------------
        X = list(range(len(loss)))
        plt.plot(X, loss, color="r", label="Training loss")
        plt.plot(
            X, top_1_acc, color="g", label="{} Accuracy".format(config.data.dataset)
        )

        plt.xlabel("Epoch")
        plt.ylabel("Training loss and Accuracy")
        plt.title("Traing graph")
        plt.legend()
        plt.savefig("{}/Graph_plot.png".format(working_dir))
        plt.close()
        print("Time taken by epoch %d:" % epoch, print_time(time.time() - tic))
        print(
            "------------------------------------------------------------------------"
        )

    # ------------------------------ final test ------------------------------
    print("====================Final Testing:=================")
    labels_csv_path = config.data.label_list
    with open(labels_csv_path, "r") as f:
        reader = csv.reader(f)
        _ = next(reader)  # skip the header row
        labels2name = {int(row[0]): row[1] for row in reader}
    print(labels2name)
    prec1, prec5 = validate(
        start_epoch,
        val_loader,
        device,
        customCLIP,
        config,
        working_dir,
        config["data"]["dataset"],
        labels2name,
        is_Train=False,
    )

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
