from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import logging

import torch
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import torch.distributed as dist
from torch.nn import functional as F

# Module-level logger. Note that validate() takes its own ``logger`` parameter,
# which shadows this name inside that function.
logger = logging.getLogger(__name__)

"""Validation code was borrowed from https://github.com/pytorch/vision"""

def reduce_tensor(inp):
    """
    Sum ``inp`` across all processes onto the process with rank 0.

    ``dist.reduce`` uses a SUM reduction by default, so rank 0 ends up with
    the *sum* over processes, not the mean. Callers that want an average
    must divide by the world size themselves. The reduction is performed in
    place, so ``inp`` itself is modified and returned. Only rank 0's tensor
    is guaranteed to hold the reduced result.

    If torch.distributed is unavailable or not initialised, or if only one
    process is running, ``inp`` is returned unchanged.
    """
    # Short-circuit order matters: is_initialized() is only meaningful when
    # the distributed package is available in this build.
    if not (dist.is_available() and dist.is_initialized()):
        return inp
    if dist.get_world_size() < 2:
        return inp
    with torch.no_grad():
        dist.reduce(inp, dst=0)
    return inp


class AverageMeter(object):
    """Computes and stores the average and current value.

    NOTE(review): this class is redefined later in this module with a
    ``(name, fmt)`` constructor and a ``__str__``. That later definition
    replaces this one when the module is imported, so this version is not
    the one the module's own functions use.
    """

    def __init__(self):
        # reset() initialises every counter, so they are not assigned twice.
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def collect_statistics(model, dataloader, device, max_batches=100,
                       log_interval=200):
    """Forward a prefix of ``dataloader`` through ``model`` to gather statistics.

    Each batch's images are moved to ``device`` and passed as
    ``model(images, True)`` under ``torch.no_grad()``. The meaning of the
    second flag is defined by the model (presumably "record statistics" --
    TODO confirm). Targets are ignored. After the loop,
    ``model.proccess_model_statistics()`` is called once.

    Args:
        model: model providing ``eval()``, ``__call__(images, flag)`` and
            ``proccess_model_statistics()``.
        dataloader: iterable of ``(images, target)`` batches.
        device: device the images are moved to.
        max_batches: number of batches to process. Defaults to 100, the value
            that was previously hard-coded. ``None`` processes every batch.
        log_interval: print a progress line after every ``log_interval``
            batches. ``0``/``None`` disables it. With the defaults (100
            batches, interval 200) no line is printed, as before.
    """
    model.eval()
    with torch.no_grad():
        for i, (images, _target) in enumerate(tqdm(dataloader)):
            if max_batches is not None and i >= max_batches:
                break
            # Only the images are needed, so the target is never copied to
            # the device. The output is discarded; the model is called for
            # its side effects.
            model(images.to(device), True)
            if log_interval and (i + 1) % log_interval == 0:
                # i + 1 batches have been processed at this point.
                print("{} batches proceeded from {}".format(i + 1, len(dataloader)))
    model.proccess_model_statistics()


def validate(model, model_full, binOp, dataloader, criterion, device, logger,
             print_freq=20):
    """Evaluate ``model`` on ``dataloader`` and return its top-1/top-5 accuracy.

    Args:
        model: network under evaluation. It is called as ``model(images, True)``
            (the meaning of the second flag is defined by the model -- TODO
            confirm). It must also provide ``dump_all_dense_fractions()`` and
            ``proccess_model_statistics()``, which are called after the loop.
        model_full: optional reference model. When given, accuracy is measured
            as agreement with its top-1 prediction (``accuracy_full``) instead
            of with the dataset labels. The loss is always computed against
            the labels.
        binOp: optional object. Its ``binarization()`` is called before the
            evaluation loop and its ``restore()`` after it, even if the loop
            raises.
        dataloader: iterable of ``(images, target)`` batches.
        criterion: loss function, called as ``criterion(output, target)``.
        device: device the images and targets are moved to.
        logger: logger that receives the final summary line.
        print_freq: print a progress line every ``print_freq`` batches.
            Defaults to 20, the previously hard-coded value; ``0`` disables it.

    Returns:
        ``(top1.avg, top5.avg)``: sample-weighted mean accuracies in percent.
        The per-batch accuracies are tensors, so these are tensors too
        (plain ``0`` if the dataloader was empty).
    """
    model.eval()
    if model_full is not None:
        model_full.eval()
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(dataloader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    if binOp is not None:
        binOp.binarization()
    # restore() lives in ``finally`` so that an exception mid-evaluation still
    # undoes binarization() (presumably restoring full-precision weights --
    # confirm against the binOp implementation).
    try:
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(tqdm(dataloader)):
                target = target.to(device)
                images = images.to(device)
                batch_size = images.size(0)

                # compute output
                output = model(images, True)
                loss = criterion(output, target)

                # measure accuracy and record loss
                if model_full is not None:
                    output_full = model_full(images)
                    acc1, acc5 = accuracy_full(output, output_full, topk=(1, 5))
                else:
                    acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), batch_size)
                top1.update(acc1[0], batch_size)
                top5.update(acc5[0], batch_size)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if print_freq and i % print_freq == 0:
                    progress.display(i)
    finally:
        if binOp is not None:
            binOp.restore()

    model.dump_all_dense_fractions()
    model.proccess_model_statistics()
    msg = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)
    print(msg)
    logger.info(msg)

    return top1.avg, top5.avg


class AverageMeter(object):
    """Running-average tracker that prints as ``"<name> <val> (<avg>)"``.

    ``fmt`` is a format-spec suffix such as ``':6.3f'``. It is applied to both
    the most recent value and the running average when the meter is
    converted to a string.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one tab-separated progress line per call to ``display``.

    Each line is ``prefix`` followed by a ``[batch/total]`` counter, then
    ``str(meter)`` for every meter given at construction.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.meters = meters
        self.prefix = prefix
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(map(str, self.meters))
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current-batch field to the width of the total so that
        # successive lines stay aligned.
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return '[%s/%s]' % (counter, counter.format(num_batches))


def adjust_learning_rate(optimizer, epoch, args, step_size=30, gamma=0.1):
    """Apply a step-decay schedule to every parameter group of ``optimizer``.

    The new learning rate is ``args.lr * gamma ** (epoch // step_size)``.
    With the defaults this is the original schedule: the initial LR decayed
    by 10x every 30 epochs.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current epoch index.
        args: object with an ``lr`` attribute holding the initial LR.
        step_size: number of epochs between decays.
        gamma: multiplicative decay factor applied at each step.

    Returns:
        The learning rate that was written to every parameter group.
    """
    lr = args.lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: prediction scores. ``topk`` is taken along dim 1, so this
            expects ``(batch, num_classes)``.
        target: ground-truth class indices, one per row of ``output``.
        topk: the ``k`` values to evaluate.

    Returns:
        A list with one 1-element tensor per ``k``. Each holds the percentage
        of samples whose target is among the ``k`` highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # After the transpose, row j holds every sample's j-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` inherits the non-contiguous layout of the transposed
            # ``pred``, so ``view(-1)`` can fail for k > 1. ``reshape`` copies
            # only when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def accuracy_full(output, target, topk=(1,)):
    """Compute top-k agreement, in percent, between ``output`` and a reference.

    Unlike ``accuracy``, ``target`` here holds the reference model's scores,
    not class indices. The "label" for each sample is the reference's top-1
    class, taken along dim 1, so ``target`` is expected to be
    ``(batch, num_classes)``.

    Args:
        output: scores of the model under evaluation, ``(batch, num_classes)``.
        target: scores of the reference model, same layout as ``output``.
        topk: the ``k`` values to evaluate.

    Returns:
        A list with one 1-element tensor per ``k``. Each holds the percentage
        of samples whose reference top-1 class is among the ``k``
        highest-scoring classes of ``output``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # After the transpose, row j holds every sample's j-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()

        # The reference model's top-1 class per sample, as a (batch,) tensor.
        _, pred_full = target.topk(1, 1, True, True)
        pred_full = pred_full[:, 0]
        correct = pred.eq(pred_full.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` is non-contiguous (it comes from a transpose), so
            # ``view(-1)`` can fail for k > 1. ``reshape`` is used instead.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

