import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from sklearn.model_selection import train_test_split
# Torchvision
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34, resnet50
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter

from dataset import *
from mymodels import *
import models 
from losses import *
from utils import *
from DivideMix.PreResNet import *

# Dataset selector; must be one of the keys of _DATASET_CONFIGS.
data_type = 'cifar-10'

# Settings that differ between datasets, as
# (n_class, padding, img_size, img_dim):
#   n_class  - number of classes (model head size, one-hot width)
#   padding  - padding passed to RandomCrop in the training transform
#   img_size - crop size passed to RandomCrop
#   img_dim  - per-dataset value (3 for CIFAR, 1 for MNIST/Fashion);
#              presumably the channel count -- not read in this part of the file
_DATASET_CONFIGS = {
    'cifar-10': (10, 2, 32, 3),
    'cifar-100': (100, 2, 32, 3),
    'mnist': (10, 0, 28, 1),
    'fashion': (10, 0, 28, 1),
}

if data_type not in _DATASET_CONFIGS:
    # Fail immediately with a clear message rather than with a NameError
    # on the first use of an undefined setting further down.
    raise ValueError(
        'unknown data_type {0!r}; expected one of {1}'.format(
            data_type, sorted(_DATASET_CONFIGS)))

n_class, padding, img_size, img_dim = _DATASET_CONFIGS[data_type]

# Settings that are the same for every dataset.
model_type = 'layer6'
ALPHA = 0.01
lr = 0.01
batch_size = 128

# Decay factor of the running (exponential moving) average of the per-class
# activation routes maintained in the training loop.
lmd = 0.99
# Number of trailing hooked layers averaged for the activation-based prediction.
constrain_layers = 3
loss_name = 'ce'
# Enables the activation-route bookkeeping and the label updates from epoch E on.
rand_label = True
total_result = []
# E is the epoch from which the training labels start being rewritten.
for E in [15]:
    # NOTE: this loop variable replaces the per-dataset `lr` chosen above.
    for lr in [0.01]:
        opt_type = 'sgd'
        max_epoch = 200
        noise_ratio = 40  # label noise in percent
        T_max = max_epoch
        flag = False
        need_board = False
        #ND_mode = 'sam' # none; constrain; rectify
        ND_mode = 'gss'
        constrain_threshold = 0.5
        constrain_flag = False

        # Pick the first unused run index so earlier logs are never overwritten.
        # The directory is created up front: open(..., 'w') does not create it.
        os.makedirs('log', exist_ok=True)
        log_template = 'log/{0:s}_noise-{1:d}_{2:s}_{3:s}_{5:s}_idx{4:d}_log.txt'
        train_idx = 0
        log_path = log_template.format(data_type, noise_ratio, 'coteaching', ND_mode, train_idx, loss_name)
        while os.path.exists(log_path):
            train_idx += 1
            log_path = log_template.format(data_type, noise_ratio, 'coteaching', ND_mode, train_idx, loss_name)
        # NOTE(review): this handle is never closed in the visible code; it is
        # only flushed after each write.
        train_log = open(log_path, 'w')

        if need_board: writer = SummaryWriter()

        # Training images get a random flip and a padded random crop;
        # validation/test images are only converted to tensors.
        train_augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(img_size, padding=padding),
            transforms.ToTensor(),
        ]
        transform_train = transforms.Compose(train_augmentations)
        transform_test = transforms.Compose([transforms.ToTensor()])

        print('data_type:{0}, model_type:{1}, ALPHA: {2}, padding: {3}, batch size:{4}, lr:{5}, loss:{6}'.format(data_type, model_type, ALPHA, padding, batch_size, lr, loss_name))
        X_train, y_train, y_train_clean, X_test, y_test = get_data(data_type, noise_ratio=noise_ratio, asym=False, random_shuffle=False)
        # Fraction of training samples whose given label differs from the clean one.
        noisy_class = np.argmax(y_train, axis=-1)
        clean_class = np.argmax(y_train_clean, axis=-1)
        print('Noise rate:', (noisy_class != clean_class).mean())

        # The validation split (10% of the training set) is drawn once and
        # cached on disk so that later runs reuse the same indices.
        valid_idx_file = 'data/%s_valid_idx.npy' % (data_type)
        if not os.path.exists(valid_idx_file):
            valid_idx = np.random.choice(X_train.shape[0], X_train.shape[0]//10, replace=False)
            np.save(valid_idx_file, valid_idx)
        else:
            valid_idx = np.load(valid_idx_file)
        # NOTE(review): this ordering file is the same for every data_type --
        # confirm it matches the dataset in use.
        e2h_idx = np.load('data/easy2hard.npy')
        total_num = X_train.shape[0]

        # Validation uses the clean labels of the held-out samples.
        X_valid = X_train[valid_idx]
        y_valid = y_train_clean[valid_idx]

        # Remaining samples, re-ordered by the easy-to-hard index.
        train_rows = np.setdiff1d(np.arange(total_num), valid_idx)
        X_train = X_train[train_rows][e2h_idx]
        y_train = y_train[train_rows][e2h_idx]
        y_train_clean = y_train_clean[train_rows][e2h_idx]

        train_datasets = CoDataset(X_train, y_train, y_train_clean, require_index=True, transform=transform_train)
        valid_datasets = MyDataset(X_valid, y_valid, y_valid, transform = transform_test)
        test_datasets = MyDataset(X_test, y_test, y_test, transform = transform_test)

        # Only the training loader shuffles.
        train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True, num_workers=0)
        valid_loader = DataLoader(valid_datasets, batch_size=batch_size, shuffle=False, num_workers=0)
        test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=False, num_workers=0)

        valid_acc_list = []
        #model_a = models.load_model(model_name='vgg16', in_channels=3, num_classes=10).cuda()
        #model_b = models.load_model(model_name='vgg16', in_channels=3, num_classes=10).cuda()
        # Two independently initialised peer networks for co-teaching.
        model_a = ResNet18(num_classes=n_class).cuda()
        model_b = ResNet18(num_classes=n_class).cuda()

        if opt_type == 'sgd':
            opt_a = optim.SGD(model_a.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
            opt_b = optim.SGD(model_b.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        elif opt_type == 'adam':
            opt_a = optim.Adam(model_a.parameters(), lr=lr)
            opt_b = optim.Adam(model_b.parameters(), lr=lr)
        else:
            # Without this the script would fail below with a NameError on opt_a.
            raise ValueError('unsupported opt_type: {0!r}'.format(opt_type))
        # The annealing period is T_max + 20 while the schedulers are stepped
        # once per epoch for max_epoch (== T_max) epochs, so the half-cycle is
        # not completed by the end of training.
        scheduler_a = lr_scheduler.CosineAnnealingLR(opt_a, T_max = T_max+20)
        scheduler_b = lr_scheduler.CosineAnnealingLR(opt_b, T_max = T_max+20)
        #scheduler_a = lr_scheduler.ReduceLROnPlateau(opt_a, mode='max', factor=0.2, patience=20, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
        #scheduler_b = lr_scheduler.ReduceLROnPlateau(opt_b, mode='max', factor=0.2, patience=20, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

        # Per-run history buffers.
        total_valid_acc_list = []
        total_valid_loss_list = []
        total_acc_noisy = []
        total_acc_true = []
        total_memorization = []
        total_rectification = []
        # One row per epoch, one column per training sample.
        total_train_p = torch.zeros(max_epoch, train_datasets.__len__()).float()

        # torch.save does not create missing directories; make sure the target
        # exists now instead of failing after the first full epoch.
        os.makedirs('checkpoints', exist_ok=True)
        checkpoint_path_a = 'checkpoints/modelA_{2}_{0}_{3}_noise{1}.pth'.format(loss_name, noise_ratio, data_type, ND_mode)
        checkpoint_path_b = 'checkpoints/modelB_{2}_{0}_{3}_noise{1}.pth'.format(loss_name, noise_ratio, data_type, ND_mode)

        # Per-epoch snapshots for the 4 hooked layers on the first 10 training
        # images: all_g holds one gradient per class logit, all_a the
        # activations. The trailing (512, 4, 4) is presumably the hooked conv
        # output shape -- TODO confirm against HookModule / ResNet18.
        all_g = torch.zeros((4, max_epoch, n_class, 10, 512, 4, 4)).float().cuda()
        all_a = torch.zeros((4, max_epoch, 10, 512, 4, 4)).float().cuda()

        for epoch in range(max_epoch):
            # ---- training phase: both peers in train mode ----
            for _net in (model_a, model_b):
                _net.train()
            print('\ncnn train:', epoch)

            # Per-batch metric histories, reset every epoch.
            cls_loss_list, cls_acc_list, true_acc_list = [], [], []
            mem_list, rec_list = [], []
            # The 0.1 denominators avoid a division by zero before the first batch.
            rec_num, all_num = 0, 0.1
            cons_a, cons_n = 0, 0.1
            rec_ratio = min(epoch/100, 1)
            rec_RT = 0

            # RT is the fraction of each batch kept by the small-loss selection.
            # Early on (or whenever rand_label is off) it decays linearly from 1
            # to 1 - noise_ratio/100 over the first 10 epochs; from epoch E+10
            # it ramps back up linearly towards 1 over 30 epochs.
            if rand_label and epoch >= E+10:
                RT = min(1, 1-noise_ratio/100 + (epoch - E-10) * noise_ratio/100/30)
            else:
                RT = max(0, (10 - epoch)/10 * noise_ratio/100) + 1-noise_ratio/100
            tau = 0.1 * (1 - epoch / max_epoch)

            # Flat per-sample records for this epoch, filled batch by batch.
            n_train_samples = train_datasets.__len__()
            total_train_label, total_train_pred_a, total_train_pred_b, total_train_idx = (
                torch.zeros(n_train_samples).long() for _ in range(4)
            )

            # From epoch E on, switch the dataset into its "random" mode.
            if rand_label and not epoch < E:
                train_datasets.random = True

            seq_len_a = len(train_datasets.targets_seq_a)
            seq_len_b = len(train_datasets.targets_seq_b)

            # One co-teaching step per batch: each network is updated on the
            # samples its peer considers small-loss, and (when rand_label is on)
            # running per-class activation routes are maintained and used to
            # rewrite the stored training labels from epoch E onwards.
            for batch_idx, data in enumerate(train_loader):
                target_layers = [3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
                # Hook the four conv layers of layer4 of EACH network. The
                # list for B was previously built from model_a, so network B's
                # activations were really network A's.
                # NOTE(review): new HookModule objects are created every batch;
                # if HookModule registers hooks without removing them they
                # accumulate -- verify in utils.
                input_layers_a = [model_a.layer4[0].conv1, model_a.layer4[0].conv2, model_a.layer4[1].conv1, model_a.layer4[1].conv2]
                input_layers_b = [model_b.layer4[0].conv1, model_b.layer4[0].conv2, model_b.layer4[1].conv1, model_b.layer4[1].conv2]
                input_modules_a = [HookModule(model=model_a, module=layer) for layer in input_layers_a]
                input_modules_b = [HookModule(model=model_b, module=layer) for layer in input_layers_b]

                train_image, train_label_a, train_label_b, train_label_true, train_index = data
                train_image = Variable(train_image.float(), requires_grad=True).cuda()
                # Class indices -> one-hot rows.
                train_label_a = Variable(torch.eye(n_class)[train_label_a]).cuda()
                train_label_b = Variable(torch.eye(n_class)[train_label_b]).cuda()
                train_label_true = Variable(train_label_true).cuda()

                # Scratch tensors; apart from a_pred_a / a_pred_b they are not
                # read anywhere in the visible code.
                grad_loss_a = torch.zeros(1)
                grad_loss_b = torch.zeros(1)
                g_pred_a = torch.zeros(train_label_a.shape[0]).long().cuda()
                a_pred_a = torch.zeros(train_label_a.shape[0]).long().cuda()
                g_pred_b = torch.zeros(train_label_a.shape[0]).long().cuda()
                a_pred_b = torch.zeros(train_label_a.shape[0]).long().cuda()
                map_pred_a = torch.zeros(train_label_a.shape[0]).long().cuda()
                cons_mask_a = torch.zeros(train_label_a.shape[0]).bool().cuda()
                map_pred_b = torch.zeros(train_label_a.shape[0]).long().cuda()
                cons_mask_b = torch.zeros(train_label_a.shape[0]).bool().cuda()

                logits_a = model_a(train_image)
                logits_b = model_b(train_image)
                # Smoothed probabilities so that log() below never sees 0.
                output_a = (nn.Softmax(dim=1)(logits_a) + 1e-4) / (1+1e-4)
                output_b = (nn.Softmax(dim=1)(logits_b) + 1e-4) / (1+1e-4)

                # Coteaching: every sample is a candidate.
                diff_mask = torch.ones((output_a.shape[0])).bool().cuda()
                # Coteaching +
                #diff_mask = output_a.argmax(dim=1) != output_b.argmax(dim=1)

                # Per-sample cross-entropy against the (possibly rewritten) labels.
                loss_a = nn.NLLLoss(reduction='none')(torch.log(output_a), torch.argmax(train_label_a, dim=1))[diff_mask]
                loss_b = nn.NLLLoss(reduction='none')(torch.log(output_b), torch.argmax(train_label_b, dim=1))[diff_mask]
                # Per-sample cross-entropy of each network against its OWN
                # prediction (B previously used A's prediction as target).
                pred_loss_a = nn.NLLLoss(reduction='none')(torch.log(output_a), torch.argmax(output_a.detach(), dim=1))[diff_mask]
                pred_loss_b = nn.NLLLoss(reduction='none')(torch.log(output_b), torch.argmax(output_b.detach(), dim=1))[diff_mask]

                # Keep the RT fraction of samples with the smallest loss. The
                # index is clamped at 0: a negative index would pick the
                # LARGEST loss as threshold and keep every sample.
                RT_num = max(int(RT * diff_mask.sum())-1, 0)
                smallloss_mask_a = loss_a <= loss_a.sort()[0][RT_num]
                smallloss_mask_b = loss_b <= loss_b.sort()[0][RT_num]
                # Cross update: A learns from B's selection and vice versa.
                masked_loss_a = loss_a[smallloss_mask_b].mean()
                masked_loss_b = loss_b[smallloss_mask_a].mean()
                masked_pred_loss_a = pred_loss_a[smallloss_mask_b].mean()
                masked_pred_loss_b = pred_loss_b[smallloss_mask_a].mean()

                diff_mask_a = (train_label_a.argmax(dim=1) != output_a.argmax(dim=1)).detach()
                diff_mask_b = (train_label_b.argmax(dim=1) != output_b.argmax(dim=1)).detach()

                if rand_label:
                    total_activations_a = [m.activations.detach() for m in input_modules_a]
                    total_activations_b = [m.activations.detach() for m in input_modules_b]
                    if epoch >= E:
                        # Activation-based class scores: similarity between each
                        # sample's activations and the running per-class
                        # routes, averaged over the last constrain_layers layers.
                        pred_act_a = sum([(norm(a, norm_type='length') @ norm(a_routes, norm_type='length').T) for a_routes, a in zip(total_a_routes_a[-constrain_layers:], total_activations_a[-constrain_layers:])]) / constrain_layers
                        pred_act_b = sum([(norm(a, norm_type='length') @ norm(a_routes, norm_type='length').T) for a_routes, a in zip(total_a_routes_b[-constrain_layers:], total_activations_b[-constrain_layers:])]) / constrain_layers
                        a_pred_a = pred_act_a.argmax(dim=1)
                        a_pred_b = pred_act_b.argmax(dim=1)
                        # Not used below; kept so the global RNG stream is
                        # consumed exactly as before.
                        rand_class_a = torch.randint(0, n_class, (train_index.shape[0],))
                        rand_class_b = torch.randint(0, n_class, (train_index.shape[0],))
                        # Contributions pushed to label sequence 0 and 1:
                        # (1) the initial labels with a weight that decays
                        #     linearly to 0 at max_epoch,
                        # (2) the PEER's activation-based prediction with a
                        #     weight growing by 0.1 per epoch, capped at 2.0,
                        # (3) a uniform term growing linearly up to 0.5.
                        # NOTE(review): how targets_seq_update combines
                        # repeated calls is defined in dataset.py -- confirm.
                        train_datasets.targets_seq_update(0, train_index, train_datasets.targets_init[train_index] * (1 - max(0, epoch-E)/(max_epoch-E)))
                        train_datasets.targets_seq_update(1, train_index, train_datasets.targets_init[train_index] * (1 - max(0, epoch-E)/(max_epoch-E)))
                        # torch.eye(n_class) lives on the CPU, so index it with
                        # CPU indices; a_pred_* presumably sit on the GPU with
                        # the hooked activations, and mixed-device indexing
                        # raises in current PyTorch.
                        train_datasets.targets_seq_update(0, train_index, torch.eye(n_class)[a_pred_b.cpu()] * min((max(0, epoch-E)) / 10, 2.0))
                        train_datasets.targets_seq_update(1, train_index, torch.eye(n_class)[a_pred_a.cpu()] * min((max(0, epoch-E)) / 10, 2.0))
                        train_datasets.targets_seq_update(0, train_index, torch.ones(train_image.shape[0], n_class).float() / n_class * max(0, epoch-E) / ((max_epoch-E)) * 0.5)
                        train_datasets.targets_seq_update(1, train_index, torch.ones(train_image.shape[0], n_class).float() / n_class * max(0, epoch-E) / ((max_epoch-E)) * 0.5)

                    # NLLLoss on -output yields the probability assigned to the
                    # label; the upper half of the ascending argsort is the
                    # most confident half of the batch.
                    conf_set_a = torch.nn.NLLLoss(reduction='none')(-output_a, train_label_a.argmax(dim=1)).argsort()[-train_label_a.shape[0]//2:]
                    conf_set_b = torch.nn.NLLLoss(reduction='none')(-output_b, train_label_b.argmax(dim=1)).argsort()[-train_label_b.shape[0]//2:]
                    # Label-weighted sum of activations over the confident
                    # samples: one activation "route" per class and layer.
                    a_routes_a = [torch.matmul(a[conf_set_a].transpose(0,3), train_label_a[conf_set_a].float()).transpose(0,3) for a in total_activations_a]
                    a_routes_b = [torch.matmul(a[conf_set_b].transpose(0,3), train_label_b[conf_set_b].float()).transpose(0,3) for a in total_activations_b]
                    if epoch == 0 and batch_idx == 0:
                        total_a_routes_a = [a for a in a_routes_a]
                        total_a_routes_b = [a for a in a_routes_b]
                    else:
                        # Exponential moving average with decay lmd.
                        for layer_idx in range(len(total_a_routes_a)):
                            total_a_routes_a[layer_idx] = total_a_routes_a[layer_idx]*lmd + a_routes_a[layer_idx]*(1-lmd)
                            total_a_routes_b[layer_idx] = total_a_routes_b[layer_idx]*lmd + a_routes_b[layer_idx]*(1-lmd)

                # Batch metrics for network A:
                #   acc_noisy     - agreement with the given (noisy) label
                #   acc_true      - agreement with the clean label
                #   memorization  - predicts the given label although it is wrong
                #   rectification - predicts the clean label although the given one is wrong
                acc_noisy = (torch.argmax(output_a, dim=1) == torch.argmax(train_label_a, dim=1)).float().mean()
                acc_true = (torch.argmax(output_a, dim=1) == torch.argmax(train_label_true, dim=1)).float().mean()
                memorization = ((torch.argmax(output_a, dim=1) == torch.argmax(train_label_a, dim=1)) & (
                            torch.argmax(train_label_a, dim=1) != torch.argmax(train_label_true, dim=1))).float().mean()
                rectification = ((torch.argmax(output_a, dim=1) == torch.argmax(train_label_true, dim=1)) & (
                            torch.argmax(train_label_a, dim=1) != torch.argmax(train_label_true, dim=1))).float().mean()
                total_train_label[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(train_label_a, dim=1).detach().cpu()
                total_train_pred_a[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(output_a, dim=1).detach().cpu()
                total_train_pred_b[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(output_b, dim=1).detach().cpu()
                total_train_idx[batch_idx * batch_size:(batch_idx + 1) * batch_size] = train_index.detach().cpu()
                # NLLLoss on raw probabilities stores MINUS the probability of
                # the given label for each sample.
                total_train_p[epoch, train_index] = \
                    nn.NLLLoss(reduction='none')(output_a, torch.argmax(train_label_a, dim=1)).detach().cpu()

                # The two networks have separate graphs, so they are stepped
                # one after the other.
                opt_a.zero_grad()
                masked_loss_a.backward()
                opt_a.step()
                opt_b.zero_grad()
                masked_loss_b.backward()
                opt_b.step()

                # Running share of samples whose current label equals the clean one.
                cons_a += (torch.argmax(train_label_a, dim=1) == torch.argmax(train_label_true, dim=1)).float().sum().item()
                cons_n += train_label_a.shape[0]
                cls_loss_list.append(masked_loss_a.mean().item())
                cls_acc_list.append(acc_noisy.item())
                true_acc_list.append(acc_true.item())
                mem_list.append(memorization.item())
                rec_list.append(rectification.item())

                print('\rbatch:{0}/{1} lr:{2:.4f} cls loss:{3:.4f} acc:{4:.4f} true:{5:.4f} mem:{6:.4f} rec:{7:.4f} clean rate:{8:.4f}'.format(
                        batch_idx, train_datasets.__len__() // batch_size,
                        opt_a.param_groups[0]['lr'],
                        sum(cls_loss_list) / len(cls_loss_list),
                        sum(cls_acc_list) / len(cls_acc_list),
                        sum(true_acc_list) / len(true_acc_list),
                        sum(mem_list) / len(mem_list),
                        sum(rec_list) / len(rec_list),
                        cons_a / cons_n), end='')
            # Epoch summary of the training metrics (the line is completed by
            # the 'Eval acc' entry written after validation).
            train_log.write('Epoch:{0:d} acc:{1:.4f} loss:{2:.4f} mem:{3:.4f} rec:{4:.4f} '.format(epoch, sum(true_acc_list)/len(true_acc_list), sum(cls_loss_list)/len(cls_loss_list), sum(mem_list)/len(mem_list), sum(rec_list)/len(rec_list)))
            train_log.flush()
            # Share of training samples whose current sequence-A label differs
            # from the clean label.
            dataset_noise = (train_datasets.targets_true.argmax(dim=1) != train_datasets.targets_seq_a.argmax(dim=1)).float().mean()
            print('\ndataset noise:', dataset_noise, 'RT:', RT)

            model_a.eval()

            # gradient bias
            # Disabled diagnostic (kept for reference): compares feature
            # gradients under noisy, resampled and clean labels on noisy samples.
            if False:#epoch in [50, 100, 150]:
                delta_g0 = 0
                delta_g1 = 0
                delta_g2 = 0
                noisy_num = 0

                for batch_idx, data in enumerate(train_loader):
                    input_layers = ([model_a.layer4[0].conv1, model_a.layer4[0].conv2, model_a.layer4[1].conv1, model_a.layer4[1].conv2])
                    input_modules = [HookModule(model=model_a, module=layer) for layer in input_layers]

                    train_image, train_label_a, train_label_b, train_label_true, train_index = data
                    train_image = Variable(train_image.float(), requires_grad=True).cuda()
                    train_label = Variable(torch.eye(n_class)[train_label_a]).cuda()
                    train_label_true = Variable(train_label_true).cuda()

                    cls_logits = model_a(train_image)
                    cls_output = nn.Softmax(dim=-1)(cls_logits)
                    cls_output = (cls_output+1e-4)/(1+2e-4)
                    cons_mask = torch.zeros(train_label.shape[0]).bool().cuda()

                    cls_loss = nn.NLLLoss()(torch.log(cls_output)[~cons_mask], torch.argmax(train_label[~cons_mask], dim=1))
                    cls_loss_true = nn.NLLLoss()(torch.log(cls_output)[~cons_mask], torch.argmax(train_label_true[~cons_mask], dim=1))
                    noisy_mask = (train_label.argmax(dim=1) != train_label_true.argmax(dim=1))
                    
                    total_feature_grads = [torch.stack([torch.autograd.grad(outputs=cls_logits[:,i].sum(), inputs=m.activations, retain_graph=True, create_graph=False)[0].cpu() for i in range(n_class)]) for m in input_modules]
                    rand_epochs = 10
                    rand_label_list = np.array([[np.eye(n_class)[train_datasets.__getitem__(i)[1]] for r in range(rand_epochs)] for i in train_index]).sum(axis=1).transpose(1,0) / rand_epochs
                    delta_g0 += sum([(g * train_label.cpu().transpose(0,1).reshape(n_class,-1, 1,1,1) - g * train_label_true.cpu().transpose(0,1).reshape(n_class,-1, 1,1,1))\
                        [:,noisy_mask].abs().sum() for g in total_feature_grads])
                    delta_g1 += sum([(g * torch.tensor(rand_label_list).reshape(n_class,-1, 1,1,1) - g * train_label_true.cpu().transpose(0,1).reshape(n_class,-1, 1,1,1))\
                        [:,noisy_mask].abs().sum() for g in total_feature_grads])
                    ly2z = torch.autograd.grad(outputs=cls_loss_true, inputs=cls_logits, retain_graph=True, create_graph=False)[0].transpose(0,1).reshape(n_class, -1, 1, 1, 1) * cls_logits.shape[0]
                    delta_g2 += sum([(g * ly2z.cpu())[:,noisy_mask].abs().sum() for g in total_feature_grads])
                    noisy_num += (noisy_mask).sum()    
                print(delta_g0/noisy_num, delta_g1/noisy_num, delta_g2/noisy_num)
            # g&a comparison
            # Snapshot, for the first 10 training images (test transform, no
            # augmentation), the per-class-logit gradients and the activations
            # at the four hooked layer4 convs of network A.
            input_layers = ([model_a.layer4[0].conv1, model_a.layer4[0].conv2, model_a.layer4[1].conv1, model_a.layer4[1].conv2])
            input_modules = [HookModule(model=model_a, module=layer) for layer in input_layers]

            train_image = Variable(torch.stack(([transform_test(img) for img in train_datasets.data[:10]])), requires_grad=True).cuda()
            cls_logits = model_a(train_image)
            total_feature_grads = [torch.stack([torch.autograd.grad(outputs=cls_logits[:,i].sum(), inputs=m.activations, retain_graph=True, create_graph=False)[0].cpu() for i in range(n_class)]) for m in input_modules]
            # Read the activations from the hooks created for THIS probe pass
            # (the same objects the gradients above were taken against), not
            # from the hook objects left over from the last training batch.
            total_activations = [m.activations.detach() for m in input_modules]
            for i in range(4):
                all_g[i][epoch] = total_feature_grads[i].detach()
                all_a[i][epoch] = total_activations[i].detach()


            # validation
            # Evaluate the averaged prediction of both networks (and each
            # network alone) on the clean held-out split.
            print('fea test:', epoch)
            model_a.eval()
            model_b.eval()
            cls_loss_list = []
            cls_acc_list = []
            total_label = torch.zeros(valid_datasets.__len__()).long()
            total_pred = torch.zeros(valid_datasets.__len__()).long()
            total_pred_a = torch.zeros(valid_datasets.__len__()).long()
            total_pred_b = torch.zeros(valid_datasets.__len__()).long()
            total_map_pred = torch.zeros(valid_datasets.__len__()).long()

            for batch_idx, data in enumerate(valid_loader):
                valid_image, _, valid_label = data
                # NOTE(review): gradients are not needed here, but the hooks
                # installed on the models may rely on a grad-enabled forward
                # pass -- confirm in HookModule before wrapping in no_grad().
                valid_image = Variable(valid_image.float(), requires_grad=True).cuda()
                valid_label = Variable(valid_label).cuda()

                output_a = model_a(valid_image)
                output_a = nn.Softmax(dim=1)(output_a)
                output_b = model_b(valid_image)
                output_b = nn.Softmax(dim=1)(output_b)
                # Ensemble: mean of the two probability vectors.
                cls_output = (output_a + output_b)/2
                cls_loss = nn.NLLLoss()(torch.log(cls_output), torch.argmax(valid_label, dim=1))

                a_pred_a = a_pred_b = torch.zeros(valid_label.shape[0]).long().cuda()
                cls_acc = (torch.argmax(cls_output, dim=1) == torch.argmax(valid_label, dim=1)).float().mean()
                total_label[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(valid_label, dim=1).detach().cpu()
                total_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(cls_output, dim=1).detach().cpu()
                total_pred_a[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(output_a, dim=1).detach().cpu()
                total_pred_b[batch_idx * batch_size:(batch_idx + 1) * batch_size] = torch.argmax(output_b, dim=1).detach().cpu()
                total_map_pred[batch_idx * batch_size:(batch_idx + 1) * batch_size] = ((a_pred_a+a_pred_b)/2).detach().cpu()

                cls_loss_list.append(cls_loss.item())
                cls_acc_list.append(cls_acc.item())

                # Running accuracy over the samples evaluated so far. The
                # buffers are zero-initialised, so comparing them in full would
                # count every not-yet-filled slot (0 == 0) as a correct answer.
                n_seen = min((batch_idx + 1) * batch_size, valid_datasets.__len__())
                print('\rbatch:{0}/{1} cls loss:{2:.4f} acc:{3:.4f} A:{4:.4f} B:{5:.4f}'.format(
                        batch_idx, valid_datasets.__len__() // batch_size, 
                        sum(cls_loss_list) / len(cls_loss_list), 
                        (total_label[:n_seen] == total_pred[:n_seen]).float().mean(),
                        (total_label[:n_seen] == total_pred_a[:n_seen]).float().mean(),
                        (total_label[:n_seen] == total_pred_b[:n_seen]).float().mean(),
                        ), end='')
            # Latch the flag once both networks individually exceed the threshold.
            if (total_label == total_pred_a).float().mean() > constrain_threshold and (total_label == total_pred_b).float().mean() > constrain_threshold:
                constrain_flag = True
            print()
            train_log.write('Eval acc:{0:.4f}\n'.format((total_label == total_pred).float().mean()))
            train_log.flush()

            # Record this epoch's ensemble validation accuracy.
            epoch_valid_acc = (total_label == total_pred).float().mean()
            total_valid_acc_list.append(epoch_valid_acc)

            # Save both networks whenever the newest accuracy is the best so
            # far (always true for the first epoch).
            is_best_epoch = epoch == 0 or total_valid_acc_list[-1] == max(total_valid_acc_list)
            if is_best_epoch:
                for _net, _ckpt_path in ((model_a, checkpoint_path_a), (model_b, checkpoint_path_b)):
                    torch.save(_net.state_dict(), _ckpt_path)
                print('model update')

            # Early-stopping condition (no improvement over the last 20
            # epochs); the break is currently disabled, so this does nothing.
            if epoch > 40 and total_valid_acc_list[-20] == max(total_valid_acc_list[-20:]):
                #break
                pass

            # Advance both cosine schedules once per epoch.
            for _sched in (scheduler_a, scheduler_b):
                _sched.step()
            #scheduler_a.step((total_label == total_pred_a).float().mean())
            #scheduler_b.step((total_label == total_pred_b).float().mean())