import argparse
import logging
import math
import os
import random
import shutil
import time

import numpy as np
import torch
import json
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from torch.cuda import amp
from tqdm import tqdm

from src.AuxMix.dataset.cifar import DATASET_GETTERS, collate_fn
from src.AuxMix.dataset.cifar import show_imgs
from src.AuxMix.utils import AverageMeter, accuracy
from src.AuxMix.utils.learning_rate_utils import WarmupMultiStepLR, WarmupStepLR, WarmupCosineLR

from src.AuxMix.models import wideresnet
from src.AuxMix.models import resnet
from src.AuxMix.models.ema import ModelEMA


def save_checkpoint(state, is_best, checkpoint_path, key_string=None):
    """Write ``state`` to disk and, if requested, keep a copy as the best model.

    The checkpoint is saved with the legacy (non-zipfile) torch serialization as
    ``checkpoint.pth.tar`` inside ``checkpoint_path``; when ``key_string`` is
    given the file names are prefixed with ``"<key_string>_"``.

    Args:
        state: picklable object (typically a dict of state dicts) to save.
        is_best: when true, the freshly written file is also copied to
            ``model_best.pth.tar`` (with the same optional prefix).
        checkpoint_path: directory that receives the files.
        key_string: optional file-name prefix, e.g. ``'ss'``.
    """
    if key_string is not None:
        latest_name = '{}_checkpoint.pth.tar'.format(key_string)
        best_name = '{}_model_best.pth.tar'.format(key_string)
    else:
        latest_name = 'checkpoint.pth.tar'
        best_name = 'model_best.pth.tar'

    latest_path = os.path.join(checkpoint_path, latest_name)
    torch.save(state, latest_path, _use_new_zipfile_serialization=False)

    if not is_best:
        return
    shutil.copyfile(latest_path, os.path.join(checkpoint_path, best_name))


def set_seed(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.seed``.

    All CUDA devices are seeded as well when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def interleave(x, size):
    """Permute the leading axis of ``x`` by viewing it as ``(-1, size)`` and transposing.

    The first dimension of ``x`` must be divisible by ``size``; trailing
    dimensions are carried along unchanged. ``de_interleave(.., size)``
    undoes this permutation.
    """
    rest = list(x.shape[1:])
    grouped = x.reshape([-1, size, *rest])
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1, *rest])


def de_interleave(x, size):
    """Undo ``interleave(.., size)`` on the leading axis of ``x``.

    The leading axis is viewed as ``(size, -1)`` and transposed back; trailing
    dimensions are carried along unchanged.
    """
    rest = list(x.shape[1:])
    grouped = x.reshape([size, -1, *rest])
    swapped = grouped.transpose(0, 1)
    return swapped.reshape([-1, *rest])


def create_model(args, rot_model=False):
    """Instantiate the backbone selected by ``args.arch``.

    Args:
        args: must provide ``arch``, ``num_classes`` and ``summary``; for
            ``arch == 'wideresnet'`` also ``model_depth`` and ``model_width``.
        rot_model: forwarded to the model builder. ``main`` passes ``True``
            for the model trained on the rotation self-supervision task.

    Returns:
        The constructed model. Its parameter count is logged, and a torchinfo
        summary for a ``(1, 3, 32, 32)`` input is printed when ``args.summary``
        is set.

    Raises:
        ValueError: if ``args.arch`` is not ``'wideresnet'``, ``'resnet50'``
            or ``'cifar10-resnet50'``. Previously this surfaced as an
            ``UnboundLocalError`` on ``model``.
    """
    if args.arch == 'wideresnet':
        model = wideresnet.build_model(depth=args.model_depth,
                                       widen_factor=args.model_width,
                                       dropout=0,
                                       num_classes=args.num_classes,
                                       rot_model=rot_model)
    elif args.arch in ('resnet50', 'cifar10-resnet50'):
        # Both names share a builder; only the CIFAR-10 variant flag differs.
        model = resnet.build_resnet50(num_classes=args.num_classes,
                                      rot_model=rot_model,
                                      cifar10_model=args.arch == 'cifar10-resnet50')
    else:
        raise ValueError("Unsupported architecture: {}".format(args.arch))

    logger.info("Total params: {:f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))

    if args.summary:
        # Optional dependency, only needed when a summary is requested.
        from torchinfo import summary
        summary(model, (1, 3, 32, 32))

    return model


def get_scheduler(optimizer, total_steps, lrs_milestones=None, lrs_step_size=None,
                  lrs_gamma=None, warmup=None, scheduler_type='cosine'):
    """Build a learning-rate scheduler with warm-up for ``optimizer``.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        total_steps: total optimisation steps (used by the cosine schedule).
        lrs_milestones: step indices for ``'multistep'``.
        lrs_step_size: step period for ``'step'``.
        lrs_gamma: decay factor for ``'multistep'`` and ``'step'``.
        warmup: number of warm-up iterations, used by every schedule type.
        scheduler_type: one of ``'multistep'``, ``'step'`` or ``'cosine'``.

    Returns:
        The scheduler instance.

    Raises:
        ValueError: for an unknown ``scheduler_type``.
    """
    if scheduler_type == 'multistep':
        return WarmupMultiStepLR(optimizer, milestones=lrs_milestones, gamma=lrs_gamma,
                                 warmup_factor=1/5, warmup_iters=warmup)
    if scheduler_type == 'step':
        return WarmupStepLR(optimizer, step_size=lrs_step_size, gamma=lrs_gamma,
                            warmup_factor=1/5, warmup_iters=warmup)
    if scheduler_type == 'cosine':
        # Use the ``warmup`` argument. The previous code read a global ``args``,
        # which ignores the caller's value and raises NameError when no such
        # global exists.
        return WarmupCosineLR(optimizer, warmup, total_steps)
    raise ValueError("LRS is invalid")


def setup_model(args, model, total_steps, lr, use_ema=True, scheduler_type='cosine',
                milestones=None, step_size=None, gamma=None):
    """Move ``model`` to ``args.device`` and build its optimizer, scheduler and EMA.

    Optimizer: SGD with momentum 0.9 (Nesterov if ``args.nesterov``). Weight
    decay is ``args.wdecay``, applied to every parameter whose name does not
    contain ``'bias'`` or ``'bn'``.

    Side effects on ``args``:
      * sets ``args.epochs`` to ``ceil(total_steps / args.eval_step)``;
      * sets ``args.start_epoch`` to 0, or to the checkpoint's epoch when resuming;
      * when ``args.resume`` is set, points ``args.out`` at the checkpoint's directory.

    When resuming, also restores the model, EMA, optimizer and scheduler state.
    Afterwards, apex AMP is applied if ``args.amp`` is set, and the model is
    wrapped in DistributedDataParallel when ``args.local_rank != -1``.

    Args:
        milestones: epoch indices for the ``'multistep'`` scheduler. They are
            converted to steps by multiplying with ``args.eval_step``. May be
            ``None`` for schedulers that do not use milestones.

    Returns:
        ``(model, ema_model, optimizer, scheduler, best_acc)``. ``ema_model``
        is ``None`` when ``use_ema`` is false.

    Raises:
        FileNotFoundError: if ``args.resume`` is set but is not an existing file.
    """
    model.to(args.device)
    no_decay = ['bias', 'bn']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = optim.SGD(grouped_parameters, lr=lr,
                          momentum=0.9, nesterov=args.nesterov)

    args.epochs = math.ceil(total_steps / args.eval_step)
    # ``milestones`` defaults to None; iterating it unconditionally raised TypeError.
    milestone_steps = None
    if milestones is not None:
        milestone_steps = [milestone * args.eval_step for milestone in milestones]
    scheduler = get_scheduler(optimizer=optimizer, total_steps=total_steps,
                              lrs_milestones=milestone_steps,
                              lrs_step_size=step_size, lrs_gamma=gamma, warmup=args.warmup,
                              scheduler_type=scheduler_type)

    ema_model = None
    if use_ema:
        ema_model = ModelEMA(args, model, args.ema_decay)

    args.start_epoch = 0
    best_acc = 0.0

    if args.resume:
        logger.info("==> Resuming from checkpoint..")
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError("Error: no checkpoint file found: {}".format(args.resume))
        args.out = os.path.dirname(args.resume)
        # Map tensors onto this process's device instead of the device they were saved from.
        checkpoint = torch.load(args.resume, map_location=args.device)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if use_ema:
            ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    if args.amp:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    return model, ema_model, optimizer, scheduler, best_acc


def create_dataloaders(args, do_rotations=False, return_index=False):
    """Build the labeled, unlabeled, auxiliary and test DataLoaders for ``args.dataset``.

    The datasets come from ``DATASET_GETTERS[args.dataset]``, which receives
    ``do_rotations`` and ``return_index`` unchanged.

    The three training loaders use a ``RandomSampler``, or a
    ``DistributedSampler`` when ``args.local_rank != -1``:
      * labeled: batch ``args.batch_size``, ``drop_last=True``;
      * unlabeled: batch ``args.batch_size * args.mu``, ``drop_last=True``;
      * auxiliary: batch ``args.batch_size``, ``drop_last=False``.

    The test loader iterates sequentially with batch ``args.batch_size``.

    NOTE(review): in distributed mode the auxiliary loader is sharded by
    ``DistributedSampler``, so each rank only sees part of the auxiliary set.
    Confirm this is intended wherever the loader is used to score every
    auxiliary sample.

    Returns:
        ``(labeled_trainloader, unlabeled_trainloader, aux_trainloader, test_loader)``.
    """
    # Ranks other than 0 wait here until rank 0 has built the datasets and
    # reaches the matching barrier below.
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    dataset_getter = DATASET_GETTERS[args.dataset]
    labeled_dataset, unlabeled_dataset, aux_dataset, test_dataset = dataset_getter(
        args, args.root_datapath, do_rotations=do_rotations, return_index=return_index)

    logger.info("Len labeled dataset: {}".format(len(labeled_dataset)))
    logger.info("Len unlabeled dataset: {}".format(len(unlabeled_dataset)))
    logger.info("Len test dataset: {}".format(len(test_dataset)))

    if args.local_rank == 0:
        torch.distributed.barrier()

    sampler_cls = DistributedSampler if args.local_rank != -1 else RandomSampler

    def make_train_loader(dataset, batch_size, drop_last):
        # Shared configuration for the three shuffled training loaders.
        return DataLoader(
            dataset,
            sampler=sampler_cls(dataset),
            batch_size=batch_size,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            drop_last=drop_last)

    labeled_trainloader = make_train_loader(labeled_dataset, args.batch_size, True)
    unlabeled_trainloader = make_train_loader(unlabeled_dataset, args.batch_size * args.mu, True)
    aux_trainloader = make_train_loader(aux_dataset, args.batch_size, False)

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    return labeled_trainloader, unlabeled_trainloader, aux_trainloader, test_loader


def _unwrap_ddp(model):
    """Return the module wrapped by DistributedDataParallel, or ``model`` itself."""
    return model.module if hasattr(model, "module") else model


def _load_ss_best_state_dict(args):
    """Read ``<args.out>/ss_model_best.pth.tar`` once and return one of its state dicts.

    Returns the ``'ema_state_dict'`` entry when ``args.use_ema_rot`` is set,
    otherwise the ``'state_dict'`` entry.
    """
    checkpoint_filename = os.path.join(args.out, 'ss_model_best.pth.tar')
    logger.info(f"  Loading checkpoint {checkpoint_filename}")
    checkpoint = torch.load(checkpoint_filename, map_location=args.device)
    return checkpoint['ema_state_dict'] if args.use_ema_rot else checkpoint['state_dict']


def main(args):
    """Run the AuxMix pipeline end to end.

    1. Set ``args.num_classes`` (and the WideResNet depth/width) from ``args.dataset``.
    2. Compute auxiliary-data scores, in one of two ways:
       * train a rotation model with ``self_supervise`` and score the auxiliary
         set with it (exit right after if ``args.ss_only``); or
       * when ``args.use_offline_scores`` is set, rebuild the rotation model
         from the saved ``ss_model_best.pth.tar`` and score with that.
    3. Build the classification model and initialise it (and its EMA) from the
       self-supervised weights, excluding the ``fc`` layer. Then train it
       with ``train_auxmix``.
    """
    import sys

    # Choose model
    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2

    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8

    elif args.dataset == 'cifar10-animals-others':
        args.num_classes = 6
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2

    if not args.use_offline_scores:
        labeled_trainloader, unlabeled_trainloader, aux_trainloader, test_loader = create_dataloaders(
            args, do_rotations=True)
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()

        rot_model = create_model(args, rot_model=True)

        if args.local_rank == 0:
            torch.distributed.barrier()

        rot_model, rot_ema_model, optimizer, scheduler, best_acc = setup_model(
            args, rot_model,
            total_steps=args.ss_epochs * args.eval_step,
            lr=args.lr_rot,
            use_ema=args.use_ema_rot,
            scheduler_type=args.scheduler_type_rot,
            milestones=args.milestones_rot,
            step_size=args.step_size_rot,
            gamma=args.gamma_rot)

        # Start self-supervised training from cleared gradients.
        rot_model.zero_grad()

        logger.info("***** Running self-supervision *****")
        logger.info(f"  Task = {args.dataset}@{args.num_labeled}")
        logger.info(f"  Num self-supervise Epochs = {args.ss_epochs}")
        logger.info(f"  Batch size per GPU = {args.batch_size}")
        logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
        logger.info(f"  Total optimization steps = {args.ss_epochs * args.eval_step}")

        model, ema_model = self_supervise(args, labeled_trainloader, aux_trainloader, test_loader,
                                          rot_model, optimizer, rot_ema_model, scheduler, best_acc)

        # Drop the rotation dataloaders before building the plain ones.
        del labeled_trainloader
        del unlabeled_trainloader
        del aux_trainloader
        del test_loader

        # Create new dataloaders without rotation
        labeled_trainloader, unlabeled_trainloader, aux_trainloader, test_loader = create_dataloaders(
            args, do_rotations=False, return_index=True)
        if args.use_ema_rot:
            model = ema_model.ema

        logger.info("***** Generating Aux Scores  *****")
        aux_scores = score_auxiliary_data(args, model, labeled_trainloader, aux_trainloader)
        del model
        del ema_model

        if args.ss_only:
            # sys.exit instead of the ``site``-provided ``exit`` builtin,
            # which is not guaranteed to exist (e.g. ``python -S``).
            sys.exit(0)

    else:
        # Generate aux scores from pretrained weights
        # Create new dataloaders without rotation
        labeled_trainloader, unlabeled_trainloader, aux_trainloader, test_loader = create_dataloaders(
            args, do_rotations=False, return_index=True)

        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()

        rot_model = create_model(args, rot_model=True)

        if args.local_rank == 0:
            torch.distributed.barrier()

        rot_model, _, _, _, _ = setup_model(args, rot_model,
                                            total_steps=args.ss_epochs * args.eval_step,
                                            lr=args.lr_rot,
                                            use_ema=args.use_ema_rot,
                                            scheduler_type=args.scheduler_type_rot,
                                            milestones=args.milestones_rot,
                                            step_size=args.step_size_rot,
                                            gamma=args.gamma_rot)

        # Load the self-supervised weights. Checkpoints store the unwrapped
        # module's keys, so load into the module rather than a DDP wrapper,
        # whose keys carry a ``module.`` prefix.
        _unwrap_ddp(rot_model).load_state_dict(_load_ss_best_state_dict(args), strict=True)

        logger.info("***** Generating Aux Scores  *****")
        aux_scores = score_auxiliary_data(args, rot_model, labeled_trainloader, aux_trainloader)
        del rot_model

    # Create non-self-supervision model
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    model = create_model(args, rot_model=False)
    if args.local_rank == 0:
        torch.distributed.barrier()

    model, ema_model, optimizer, scheduler, best_acc = setup_model(args, model, total_steps=args.total_steps,
                                                                   lr=args.lr,
                                                                   scheduler_type=args.scheduler_type,
                                                                   use_ema=args.use_ema,
                                                                   milestones=args.milestones,
                                                                   step_size=args.step_size,
                                                                   gamma=args.gamma)

    # Initialise from the self-supervised weights, excluding the fc layer.
    # NOTE(review): the fc layer presumably has a different shape in the
    # rotation model than in the classifier; confirm.
    pretrained_dict = _load_ss_best_state_dict(args)
    pretrained_dict.pop('fc.weight', None)
    pretrained_dict.pop('fc.bias', None)
    # Load into the unwrapped module. With strict=False, loading into a DDP
    # wrapper would silently match no keys at all.
    _unwrap_ddp(model).load_state_dict(pretrained_dict, strict=False)
    if args.use_ema:
        _unwrap_ddp(ema_model.ema).load_state_dict(pretrained_dict, strict=False)

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.num_labeled}")
    logger.info(f"  Num Epochs = {args.epochs}")
    logger.info(f"  Batch size per GPU = {args.batch_size}")
    logger.info(f"  Total train batch size = {args.batch_size * args.world_size}")
    logger.info(f"  Total optimization steps = {args.total_steps}")

    model.zero_grad()
    train_auxmix(args, labeled_trainloader, unlabeled_trainloader, test_loader,
                 model, optimizer, ema_model, scheduler, aux_scores, best_acc)


def self_supervise(args, labeled_trainloader, aux_trainloader, test_loader,
                   model, optimizer, ema_model, scheduler, best_acc=0.0):
    """Train ``model`` on the rotation-prediction pretext task.

    Each step draws one batch from the labeled loader and one from the
    auxiliary loader. The two are concatenated, and the cross-entropy between
    the model output and the concatenated rotation targets is minimised. At the
    end of every epoch, rank 0 (or the single process) does the following:
      * evaluates the model (its EMA if ``args.use_ema_rot``) with ``test_rot``;
      * logs to ``args.writer``;
      * saves a checkpoint with the ``'ss'`` prefix via ``save_checkpoint``.

    After the last epoch, ``ss_model_best.pth.tar`` is loaded back into
    ``model`` and, if enabled, into ``ema_model``.

    Args:
        labeled_trainloader, aux_trainloader: loaders yielding 4-tuples whose
            first element is the images and third the rotation targets.
        test_loader: loader passed to ``test_rot``.
        model, optimizer, ema_model, scheduler: as returned by ``setup_model``.
        best_acc: best accuracy so far (e.g. restored from a checkpoint).

    Returns:
        ``(model, ema_model)`` holding the best self-supervised weights.
    """
    if args.amp:
        from apex import amp

    test_accs = []
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    rot_top1 = AverageMeter()
    end = time.time()

    labeled_epoch = 0
    aux_epoch = 0
    if args.world_size > 1:
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        aux_trainloader.sampler.set_epoch(aux_epoch)

    labeled_iter = iter(labeled_trainloader)
    aux_iter = iter(aux_trainloader)

    for epoch in range(args.start_epoch, args.ss_epochs):
        model.train()
        if not args.no_progress:
            p_bar = tqdm(range(args.eval_step),
                         disable=args.local_rank not in [-1, 0])
        for batch_idx in range(args.eval_step):
            # Use next() and catch only StopIteration. DataLoader iterators
            # lack a ``.next()`` method in newer PyTorch, and a bare except
            # would also hide genuine data-loading errors.
            try:
                inputs_x, _, rot_targets_x, _ = next(labeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    labeled_epoch += 1
                    labeled_trainloader.sampler.set_epoch(labeled_epoch)
                labeled_iter = iter(labeled_trainloader)
                inputs_x, _, rot_targets_x, _ = next(labeled_iter)

            try:
                inputs_a, _, rot_targets_a, _ = next(aux_iter)
            except StopIteration:
                if args.world_size > 1:
                    aux_epoch += 1
                    aux_trainloader.sampler.set_epoch(aux_epoch)
                aux_iter = iter(aux_trainloader)
                inputs_a, _, rot_targets_a, _ = next(aux_iter)

            data_time.update(time.time() - end)
            inputs = torch.cat((inputs_x, inputs_a))
            targets = torch.cat((rot_targets_x, rot_targets_a)).to(args.device)

            # interleave()/de_interleave() only permute the batch and undo the
            # permutation, but their reshape needs an even sample count. The
            # auxiliary loader keeps its last partial batch, so skip the
            # permutation when the concatenated batch is odd-sized.
            use_interleave = inputs.shape[0] % 2 == 0
            if use_interleave:
                inputs = interleave(inputs, 2)
            logits = model(inputs.to(args.device))
            if use_interleave:
                logits = de_interleave(logits, 2)

            loss = F.cross_entropy(logits, targets, reduction='mean')
            acc = accuracy(logits, targets, topk=(1,))[0]
            rot_top1.update(acc.item())

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            losses.update(loss.item())

            optimizer.step()
            scheduler.step()
            if args.use_ema_rot:
                ema_model.update(model)
            model.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            if not args.no_progress:
                p_bar.set_description(
                    "Epoch:{epoch}/{epochs:4} LR:{lr:.4f} Data:{data:.3f}s "
                    "Batch:{bt:.3f}s Loss:{loss:.4f} rot_top1:{rot_top1:.4f}".format(
                        epoch=epoch + 1,
                        epochs=args.ss_epochs,
                        lr=scheduler.get_last_lr()[0],
                        data=data_time.avg,
                        bt=batch_time.avg,
                        loss=losses.avg,
                        rot_top1=rot_top1.avg))
                p_bar.update()

        if not args.no_progress:
            p_bar.close()
        else:
            logger.info("Epoch:{epoch}/{epochs:4} LR:{lr:.4f} Data:{data:.3f}s "
                        "Batch:{bt:.3f}s Loss:{loss:.4f} rot_top1:{rot_top1:.4f}".format(
                            epoch=epoch + 1,
                            epochs=args.ss_epochs,
                            lr=scheduler.get_last_lr()[0],
                            data=data_time.avg,
                            bt=batch_time.avg,
                            loss=losses.avg,
                            rot_top1=rot_top1.avg))

        test_model = ema_model.ema if args.use_ema_rot else model

        if args.local_rank in [-1, 0]:
            logger.info('Epoch:{}/{:4}'.format(epoch + 1, args.ss_epochs))
            test_loss, test_acc = test_rot(args, test_loader, test_model)

            args.writer.add_scalar('train/1.train_loss_ss', losses.avg, epoch)
            args.writer.add_scalar('train/2.train_acc_ss', rot_top1.avg, epoch)
            args.writer.add_scalar('test/1.test_acc_ss', test_acc, epoch)
            args.writer.add_scalar('test/2.test_loss_ss', test_loss, epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            # Save the unwrapped module so checkpoint keys have no DDP prefix.
            model_to_save = model.module if hasattr(model, "module") else model
            if args.use_ema_rot:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'ema_state_dict': ema_to_save.state_dict() if args.use_ema_rot else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args.out, key_string='ss')

            test_accs.append(test_acc)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc))
            logger.info('Mean top-1 acc: {:.2f}\n'.format(
                np.mean(test_accs[-20:])))

    if args.local_rank in [-1, 0]:
        # Flush rather than close: main() keeps writing to args.writer in train_auxmix.
        args.writer.flush()

    # Only rank 0 writes checkpoints; other ranks wait here until it has
    # finished before reading the best checkpoint back.
    if args.local_rank != -1:
        torch.distributed.barrier()

    checkpoint_filename = os.path.join(args.out, 'ss_model_best.pth.tar')
    logger.info(f"  Loading checkpoint {checkpoint_filename}")
    best_checkpoint = torch.load(checkpoint_filename, map_location=args.device)
    # The checkpoint holds the unwrapped module's keys (see model_to_save
    # above), so load into the unwrapped module, not a DDP wrapper.
    model_to_load = model.module if hasattr(model, "module") else model
    model_to_load.load_state_dict(best_checkpoint['state_dict'], strict=True)
    if args.use_ema_rot:
        ema_to_load = ema_model.ema.module if hasattr(ema_model.ema, "module") else ema_model.ema
        ema_to_load.load_state_dict(best_checkpoint['ema_state_dict'], strict=True)

    return model, ema_model


def score_auxiliary_data(args, model, labeled_trainloader, aux_trainloader):
    """Score each auxiliary sample by cosine similarity to labeled class prototypes.

    The function runs in three passes:
      1. Average the model features of the labeled batches per target class;
         these averages are the prototypes.
      2. For every auxiliary sample, find the prototype with the highest
         cosine similarity.
      3. Write two JSON files to ``args.out``:
         ``aux_sim_scores.json`` maps index to best similarity, and
         ``aux_sim_classes.json`` maps index to the class of that prototype.

    The model is called as ``model(inputs, get_features=True)`` and its second
    output is used as the feature matrix. Inference runs in eval mode without
    autograd.

    Args:
        labeled_trainloader: yields ``(inputs, targets)`` batches.
        aux_trainloader: yields ``(inputs, _, indices)`` batches; each index
            element must support ``.item()``.

    Returns:
        dict mapping auxiliary-sample index to its maximum cosine similarity.
    """
    model.eval()
    # One running mean of feature vectors per labeled class.
    prototypes = [AverageMeter() for _ in range(args.num_classes)]
    max_sim_dict = {}
    sim_class_dict = {}

    # Pure inference: no autograd graph is needed for either pass.
    with torch.no_grad():
        if not args.no_progress:
            p_bar = tqdm(range(len(labeled_trainloader)),
                         disable=args.local_rank not in [-1, 0])

        # Get prototypes
        logger.info("Generating prototypes")
        for batch_idx, (inputs_x, targets_x) in enumerate(labeled_trainloader):
            _, features = model(inputs_x.to(args.device), get_features=True)

            # Convert targets to ints once per batch rather than per sample.
            for feature, target in zip(features, targets_x.tolist()):
                prototypes[target].update(feature.detach())

            if not args.no_progress:
                p_bar.set_description(
                    "Iter: {batch:4}/{iter:4}".format(
                        batch=batch_idx + 1,
                        iter=len(labeled_trainloader)))
                p_bar.update()

        if not args.no_progress:
            p_bar.close()

        # (num_classes, feature_dim) matrix of per-class mean features.
        prototypes = torch.stack([meter.avg for meter in prototypes], dim=0)

        if not args.no_progress:
            p_bar = tqdm(range(len(aux_trainloader)),
                         disable=args.local_rank not in [-1, 0])

        # Get aux scores
        logger.info("Generating aux scores")
        for batch_idx, (inputs_a, _, dict_indices) in enumerate(aux_trainloader):
            _, features = model(inputs_a.to(args.device), get_features=True)

            # Similarity of every feature to every prototype in one broadcast
            # call: (batch, 1, D) vs (1, num_classes, D) -> (batch, num_classes).
            sims = F.cosine_similarity(features.unsqueeze(1), prototypes.unsqueeze(0), dim=-1)
            max_sims, sim_classes = sims.max(dim=1)

            for dict_index, max_sim, sim_class in zip(dict_indices, max_sims.tolist(),
                                                      sim_classes.tolist()):
                key = dict_index.item()
                max_sim_dict[key] = max_sim
                sim_class_dict[key] = sim_class

            if not args.no_progress:
                p_bar.set_description(
                    "Iter: {batch:4}/{iter:4}".format(
                        batch=batch_idx + 1,
                        iter=len(aux_trainloader)))
                p_bar.update()

        if not args.no_progress:
            p_bar.close()

    # Save dicts to file as json
    with open(os.path.join(args.out, 'aux_sim_scores.json'), 'w') as f:
        json.dump(max_sim_dict, f)

    with open(os.path.join(args.out, 'aux_sim_classes.json'), 'w') as f:
        json.dump(sim_class_dict, f)

    # Print some statistics to file
    scores = torch.Tensor(list(max_sim_dict.values()))
    logger.info('Aux scores mean: {:.4f}'.format(scores.mean()))
    logger.info('Aux scores median: {:.4f}'.format(scores.median()))
    logger.info('Aux scores max: {:.4f}'.format(scores.max()))
    logger.info('Aux scores min: {:.4f}'.format(scores.min()))
    logger.info('Aux scores std: {:.4f}'.format(scores.std()))

    return max_sim_dict


def train_auxmix(args, labeled_trainloader, unlabeled_trainloader, test_loader,
                 model, optimizer, ema_model, scheduler, aux_scores_dict, best_acc=0.0):
    """Train the classifier with AuxMix on labeled plus unlabeled/auxiliary data.

    Loss per step:
      * Supervised: cross-entropy on the labeled batch.
      * Unlabeled: pseudo-labels come from the softmax (temperature ``args.T``)
        of the weak-view logits, either as soft targets or as argmax targets.
        They supervise the strong-view logits for samples whose auxiliary
        score is at least ``args.threshold``; this term is weighted by
        ``args.lambda_u``.
      * Regulariser (only with ``args.regularize``): rejected samples are
        pushed toward a uniform prediction via KL to the uniform
        distribution. It is applied to the weak view, the strong view or both
        (``args.reg_augw`` / ``args.reg_augs``) and weighted by ``args.lambda_reg``.

    After every epoch, rank 0 (or the single process) evaluates with ``test``
    (the EMA model if ``args.use_ema``), logs to ``args.writer`` and saves a
    checkpoint via ``save_checkpoint``.

    Args:
        aux_scores_dict: mapping from auxiliary-sample index to score, as
            returned by ``score_auxiliary_data``. Scores are ordered by key.
            NOTE(review): this assumes the keys are contiguous from 0, so that
            position i holds the score of index i; confirm against the dataset.
        best_acc: best test accuracy so far (e.g. restored from a checkpoint).
    """
    # Convert aux_scores to a tensor for vector access, ordered by sample index.
    aux_scores = torch.Tensor([score for _, score in sorted(aux_scores_dict.items())])
    logger.info('Aux scores mean: {:.4f}'.format(aux_scores.mean()))
    logger.info('Aux scores median: {:.4f}'.format(aux_scores.median()))
    logger.info('Aux scores max: {:.4f}'.format(aux_scores.max()))
    logger.info('Aux scores min: {:.4f}'.format(aux_scores.min()))
    logger.info('Aux scores std: {:.4f}'.format(aux_scores.std()))

    if args.auto_thresh:
        # Threshold = mean + alpha * std, rounded to 2 decimals. Store it as a
        # plain float rather than whatever np.around returns for a torch
        # tensor, since that result type depends on the NumPy/PyTorch versions.
        args.threshold = round(float(aux_scores.mean() + args.alpha_thresh * aux_scores.std()), 2)
        logger.info("  Threshold {:.2f}".format(args.threshold))

    # The first args.num_labeled entries are meant for labeled samples re-added
    # as unlabeled; they get score 1, so they pass any threshold <= 1.
    aux_scores = torch.cat((torch.ones(args.num_labeled), aux_scores), dim=0).to(args.device)

    if args.amp:
        from apex import amp

    test_accs = []
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    losses_u_sel = AverageMeter()
    losses_u_rej = AverageMeter()
    mask_probs = AverageMeter()
    end = time.time()

    labeled_epoch = 0
    unlabeled_epoch = 0
    if args.world_size > 1:
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)

    for epoch in range(args.start_epoch, args.epochs):
        model.train()

        if not args.no_progress:
            p_bar = tqdm(range(args.eval_step),
                         disable=args.local_rank not in [-1, 0])
        for batch_idx in range(args.eval_step):
            # Use next() and catch only StopIteration. DataLoader iterators
            # lack a ``.next()`` method in newer PyTorch, and a bare except
            # would also hide genuine data-loading errors.
            try:
                inputs_x, targets_x = next(labeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    labeled_epoch += 1
                    labeled_trainloader.sampler.set_epoch(labeled_epoch)
                labeled_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_iter)

            try:
                (inputs_u_w, inputs_u_s), _, indices = next(unlabeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    unlabeled_epoch += 1
                    unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
                unlabeled_iter = iter(unlabeled_trainloader)
                (inputs_u_w, inputs_u_s), _, indices = next(unlabeled_iter)

            data_time.update(time.time() - end)
            batch_size = inputs_x.shape[0]
            inputs = interleave(torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*args.mu+1).to(args.device)
            targets_x = targets_x.to(args.device)
            logits = model(inputs)
            logits = de_interleave(logits, 2*args.mu+1)
            logits_x = logits[:batch_size]
            logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
            del logits

            Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

            pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1)
            if args.soft_targets:
                targets_u = pseudo_label
            else:
                _, targets_u = torch.max(pseudo_label, dim=-1)

            # Mask is calculated using aux scores
            mask = aux_scores[indices].ge(args.threshold).float()
            reject_mask = 1 - mask

            if not args.regularize:
                if args.soft_targets:
                    Lu = (F.kl_div(F.log_softmax(logits_u_s, dim=1), targets_u, reduction='none').sum(
                        dim=1) * mask).mean()
                else:
                    Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()
                Lu = args.lambda_u * Lu
            else:
                if args.soft_targets:
                    Lu_select = (F.kl_div(F.log_softmax(logits_u_s, dim=1), targets_u,
                                          reduction='none').sum(dim=1) * mask).mean()
                else:
                    Lu_select = (F.cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()

                # Uniform target distribution for rejected samples.
                targets_a = (1. / args.num_classes) * torch.ones((inputs_u_w.shape[0], args.num_classes)).to(
                    args.device)
                Lu_reject_w = (F.kl_div(F.log_softmax(logits_u_w, dim=1), targets_a,
                                        reduction='none').sum(dim=1) * reject_mask)
                Lu_reject_s = (F.kl_div(F.log_softmax(logits_u_s, dim=1), targets_a,
                                        reduction='none').sum(dim=1) * reject_mask)

                if args.reg_augw:
                    Lu_reject = Lu_reject_w.mean()
                elif args.reg_augs:
                    Lu_reject = Lu_reject_s.mean()
                else:
                    Lu_reject = Lu_reject_w.mean() + Lu_reject_s.mean()

                Lu = args.lambda_u * Lu_select + args.lambda_reg * Lu_reject

                losses_u_sel.update(Lu_select.item())
                losses_u_rej.update(Lu_reject.item())

            loss = Lx + Lu

            if args.amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            losses.update(loss.item())
            losses_x.update(Lx.item())
            losses_u.update(Lu.item())

            optimizer.step()
            scheduler.step()
            if args.use_ema:
                ema_model.update(model)
            model.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            mask_probs.update(mask.mean().item())
            if not args.no_progress:
                p_bar.set_description("Train Epoch:{epoch}/{epochs:3} Iter:{batch:4}/{iter:4} "
                                      "LR:{lr:.4f} Data:{data:.3f}s Batch:{bt:.3f}s Loss:{loss:.4f} "
                                      "Loss_x:{loss_x:.4f} Loss_u:{loss_u:.4f} Lu_sel:{loss_u_sel:.4f} "
                                      "Lu_rej:{loss_u_rej:.4f} Mask:{mask:.2f}".format(
                                          epoch=epoch + 1,
                                          epochs=args.epochs,
                                          batch=batch_idx + 1,
                                          iter=args.eval_step,
                                          lr=scheduler.get_last_lr()[0],
                                          data=data_time.avg,
                                          bt=batch_time.avg,
                                          loss=losses.avg,
                                          loss_x=losses_x.avg,
                                          loss_u=losses_u.avg,
                                          loss_u_sel=losses_u_sel.avg,
                                          loss_u_rej=losses_u_rej.avg,
                                          mask=mask_probs.avg))
                p_bar.update()

        if not args.no_progress:
            p_bar.close()

        # Epoch training statistics to file (the space before "Lu_rej" was
        # previously missing, fusing the two values together).
        logger.info("Train Epoch: {epoch}/{epochs:4} Iter: {batch:4}/{iter:4} "
                    "LR: {lr:.4f} Data: {data:.3f}s Batch: {bt:.3f}s Loss: {loss:.4f}. "
                    "Loss_x: {loss_x:.4f} Loss_u: {loss_u:.4f} Lu_sel: {loss_u_sel:.4f} "
                    "Lu_rej: {loss_u_rej:.4f} Mask: {mask:.2f}".format(
                        epoch=epoch + 1,
                        epochs=args.epochs,
                        batch=batch_idx + 1,
                        iter=args.eval_step,
                        lr=scheduler.get_last_lr()[0],
                        data=data_time.avg,
                        bt=batch_time.avg,
                        loss=losses.avg,
                        loss_x=losses_x.avg,
                        loss_u=losses_u.avg,
                        loss_u_sel=losses_u_sel.avg,
                        loss_u_rej=losses_u_rej.avg,
                        mask=mask_probs.avg))

        test_model = ema_model.ema if args.use_ema else model

        if args.local_rank in [-1, 0]:
            logger.info('Epoch:{}/{:4}'.format(epoch + 1, args.epochs))
            test_loss, test_acc = test(args, test_loader, test_model)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
            args.writer.add_scalar('train/3.train_loss_u', losses_u.avg, epoch)
            args.writer.add_scalar('train/4.mask', mask_probs.avg, epoch)
            args.writer.add_scalar('test/1.test_acc', test_acc, epoch)
            args.writer.add_scalar('test/2.test_loss', test_loss, epoch)
            args.writer.add_scalar('test/3.best_acc', best_acc, epoch)

            # Save the unwrapped module so checkpoint keys have no DDP prefix.
            model_to_save = model.module if hasattr(model, "module") else model
            if args.use_ema:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, args.out)

            test_accs.append(test_acc)
            logger.info('Best top-1 acc: {:.2f}'.format(best_acc))
            logger.info('Mean top-1 acc: {:.2f}\n'.format(
                np.mean(test_accs[-20:])))

    if args.local_rank in [-1, 0]:
        args.writer.close()


def test_rot(args, test_loader, model):
    """Evaluate ``model`` on a loader whose batches are 4-tuples.

    Only the first (inputs) and third (targets) element of every batch are
    used; the other two are ignored. Inputs and targets are moved to
    ``args.device``, scored with cross-entropy, and top-1 accuracy is
    accumulated weighted by batch size.

    Args:
        args: parsed options; reads ``no_progress``, ``local_rank`` and ``device``.
        test_loader: iterable of ``(inputs, _, targets, _)`` batches.
        model: module to evaluate; ``model.eval()`` is called on every batch.

    Returns:
        tuple: (average loss, average top-1 accuracy) as tracked by ``AverageMeter``.
    """
    data_meter = AverageMeter()
    batch_meter = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    tic = time.time()

    show_progress = not args.no_progress
    if show_progress:
        # The progress bar is only rendered on the main process (rank -1 or 0).
        test_loader = tqdm(test_loader, disable=args.local_rank not in [-1, 0])

    with torch.no_grad():
        for step, (inputs, _, targets, _) in enumerate(test_loader):
            data_meter.update(time.time() - tic)
            model.eval()

            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)
            top1_acc = accuracy(logits, targets, topk=(1,))[0]

            num_samples = inputs.shape[0]
            loss_meter.update(loss.item(), num_samples)
            top1_meter.update(top1_acc.item(), num_samples)

            batch_meter.update(time.time() - tic)
            tic = time.time()
            if show_progress:
                test_loader.set_description(
                    f"Test Iter: {step + 1:4}/{len(test_loader):4}. "
                    f"Data: {data_meter.avg:.3f}s. Batch: {batch_meter.avg:.3f}s. "
                    f"Loss: {loss_meter.avg:.4f}. top1: {top1_meter.avg:.2f}")
        if show_progress:
            test_loader.close()

    logger.info(f"top-1 acc: {top1_meter.avg:.2f}")
    return loss_meter.avg, top1_meter.avg


def test(args, test_loader, model):
    """Evaluate ``model`` on ``(inputs, targets)`` batches.

    Each batch is moved to ``args.device`` and scored with cross-entropy.
    Top-1 and top-5 accuracy are accumulated weighted by batch size, and both
    are logged once the loader is exhausted.

    Args:
        args: parsed options; reads ``no_progress``, ``local_rank`` and ``device``.
        test_loader: iterable of ``(inputs, targets)`` batches.
        model: module to evaluate; ``model.eval()`` is called on every batch.

    Returns:
        tuple: (average loss, average top-1 accuracy). Top-5 is only logged.
    """
    data_meter = AverageMeter()
    batch_meter = AverageMeter()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()
    tic = time.time()

    show_progress = not args.no_progress
    if show_progress:
        # The progress bar is only rendered on the main process (rank -1 or 0).
        test_loader = tqdm(test_loader,
                           disable=args.local_rank not in [-1, 0])

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader):
            data_meter.update(time.time() - tic)
            model.eval()

            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)

            top1_acc, top5_acc = accuracy(logits, targets, topk=(1, 5))
            num_samples = inputs.shape[0]
            loss_meter.update(loss.item(), num_samples)
            top1_meter.update(top1_acc.item(), num_samples)
            top5_meter.update(top5_acc.item(), num_samples)

            batch_meter.update(time.time() - tic)
            tic = time.time()
            if show_progress:
                test_loader.set_description(
                    f"Test Iter: {step + 1:4}/{len(test_loader):4}. "
                    f"Data: {data_meter.avg:.3f}s. Batch: {batch_meter.avg:.3f}s. "
                    f"Loss: {loss_meter.avg:.4f}. top1: {top1_meter.avg:.2f}. "
                    f"top5: {top5_meter.avg:.2f}. ")
        if show_progress:
            test_loader.close()

    logger.info(f"top-1 acc: {top1_meter.avg:.2f}")
    logger.info(f"top-5 acc: {top5_meter.avg:.2f}")
    return loss_meter.avg, top1_meter.avg


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch AuxMix Training')

    parser.add_argument('--data_url', type=str)
    parser.add_argument('--init_method', type=str)
    parser.add_argument('--train_url', type=str)

    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100', 'cifar10-animals-others'],
                        help='dataset name')
    parser.add_argument('--aux-dataset', default='tinyimagenet', type=str,
                        choices=['tinyimagenet', 'cub', 'noise'],
                        help='dataset name')
    parser.add_argument('--root-datapath', default='/mnt/my_files/aux-learning/data/cifar10/', type=str,
                        help='path to root data')
    parser.add_argument('--aux-datapath', default='/mnt/my_files/aux-learning/data/tiny_imagenet/train_master/', type=str,
                        help='path to aux data')
    parser.add_argument('--scores-path', default=None, type=str,
                        help='path to aux scores to use instead of default')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--arch', default='resnet50', type=str,
                        choices=['wideresnet', 'resnet50', 'cifar10-resnet50'],
                        help='dataset name')
    parser.add_argument('--summary', action='store_true', default=False,
                        help='print model summary')

    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps")

    parser.add_argument('--total-steps', default=307200, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--ss-epochs', default=100, type=int,
                        help='number of regularization epochs')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr-rot', '--learning-rate-rot', default=0.1, type=float,
                        help='initial learning rate for rotation')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--scheduler-type-rot', default='cosine', type=str,
                        choices=['multistep', 'step', 'cosine'], help='scheduler for rotation task')
    parser.add_argument('--scheduler-type', default='multistep', type=str,
                        choices=['multistep', 'step', 'cosine'], help='scheduler for rotation task')

    # If using step scheduling
    parser.add_argument('--step-size-rot', default=10, type=float,
                        help='reduce lr every step-size epochs')
    parser.add_argument('--step-size', default=10, type=float,
                        help='reduce lr every step-size epochs')
    parser.add_argument('--milestones-rot', nargs='+', default=[30, 60, 80], type=int,
                        help='reduce lr every milestone epochs')
    parser.add_argument('--milestones', nargs='+', default=[100, 200, 250], type=int,
                        help='reduce lr every milestone epochs')
    parser.add_argument('--gamma-rot', default=0.1, type=float,
                        help='reduce lr by lr x lrs-gamma')
    parser.add_argument('--gamma', default=0.5, type=float,
                        help='reduce lr by lr x lrs-gamma')


    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--use-ema-rot', action='store_true',
                        help='use EMA model for rotation')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda-u', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--lambda-reg', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--T', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='pseudo label threshold')

    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision through NVIDIA apex AMP")
    parser.add_argument("--opt_level", type=str, default="O1",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--no-progress', action='store_true',
                        help="don't use progress bar")

    parser.add_argument('--use-offline-scores', action='store_true', default=False,
                        help='use the similarity scores dict')
    parser.add_argument('--regularize', action='store_true', default=False,
                        help='whether to use regularization on reject mask')
    parser.add_argument('--ss-only', action='store_true', default=False,
                        help='only train rotation')
    parser.add_argument('--auto-thresh', action='store_true',
                        help='select thresholds based on mean and std')
    parser.add_argument('--alpha-thresh', default=-1, type=float,
                        help='alpha mulitplier on standard deviation')
    parser.add_argument('--reg-augw', action='store_true', default=False,
                        help='regularize weak augs of the unlabeled data')
    parser.add_argument('--reg-augs', action='store_true', default=False,
                        help='regularize strong augs of the unlabeled data')
    parser.add_argument('--soft-targets', action='store_true', default=False,
                        help='whether to use distrubtion or one hot targets for unlabeled data')
    args = parser.parse_args()

    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    # Setup device
    if args.local_rank == -1:
        device = torch.device('cuda', args.gpu_id)
        args.world_size = 1
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
        args.n_gpu = 1

    args.device = device

    # Set seed
    if args.seed is not None:
        set_seed(args)

    # Setup tensorboarding
    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        args.writer = SummaryWriter(args.out)

    logger = logging.getLogger(__name__)
    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s -   %(message)s")
    handler = logging.FileHandler(os.path.join(args.out, 'logs.txt'))
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: {args.amp}", )

    logger.info(dict(args._get_kwargs()))

    main(args)
