import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
# Torchvision
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34, resnet50
from torch.utils.tensorboard import SummaryWriter
import argparse

from dataset import *
#from mymodels import *
from DivideMix.PreResNet import *
import models
from losses import *
from utils import *

# Seed every random-number generator the script draws from (NumPy, torch CPU,
# torch CUDA) so that runs are reproducible.
seed = 0
for _seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seeder(seed)

def _str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` is a well-known argparse pitfall: ``bool('False')`` is True,
    so every non-empty value would switch the option on.

    Args:
        value: the raw command-line string (an actual bool is passed through).

    Returns:
        True for 'true'/'t'/'yes'/'y'/'1', False for 'false'/'f'/'no'/'n'/'0'
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
        a usage error instead of silently accepting it.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line options for the noisy-label CIFAR training run.
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--data_type', default='cifar-10', type=str)
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--max_epoch', default=200, type=int)
parser.add_argument('--n_class', default=10, type=int)
parser.add_argument('--loss_name', default='ce', type=str)
parser.add_argument('--noise_ratio', default=40, type=int, help='noise ratio')
# Boolean switches: parsed explicitly so that "--asym False" really means False.
parser.add_argument('--asym', default=False, type=_str2bool)
parser.add_argument('--CGO', default=True, type=_str2bool)
parser.add_argument('--lmd', default=0.99, type=float, help='lambda of temporal assembling')
parser.add_argument('--constrain_layers', default=3, type=int)
parser.add_argument('--T', default=0.1, type=float, help='sharpening temperature')
args = parser.parse_args()

# Copy the parsed command-line options into the module-level names that the
# rest of the script reads.
lr, n_class, batch_size = args.lr, args.n_class, args.batch_size
data_type, max_epoch = args.data_type, args.max_epoch
noise_ratio, loss_name = args.noise_ratio, args.loss_name
if args.CGO:
    CGO_mode = 'cgo'
else:
    CGO_mode = 'none'
lmd, constrain_layers = args.lmd, args.constrain_layers
# Test accuracy of every finished run is appended to this list.
total_result = []

# Fixed hyper-parameters (not exposed on the command line).
ALPHA = 0.01  # NOTE(review): not referenced in the visible code
padding = 2  # padding passed to RandomCrop
img_size = 32  # crop size passed to RandomCrop
img_dim = 3  # presumably input channels; only used by a commented-out model line
constrain_threshold = 0.5  # NOTE(review): not referenced in the visible code
constrain_flag = False  # NOTE(review): not referenced in the visible code
init_tau = 0.5  # starting value of the per-epoch `tau` schedule
std_exp = 1  # exponent of the per-epoch `std_w` schedule

# Earlier hyper-parameter sweep, kept for reference:
#for init_tau in [0.5, 0.3, 0.1, 0.05, 0.01]:  for std_exp in [1, 0.5, 0.3, 0.1, 0.05, 0.01]:
mt = torch.zeros((max_epoch, n_class, n_class))  # NOTE(review): not referenced in the visible code
mv = torch.zeros((max_epoch, n_class, n_class))  # NOTE(review): not referenced in the visible code
# Main experiment driver. For every setting it does the following:
#   1. Re-seeds the RNGs and opens a fresh log file.
#   2. Builds the train/valid/test splits and their loaders.
#   3. Trains a ResNet18 for `max_epoch` epochs, saving the weights with the
#      best validation accuracy to `checkpoint_path`.
#   4. Reloads that checkpoint and appends its test accuracy to `total_result`.
# Loop variables:
#   nr, cr : only used by the commented-out `update_ramdom_noise` call below.
#   E      : epoch after which dataset randomisation and the CGO label-update
#            branch are enabled (`epoch > E`).
for nr, cr in zip([0.8], [0.2]):
    for E in [15]:
        # Only these losses define `cls_loss` in the training step below; any
        # other value previously crashed with a NameError at backward().
        if loss_name not in ('ce', 'sl', 'gce'):
            raise ValueError('unsupported --loss_name for training: {0}'.format(loss_name))

        seed = 0
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        #CGO_mode = 'none' # none, constrain, rectify
        # The output directories may not exist on a fresh checkout; open(),
        # np.save() and torch.save() below would then fail.
        for out_dir in ('log', 'checkpoints', 'settings'):
            os.makedirs(out_dir, exist_ok=True)
        # NOTE: the checkpoint name carries no run index, so repeated runs with
        # the same settings overwrite the same checkpoint file.
        checkpoint_path = 'checkpoints/{2}_{0}_{3}_noise{1}_.pth'.format(loss_name, noise_ratio, data_type, CGO_mode)
        # Use the first run index whose log file does not exist yet, so earlier
        # logs are never overwritten.
        log_path_fmt = 'log/{0:s}_noise-{1:d}_{2:s}_{3:s}_{5:s}_idx{4:d}_log_.txt'
        train_idx = 0
        while os.path.exists(log_path_fmt.format(data_type, noise_ratio, 'cnn', CGO_mode, train_idx, loss_name)):
            train_idx += 1
        train_log = open(log_path_fmt.format(data_type, noise_ratio, 'cnn', CGO_mode, train_idx, loss_name), 'w')

        # dataset
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(img_size, padding=padding),
            transforms.ToTensor(),
        ])
        # Evaluation is deterministic. A RandomCrop here made the validation
        # accuracy (used to pick the checkpoint) and the test accuracy depend on
        # augmentation randomness.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
        ])
        X_train, y_train, y_train_clean, X_test, y_test = get_data(data_type, noise_ratio=noise_ratio, asym=args.asym, random_shuffle=False)
        # Hold out 10% of the (noisy) training set for validation. The indices
        # are cached per dataset so every run uses the same split.
        valid_idx_path = 'settings/%s_valid_idx.npy' % (data_type)
        if not os.path.exists(valid_idx_path):
            valid_idx = np.random.choice(np.arange(X_train.shape[0]), X_train.shape[0]//10, replace=False)
            np.save(valid_idx_path, valid_idx)
        valid_idx = np.load(valid_idx_path)
        total_num = X_train.shape[0]
        train_keep = np.setdiff1d(np.arange(total_num), valid_idx)
        X_valid = X_train[valid_idx]
        y_valid = y_train[valid_idx]
        y_valid_clean = y_train_clean[valid_idx]
        X_train = X_train[train_keep]
        y_train = y_train[train_keep]
        #y_train = np.load('rec_label.npy')
        y_train_clean = y_train_clean[train_keep]
        # Split the training samples by whether the given label matches the
        # clean one (labels are compared via argmax over axis 1).
        clean_set = np.argmax(y_train, axis=1)==np.argmax(y_train_clean, axis=1)
        corrupt_set = np.argmax(y_train, axis=1)!=np.argmax(y_train_clean, axis=1)
        clean_datasets = MyDataset(X_train[clean_set], y_train[clean_set], y_train_clean[clean_set], require_index=True, transform=transform_train)
        clean_loader = DataLoader(clean_datasets, batch_size=batch_size, shuffle=True, num_workers=0)
        if noise_ratio > 0:
            corrupt_datasets = MyDataset(X_train[corrupt_set], y_train[corrupt_set], y_train_clean[corrupt_set], require_index=True, transform=transform_train)
            corrupt_loader = DataLoader(corrupt_datasets, batch_size=batch_size, shuffle=True, num_workers=0)
        all_datasets = MyDataset(X_train, y_train, y_train_clean, require_index=True, transform=transform_train, max_epoch=max_epoch)
        all_loader = DataLoader(all_datasets, batch_size=batch_size, shuffle=True, num_workers=0)
        all_clean_datasets = MyDataset(X_train, y_train_clean, y_train_clean, require_index=True, transform=transform_train, max_epoch=max_epoch)
        all_clean_loader = DataLoader(all_clean_datasets, batch_size=batch_size, shuffle=True, num_workers=0)
        valid_datasets = MyDataset(X_valid, y_valid, y_valid_clean, transform = transform_test)
        valid_loader = DataLoader(valid_datasets, batch_size=batch_size, shuffle=False, num_workers=0)
        test_datasets = MyDataset(X_test, y_test, y_test, transform = transform_test)
        test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=False, num_workers=0)

        #model = models.load_model(model_name='vgg16', in_channels=img_dim, num_classes=n_class).cuda()
        model = ResNet18(num_classes=n_class).cuda()
        opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        #opt = optim.Adam(model.parameters(), lr=lr)
        scheduler = lr_scheduler.CosineAnnealingLR(opt, T_max=max_epoch)
        #scheduler = lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.2, patience=20, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

        total_valid_acc_list = []
        total_valid_loss_list = []
        total_acc_noisy = []
        total_acc_true = []
        total_memorization = []
        total_rectification = []
        # Which training subset to train on: 'clean', 'corrupt' or 'all'.
        dataset_type = 'all'
        if dataset_type == 'clean':
            train_datasets = clean_datasets
            train_loader = clean_loader
        elif dataset_type == 'corrupt':
            train_datasets = corrupt_datasets
            train_loader = corrupt_loader
        elif dataset_type == 'all':
            train_datasets = all_datasets
            train_loader = all_loader

        total_train_p = torch.zeros(max_epoch, train_datasets.__len__()).float()
        total_epoch_loss = torch.zeros(max_epoch//10, train_datasets.__len__()).float()
        total_epoch_noisy = torch.zeros(max_epoch//10, train_datasets.__len__()).bool()

        #all_datasets.random = True
        #rand_index = np.random.choice(np.arange(all_datasets.__len__()), int(all_datasets.__len__()*0.1), replace=False)
        #all_datasets.targets_noisy_seq[rand_index] = all_datasets.targets_noisy_seq[rand_index] * (n_class-1) + torch.ones_like(all_datasets.targets_noisy_seq[rand_index])

        for epoch in range(max_epoch):
            # train
            model.train()
            print('\ncnn train:', epoch)
            cls_loss_list = []
            grad_loss_list = []
            cls_acc_list = []
            true_acc_list = []
            cls_mem_list = []
            cls_rec_list = []
            map_mem_list = []
            map_rec_list = []
            # Running hit/count accumulators for the progress line. The counts
            # start at 0.1 so the ratios never divide by zero.
            g_pred_acc = 0
            g_pred_num = 0.1
            a_pred_acc = 0
            a_pred_num = 0.1
            cons_a = 0
            cons_n = 0.1
            #RT = min(1, abs(epoch-E)/E * noise_ratio/100 + 1-noise_ratio/100)
            tau = init_tau * (1 - epoch / max_epoch)
            std_w = (epoch / max_epoch) ** std_exp

            rec_ratio = min(epoch/100, 1)
            cons_ratio = 1 - min(noise_ratio, noise_ratio/10*max(0, epoch))/100

            #all_datasets.update_ramdom_noise(noise_ratio=nr, clean_ratio=cr)
            if epoch > E:
                all_datasets.random = True

            total_train_label = torch.zeros(train_datasets.__len__()).long()
            total_train_pred = torch.zeros(train_datasets.__len__()).long()
            total_train_idx = torch.zeros(train_datasets.__len__()).long()

            #cur_seq_len = len(train_datasets.targets_noisy_seq)

            for batch_idx, data in enumerate(train_loader):
                target_layers = [3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
                #output_layers = ([model.features[i] for i in target_layers[-constrain_layers:]])
                #input_layers = ([model.features[i-1] for i in target_layers[-constrain_layers:]])
                # NOTE(review): new HookModule objects are created for every batch.
                # HookModule is defined elsewhere. If it registers forward hooks
                # that are never removed, hooks pile up across batches (slower
                # forwards, growing memory). Consider building these once per run.
                input_layers = ([model.layer4[0].conv1, model.layer4[0].conv2, model.layer4[1].conv1, model.layer4[1].conv2])
                input_modules = [HookModule(model=model, module=layer) for layer in input_layers]
                #output_modules = [HookModule(model=model, module=layer) for layer in output_layers]

                train_image, train_label, train_label_true, train_index = data
                train_image = Variable(train_image.float(), requires_grad=True).cuda()
                # One-hot encode with the configured class count. It was
                # hard-coded to 10, which broke any --n_class other than 10.
                train_label = Variable(torch.eye(n_class)[train_label]).cuda()
                train_label_true = Variable(train_label_true).cuda()

                cls_logits = model(train_image)
                cls_output = nn.Softmax(dim=-1)(cls_logits)
                # Shift probabilities away from 0 so torch.log() below stays finite.
                cls_output = (cls_output+1e-4)/(1+2e-4)

                grad_loss = ag_loss = torch.zeros(1)
                g_pred = torch.zeros(train_label.shape[0]).long().cuda()
                a_pred = torch.zeros(train_label.shape[0]).long().cuda()
                # Samples with cons_mask=True are excluded from the loss. The mask
                # is all False here, so the whole batch is used.
                cons_mask = torch.zeros(train_label.shape[0]).bool().cuda()
                #conf_threshold = torch.nn.NLLLoss(reduction='none')(-cls_output, train_label.argmax(dim=1)).sort()[0][int((1-RT)*cls_output.shape[0])]
                #cons_mask = (torch.nn.NLLLoss(reduction='none')(-cls_output, train_label.argmax(dim=1)).detach() < conf_threshold)

                if (~cons_mask).sum() > 0:
                    if loss_name == 'ce': #or epoch < 20:
                        cls_loss = nn.NLLLoss()(torch.log(cls_output)[~cons_mask], torch.argmax(train_label[~cons_mask], dim=1))
                        cls_loss_true = nn.NLLLoss()(torch.log(cls_output)[~cons_mask], torch.argmax(train_label_true[~cons_mask], dim=1))
                        pred_loss = nn.NLLLoss()(torch.log(cls_output)[~cons_mask], torch.argmax(cls_output[~cons_mask].detach(), dim=1))
                    elif loss_name == 'sl':
                        cls_loss = SLLoss(cls_output[~cons_mask], train_label[~cons_mask], alpha=0.1, beta=1.0, n_class=n_class)
                        pred_loss = SLLoss(cls_output[~cons_mask], torch.argmax(cls_output[~cons_mask].detach(), dim=1), alpha=0.1, beta=1.0, n_class=n_class)
                    elif loss_name == 'gce':
                        cls_loss = GCELoss(cls_output[~cons_mask], train_label[~cons_mask], 0.9)
                        pred_loss = GCELoss(cls_output[~cons_mask], cls_output[~cons_mask].detach(), 0.9)
                        grad_loss = GCELoss(cls_output[train_label.argmax(dim=1)!=train_label_true.argmax(dim=1)], train_label_true[train_label.argmax(dim=1)!=train_label_true.argmax(dim=1)], 0.9)

                    # Activations captured by the hooks during the forward pass above.
                    total_activations = [m.activations.detach() for m in input_modules]
                    if epoch > E and (CGO_mode != 'none' and CGO_mode != 'smallloss'):
                        # Predict each sample's class from the similarity between its
                        # activations and the per-class running routes. The score is
                        # averaged over the last `constrain_layers` hooked layers.
                        #all_datasets.targets_seq_update(epoch*2+12, train_index, all_datasets.targets_true[train_index])
                        pred_act = sum([(norm(a, norm_type='length') @ norm(a_routes, norm_type='length').T) for a_routes, a in zip(total_a_routes[-constrain_layers:], total_activations[-constrain_layers:])]) / constrain_layers
                        a_pred = pred_act.argmax(dim=1)
                        pred_act = pred_act.cpu() * torch.eye(n_class)[a_pred]
                        # Kept even though unused: it draws from the torch RNG, and
                        # removing it would change the rest of the random stream.
                        rand_class = torch.randint(0, n_class, (train_index.shape[0],))
                        # Feed three weighted target contributions to the dataset's
                        # target sequence: the decaying initial labels, the
                        # route-based prediction, and a growing uniform term.
                        # NOTE(review): how targets_seq_update combines them is defined in MyDataset.
                        all_datasets.targets_seq_update(train_index, all_datasets.targets_init[train_index].float() * (1 - max(0, epoch-E)/(max_epoch-E)))
                        #all_datasets.targets_seq_update(train_index, pred_act * (1 - max(0, epoch-E)/(max_epoch-E) * 0.5))
                        all_datasets.targets_seq_update(train_index, pred_act * min((max(0, epoch-E)) / 10, 2.0))
                        #all_datasets.targets_seq_update(train_index, torch.eye(n_class)[rand_class] * (epoch - 20) / (max_epoch//2))
                        all_datasets.targets_seq_update(train_index, torch.ones(train_index.shape[0], n_class).float() / n_class * max(0, epoch-E) / ((max_epoch-E)) * 1.0)

                    # Per-class activation "routes". NLLLoss(-p, y) equals p[y], so
                    # conf_set is the half of the batch with the highest probability
                    # on its given label. Multiplying by the one-hot labels over the
                    # batch axis sums those activations per class.
                    # NOTE(review): assumes 4-D activations (N, C, H, W); confirm.
                    conf_set = torch.nn.NLLLoss(reduction='none')(-cls_output, train_label.argmax(dim=1)).argsort()[-train_label.shape[0]//2:]
                    a_routes = [torch.matmul(a[conf_set].transpose(0,3), train_label[conf_set].float()).transpose(0,3) for a in total_activations]
                    if epoch == 0 and batch_idx == 0:
                        total_a_routes = [a for a in a_routes]
                    else:
                        # Exponential moving average with decay `lmd`.
                        for layer_idx in range(len(total_a_routes)):
                            total_a_routes[layer_idx] = total_a_routes[layer_idx]*lmd + a_routes[layer_idx]*(1-lmd)

                    opt.zero_grad()
                    cls_loss.backward()
                    opt.step()
                # Batch metrics. "mem" = fraction predicted as the wrong (noisy)
                # original label; "rec" = fraction of noisy samples predicted
                # correctly.
                acc_noisy = (torch.argmax(cls_output, dim=1) == all_datasets.targets_init[train_index].argmax(dim=1).cuda()).float().mean()
                acc_true = (torch.argmax(cls_output, dim=1) == torch.argmax(train_label_true, dim=1)).float().mean()
                cls_mem = ((torch.argmax(cls_output, dim=1) == all_datasets.targets_init[train_index].argmax(dim=1).cuda()) & (
                            all_datasets.targets_init[train_index].argmax(dim=1).cuda() != torch.argmax(train_label_true, dim=1))).float().mean()
                cls_rec = ((torch.argmax(cls_output, dim=1) == torch.argmax(train_label_true, dim=1)) & (
                            all_datasets.targets_init[train_index].argmax(dim=1).cuda() != torch.argmax(train_label_true, dim=1))).float().mean()

                g_pred_acc += ((g_pred == train_label_true.argmax(dim=1))).sum().item()
                g_pred_num += g_pred.shape[0]
                a_pred_acc += ((a_pred == train_label_true.argmax(dim=1))).sum().item()
                a_pred_num += a_pred.shape[0]
                cons_a += (torch.argmax(train_label, dim=1) == torch.argmax(train_label_true, dim=1))[~cons_mask].float().sum().item()
                cons_n += (~cons_mask).sum().item()
                total_train_label[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(train_label, dim=1).detach().cpu()
                total_train_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(cls_output, dim=1).detach().cpu()
                total_train_idx[batch_idx * batch_size:(batch_idx + 1) * batch_size] = train_index.detach().cpu()
                #total_train_p[epoch, train_index] = nn.NLLLoss(reduction='none')(cls_output, torch.argmax(train_label, dim=1)).detach().cpu()

                cls_loss_list.append(cls_loss.mean().item())
                grad_loss_list.append(ag_loss.mean().item())
                cls_acc_list.append(acc_noisy.item())
                true_acc_list.append(acc_true.item())
                cls_mem_list.append(cls_mem.item())
                cls_rec_list.append(cls_rec.item())
                #map_mem_list.append(map_mem.item())
                #map_rec_list.append(map_rec.item())

                print(
                    '\rbatch:{0}/{1} lr:{2:.4f} cls loss:{3:.4f} grad loss:{12:.4f} acc:{4:.4f} true:{5:.4f} mem:{6:.4f} rec:{7:.4f} route_pred(g/a):{8:.4f}/{9:.4f} consnum:{11:d}:{10:.4f}'.format(
                        batch_idx, train_datasets.__len__() // batch_size,
                        opt.param_groups[0]['lr'],
                        sum(cls_loss_list) / len(cls_loss_list),
                        sum(cls_acc_list) / len(cls_acc_list),
                        sum(true_acc_list) / len(true_acc_list),
                        sum(cls_mem_list) / len(cls_mem_list),
                        sum(cls_rec_list) / len(cls_rec_list),
                        g_pred_acc / g_pred_num, a_pred_acc / a_pred_num, cons_a / cons_n, int(cons_n),
                        sum(grad_loss_list) / len(grad_loss_list)
                        ), end='')

            train_log.write('Epoch:{0:d} acc:{1:.4f} loss:{2:.4f} nloss:{5:.4f} mem:{3:.4f} rec:{4:.4f} '.format(epoch, sum(true_acc_list)/len(true_acc_list), sum(cls_loss_list)/len(cls_loss_list), sum(cls_mem_list)/len(cls_mem_list), sum(cls_rec_list)/len(cls_rec_list), sum(grad_loss_list)/len(grad_loss_list)))
            train_log.flush()
            print('\ndataset noise:', (train_datasets.targets_true.argmax(dim=1) != train_datasets.targets_noisy_seq.argmax(dim=1)).float().mean())
            #scheduler.step()
            total_acc_noisy.append(sum(cls_acc_list) / len(cls_acc_list))
            total_acc_true.append(sum(true_acc_list) / len(true_acc_list))
            total_memorization.append(sum(cls_mem_list) / len(cls_mem_list))
            total_rectification.append(sum(cls_rec_list) / len(cls_rec_list))

            # validation
            # NOTE(review): not wrapped in torch.no_grad(). The input is created
            # with requires_grad=True and the hooks may rely on autograd; confirm
            # this before disabling gradients here.
            #with torch.no_grad():
            model.eval()
            print('fea val:', epoch)
            cls_loss_list = []
            cls_acc_list = []
            total_label = torch.zeros(valid_datasets.__len__()).long()
            total_pred = torch.zeros(valid_datasets.__len__()).long()
            total_g_pred = torch.zeros(valid_datasets.__len__()).long()
            total_a_pred = torch.zeros(valid_datasets.__len__()).long()

            for batch_idx, data in enumerate(valid_loader):
                valid_image, valid_label, valid_label_true = data
                valid_image = Variable(valid_image.float(), requires_grad=True).cuda()
                valid_label = Variable(valid_label).cuda()
                valid_label_true = Variable(valid_label_true).cuda()

                if loss_name == 'cm':
                    cls_output, noise_output, valid_cm = model(valid_image)
                else:
                    cls_logits = model(valid_image)
                    cls_output = nn.Softmax(dim=-1)(cls_logits)
                # Validation loss is measured against the clean labels.
                if loss_name == 'ce':
                    cls_loss = nn.NLLLoss()(torch.log(cls_output), torch.argmax(valid_label_true, dim=1))
                elif loss_name == 'sl':
                    cls_loss = SLLoss(cls_output, valid_label_true, 0.1, 1.0).mean()
                elif loss_name == 'tl':
                    cls_loss = TLoss(cls_output, valid_label_true, 0.5, 0.1).mean()
                elif loss_name == 'cm':
                    cls_loss = nn.NLLLoss()(torch.log(cls_output), torch.argmax(valid_label_true, dim=1))
                    cm_loss = -torch.log(torch.diag(torch.mean(valid_cm, dim=0))).mean()
                    cls_loss = cls_loss + cm_loss * 0.1
                elif loss_name == 'gce':
                    # Previously missing: 'gce' logged the stale last training loss.
                    cls_loss = GCELoss(cls_output, valid_label_true, 0.9).mean()
                cls_acc = (torch.argmax(cls_output, dim=1) == torch.argmax(valid_label_true, dim=1)).float().mean()
                total_label[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(valid_label_true, dim=1).detach().cpu()
                total_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(cls_output, dim=1).detach().cpu()

                cls_loss_list.append(cls_loss.item())
                cls_acc_list.append(cls_acc.item())

                print('\rbatch:{0}/{1} cls loss:{2:.4f} acc:{3:.4f}'.format(
                        batch_idx, valid_datasets.__len__() // batch_size,
                        sum(cls_loss_list) / len(cls_loss_list),
                        (total_label == total_pred).float().mean()
                        ), end='')
            train_log.write('Eval acc:{0:.4f}\n'.format((total_label == total_pred).float().mean()))
            train_log.flush()
            print()

            # Keep the weights with the best validation accuracy so far.
            total_valid_acc_list.append((total_label == total_pred).float().mean())
            if epoch == 0 or total_valid_acc_list[-1] == max(total_valid_acc_list):
                torch.save(model.state_dict(), checkpoint_path)
                print('model update')
            #scheduler.step((total_label == total_pred).float().mean())
            scheduler.step()

        # test: evaluate the best-validation checkpoint.
        model.load_state_dict(torch.load(checkpoint_path))
        with torch.no_grad():
            model.eval()
            print('\nfea test:', epoch)
            cls_loss_list = []
            cls_acc_list = []
            total_label = torch.zeros(test_datasets.__len__()).long()
            total_pred = torch.zeros(test_datasets.__len__()).long()
            total_g_pred = torch.zeros(test_datasets.__len__()).long()
            total_a_pred = torch.zeros(test_datasets.__len__()).long()

            for batch_idx, data in enumerate(test_loader):
                test_image, _, test_label_true = data
                test_image = Variable(test_image.float(), requires_grad=True).cuda()
                test_label_true = Variable(test_label_true).cuda()

                cls_output = model(test_image)
                cls_output = nn.Softmax(dim=-1)(cls_output)
                if loss_name == 'ce':
                    cls_loss = nn.NLLLoss()(torch.log(cls_output), torch.argmax(test_label_true, dim=1))
                elif loss_name == 'sl':
                    cls_loss = SLLoss(cls_output, test_label_true, 0.1, 1.0).mean()
                elif loss_name == 'tl':
                    cls_loss = TLoss(cls_output, test_label_true, 0.5, 0.1).mean()
                elif loss_name == 'gce':
                    cls_loss = GCELoss(cls_output, test_label_true, 0.9).mean()

                cls_acc = (torch.argmax(cls_output, dim=1) == torch.argmax(test_label_true, dim=1)).float().mean()
                total_label[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(test_label_true, dim=1).detach().cpu()
                total_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(cls_output, dim=1).detach().cpu()

                cls_loss_list.append(cls_loss.item())
                cls_acc_list.append(cls_acc.item())

            test_acc = (total_label == total_pred).float().mean()
            total_result.append(test_acc.item())
            print(total_result)
        # The log file was previously never closed.
        train_log.close()