import os
import argparse
import json
import sys
import time
import torch
import torchvision
import numpy as np
from pathlib import Path
from torch import nn, optim
from torchvision import datasets, transforms
from trainer import PredictionMLP, ProjectionMLP, Identity
from datasets import *
from models import EquiRotate, MultiLinearAlign
from utils import fix_seed
from lightcnn import LightCNN

# Command-line interface for the linear-probe evaluation script.
# Flag names, types and defaults are part of the script's public interface.
parser = argparse.ArgumentParser(description='Evaluate resnet50 features on ImageNet')
parser.add_argument('--data', default='/data',type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--pretrained-dir', default='./experiments/stl10_escnn18',
                    type=Path, metavar='FILE',
                    help='path to pretrained model')
parser.add_argument('--weights', default='freeze', type=str,
                    choices=('finetune', 'freeze'),
                    help='finetune or freeze resnet weights')
parser.add_argument('--train-percent', default=100, type=int,
                    choices=(100, 10, 1),
                    help='size of training set in percent')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--lr-backbone', default=0.0, type=float, metavar='LR',
                    help='backbone base learning rate')
parser.add_argument('--lr-classifier', default=1.0, type=float, metavar='LR',
                    help='classifier base learning rate')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=10, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--gpus', default='0', type=str)
# Required on purpose: argparse does not check a default against `choices`,
# so an out-of-choices default would only fail later, when the model is built.
# Only the architectures listed here can be constructed by Encoder.
parser.add_argument('--arch', required=True, type=str, help='model architecture',
                    choices=['resnet18', 'resnet50', 'lightcnn'])
# Overwritten in main_worker with the class count reported by the dataset loader.
parser.add_argument('--num-classes', default=196, type=int)
parser.add_argument('--pretrain-set', default='stl10', type=str, help='pretrain dataset',
                    choices=['stl10','stl10-R','imagenet100','imagenet100-R','imagenet1k','caltech256','mtarsi','imagenet100_app','EMNIST','EMNIST-R','AffNIST','AffNIST-R'])
parser.add_argument('--eval-set', default='stl10', type=str, help='evaluation dataset',
                    choices=['stl10','imagenet100','stanford_cars','fgvc_aircraft','cub_200_2011','cifar10','cifar100','caltech256','imagenet1k','mtarsi','oxford_flowers','EMNIST','RotNIST','AffNIST'])
parser.add_argument('--connector', default='softmax', type=str, help='equivariance connection map',
                    choices=['softmax', 'identity', 'tanh', 'shift'])
parser.add_argument('--rotated', action='store_true')
parser.add_argument('--save_linear_probe', action='store_true')
parser.add_argument('--random_rotation', action='store_true', help='To have Dataset with Random Rotation Augmentation')
parser.add_argument('--iterations', default=0, type=int)
# '' keeps the learning rate constant (the scheduler is never stepped).
parser.add_argument('--lr_scheduler', default='', type=str, choices=['','cosine','step'])
parser.add_argument('--pred_hidden_dim', default=512, type=int)
parser.add_argument('--gie', action='store_true')
parser.add_argument('--imsize', default=96, type=int)
parser.add_argument('--eqv_loss_type', default='mse', type=str, choices=['mse', 'infonce', 'shift'])
parser.add_argument('--use_mlp', default=False, action='store_true')
parser.add_argument('--R_choices', default='linear', type=str,
                    choices=['linear','multi_linear_align',])
parser.add_argument('--deep', default=False, action='store_true')
parser.add_argument('--depth', default=3, type=int)
parser.add_argument('--multi_num', default=4, type=int)
parser.add_argument('--emnist_type', default='byclass', type=str, choices=['byclass', 'balanced'])
parser.add_argument('--lightcnn_feat_dim', default=128, type=int)

def main():
    """Parse the CLI, seed the run with a fresh random seed, and start the worker.

    The seed is drawn from ``[0, 10000)`` and stored on ``args.seed``.
    ``args.gpus`` is converted from a comma-separated string to a list of ints,
    and ``args.ngpus_per_node`` records the number of visible CUDA devices.
    """
    args = parser.parse_args()

    # Draw a new seed for every run and keep it on the namespace for logging.
    args.seed = np.random.randint(0, 10000)
    fix_seed(args.seed)

    args.ngpus_per_node = torch.cuda.device_count()
    args.gpus = list(map(int, args.gpus.split(',')))

    main_worker(args)

def _run_tag(args, gie_tag):
    """Return the '[rotated_][<gie_tag>_]<iterations>' suffix used in output file names."""
    parts = []
    if args.rotated:
        parts.append('rotated')
    if args.gie:
        parts.append(gie_tag)
    parts.append(str(args.iterations))
    return '_'.join(parts)


def _gie_submodules(model, args):
    """Return (checkpoint_key, module) pairs for the GIE branch; empty without --gie.

    ``equi_transform`` is included only when the model actually has one: Encoder
    does not build it for ``--eqv_loss_type shift``.
    """
    if not args.gie:
        return []
    found = [('predictor', model.predictor)]
    if hasattr(model, 'equi_transform'):
        found.append(('equi_transform', model.equi_transform))
    return found


def _linear_probe(args, stats_file):
    """Train the linear classifier and evaluate it after every epoch.

    Progress and accuracy are printed to stdout and mirrored, as JSON lines,
    into ``stats_file``.
    """
    print('Seed Num: {}'.format(args.seed))
    print('Seed Num: {}'.format(args.seed), file=stats_file)

    # Data loading code
    if not args.random_rotation:
        train_dataset, val_dataset, num_classes = load_eval_datasets(args)
    else:
        train_dataset, val_dataset, num_classes = load_eval_random_rotation_sets(args)

    # The dataset decides the classifier width, overriding --num-classes.
    args.num_classes = num_classes

    kwargs = dict(batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, **kwargs)

    model = Encoder(args)
    model.cuda()

    # Restore the pretrained weights; strict=False so mismatches are reported
    # (printed below) rather than raised.
    state_dict = torch.load(os.path.join(args.pretrained_dir, 'final.pth'), map_location='cpu')
    gie_modules = _gie_submodules(model, args)
    for key, module in [('backbone', model.backbone)] + gie_modules:
        missing_keys, unexpected_keys = module.load_state_dict(state_dict[key], strict=False)
        print(missing_keys)
        print(unexpected_keys)

    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()
    if args.weights == 'freeze':
        model.backbone.requires_grad_(False)
        model.fc.requires_grad_(True)
        for _, module in gie_modules:
            module.requires_grad_(False)

    classifier_parameters, model_parameters = [], []
    for name, param in model.named_parameters():
        if name in {'fc.weight', 'fc.bias'}:
            classifier_parameters.append(param)
        else:
            model_parameters.append(param)

    criterion = nn.CrossEntropyLoss().cuda()

    param_groups = [dict(params=classifier_parameters, lr=args.lr_classifier)]
    if args.weights == 'finetune':
        param_groups.append(dict(params=model_parameters, lr=args.lr_backbone))
    optimizer = optim.SGD(param_groups, 0, momentum=0.9, weight_decay=args.weight_decay)
    if args.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    else:
        def lr_lambda(epoch):
            if epoch < 30:
                return 1.0
            elif epoch < 40:
                return 0.1
            elif epoch < 50:
                return 0.01
            else:
                return 0.001
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # NOTE: 'checkcpoint' (sic) is the existing on-disk name; kept so files
    # written by earlier runs keep being overwritten rather than duplicated.
    ckpt_tag = _run_tag(args, 'hx')
    epoch_ckpt_path = args.pretrained_dir / f'checkcpoint_fc_{ckpt_tag}.pth'
    final_ckpt_path = args.pretrained_dir / f'final_fc_{ckpt_tag}.pth'

    start_epoch = 0
    best_acc = argparse.Namespace(top1=0, top5=0)

    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        # train
        if args.weights == 'finetune':
            model.train()
        elif args.weights == 'freeze':
            # Frozen parts stay in eval mode during training.
            model.backbone.eval()
            for _, module in gie_modules:
                module.eval()
        else:
            raise ValueError(f'unknown --weights mode: {args.weights!r}')
        for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)):
            output = model(images.cuda(non_blocking=True))
            loss = criterion(output, target.cuda(non_blocking=True))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % args.print_freq == 0:
                pg = optimizer.param_groups
                lr_classifier = pg[0]['lr']
                lr_backbone = pg[1]['lr'] if len(pg) == 2 else 0
                stats = dict(epoch=epoch, step=step, lr_backbone=lr_backbone,
                             lr_classifier=lr_classifier, loss=loss.item(),
                             time=int(time.time() - start_time))
                print(json.dumps(stats))
                print(json.dumps(stats), file=stats_file)

        # With --lr_scheduler '' the schedule is never advanced (constant LR).
        if args.lr_scheduler:
            scheduler.step()

        if args.save_linear_probe:
            state = dict(epoch=epoch+1, model=model.fc.state_dict(), optimizer=optimizer.state_dict())
            torch.save(state, epoch_ckpt_path)

        # evaluate
        model.eval()
        top1 = AverageMeter('Acc@1')
        top5 = AverageMeter('Acc@5')
        with torch.no_grad():
            for images, target in val_loader:
                output = model(images.cuda(non_blocking=True))
                acc1, acc5 = accuracy(output, target.cuda(non_blocking=True), topk=(1, 5))
                top1.update(acc1[0].item(), images.size(0))
                top5.update(acc5[0].item(), images.size(0))
        best_acc.top1 = max(best_acc.top1, top1.avg)
        best_acc.top5 = max(best_acc.top5, top5.avg)
        stats = dict(epoch=epoch, acc1=top1.avg, acc5=top5.avg, best_acc1=best_acc.top1, best_acc5=best_acc.top5)
        print(json.dumps(stats))
        print(json.dumps(stats), file=stats_file)

    if args.save_linear_probe:
        torch.save(dict(fc=model.fc.state_dict()), final_ckpt_path)


def main_worker(args):
    """Run linear-probe training/evaluation on top of a pretrained encoder.

    Reads ``final.pth`` from ``args.pretrained_dir``, trains ``Encoder.fc``
    (and the rest of the model when ``args.weights == 'finetune'``) for
    ``args.epochs`` epochs, and evaluates top-1/top-5 accuracy after each epoch.

    Statistics are appended to
    ``<eval_set>_stats_linear_probe_[rotated_][inv_]<iterations>.txt`` inside
    ``args.pretrained_dir``. With ``args.save_linear_probe`` the classifier is
    also checkpointed every epoch and once more at the end.

    Args:
        args: parsed command-line namespace; ``args.seed`` must already be set.
            ``args.num_classes`` is overwritten with the dataset's class count.
    """
    statfile_name = args.eval_set + f"_stats_linear_probe_{_run_tag(args, 'inv')}.txt"
    torch.backends.cudnn.benchmark = True

    # The context manager closes the log even if training raises.
    with open(args.pretrained_dir / statfile_name, 'a', buffering=1) as stats_file:
        _linear_probe(args, stats_file)


class Encoder(nn.Module):
    """Backbone followed by a linear classifier, with an optional GIE branch.

    Without ``args.gie`` the classifier is applied directly to the backbone
    features. With ``args.gie`` a predictor produces ``order`` (= 4) logits per
    sample; these either weight (soft, connector is a module) or select (hard,
    connector is ``None``) one of ``order`` transformed copies of the features
    before classification.

    Args:
        args: parsed command-line namespace. Reads ``arch``, ``gie``,
            ``connector``, ``eqv_loss_type``, ``R_choices``, ``multi_num``,
            ``use_mlp``, ``pred_hidden_dim``, ``pretrain_set``,
            ``lightcnn_feat_dim`` and ``num_classes``.

    Raises:
        ValueError: for an architecture this class cannot build, or for a
            connector that has no implementation here while ``gie`` is on.
    """

    def __init__(self, args):
        super().__init__()

        self.order = 4
        self.gie = args.gie
        self.args = args

        if args.arch == 'resnet18':
            self.backbone = torchvision.models.resnet18(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 512

        elif args.arch == 'resnet50':
            self.backbone = torchvision.models.resnet50(zero_init_residual=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048

        elif args.arch == 'lightcnn':
            self.backbone = LightCNN(feat_dim=args.lightcnn_feat_dim)
            self.backbone.fc = nn.Identity()
            feature_dim = args.lightcnn_feat_dim

        else:
            # Fail here instead of with an unbound `feature_dim` further down.
            raise ValueError(
                f'unsupported arch {args.arch!r}; expected one of '
                "'resnet18', 'resnet50', 'lightcnn'")

        if self.gie:
            self.predictor = PredictionMLP(feature_dim, args.pred_hidden_dim, 4, args.pretrain_set)

            # The 'shift' loss permutes feature channels directly and needs no
            # learned transform.
            if args.eqv_loss_type != 'shift':
                if args.R_choices == 'linear':
                    self.equi_transform = EquiRotate(feature_dim, self.args.use_mlp)
                elif args.R_choices == 'multi_linear_align':
                    self.equi_transform = MultiLinearAlign(args, feature_dim)

        if args.connector == 'softmax':
            self.connector = torch.nn.Softmax(dim=1)
        elif args.connector == 'shift':
            self.connector = None
            permute_patterns = [torch.roll(torch.arange(self.order), shifts=-i).tolist() for i in range(self.order)]
            # Non-persistent buffer: follows .cuda()/.to() with the module but
            # stays out of state_dict, so checkpoints are unaffected.
            self.register_buffer('permute_tensor', torch.tensor(permute_patterns), persistent=False)
        elif self.gie:
            # forward() needs a connector on the GIE path; the remaining CLI
            # choices are not implemented in this class.
            raise ValueError(
                f'connector {args.connector!r} is not supported together with --gie')

        self.fc = nn.Linear(feature_dim, args.num_classes)

    def _inverse_steps(self):
        """Return [(order - i) % order for i in range(order)], i.e. [0, 3, 2, 1] for order 4."""
        return [(self.order - i) % self.order for i in range(self.order)]

    def _soft_align(self, FX, eqv_score):
        """Combine ``order`` transformed copies of ``FX`` weighted by ``eqv_score``.

        ``FX`` is ``[b, c]`` and ``eqv_score`` is ``[b, order]``; returns ``[b, c]``.
        Assumes ``equi_transform`` preserves the ``[b, c]`` shape -- TODO confirm.
        """
        b, c = FX.shape

        if self.args.eqv_loss_type == 'shift':
            FX_re = FX.reshape([b, c // self.order, self.order])
            candidates = [torch.roll(FX_re, shifts=-i, dims=2).reshape([b, c])
                          for i in range(self.order)]

        elif self.args.R_choices == 'linear':
            # Apply the transform repeatedly: chain[k] is FX transformed k times.
            chain = [FX]
            for _ in range(self.order - 1):
                chain.append(self.equi_transform(chain[-1]))
            candidates = [chain[k] for k in self._inverse_steps()]

        elif self.args.R_choices == 'multi_linear_align':
            def element(k):
                return torch.full((FX.size()[0],), k, dtype=torch.long, device=FX.device)

            steps = self._inverse_steps()
            if self.args.multi_num == 3:
                # Index 0 is taken as FX itself; only the other steps go
                # through the transform.
                candidates = [FX] + [self.equi_transform(FX, element(k)) for k in steps[1:]]
            else:
                candidates = [self.equi_transform(FX, element(k)) for k in steps]

        else:
            raise ValueError(f'unsupported R_choices {self.args.R_choices!r}')

        stacked = torch.stack(candidates, dim=-1)
        # squeeze(-1) removes only the trailing matmul dimension; a bare
        # squeeze() would also drop the batch dimension when b == 1.
        return torch.matmul(stacked, eqv_score.unsqueeze(dim=-1)).squeeze(-1)

    def _hard_align(self, FX, eqv_idx):
        """Pick, per sample, the single transformed copy of ``FX`` indexed by ``eqv_idx``.

        ``FX`` is ``[b, c]`` and ``eqv_idx`` is a ``[b]`` long tensor; returns ``[b, c]``.
        """
        b, c = FX.shape

        if self.args.eqv_loss_type == 'shift':
            batch_perm = self.permute_tensor[eqv_idx].unsqueeze(1).expand(-1, c // self.order, -1)
            FX_re = FX.reshape([b, c // self.order, self.order])
            return FX_re.gather(2, batch_perm).reshape([b, c])

        trans = (self.order - eqv_idx) % self.order

        if self.args.R_choices == 'linear':
            FX_all = [FX]
            for _ in range(self.order - 1):
                FX_all.append(self.equi_transform(FX_all[-1]))
            FX_stack = torch.stack(FX_all, dim=1)
            index = trans.view(-1, 1, 1).expand(-1, -1, FX.size()[-1])
            return torch.gather(FX_stack, dim=1, index=index).squeeze(1)

        if self.args.R_choices == 'multi_linear_align':
            return self.equi_transform(FX, trans)

        raise ValueError(f'unsupported R_choices {self.args.R_choices!r}')

    def forward(self, x):
        """Return class logits of shape ``[batch, num_classes]`` for input batch ``x``."""
        if not self.gie:
            return self.fc(self.backbone(x))

        FX = self.backbone(x)
        eqv_logit = self.predictor(FX).flatten(1)

        if self.connector is not None:
            HX = self._soft_align(FX, self.connector(eqv_logit))
        else:
            HX = self._hard_align(FX, torch.argmax(eqv_logit, dim=1))

        return self.fc(HX)
    
    
class AverageMeter(object):
    """Keep the most recent value of a metric together with its running mean.

    Attributes:
        name: label used when the meter is printed.
        fmt: format spec (e.g. ``':f'``) applied to ``val`` and ``avg``.
        val: last value passed to ``update``.
        sum: weighted sum of all values seen since the last reset.
        count: total weight (number of samples) seen since the last reset.
        avg: ``sum / count``.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:f} ({avg:f})' and fills it from the instance.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies, in percent, for each k in ``topk``.

    Args:
        output: per-class scores; the k highest along dim 1 are the predictions.
        target: ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to report.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        ranked = torch.topk(output, k=largest_k, dim=1, largest=True, sorted=True)[1]
        ranked = ranked.t()  # one row per rank, one column per sample
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
