import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch import nn


class KNNValidation(object):
    """Nearest-neighbour retrieval accuracy of an encoder on a labelled dataset.

    Every image of the train split is embedded and L2-normalised to form a
    feature bank. Each image of the validation split is then embedded, matched
    against the bank by inner product, and assigned the label of its best
    match. The fraction of validation images whose assigned label equals
    their true label is reported.

    Args:
        args: Configuration namespace. Attributes read here: ``dataset_type``
            ("CIFAR10", "CIFAR100" or "STL10"), ``data_root``, ``batch_size``,
            ``num_workers``, ``feat_dim``, ``net`` and, for the two-stage
            nets, ``arch``, ``last_dim`` and ``hp1``.
        model: Encoder mapping an image batch to a 2-D feature batch.
        model_p2: Second-stage module, required when ``args.net`` is one of
            the two-stage nets listed in ``_TWO_STAGE_NETS``.
        K: Number of neighbours requested from ``topk``. Only the single best
            match is used for scoring.

    Raises:
        ValueError: If ``args.dataset_type`` is not one of the supported names.
    """

    # Nets whose encoder output is truncated and then passed through model_p2.
    _TWO_STAGE_NETS = ("SepCL_v1_OTL1", "SepCL_v1_Simsiam_OTL1")

    def __init__(self, args, model, model_p2=None, K=1):
        self.model = model
        self.model_p2 = model_p2
        self.device = torch.device('cuda' if next(model.parameters()).is_cuda else 'cpu')
        self.args = args
        self.K = K

        # The same normalisation is applied whatever dataset is selected.
        # NOTE(review): the constants look like the commonly quoted CIFAR-10
        # channel statistics; confirm they are intended for CIFAR100/STL10 too.
        base_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        common = dict(root=args.data_root, download=True, transform=base_transforms)
        if args.dataset_type == "CIFAR10":
            train_dataset = datasets.CIFAR10(train=True, **common)
            val_dataset = datasets.CIFAR10(train=False, **common)
        elif args.dataset_type == "CIFAR100":
            train_dataset = datasets.CIFAR100(train=True, **common)
            val_dataset = datasets.CIFAR100(train=False, **common)
        elif args.dataset_type == "STL10":
            train_dataset = datasets.STL10(split="train", **common)
            val_dataset = datasets.STL10(split="test", **common)
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(
                f"Unsupported dataset_type {args.dataset_type!r}; "
                "expected 'CIFAR10', 'CIFAR100' or 'STL10'."
            )

        # drop_last must be False: evaluation has to see every sample.
        # Dropping the tail of the train split would leave all-zero columns
        # in the feature bank, and dropping the tail of the validation split
        # would exclude samples from the reported accuracy.
        loader_kwargs = dict(batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=self.device.type == 'cuda',
                             drop_last=False)
        self.train_dataloader = DataLoader(train_dataset, **loader_kwargs)
        self.val_dataloader = DataLoader(val_dataset, **loader_kwargs)

    def _backbone_out_dim(self):
        """Return the encoder output width used to size the two-stage split.

        Raises:
            ValueError: If ``args.arch`` is not 'resnet18', the only
                architecture for which a width is defined.
        """
        args = self.args
        if args.arch != 'resnet18':
            raise ValueError(
                f"No encoder output width is defined for arch {args.arch!r}; "
                "only 'resnet18' is supported by the two-stage nets."
            )
        if args.last_dim != -1:
            return args.last_dim
        # Defaults used when last_dim is left at its -1 sentinel.
        return 4608 if args.dataset_type == 'STL10' else 512

    def _extract_features(self, inputs, head_width):
        """Embed one batch; with a two-stage net, truncate and apply model_p2.

        Args:
            inputs: Image batch as yielded by the data loader.
            head_width: Number of leading feature columns forwarded to
                ``model_p2``, or ``None`` for single-stage nets.
        """
        features = self.model(inputs.to(self.device))
        if head_width is not None:
            features = self.model_p2(features[:, :head_width])
        return features

    def _topk_retrieval(self):
        """Extract features from validation split and search on train split features.

        Returns:
            float: Fraction of validation samples whose best train match
            carries the same label.

        Raises:
            ValueError: If a two-stage net is configured without ``model_p2``
                or with an unsupported ``args.arch``.
        """
        train_dataset = self.train_dataloader.dataset
        n_data = len(train_dataset)
        feat_dim = self.args.feat_dim

        self.model.eval()
        head_width = None
        if self.args.net in self._TWO_STAGE_NETS:
            if self.model_p2 is None:
                raise ValueError(
                    f"net {self.args.net!r} requires model_p2, but none was given."
                )
            self.model_p2.eval()
            # Loop-invariant, so computed once rather than per batch.
            head_width = int(round(self._backbone_out_dim() * self.args.hp1))
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

        # Feature bank laid out as (feat_dim, n_data) so a query batch can be
        # scored against it with a single matrix product.
        train_features = torch.zeros([feat_dim, n_data], device=self.device)
        filled = 0
        with torch.no_grad():
            for inputs, _ in self.train_dataloader:
                features = nn.functional.normalize(
                    self._extract_features(inputs, head_width))
                count = features.size(0)
                # A running offset stays correct when the final batch is short.
                train_features[:, filled:filled + count] = features.t()
                filled += count

        label_attr = "labels" if self.args.dataset_type == "STL10" else "targets"
        train_labels = torch.as_tensor(getattr(train_dataset, label_attr),
                                       dtype=torch.long,
                                       device=self.device)
        # Keep only the columns actually written, so an externally supplied
        # loader that yields fewer samples cannot leave zero vectors behind.
        train_features = train_features[:, :filled]
        train_labels = train_labels[:filled]

        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, targets in self.val_dataloader:
                targets = targets.to(self.device, non_blocking=True)
                # Query features are not normalised: scaling a row by a
                # positive constant does not change its ranking of columns.
                features = self._extract_features(inputs, head_width)

                dist = torch.mm(features, train_features)
                _, yi = dist.topk(self.K, dim=1, largest=True, sorted=True)
                # topk is sorted, so column 0 holds the best match per query.
                retrieval = train_labels[yi[:, 0]]

                total += targets.size(0)
                correct += retrieval.eq(targets).sum().item()
        top1 = correct / total

        return top1

    def eval(self):
        """Run the retrieval evaluation and return the top-1 accuracy."""
        return self._topk_retrieval()


