import argparse
import time
import math
from os import path, makedirs

import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
from torchvision import datasets
from torchvision import transforms

from simsiam.loader import TwoCropsTransform
from simsiam.model_factory import SimSiam
from simsiam.criterion import SimSiamLoss, TwinsLoss, SepCLLoss_v1_OTL1, SepCLLoss_v1_SimSiam_OTL1
from simsiam.validation import KNNValidation
from simsiam.barlowtwins import BarlowTwins
from simsiam.SepCL_v1 import SepCL_v1
from simsiam.SepCL_v1_Simsiam import SepCL_v1_Simsiam

# Command-line interface of the training script. The parsed namespace is stored
# in the module-level ``args`` and is read directly by main(), train() and
# save_checkpoint(); it is also handed as a whole to the model, criterion and
# validation constructors.
parser = argparse.ArgumentParser('arguments for training')
# --- paths / run identification ---
parser.add_argument('--data_root', type=str, help='path to dataset directory')
parser.add_argument('--exp_dir', type=str, help='path to experiment directory')
parser.add_argument('--trial', type=str, default='1', help='trial id')
# Side length passed to RandomResizedCrop; forced to 96 after parsing when the
# dataset is STL10.
parser.add_argument('--img_dim', default=32, type=int)

parser.add_argument('--arch', default='resnet18', help='model name is used for training')

# --- model / data-loading / schedule sizes ---
# feat_dim is forced to 8192 after parsing when --projector_img is "8192-8192-8192".
parser.add_argument('--feat_dim', default=2048, type=int, help='feature dimension')
parser.add_argument('--num_proj_layers', type=int, default=2, help='number of projection layer')
parser.add_argument('--batch_size', type=int, default=512, help='batch_size')
parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
parser.add_argument('--epochs', type=int, default=800, help='number of training epochs')
# When left as None the model and criterion are never moved to a CUDA device.
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
# Only passed to SimSiamLoss (i.e. only used with --net simsiam).
parser.add_argument('--loss_version', default='simplified', type=str,
                    choices=['simplified', 'original'],
                    help='do the same thing but simplified version is much faster. ()')
# --- logging / evaluation / checkpoint cadence ---
# print_freq is counted in batches; eval_freq and save_freq are counted in epochs.
parser.add_argument('--print_freq', default=10, type=int, help='print frequency')
parser.add_argument('--eval_freq', default=5, type=int, help='evaluate model frequency')
parser.add_argument('--save_freq', default=100, type=int, help='save model frequency')
parser.add_argument('--resume', default=None, type=str, help='path to latest checkpoint')

# --- SGD hyper-parameters ---
parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

# Not read anywhere in this file; presumably consumed by the model or the
# validation code that receives ``args`` -- TODO confirm.
parser.add_argument('--last_dim', default=-1, type=int,
                    help='feature dimension (default: -1)')

# --- augmentation / network selection ---
# NOTE(review): 'aug2' is accepted here, but main() only builds pipelines for
# 'aug1' and 'aug1_2'.
parser.add_argument('--aug', default='aug1', type=str, choices=['aug1', 'aug1_2', 'aug2'])
parser.add_argument('--net', default='simsiam', type=str, choices=['simsiam', 'barlowtwins', 'SepCL_v1_OTL1', 'SepCL_v1_Simsiam_OTL1'])
# hp1 and lambd are not read in this file; presumably loss weights used by the
# criterion classes that receive ``args`` -- TODO confirm.
parser.add_argument('--hp1', default=1.0, type=float, help='hp1')
parser.add_argument('--lambd', default=0.005, type=float, help='lambd')
# --- SepCL_v1 options ---
# sep_lambd is taken as a string (not a float) so that it can carry either one
# value or a dash-separated list; it is converted right after parse_args().
parser.add_argument('--sep_lambd', default="1.0", type=str, help='sep_lambd, 1.0, or 1.0-300-300')
parser.add_argument('--projector_img', default='2048-2048-2048', type=str, help='projector_img')
parser.add_argument('--projector_aug', default='2048-2048-2048', type=str, help='projector_aug')
# --- warm-up: number of epochs run before the main schedule (0 disables it) ---
parser.add_argument('--warmup_epochs', type=int, default=0, help='number of warmup_epochs')
# --- dataset selection ---
parser.add_argument('--dataset_type', default='CIFAR10', type=str, choices=['CIFAR10', 'CIFAR100', 'STL10'])

# NOTE: parsing happens at import time, so importing this module consumes sys.argv.
args = parser.parse_args()

def _parse_sep_lambd(raw):
    """Convert the ``--sep_lambd`` string into a float or a list of floats.

    A value that is a valid float on its own (``"1.0"``, ``"1e-3"``, ``"-0.5"``)
    is returned as a single float. Anything else is treated as a
    dash-separated list (``"1.0-300-300"`` -> ``[1.0, 300.0, 300.0]``).

    Trying the scalar form first matters: splitting on every ``-`` would break
    scientific notation and negative numbers, which contain a dash themselves.
    List entries still cannot contain a dash.

    Raises:
        ValueError: if the string is neither a float nor a dash-separated
            list of floats.
    """
    try:
        return float(raw)
    except ValueError:
        return [float(part) for part in raw.split('-')]


print("args.sep_lambd", args.sep_lambd)
args.sep_lambd = _parse_sep_lambd(args.sep_lambd)

# STL10 is trained at 96x96, overriding whatever --img_dim was given.
if args.dataset_type == "STL10":
    args.img_dim = 96
print("args.img_dim", args.img_dim)

# Only this exact projector spec changes feat_dim; other specs keep --feat_dim.
if args.projector_img == "8192-8192-8192":
    args.feat_dim = 8192


def _build_train_transforms(args):
    """Return the single-view augmentation pipeline selected by ``args.aug``.

    'aug1' and 'aug1_2' share the same steps; 'aug1_2' additionally applies a
    random rotation in [-90, 90] degrees after the horizontal flip.

    Raises:
        NotImplementedError: for any other value (the CLI also accepts 'aug2',
            for which no pipeline is defined).
    """
    if args.aug not in ("aug1", "aug1_2"):
        raise NotImplementedError(
            "augmentation '{}' is not implemented; use 'aug1' or 'aug1_2'".format(args.aug))

    steps = [
        transforms.RandomResizedCrop(args.img_dim, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
    ]
    if args.aug == "aug1_2":
        steps.append(transforms.RandomRotation(degrees=(-90, 90)))
    steps += [
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # The same normalisation constants are applied for every dataset_type.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    print("args.aug", args.aug)
    return transforms.Compose(steps)


def _build_train_set(args, transform):
    """Return the training dataset for ``args.dataset_type``.

    ``transform`` is wrapped in ``TwoCropsTransform`` before being handed to the
    dataset. The data is downloaded into ``args.data_root`` when missing.

    Raises:
        ValueError: if ``args.dataset_type`` is not one of the supported names.
    """
    two_crops = TwoCropsTransform(transform)
    if args.dataset_type == "CIFAR10":
        return datasets.CIFAR10(root=args.data_root, train=True,
                                download=True, transform=two_crops)
    if args.dataset_type == "CIFAR100":
        return datasets.CIFAR100(root=args.data_root, train=True,
                                 download=True, transform=two_crops)
    if args.dataset_type == "STL10":
        # STL10 has no ``train`` flag; labelled and unlabelled images are both used.
        return datasets.STL10(root=args.data_root, split="train+unlabeled",
                              download=True, transform=two_crops)
    raise ValueError("unsupported dataset_type '{}'".format(args.dataset_type))


def main():
    """Train the network selected by ``args.net`` and checkpoint it.

    Steps: create the experiment directory and TensorBoard writer, build the
    data loader, model, SGD optimizer and criterion, optionally resume from
    ``args.resume``, optionally run ``args.warmup_epochs`` warm-up epochs with a
    separate optimizer, then train up to ``args.epochs``. Every
    ``args.eval_freq`` epochs the KNN validation runs and the best model is
    saved; every ``args.save_freq`` epochs and at the very end a regular
    checkpoint is written.

    Checkpoints written before any validation has run store ``None`` as
    accuracy.
    """
    makedirs(args.exp_dir, exist_ok=True)

    trial_dir = path.join(args.exp_dir, args.trial)
    logger = SummaryWriter(trial_dir)
    print(vars(args))

    train_transforms = _build_train_transforms(args)
    train_set = _build_train_set(args, train_transforms)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=True)

    if args.net == "simsiam":
        model = SimSiam(args)
    elif args.net == "barlowtwins":
        model = BarlowTwins(args)
        print("barlowtwins loaded")
    elif args.net == "SepCL_v1_OTL1":
        model = SepCL_v1(args)
        print("SepCL_v1 loaded")
    elif args.net == "SepCL_v1_Simsiam_OTL1":
        model = SepCL_v1_Simsiam(args)
        print("SepCL_v1_Simsiam loaded")
    else:
        raise ValueError("unsupported net '{}'".format(args.net))

    print(model)  # architecture as constructed (no conversion is applied here)

    print("args.learning_rate", args.learning_rate)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.net == "simsiam":
        criterion = SimSiamLoss(args.loss_version)
    elif args.net == "barlowtwins":
        criterion = TwinsLoss(args)
        print("TwinsLoss loaded")
    elif args.net == "SepCL_v1_OTL1":
        criterion = SepCLLoss_v1_OTL1(args)
        print("SepCLLoss_v1_OTL1 loaded")
    else:  # "SepCL_v1_Simsiam_OTL1" -- other values were rejected above
        criterion = SepCLLoss_v1_SimSiam_OTL1(args)
        print("SepCLLoss_v1_SimSiam_OTL1 loaded")

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)
        cudnn.benchmark = True

    start_epoch = 1
    if args.resume is not None:
        if path.isfile(args.resume):
            ckpt_epoch, model, optimizer = load_checkpoint(model, optimizer, args.resume)
            # Checkpoints are written after their epoch has been trained, so
            # training continues with the following epoch.
            start_epoch = ckpt_epoch + 1
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, ckpt_epoch))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # routine
    best_acc = 0.0
    # Stays None until the first validation so that saving a checkpoint before
    # any evaluation does not fail on an unbound name.
    val_top1_acc = None
    if args.net in ("SepCL_v1_OTL1", "SepCL_v1_Simsiam_OTL1"):
        validation = KNNValidation(args, model.backbone, model.projector_img)
    else:
        validation = KNNValidation(args, model.encoder)

    if args.warmup_epochs > 0:
        # Warm-up uses its own optimizer, so the main optimizer's momentum
        # buffers are untouched by these epochs.
        # NOTE(review): warm-up also runs when resuming, and it logs to the same
        # 'Loss/train' tag/steps as the main loop -- confirm this is intended.
        optimizer2 = optim.SGD(model.parameters(),
                               lr=1e-9,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)

        for wp_epoch in range(0, args.warmup_epochs):
            adjust_learning_rate_wp(optimizer2, wp_epoch, args)
            print("Warmming up...")

            train_loss = train(train_loader, model, criterion, optimizer2, wp_epoch, args)
            logger.add_scalar('Loss/train', train_loss, wp_epoch)

            if wp_epoch % args.eval_freq == 0:
                print("Validating...")
                val_top1_acc = validation.eval()
                print('Top1: {}'.format(val_top1_acc))

    # Keeps ``epoch`` defined for the final save even if the loop below does not
    # run (e.g. when resuming a run that already reached args.epochs).
    epoch = start_epoch - 1
    for epoch in range(start_epoch, args.epochs+1):

        adjust_learning_rate(optimizer, epoch, args)
        print("Training...")

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer, epoch, args)
        logger.add_scalar('Loss/train', train_loss, epoch)

        if epoch % args.eval_freq == 0:
            print("Validating...")
            val_top1_acc = validation.eval()
            print('Top1: {}'.format(val_top1_acc))

            # save the best model
            if val_top1_acc > best_acc:
                best_acc = val_top1_acc

                save_checkpoint(epoch, model, optimizer, best_acc,
                                path.join(trial_dir, '{}_best.pth'.format(args.trial)),
                                'Saving the best model!')
            logger.add_scalar('Acc/val_top1', val_top1_acc, epoch)

        # save the model
        if epoch % args.save_freq == 0:
            save_checkpoint(epoch, model, optimizer, val_top1_acc,
                            path.join(trial_dir, 'ckpt_epoch_{}_{}.pth'.format(epoch, args.trial)),
                            'Saving...')

    print('Best accuracy:', best_acc)

    # save model
    save_checkpoint(epoch, model, optimizer, val_top1_acc,
                    path.join(trial_dir, '{}_last.pth'.format(args.trial)),
                    'Saving the model at the last epoch.')

    # Flush pending TensorBoard events to disk.
    logger.close()


def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(images, target)`` where ``images[0]``
            and ``images[1]`` are the two augmented views; the target is ignored.
        model: called as ``model(im_aug1=..., im_aug2=...)``; must return a
            mapping that contains the output keys listed for ``args.net`` below.
        criterion: called with those outputs; element 0 of its result is the
            loss that is back-propagated, the other elements are only logged.
        optimizer: optimizer stepped once per batch.
        epoch: only used in the progress prefix.
        args: namespace providing ``net``, ``gpu`` and ``print_freq``.

    Returns:
        The average of the total loss over the epoch, weighted by batch size.

    Raises:
        ValueError: if ``args.net`` has no entry in the dispatch table.
    """
    # Per network: the model-output keys passed to the criterion (in call order)
    # and the extra meters to show, as (label, index into the criterion result).
    net_specs = {
        "simsiam": (
            ('z1', 'z2', 'p1', 'p2'),
            (('Loss1', 1), ('Loss2', 2)),
        ),
        "barlowtwins": (
            ('za', 'zb'),
            (('Loss_on_diag', 1), ('Loss_off_diag', 2)),
        ),
        "SepCL_v1_OTL1": (
            ('za_img', 'zb_img', 'za_aug', 'zb_aug'),
            (('Loss_img', 1), ('Loss_aug', 4),
             ('Loss_img_on_diag', 2), ('Loss_img_off_diag', 3)),
        ),
        "SepCL_v1_Simsiam_OTL1": (
            ('za_img', 'zb_img', 'za_aug', 'zb_aug',
             'pa_img', 'pb_img', 'pa_aug', 'pb_aug'),
            (('Loss_img', 1), ('Loss_aug', 4),
             ('Loss1_img', 2), ('Loss2_img', 3),
             ('Loss1_aug', 5), ('Loss2_aug', 6)),
        ),
    }
    if args.net not in net_specs:
        # Fail before any computation instead of on an unbound name mid-batch.
        raise ValueError("unsupported net '{}'".format(args.net))
    output_keys, component_specs = net_specs[args.net]

    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    component_meters = [(AverageMeter(label, ':.4e'), index)
                        for label, index in component_specs]
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses] + [meter for meter, _ in component_meters],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, _) in enumerate(train_loader):

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output
        outs = model(im_aug1=images[0], im_aug2=images[1])
        loss_list = criterion(*(outs[key] for key in output_keys))

        loss = loss_list[0]
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update running statistics and measure elapsed time
        batch_size = images[0].size(0)
        losses.update(loss.item(), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        for meter, index in component_meters:
            meter.update(loss_list[index].item(), batch_size)

        if i % args.print_freq == 0:
            progress.display(i)

    return losses.avg


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every parameter group from a half-cosine schedule.

    The rate is ``args.learning_rate * 0.5 * (1 + cos(pi * epoch / args.epochs))``:
    the full base rate at ``epoch == 0`` and zero at ``epoch == args.epochs``.
    The optimizer is modified in place; nothing is returned.
    """
    cosine_factor = 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    scheduled_lr = args.learning_rate * cosine_factor

    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr

def adjust_learning_rate_wp(optimizer, epoch, args):
    """Set the warm-up learning rate of every parameter group.

    The rate grows linearly with the warm-up epoch index:
    ``1e-9 + args.learning_rate * epoch / args.warmup_epochs``. It therefore
    starts at 1e-9 for ``epoch == 0``. The optimizer is modified in place;
    nothing is returned.
    """
    # linear ramp on top of a tiny non-zero floor
    warmup_lr = 1e-9 + args.learning_rate * (epoch / args.warmup_epochs)

    for group in optimizer.param_groups:
        group['lr'] = warmup_lr

class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric.

    ``name`` is the label used when printing; ``fmt`` is a format-spec suffix
    (including the leading colon, e.g. ``':.4e'``) applied to both the latest
    value and the average.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:.4e} ({avg:.4e})' and fills it from the
        # instance attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one tab-separated progress line made of a batch counter and meters.

    ``num_batches`` fixes the counter width and its total, ``meters`` are the
    objects whose ``str()`` is appended to each line, and ``prefix`` is put in
    front of the counter.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for ``num_batches == 100``."""
        # The counter is padded to as many digits as the total has.
        width = len(str(num_batches // 1))
        slot = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(slot, slot.format(num_batches))


def save_checkpoint(epoch, model, optimizer, acc, filename, msg):
    """Write a training snapshot to ``filename`` and then print ``msg``.

    The snapshot is a dict holding the epoch, the architecture name taken from
    the module-level ``args``, the model and optimizer state dicts, and ``acc``
    under the key ``'top1_acc'``.
    """
    snapshot = dict(
        epoch=epoch,
        arch=args.arch,
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        top1_acc=acc,
    )
    torch.save(snapshot, filename)
    print(msg)


def load_checkpoint(model, optimizer, filename):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    Args:
        model: module whose parameters are overwritten with the saved
            ``'state_dict'``.
        optimizer: optimizer whose state is overwritten with the saved
            ``'optimizer'`` entry.
        filename: path of a file holding a dict with the keys ``'epoch'``,
            ``'state_dict'`` and ``'optimizer'``.

    Returns:
        ``(epoch, model, optimizer)`` where ``epoch`` is the value stored in the
        checkpoint and the other two are the objects passed in.
    """
    # Deserialise onto the CPU rather than a hard-coded 'cuda:0': this also
    # works on machines without that device, and load_state_dict() below copies
    # the values onto whatever device the model's parameters already live on.
    checkpoint = torch.load(filename, map_location='cpu')
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    return start_epoch, model, optimizer


# Start training only when the file is executed as a script. Note that the
# command line is parsed at module level, so it is consumed even on import.
if __name__ == '__main__':
    main()



