from __future__ import print_function
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
from types import new_class
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import argparse
import numpy as np
from sklearn.mixture import GaussianMixture

from DivideMix.PreResNet import *
import DivideMix.dataloader_cifar as dataloader
import models
from utils import *

# Fix the NumPy and PyTorch (CPU and CUDA) RNG seeds so runs start from the
# same random state.
# NOTE(review): Python's built-in `random` module is imported by this file but
# is not seeded here -- confirm whether the data loader relies on it.
seed = 0
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)

# Command-line interface. Each entry is (option strings, add_argument kwargs);
# the options are registered in the order listed, which is also the order in
# which they appear in --help.
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
_OPTION_SPECS = (
    (('--batch_size',), dict(default=128, type=int, help='train batchsize')),
    (('--lr', '--learning_rate'), dict(default=0.01, type=float, help='initial learning rate')),
    (('--noise_mode',), dict(default='sym')),
    (('--alpha',), dict(default=4, type=float, help='parameter for Beta')),
    (('--lambda_u',), dict(default=25, type=float, help='weight for unsupervised loss')),
    (('--p_threshold',), dict(default=0.5, type=float, help='clean probability threshold')),
    (('--T',), dict(default=0.5, type=float, help='sharpening temperature')),
    (('--num_epochs',), dict(default=200, type=int)),
    (('--r',), dict(default=0.4, type=float, help='noise ratio')),
    (('--id',), dict(default='')),
    (('--num_class',), dict(default=10, type=int)),
    (('--data_path',), dict(default='./cifar-10', type=str, help='path to dataset')),
    (('--dataset',), dict(default='cifar10', type=str)),
)
for _option_strings, _option_kwargs in _OPTION_SPECS:
    parser.add_argument(*_option_strings, **_option_kwargs)
# Parsed once at import time; the rest of the script reads the module-level `args`.
args = parser.parse_args()

# Per-run log file: log/<data_type>_noise-<pct>_dividemix_idx<k>_log.txt, where
# <k> is the first index whose file does not exist yet, so earlier runs are
# never overwritten.
n_class = args.num_class
train_idx = 0
# NOTE(review): any dataset other than 'cifar10' (e.g. 'cifar100') gets an
# empty prefix here, so its log name starts with '_noise-' -- confirm intent.
data_type = 'cifar-10' if args.dataset == 'cifar10' else ''
# round() before int(): plain truncation maps e.g. 0.29 * 100 == 28.999... to 28.
noise_ratio = int(round(args.r * 100))
_LOG_PATH_FORMAT = 'log/{0:s}_noise-{1:d}_{2:s}_idx{3:d}_log.txt'
# open(..., 'w') does not create missing directories.
os.makedirs('log', exist_ok=True)
while os.path.exists(_LOG_PATH_FORMAT.format(data_type, noise_ratio, 'dividemix', train_idx)):
    train_idx += 1
# Kept open for the whole run; the training/eval functions write and flush it.
train_log=open(_LOG_PATH_FORMAT.format(data_type, noise_ratio, 'dividemix', train_idx),'w')

# Training
def train(epoch,net_idx,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,loader, E=30):
    """Run one semi-supervised (MixMatch-style) training epoch for `net`.

    `net` is optimised while `net2` is kept in eval mode and only used for
    label co-guessing. From epoch `E` onwards the stored labels of the *other*
    network (index ``1-net_idx``) are updated through ``loader.update_labels``.

    Relies on the module-level names `args`, `criterion`, `warm_up`, `n_class`
    and `train_log`.

    Args:
        epoch: Current epoch number (used for the loss ramp-up and label update).
        net_idx: 0 or 1. When 1, the second label set of each batch
            (`labels_x2` / `labels_u2`) is used instead of the first.
        net: Network being trained.
        net2: Fixed peer network.
        optimizer: Optimizer over `net`'s parameters.
        labeled_trainloader: Yields 7-tuples
            (inputs_x, inputs_x2, labels_x, labels_x2, labels_true_x, w_x, index_x).
        unlabeled_trainloader: Yields 6-tuples
            (inputs_u, inputs_u2, labels_u, labels_u2, labels_true_u, index_u).
            It is restarted whenever it is exhausted before the labeled loader.
        loader: Object exposing `update_labels(...)` and `targets_noisy`.
        E: Epoch from which given labels are blended into the unlabeled
            targets (weight 0.2) and label updates are pushed to `loader`.

    Returns:
        None. Prints accuracy statistics and appends one entry to `train_log`.
    """
    net.train()
    net2.eval() #fix one network and train the other

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
    correct = correct_x = correct_u = correct_lx = correct_lu = 0
    total = total_x = total_u = total_lx = total_lu = 0
    for batch_idx, (inputs_x, inputs_x2, labels_x, labels_x2, labels_true_x, w_x, index_x) in enumerate(labeled_trainloader):
        # Use the built-in next(): iterator objects have no .next() method on
        # Python 3 / recent PyTorch. Only exhaustion restarts the loader; any
        # other error (e.g. a failing worker) must propagate.
        try:
            inputs_u, inputs_u2, labels_u, labels_u2, labels_true_u, index_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2, labels_u, labels_u2, labels_true_u, index_u = next(unlabeled_train_iter)
        batch_size = inputs_x.size(0)
        batch_size_u = inputs_u.size(0)

        # Transform label to one-hot
        if net_idx == 1:
            labels_x = labels_x2.clone()
            labels_u = labels_u2.clone()
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
        labels_u = torch.zeros(batch_size_u, args.num_class).scatter_(1, labels_u.view(-1,1), 1)
        labels_true_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_true_x.view(-1,1), 1)
        labels_true_u = torch.zeros(batch_size_u, args.num_class).scatter_(1, labels_true_u.view(-1,1), 1)
        w_x = w_x.view(-1,1).type(torch.FloatTensor)
        # Weight of the given label in the unlabeled targets; zero before epoch E.
        w_u = 0.2 if epoch >= E else 0.0

        inputs_x, inputs_x2, labels_x, labels_true_x, w_x, index_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), labels_true_x.cuda(), w_x.cuda(), index_x.cuda()
        inputs_u, inputs_u2, labels_u, labels_true_u = inputs_u.cuda(), inputs_u2.cuda(), labels_u.cuda(), labels_true_u.cuda()

        with torch.no_grad():
            # label co-guessing of unlabeled samples
            outputs_u11 = net(inputs_u)
            outputs_u12 = net(inputs_u2)
            outputs_u21 = net2(inputs_u)
            outputs_u22 = net2(inputs_u2)

            # Average both networks' predictions on both inputs, then blend in the given label.
            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
            pu = w_u*labels_u + (1-w_u)*pu
            ptu = pu**(1/args.T) # temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
            targets_u = targets_u.detach()

            # label refinement of labeled samples
            outputs_x = net(inputs_x)
            outputs_x2 = net(inputs_x2)

            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
            px = w_x*labels_x + (1-w_x)*px
            ptx = px**(1/args.T) # temperature sharpening

            targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize
            targets_x = targets_x.detach()

        # mixmatch
        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1-l)  # keep the mixed sample closer to its first operand

        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)
        all_targets_true = torch.cat([labels_true_x, labels_true_x, labels_true_u, labels_true_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        logits = net(mixed_input)
        # The first 2*batch_size rows come from the labeled inputs, the rest from the unlabeled ones.
        logits_x = logits[:batch_size*2]
        logits_u = logits[batch_size*2:]
        prediction = nn.Softmax(dim=1)(logits)

        Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size*2], logits_u, mixed_target[batch_size*2:], epoch+batch_idx/num_iter, warm_up)

        # label update
        if epoch >= E:
            a_pred_x = outputs_x.argmax(dim=1)
            a_pred_x2 = outputs_x2.argmax(dim=1)
            # NOTE: rand_class is never read, but drawing it advances the global
            # torch RNG; it is kept so seeded runs stay unchanged.
            rand_class = torch.randint(0, n_class, (index_x.shape[0],))
            loader.update_labels(1-net_idx, index_x, loader.targets_noisy[index_x])
            loader.update_labels(1-net_idx, index_x, torch.eye(n_class)[a_pred_x] / 2 * min((max(0, epoch-E)) / 10, 0.1))
            loader.update_labels(1-net_idx, index_x, torch.eye(n_class)[a_pred_x2] / 2 * min((max(0, epoch-E)) / 10, 0.1))
            #loader.update_labels(1-net_idx, index_x, torch.eye(n_class)[a_pred_x] / 2 * (1 - max(0, epoch-E)/(args.num_epochs-E) * 0.5))
            #loader.update_labels(1-net_idx, index_x, torch.eye(n_class)[a_pred_x2] / 2 * (1 - max(0, epoch-E)/(args.num_epochs-E) * 0.5))
            loader.update_labels(1-net_idx, index_x, torch.ones(index_x.shape[0], args.num_class).float() / args.num_class * max(0, epoch-E) / ((args.num_epochs-E)) * 0.1)
            a_pred_u = outputs_u11.argmax(dim=1)
            a_pred_u2 = outputs_u12.argmax(dim=1)
            rand_class = torch.randint(0, n_class, (index_u.shape[0],))
            loader.update_labels(1-net_idx, index_u, loader.targets_noisy[index_u] * (1 - max(0, epoch-E)/(args.num_epochs-E)))
            loader.update_labels(1-net_idx, index_u, torch.eye(n_class)[a_pred_u] / 2 * min((max(0, epoch-E)) / 30, 0.2))
            loader.update_labels(1-net_idx, index_u, torch.eye(n_class)[a_pred_u2] / 2 * min((max(0, epoch-E)) / 30, 0.2))
            #loader.update_labels(1-net_idx, index_u, torch.eye(n_class)[a_pred_u] / 2 * (1 - max(0, epoch-E)/(args.num_epochs-E) * 0.5))
            #loader.update_labels(1-net_idx, index_u, torch.eye(n_class)[a_pred_u2] / 2 * (1 - max(0, epoch-E)/(args.num_epochs-E) * 0.5))
            loader.update_labels(1-net_idx, index_u, torch.ones(index_u.shape[0], args.num_class).float() / args.num_class * max(0, epoch-E) / ((args.num_epochs-E)) * 0.5)

        # regularization: KL divergence from a uniform class prior to the
        # batch-mean prediction.
        prior = torch.ones(args.num_class)/args.num_class
        prior = prior.cuda()
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior*torch.log(prior/pred_mean))

        loss = Lx + lamb * Lu  + penalty
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping: predictions on the mixed inputs and the (possibly
        # noisy) given labels are both scored against the true labels.
        total += mixed_target.size(0)
        correct += prediction.argmax(dim=1).eq(all_targets_true.argmax(dim=1)).cpu().sum().item()
        total_x += targets_x.size(0) * 2
        correct_x += logits_x.argmax(dim=1).eq(torch.cat([labels_true_x, labels_true_x], dim=0).argmax(dim=1)).cpu().sum().item()
        total_u += targets_u.size(0) * 2
        correct_u += logits_u.argmax(dim=1).eq(torch.cat([labels_true_u, labels_true_u], dim=0).argmax(dim=1)).cpu().sum().item()
        total_lx += labels_x.size(0)
        correct_lx += labels_x.argmax(dim=1).eq(labels_true_x.argmax(dim=1)).cpu().sum().item()
        total_lu += labels_u.size(0)
        correct_lu += labels_u.argmax(dim=1).eq(labels_true_u.argmax(dim=1)).cpu().sum().item()
    acc = 100.*correct/total
    acc_x = 100.*correct_x/total_x
    acc_u = 100.*correct_u/total_u
    acc_lx = 100.*correct_lx/total_lx
    acc_lu = 100.*correct_lu/total_lu
    print("| Train Epoch #%d\t Accuracy: %.2f%% (%.2f%%, %.2f%%)\t Label Acc: (%.2f%%, %.2f%%)" %(epoch, acc, acc_x, acc_u, acc_lx, acc_lu))
    # `loss` is the loss of the last batch only.
    train_log.write('Epoch:{0:d} acc:{1:.4f} loss:{2:.4f} '.format(epoch, acc/100, loss))
    train_log.flush()


def warmup(epoch,net,optimizer,dataloader):
    """Train `net` for one epoch with plain cross-entropy on the given labels.

    Relies on the module-level names `args`, `CEloss`, `train_log` and, when
    ``args.noise_mode == 'asym'``, `conf_penalty`.

    Args:
        epoch: Current epoch number (only used for reporting).
        net: Network to train.
        optimizer: Optimizer over `net`'s parameters.
        dataloader: Yields (inputs, labels, index) batches. Note that inside
            this function the name shadows the module-level `dataloader` import.

    Returns:
        None. Prints the epoch accuracy and appends one entry to `train_log`.
    """
    net.train()
    correct = 0
    total = 0
    for batch_idx, (inputs, labels, index) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        loss = CEloss(outputs, labels)
        # Default to plain cross-entropy so that `L` is always bound; the
        # original left it undefined for any noise_mode other than 'sym'/'asym'.
        L = loss
        if args.noise_mode=='asym':  # penalize confident prediction for asymmetric noise
            penalty = conf_penalty(outputs)
            L = loss + penalty

        optimizer.zero_grad()
        L.backward()
        optimizer.step()

        total += labels.size(0)
        correct += predicted.eq(labels).cpu().sum().item()
    acc = 100.*correct/total
    print("| Warmup Epoch #%d\t Accuracy: %.2f%%" %(epoch,acc))
    # The logged loss is the cross-entropy (without penalty) of the last batch.
    train_log.write('Epoch:{0:d} acc:{1:.4f} loss:{2:.4f} '.format(epoch, acc/100, loss))
    train_log.flush()

def test(epoch,net1,net2):
    """Evaluate the two-network ensemble on the module-level `test_loader`.

    The prediction for each sample is the argmax of the summed logits of
    `net1` and `net2`. The accuracy is printed and appended to `train_log`.

    Args:
        epoch: Current epoch number (only used for reporting).
        net1: First network of the ensemble.
        net2: Second network of the ensemble.
    """
    for member in (net1, net2):
        member.eval()

    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            # Ensemble by adding the raw outputs of both networks.
            ensemble_out = net1(inputs) + net2(inputs)
            predicted = torch.max(ensemble_out, 1)[1]

            n_seen += targets.size(0)
            n_right += predicted.eq(targets).cpu().sum().item()

    acc = 100.*n_right/n_seen
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" %(epoch,acc))
    train_log.write('Eval acc:{0:.4f}\n'.format(acc/100))
    train_log.flush()

def eval_train(model,all_loss):
    """Estimate, per training sample, the probability that its label is clean.

    Computes the per-sample cross-entropy of `model` over the module-level
    `eval_loader`, min-max normalises it, and fits a two-component Gaussian
    mixture to the losses. The posterior of the component with the smaller
    mean is returned as the clean probability.

    Relies on the module-level names `eval_loader`, `CE` and `args`.

    Args:
        model: Network to evaluate (switched to eval mode).
        all_loss: List of normalised loss tensors from previous calls; the
            losses of this call are appended to it in place.

    Returns:
        (prob, all_loss): `prob` is the per-sample posterior returned by
        ``GaussianMixture.predict_proba`` for the low-mean component;
        `all_loss` is the same list object that was passed in.
    """
    model.eval()
    # NOTE(review): the buffer size is hard-coded; this presumably equals the
    # number of samples served by `eval_loader` -- confirm against the loader.
    losses = torch.zeros(45000)
    with torch.no_grad():
        for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = CE(outputs, targets)
            # One indexed write per batch instead of a Python loop that copies
            # a single GPU scalar to the CPU per sample.
            losses[index] = loss.cpu()
    losses = (losses-losses.min())/(losses.max()-losses.min())
    all_loss.append(losses)

    if args.r==0.9: # average loss over last 5 epochs to improve convergence stability
        history = torch.stack(all_loss)
        input_loss = history[-5:].mean(0)
        input_loss = input_loss.reshape(-1,1)
    else:
        input_loss = losses.reshape(-1,1)

    # fit a two-component GMM to the loss
    gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    # The component with the smaller mean loss is taken as the "clean" one.
    prob = prob[:,gmm.means_.argmin()]
    return prob,all_loss

def linear_rampup(current, warm_up, rampup_length=16):
    """Return the unsupervised-loss weight for the (fractional) epoch `current`.

    The weight rises linearly from 0 at `warm_up` to ``args.lambda_u`` at
    ``warm_up + rampup_length`` and is clamped outside that interval.
    """
    progress = (current - warm_up) / rampup_length
    clamped = np.clip(progress, 0.0, 1.0)
    return args.lambda_u * float(clamped)

class SemiLoss(object):
    """Callable computing the supervised and unsupervised loss terms."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        """Return ``(Lx, Lu, weight)``.

        Lx is the soft-target cross-entropy of `outputs_x` against `targets_x`,
        Lu is the mean squared error between softmax(`outputs_u`) and
        `targets_u`, and `weight` is ``linear_rampup(epoch, warm_up)``.
        """
        # Supervised term: cross-entropy with soft targets, averaged over the batch.
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))

        # Unsupervised term: squared error on probabilities, averaged over all entries.
        probs_u = torch.softmax(outputs_u, dim=1)
        squared_error = (probs_u - targets_u)**2
        Lu = torch.mean(squared_error)

        return Lx, Lu, linear_rampup(epoch,warm_up)

class NegEntropy(object):
    """Callable returning the batch-mean negative entropy of softmax(outputs).

    Adding this value to a loss penalises over-confident predictions.
    """

    def __call__(self,outputs):
        """Return mean over the batch of ``sum_c p_c * log p_c`` (a scalar tensor).

        Args:
            outputs: Logits; classes are taken along dim 1.
        """
        # Work from log_softmax instead of softmax(...).log(): when a
        # probability underflows to 0 the latter yields 0 * -inf = NaN, whereas
        # log_softmax stays finite and the product becomes exactly 0.
        log_probs = F.log_softmax(outputs, dim=1)
        probs = log_probs.exp()
        return torch.mean(torch.sum(log_probs*probs, dim=1))

def create_model():
    """Build a ResNet18 with ``args.num_class`` outputs and move it to the GPU."""
    # Alternative backbone, currently disabled:
    #model = models.load_model(model_name='vgg16', in_channels=3, num_classes=args.num_class)
    network = ResNet18(num_classes=args.num_class)
    return network.cuda()

# Number of initial epochs trained with plain cross-entropy before the
# co-divide / semi-supervised phase starts.
if args.dataset=='cifar10':
    warm_up = 10
elif args.dataset=='cifar100':
    warm_up = 30
else:
    # Fail here with a clear message instead of a NameError on `warm_up`
    # once the training loop starts.
    raise ValueError(
        "unsupported --dataset {0!r}: expected 'cifar10' or 'cifar100'".format(args.dataset))

# Data pipeline. The noise file path is derived from the data path, the noise
# ratio (one decimal) and the noise mode.
loader = dataloader.cifar_dataloader(
    args.dataset,
    n_class=args.num_class,
    r=args.r,
    noise_mode=args.noise_mode,
    batch_size=args.batch_size,
    num_workers=5,
    root_dir=args.data_path,
    log=train_log,
    noise_file='%s/%.1f_%s.json'%(args.data_path,args.r,args.noise_mode),
)

# Two networks are built, net1 first and then net2.
print('| Building net')
net1 = create_model()
net2 = create_model()
cudnn.benchmark = True

# Semi-supervised loss used by train(), plus one SGD optimizer and one cosine
# schedule per network.
criterion = SemiLoss()
optimizer1, optimizer2 = (
    optim.SGD(network.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    for network in (net1, net2)
)
scheduler1, scheduler2 = (
    lr_scheduler.CosineAnnealingLR(sgd, T_max=200)
    for sgd in (optimizer1, optimizer2)
)

# CE keeps one loss value per sample (used by eval_train); CEloss is the
# batch-averaged version (used by warmup).
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
# The confidence penalty is only created for asymmetric noise.
if args.noise_mode=='asym':
    conf_penalty = NegEntropy()

all_loss = [[],[]] # save the history of losses from two networks

# Main loop: epochs 0..args.num_epochs inclusive. The first `warm_up` epochs
# train both networks with plain cross-entropy; afterwards each network is
# trained on the labeled/unlabeled split produced from the other's losses.
for epoch in range(args.num_epochs+1):
    # Step schedule: base rate, divided by 10 from epoch 150 on. This value is
    # written into both optimizers at the start of every epoch, so it replaces
    # whatever the cosine schedulers set at the end of the previous epoch.
    lr = args.lr / 10 if epoch >= 150 else args.lr
    for sgd in (optimizer1, optimizer2):
        for param_group in sgd.param_groups:
            param_group['lr'] = lr

    # test() and eval_train() read these two loaders as module-level globals.
    test_loader = loader.run('test')
    eval_loader = loader.run('eval_train')

    if epoch >= warm_up:
        # Per-sample clean probabilities from each network's loss distribution.
        prob1,all_loss[0]=eval_train(net1,all_loss[0])
        prob2,all_loss[1]=eval_train(net2,all_loss[1])

        pred1 = (prob1 > args.p_threshold)
        pred2 = (prob2 > args.p_threshold)

        # Each network is trained on the split chosen by its peer.
        print('Train Net1')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide
        train(epoch,0,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader, loader) # train net1

        print('\nTrain Net2')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide
        train(epoch,1,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader, loader) # train net2
    else:
        # Both networks warm up on the same loader object.
        warmup_trainloader = loader.run('warmup')
        print('Warmup Net1')
        warmup(epoch,net1,optimizer1,warmup_trainloader)
        print('\nWarmup Net2')
        warmup(epoch,net2,optimizer2,warmup_trainloader)

    test(epoch,net1,net2)
    scheduler1.step()
    scheduler2.step()


