import os
import cv2
import pickle
import argparse
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data.dataset import random_split

from image_datasets import VisualGenome
from torchvision import models

from torchtext.vocab import GloVe

from visual_semantic_feature_models import Identity


def set_bn_eval(m):
    """Put ``m`` into eval mode when its class name marks it as batch norm.

    Meant to be passed to ``model.apply`` so that every ``BatchNorm1d`` /
    ``BatchNorm2d`` submodule is switched to eval mode while all other
    modules keep their current mode.
    """
    layer_type = m.__class__.__name__
    # Substring match (not equality) so subclasses named e.g.
    # "MyBatchNorm2d" are caught as well.
    if any(tag in layer_type for tag in ('BatchNorm2d', 'BatchNorm1d')):
        m.eval()


class CSMRLoss(torch.nn.Module):
    """Cosine Similarity Margin Ranking Loss.

    For every sample that has at least one positive label, the mean cosine
    similarity between the prediction and its positive label embeddings is
    compared with the similarity to each negative label.  A negative label
    contributes ``max(0, margin - mean_positive + negative_similarity)``;
    contributions are summed per sample and averaged over the kept samples.

    Shape:
        - output: ``(batch, dim)`` predicted embeddings (2-D, used in ``mm``).
        - embeddings: ``(dim, num_labels)``, one label embedding per column.
        - target_onehot: ``(batch, len(train_label_idx))`` multi-hot labels.
        - returns: 0-dim tensor on the device of ``output``.

    Args:
        weight: object stored on ``self.weight``; it is moved to CUDA when a
            GPU is available.  ``forward`` does not read it.
        margin: ranking margin, stored as a 0-dim tensor.
    """

    def __init__(self, weight, margin=0.3):
        super(CSMRLoss, self).__init__()
        margin = torch.tensor(margin)
        # Keep the historical "lives on the GPU" behaviour when a GPU exists,
        # but stay constructible on CPU-only machines.
        if torch.cuda.is_available():
            margin = margin.cuda()
            weight = weight.cuda()
        self.margin = margin
        self.weight = weight

    def forward(self, output, target_onehot, embeddings, train_label_idx, filter_max):
        """Compute the loss.

        ``train_label_idx`` selects the label columns of the similarity
        matrix that ``target_onehot`` refers to.  ``filter_max`` is accepted
        for interface compatibility and is not used.
        """
        device = output.device
        output_norm = torch.sqrt(torch.sum(output ** 2, dim=1, keepdim=True))
        embedding_norm = torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True))
        cosine_similarity = torch.mm(output, embeddings) / torch.mm(output_norm, embedding_norm)
        cosine_similarity = cosine_similarity[:, train_label_idx]

        # Samples without any positive label carry no ranking signal.
        has_positive = torch.sum(target_onehot, dim=1) > 0
        if not bool(has_positive.any()):
            return torch.zeros((), device=device)
        cosine_similarity = cosine_similarity[has_positive]
        target_onehot = target_onehot[has_positive]

        negative_mask = 1 - target_onehot
        false_terms = negative_mask * cosine_similarity
        positive_mean = torch.sum(target_onehot * cosine_similarity, dim=1) / torch.sum(target_onehot, dim=1)
        # The margin may live on another device than the inputs.
        margin = self.margin.to(device)
        loss = negative_mask * (margin - positive_mean.unsqueeze(1) + false_terms)

        # A zero-norm prediction or embedding yields 0/0 = NaN; drop those terms.
        loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
        loss = torch.max(torch.zeros((), device=device), loss.float())
        return torch.mean(torch.sum(loss, dim=1))


def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, embeddings, output_path, experiment_name,
                    train_label_idx, labels, k=5):
    """Train ``model`` for one pass over ``train_loader``.

    The model's top-level modules are run one after another.  At the module
    named ``args.layer`` the activations are multiplied by the batch mask and
    the targets are restricted to ``train_label_idx``.  The final output is
    scored against ``embeddings`` by cosine similarity.

    Reads the module-level ``args`` (``args.freeze``, ``args.layer``) and
    moves every batch to CUDA.

    Args:
        epoch: epoch number, used for logging and stored in checkpoints.
        model: network whose top-level modules include ``args.layer`` and ``fc``.
        loss_fn: callable ``(predict, target, embeddings, train_label_idx, filter_max)``.
        optimizer: optimizer stepped once per batch whose loss requires grad.
        train_loader: loader yielding ``(data, target, mask, ...)`` batches.
        embeddings: label embedding matrix, used as the right operand of ``mm``.
        output_path: directory for the temporary checkpoint written every
            1000 batches.
        experiment_name: suffix for the checkpoint file name.
        train_label_idx: label columns kept from the target.
        labels: unused.
        k: number of top-ranked labels considered for top-k accuracy.

    Returns:
        ``(epoch_loss, train_acc)``: the summed batch losses divided by the
        dataset size, and top-1 accuracy in percent.

    Raises:
        ValueError: if ``args.layer`` names no top-level module of ``model``.
    """
    print("\nEpoch {} starting.".format(epoch))
    epoch_loss = 0.0
    num_batch = len(train_loader)
    correct = 0.0
    top_k_correct = 0.0
    model.train()
    if args.freeze:
        # Frozen backbone: keep batch-norm layers in eval mode as well.
        model.apply(set_bn_eval)
    for batch_index, batch in enumerate(train_loader, start=1):
        data, target, mask = batch[0].cuda(), batch[1].squeeze(0).cuda(), batch[2].squeeze(0).cuda()
        if data.shape[0] == 1:
            # NOTE(review): presumably guards batch norm against single-sample
            # batches; the layers stay in eval mode for the rest of the epoch.
            model.apply(set_bn_eval)
        predict = data.clone()
        filter_max = None
        for name, module in model._modules.items():
            # Compare names by value: `is` only worked for interned strings and
            # silently failed for a --layer value typed on the command line.
            if name == 'fc':
                predict = torch.flatten(predict, 1)
            predict = module(predict)
            if name == args.layer:
                # Maximum over the last two dims of the activations.
                filter_max, _ = torch.max(predict, dim=-1, keepdim=True)
                filter_max, _ = torch.max(filter_max, dim=-2, keepdim=True)
                filter_max[filter_max == 0] = 1e-6
                predict = predict * mask
                target = target[:, train_label_idx]
        if filter_max is None:
            raise ValueError(
                "args.layer=%r does not name a top-level module of the model" % (args.layer,))
        loss = loss_fn(predict, target, embeddings, train_label_idx, filter_max.squeeze())

        # Ranking metrics need no gradient graph.
        with torch.no_grad():
            similarity = torch.mm(predict, embeddings) / \
                torch.mm(torch.sqrt(torch.sum(predict ** 2, dim=1, keepdim=True)),
                         torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True)))
            sorted_predict = torch.argsort(similarity, dim=1, descending=True)[:, :k]
            for i, pred in enumerate(sorted_predict):
                correct += target[i, pred[0]].item()
                top_k_correct += (torch.sum(target[i, pred]) > 0).item()

        optimizer.zero_grad()
        if loss.requires_grad:
            loss.backward()
            optimizer.step()

        train_log = 'Epoch {:2d}\tLoss: {:.6f}\tTrain: [{:4d}/{:4d} ({:.0f}%)]'.format(
            epoch, loss.cpu().item(),
            batch_index, num_batch,
            100. * batch_index / num_batch)
        print(train_log, end='\r')

        if batch_index % 1000 == 0:
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, os.path.join(output_path,
                            'ckpt_%s_tmp.pth.tar' % experiment_name))

        epoch_loss += loss.item()
        torch.cuda.empty_cache()

    # NOTE(review): per-batch losses are divided by the number of samples, so
    # this is a per-sample mean only when batch_size == 1.
    epoch_loss /= len(train_loader.dataset)
    num_samples = len(train_loader) * train_loader.batch_size
    train_acc = correct / num_samples * 100
    # Each sample contributes at most one top-k hit, so the denominator is the
    # sample count (dividing by k as well capped the value at 100/k percent).
    train_top_k_acc = top_k_correct / num_samples * 100
    print()
    print("Train average loss: {:.6f}\t".format(epoch_loss))
    print("Train top-1 accuracy: {:.2f}%".format(train_acc))
    print("Train top-{} accuracy: {:.2f}%".format(k, train_top_k_acc))
    return epoch_loss, train_acc


def validate(model, loss_fn, valid_loader, embeddings, train_label_idx, k=5):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Runs the model's top-level modules in order.  At the module named
    ``args.layer`` the activations are multiplied by the batch mask and the
    targets are restricted to ``train_label_idx``.  The final output is ranked
    against ``embeddings`` by cosine similarity.

    Reads the module-level ``args`` (``args.layer``) and moves every batch to
    CUDA.

    Args:
        model: network whose top-level modules include ``args.layer`` and ``fc``.
        loss_fn: callable ``(predict, target, embeddings, train_label_idx, filter_max)``.
        valid_loader: loader yielding ``(data, target, mask, ...)`` batches.
        embeddings: label embedding matrix, used as the right operand of ``mm``.
        train_label_idx: label columns kept from the target.
        k: number of top-ranked labels considered for top-k accuracy.

    Returns:
        ``(valid_loss, valid_acc)``: the summed batch losses divided by the
        dataset size, and top-1 accuracy in percent.

    Raises:
        ValueError: if ``args.layer`` names no top-level module of ``model``.
    """
    model.eval()
    valid_loss = 0
    correct = 0.0
    top_k_correct = 0.0
    for batch in valid_loader:
        with torch.no_grad():
            data, target, mask = batch[0].cuda(), batch[1].squeeze(0).cuda(), batch[2].squeeze(0).cuda()
            predict = data.clone()
            filter_max = None
            for name, module in model._modules.items():
                # Compare names by value: `is` only worked for interned strings
                # and silently failed for a --layer value typed on the command line.
                if name == 'fc':
                    predict = torch.flatten(predict, 1)
                predict = module(predict)
                if name == args.layer:
                    # Maximum over the last two dims of the activations.
                    filter_max, _ = torch.max(predict, dim=-1, keepdim=True)
                    filter_max, _ = torch.max(filter_max, dim=-2, keepdim=True)
                    filter_max[filter_max == 0] = 1e-6
                    predict = predict * mask
                    target = target[:, train_label_idx]
            if filter_max is None:
                raise ValueError(
                    "args.layer=%r does not name a top-level module of the model" % (args.layer,))

            similarity = torch.mm(predict, embeddings) / \
                torch.mm(torch.sqrt(torch.sum(predict ** 2, dim=1, keepdim=True)),
                         torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True)))
            sorted_predict = torch.argsort(similarity, dim=1, descending=True)[:, :k]
            for i, pred in enumerate(sorted_predict):
                correct += target[i, pred[0]].item()
                top_k_correct += (torch.sum(target[i, pred]) > 0).item()

            valid_loss += loss_fn(predict, target, embeddings, train_label_idx, filter_max.squeeze()).item()
        torch.cuda.empty_cache()

    # NOTE(review): per-batch losses are divided by the number of samples, so
    # this is a per-sample mean only when batch_size == 1.
    valid_loss /= len(valid_loader.dataset)
    num_samples = len(valid_loader) * valid_loader.batch_size
    valid_acc = correct / num_samples * 100
    # Each sample contributes at most one top-k hit, so the denominator is the
    # sample count (dividing by k as well capped the value at 100/k percent).
    valid_top_k_acc = top_k_correct / num_samples * 100
    print('Valid average loss: {:.6f}\t'.format(valid_loss))
    print("Valid top-1 accuracy: {:.2f}%".format(valid_acc))
    print("Valid top-{} accuracy: {:.2f}%".format(k, valid_top_k_acc))
    return valid_loss, valid_acc


def main(args, train_rate=0.9):
    """Train a ResNet-50 whose ``fc`` head predicts GloVe label embeddings.

    Splits the ``VisualGenome`` dataset into train/validation parts (seeded
    with 0), replaces the network's ``fc`` with a two-layer head that outputs
    ``args.word_embedding_dim`` values, and trains with ``CSMRLoss``.  After
    every epoch the metrics are appended to a text log; whenever the
    validation loss improves, a checkpoint and a loss-curve plot are written
    to ``model_outputs``.

    Args:
        args: parsed command-line namespace; reads ``word_embedding_dim``,
            ``batch_size``, ``freeze``, ``layer``, ``level`` and ``epochs``.
        train_rate: fraction of the dataset used for training.
    """
    embedding_glove = GloVe(name='6B', dim=args.word_embedding_dim)
    torch.cuda.empty_cache()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    # Both splits come from one dataset object, so both use the 'val' transform.
    dataset = VisualGenome(transform=data_transforms['val'])
    datasets = {}
    train_size = int(train_rate * len(dataset))
    test_size = len(dataset) - train_size
    torch.manual_seed(0)
    datasets['train'], datasets['val'] = random_split(dataset, [train_size, test_size])
    dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=args.batch_size,
                                                  shuffle=True, num_workers=0)
                   for x in ['train', 'val']}

    model = models.resnet50(pretrained=True)
    # Set grad to false to freeze.
    if args.freeze:
        for param in model.parameters():
            param.requires_grad = False

    # Modules created below default to requires_grad=True,
    # so the new fc head can be optimized even when the backbone is frozen.
    if args.layer == 'layer4':
        NUM_FEATURES = 512
    else:
        NUM_FEATURES = 256
        model.layer4 = Identity()
    model.fc = nn.Sequential(nn.BatchNorm1d(NUM_FEATURES * 4),
                             nn.Dropout(0.1),
                             nn.Linear(in_features=NUM_FEATURES * 4, out_features=NUM_FEATURES * 4, bias=True),
                             nn.ReLU(),
                             nn.BatchNorm1d(NUM_FEATURES * 4),
                             nn.Dropout(0.1),
                             nn.Linear(in_features=NUM_FEATURES * 4, out_features=args.word_embedding_dim, bias=True))
    model = model.cuda()
    output_path = 'model_outputs'
    # The log, checkpoints and plot are all written here; create it up front
    # so a fresh checkout does not fail on the first open().
    os.makedirs(output_path, exist_ok=True)
    freeze = 'freeze' if args.freeze else 'unfreeze'
    experiment_name = 'vsf_vg_resnet50_%s_%s_%s_all' % (args.layer, args.level, freeze)

    if args.freeze:
        parameters = model.fc.parameters()
    else:
        parameters = model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, verbose=True)

    label_index_file = "./image_data/visual_genome_python_driver-master/obj_per_image_cleaned_idx_convert.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(label_index_file, 'rb') as f:
        labels = pickle.load(f)
    label_index = [embedding_glove.stoi[label] for label in labels]
    np.random.seed(0)
    # All labels are used for training (no held-out label subset).
    train_label_index = np.arange(len(label_index))
    loss_fn = CSMRLoss(dataset._lvl_supervision)
    word_embeddings_vec = embedding_glove.vectors[label_index].T.cuda()
    print("Model setup...")

    # Train and validate
    best_valid_loss = 9999999.
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    with open(os.path.join(output_path, 'valid_%s.txt' % experiment_name), 'w') as f:
        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(epoch, model, loss_fn, optimizer, dataloaders['train'],
                                                    word_embeddings_vec, output_path, experiment_name,
                                                    train_label_index, list(labels.keys()))
            ave_valid_loss, valid_acc = validate(model, loss_fn, dataloaders['val'],
                                                 word_embeddings_vec, train_label_index)
            train_losses.append(train_loss)
            valid_losses.append(ave_valid_loss)
            train_accuracies.append(train_acc)
            valid_accuracies.append(valid_acc)
            scheduler.step(ave_valid_loss)
            f.write('epoch: %d\n' % epoch)
            f.write('train loss: %f\n' % train_loss)
            f.write('train accuracy: %f\n' % train_acc)
            f.write('validation loss: %f\n' % ave_valid_loss)
            f.write('validation accuracy: %f\n' % valid_acc)

            if ave_valid_loss < best_valid_loss:
                best_valid_loss = ave_valid_loss
                print('==> new checkpoint saved')
                f.write('==> new checkpoint saved\n')
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(output_path,
                                'ckpt_%s.pth.tar' % experiment_name))
                # Plot the whole per-epoch history; plotting only the current
                # scalars produced a single point per curve.
                plt.figure()
                plt.plot(train_losses, '-o', label='train')
                plt.plot(valid_losses, '-o', label='valid')
                plt.xlabel('Epoch')
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(output_path, 'losses_%s.png' % experiment_name))
                plt.close()


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``.  This
    parser accepts the usual spellings instead.  The empty string maps to
    ``False``, as it did under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=1, metavar='BS')
    parser.add_argument('--word-embedding-dim', type=int, default=300, metavar='WED')
    parser.add_argument('--epochs', type=int, default=100, metavar='E')
    parser.add_argument('--freeze', type=str2bool, default=True, metavar='F')
    parser.add_argument('--level', type=str, default='image', metavar='L')
    parser.add_argument('--layer', type=str, default='layer3', metavar='L')
    # `args` must stay a module-level name: train_one_epoch and validate read it.
    args = parser.parse_args()
    main(args)
