import os
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data.dataset import random_split

from image_datasets import MyCocoSemantic, MyCocoSegmentation
from visual_semantic_feature_models import ResNet18VSF, ResNet50VSF

from torchtext.vocab import GloVe


def set_bn_eval(m):
    """Switch ``m`` to evaluation mode if its class name contains ``'BatchNorm2d'``.

    Meant to be used as ``model.apply(set_bn_eval)``. Every other module is
    left untouched. Returns ``None``.
    """
    layer_type = m.__class__.__name__
    is_bn2d = 'BatchNorm2d' in layer_type
    if is_bn2d:
        m.eval()


class CSMRLoss(torch.nn.Module):
    """Cosine Similarity Margin Ranking Loss.

    With ``s = cos(output_i, embeddings[:, j])`` and
    ``s_pos_i = sum_j target_ij * s_ij``, each element of the per-class loss is
    ``max(0, (1 - t_ij) * (margin - s_pos_i + (1 - t_ij) * s_ij))``.
    Terms are summed over classes and averaged over the batch. For a one-hot
    target this is the usual hinge ``max(0, margin - s_pos + s_neg)`` over
    every negative class.

    Shape:
        - output: ``(N, D)`` predicted vectors.
        - target_onehot: ``(N, C)`` per-class target weights (1 marks a target).
        - embeddings: ``(D, C)`` class embedding vectors stored as columns.
        - returns: scalar tensor.

    The loss runs on whatever device its inputs live on (no CUDA requirement).
    """

    def __init__(self, margin=1, eps=1e-8):
        """
        Args:
            margin: ranking margin between the target score and each other class.
            eps: lower bound on the product of vector norms, so that an
                all-zero vector yields similarity 0 instead of NaN.
        """
        super(CSMRLoss, self).__init__()
        # A plain float broadcasts against tensors on any device, unlike the
        # previous hard-coded CUDA tensor.
        self.margin = float(margin)
        self.eps = float(eps)

    def forward(self, output, target_onehot, embeddings):
        norms = torch.mm(torch.sqrt(torch.sum(output ** 2, dim=1, keepdim=True)),
                         torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True)))
        cosine_similarity = torch.mm(output, embeddings) / norms.clamp_min(self.eps)
        negative_weight = 1 - target_onehot
        positive_score = torch.sum(target_onehot * cosine_similarity, dim=1, keepdim=True)
        loss = negative_weight * (self.margin - positive_score + negative_weight * cosine_similarity)
        # Hinge: negatives already ranked below the target by ``margin`` cost nothing.
        loss = torch.clamp(loss.float(), min=0.0)
        return torch.mean(torch.sum(loss, dim=1))


def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, embeddings, label_embedding, label_index, k=5,
                    freeze=None):
    """Run one training pass over ``train_loader`` and report loss/accuracy.

    Predictions are ranked by cosine similarity against the columns of
    ``embeddings`` (shape ``(D, C)``). Top-1 accuracy accumulates the target
    value at the best-ranked class. Top-k accuracy counts samples whose target
    weights over the ``k`` best-ranked classes sum to more than 0.

    Args:
        epoch: epoch number (only used for logging).
        model: module mapping a batch of inputs to ``(N, D)`` vectors.
        loss_fn: callable ``loss_fn(predict, target, embeddings) -> scalar tensor``.
        optimizer: optimizer stepped once per batch.
        train_loader: iterable of ``(data, target, ...)`` batches; a leading
            singleton dimension is squeezed off each tensor if present.
        embeddings: ``(D, C)`` class embeddings; inputs are moved to its device.
        label_embedding, label_index: unused; kept for interface compatibility.
        k: rank cut-off for the top-k accuracy.
        freeze: if true, BatchNorm2d layers are kept in eval mode. ``None``
            falls back to the global ``args.freeze`` parsed in ``__main__``.

    Returns:
        ``(mean per-sample loss, top-1 accuracy in %)``. Both are NaN when no
        batch was used.
    """
    print("\nEpoch {} starting.".format(epoch))
    if freeze is None:
        freeze = args.freeze
    device = embeddings.device
    num_batch = len(train_loader)
    epoch_loss = 0.0
    correct = 0.0
    top_k_correct = 0.0
    num_samples = 0
    model.train()
    if freeze:
        model.apply(set_bn_eval)
    for batch_index, batch in enumerate(train_loader, start=1):
        data = batch[0].squeeze(0).to(device)
        target = batch[1].squeeze(0).to(device)
        # Single-sample batches are skipped (presumably because BatchNorm in
        # training mode cannot normalise a batch of one).
        if data.shape[0] == 1:
            continue
        predict = model(data)
        loss = loss_fn(predict, target, embeddings)

        # Ranking is only used for metrics; keep it out of the autograd graph.
        with torch.no_grad():
            cosine = torch.mm(predict, embeddings) / torch.mm(
                torch.sqrt(torch.sum(predict ** 2, dim=1, keepdim=True)),
                torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True)))
            top_k = torch.argsort(cosine, dim=1, descending=True)[:, :k]
            hits = target.gather(1, top_k)  # target value at each ranked class
            correct += hits[:, 0].sum().item()
            top_k_correct += (hits.sum(dim=1) > 0).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        print('Epoch {:2d}\tLoss: {:.6f}\tTrain: [{:4d}/{:4d} ({:.0f}%)]'.format(
            epoch, loss_value, batch_index, num_batch, 100. * batch_index / num_batch), end='\r')

        epoch_loss += loss_value * len(target)
        num_samples += len(target)
        torch.cuda.empty_cache()

    if num_samples == 0:
        print("\nNo usable training batches; metrics are undefined.")
        return float('nan'), float('nan')

    epoch_loss /= num_samples
    train_acc = correct / num_samples * 100
    train_top_k_acc = top_k_correct / num_samples * 100
    print()
    print("Train average loss: {:.6f}\t".format(epoch_loss))
    print("Train top-1 accuracy: {:.2f}%".format(train_acc))
    print("Train top-{} accuracy: {:.2f}%".format(k, train_top_k_acc))
    return epoch_loss, train_acc


def validate(model, loss_fn, valid_loader, embeddings, label_embedding, label_index, k=5):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Predictions are ranked by cosine similarity against the columns of
    ``embeddings`` (shape ``(D, C)``). Top-1 accuracy accumulates the target
    value at the best-ranked class. Top-k accuracy counts samples whose target
    weights over the ``k`` best-ranked classes sum to more than 0. The model
    is left in eval mode.

    Args:
        model: module mapping a batch of inputs to ``(N, D)`` vectors.
        loss_fn: callable ``loss_fn(predict, target, embeddings) -> scalar tensor``.
        valid_loader: iterable of ``(data, target, ...)`` batches; a leading
            singleton dimension is squeezed off each tensor if present.
        embeddings: ``(D, C)`` class embeddings; inputs are moved to its device.
        label_embedding, label_index: unused; kept for interface compatibility.
        k: rank cut-off for the top-k accuracy.

    Returns:
        ``(mean per-sample loss, top-1 accuracy in %)``. Both are NaN for an
        empty loader.
    """
    model.eval()
    device = embeddings.device
    valid_loss = 0.0
    correct = 0.0
    top_k_correct = 0.0
    num_samples = 0
    for batch in valid_loader:
        with torch.no_grad():
            data = batch[0].squeeze(0).to(device)
            target = batch[1].squeeze(0).to(device)
            predict = model(data)
            cosine = torch.mm(predict, embeddings) / torch.mm(
                torch.sqrt(torch.sum(predict ** 2, dim=1, keepdim=True)),
                torch.sqrt(torch.sum(embeddings ** 2, dim=0, keepdim=True)))
            top_k = torch.argsort(cosine, dim=1, descending=True)[:, :k]
            hits = target.gather(1, top_k)  # target value at each ranked class
            correct += hits[:, 0].sum().item()
            top_k_correct += (hits.sum(dim=1) > 0).sum().item()

            valid_loss += loss_fn(predict, target, embeddings).item() * len(target)
            num_samples += len(target)
        torch.cuda.empty_cache()

    if num_samples == 0:
        print("No validation samples; metrics are undefined.")
        return float('nan'), float('nan')

    valid_loss /= num_samples
    valid_acc = correct / num_samples * 100
    valid_top_k_acc = top_k_correct / num_samples * 100
    print('Valid average loss: {:.6f}\t'.format(valid_loss))
    print("Valid top-1 accuracy: {:.2f}%".format(valid_acc))
    print("Valid top-{} accuracy: {:.2f}%".format(k, valid_top_k_acc))
    return valid_loss, valid_acc


def _build_dataloaders(args, data_transforms):
    """Build the COCO ``'train'``/``'val'`` DataLoaders selected by ``args.level``.

    ``args.level == 'image'`` uses ``MyCocoSemantic``; any other value uses
    ``MyCocoSegmentation``. Only the training split is shuffled.
    """
    dataset_cls = MyCocoSemantic if args.level == 'image' else MyCocoSegmentation
    datasets = {
        'val': dataset_cls(root='./image_data/coco/val2017',
                           annFile='./image_data/coco/annotations/instances_val2017.json',
                           transform=data_transforms['val']),
        'train': dataset_cls(root='./image_data/coco/train2017',
                             annFile='./image_data/coco/annotations/instances_train2017.json',
                             transform=data_transforms['train']),
    }
    # Shuffling the validation split adds nothing: its metrics are sums over
    # all samples. A fixed order keeps runs comparable.
    return {split: torch.utils.data.DataLoader(datasets[split], batch_size=args.batch_size,
                                               shuffle=(split == 'train'), num_workers=0)
            for split in ['train', 'val']}


def _load_label_embedding(embedding_glove, class_name_file, label_embedding_file):
    """Load the cached label <-> GloVe-index maps, building and caching them if absent.

    Returns ``{'stoi': {label_line: [glove_idx, ...]}, 'itos': {glove_idx: label_line}}``.
    Words missing from the GloVe vocabulary are printed and skipped. Reading
    stops at the first blank line of ``class_name_file``.
    """
    if os.path.exists(label_embedding_file):
        return torch.load(label_embedding_file)
    label_to_emb_idx = {}
    emb_idx_to_label = {}
    with open(class_name_file, 'r') as f:
        for line in f:
            line = line.strip()
            if line == '':
                break
            idx_list = []
            for word in line.split():
                if word in embedding_glove.stoi:
                    idx = embedding_glove.stoi[word]
                    emb_idx_to_label[idx] = line
                    idx_list.append(idx)
                else:
                    print(word)
            label_to_emb_idx[line] = idx_list
    label_embedding = {'stoi': label_to_emb_idx, 'itos': emb_idx_to_label}
    torch.save(label_embedding, label_embedding_file)
    return label_embedding


def main(args, glove_dim=None):
    """Train a visual-semantic feature model on COCO and checkpoint the best epoch.

    Args:
        args: parsed command-line namespace. It supplies ``level``,
            ``batch_size``, ``model``, ``layer``, ``freeze``, ``epochs`` and
            optionally ``word_embedding_dim``, and is passed to the model
            constructor.
        glove_dim: GloVe 6B vector size. ``None`` (default) uses
            ``args.word_embedding_dim`` when present, otherwise 300.

    Writes ``valid_<experiment>.txt``, ``ckpt_<experiment>.pth.tar`` and
    ``accuracies_<experiment>.png`` under ``model_outputs/``, creating the
    directory if needed.
    """
    if glove_dim is None:
        # ``--word-embedding-dim`` was previously ignored. Honour it so the GloVe
        # size can match the model head (presumably sized from args — confirm).
        glove_dim = getattr(args, 'word_embedding_dim', 300)
    embedding_glove = GloVe(name='6B', dim=glove_dim)
    torch.cuda.empty_cache()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    dataloaders = _build_dataloaders(args, data_transforms)

    if args.model == 'resnet18':
        model = ResNet18VSF(args).cuda()
    else:
        model = ResNet50VSF(args).cuda()
    output_path = 'model_outputs'
    # The log file and checkpoints below cannot be opened in a missing directory.
    os.makedirs(output_path, exist_ok=True)
    freeze = 'freeze' if args.freeze else 'unfreeze'
    experiment_name = 'vsf_coco_%s_%s_%s_%s_soft' % (args.model, args.layer, args.level, freeze)

    # When frozen, only the ``fc`` layer of ``model.vsf_model`` is optimised.
    if args.freeze:
        parameters = model.vsf_model.fc.parameters()
    else:
        parameters = model.vsf_model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, verbose=True)

    label_embedding = _load_label_embedding(embedding_glove,
                                            './image_data/coco/coco-label-paper.txt',
                                            './image_data/coco/label_embedding.pth')
    loss_fn = CSMRLoss()
    # Column j of ``word_embeddings_vec`` is the GloVe vector of ``label_index[j]``.
    label_index = list(label_embedding['itos'].keys())
    word_embeddings_vec = embedding_glove.vectors[label_index].T.cuda()
    print("Model setup...")

    # Train and validate
    best_valid_acc = 0.
    train_accuracies = []
    valid_accuracies = []
    with open(os.path.join(output_path, 'valid_%s.txt' % experiment_name), 'w') as f:
        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(epoch, model, loss_fn, optimizer, dataloaders['train'],
                                                    word_embeddings_vec, label_embedding, label_index)
            ave_valid_loss, valid_acc = validate(model, loss_fn, dataloaders['val'], word_embeddings_vec,
                                                 label_embedding, label_index)
            train_accuracies.append(train_acc)
            valid_accuracies.append(valid_acc)
            scheduler.step(ave_valid_loss)
            f.write('epoch: %d\n' % epoch)
            f.write('train loss: %f\n' % train_loss)
            f.write('train accuracy: %f\n' % train_acc)
            f.write('validation loss: %f\n' % ave_valid_loss)
            f.write('validation accuracy: %f\n' % valid_acc)

            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                print('==> new checkpoint saved')
                f.write('==> new checkpoint saved\n')
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(output_path,
                                'ckpt_%s.pth.tar' % experiment_name))
                plt.figure()
                plt.plot(train_accuracies, '-o', label='train')
                plt.plot(valid_accuracies, '-o', label='valid')
                plt.xlabel('Epoch')
                plt.legend(loc='upper right')
                plt.savefig(os.path.join(output_path, 'accuracies_%s.png' % experiment_name))
                plt.close()
            # Persist each epoch's log so a crash later in training keeps it.
            f.flush()


def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``'False'``, as ``True``. This parser maps the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='coco', metavar='D')
    parser.add_argument('--batch-size', type=int, default=64, metavar='BS')
    parser.add_argument('--word-embedding-dim', type=int, default=300, metavar='WED')
    parser.add_argument('--epochs', type=int, default=100, metavar='E')
    parser.add_argument('--freeze', type=str2bool, default=True, metavar='F')
    parser.add_argument('--level', type=str, default='image', metavar='LVL')
    parser.add_argument('--layer', type=str, default='layer4', metavar='L')
    parser.add_argument('--model', type=str, default='resnet50', metavar='M')
    # ``args`` stays module-global: train_one_epoch reads ``args.freeze``.
    args = parser.parse_args()
    main(args)
