import logging
import math
import torch

import os
import numpy as np
import random
from PIL import Image
from torchvision import datasets
from torchvision import transforms
from src.AuxMix.dataset.auxdataset import AuxDataset
from torch.utils.data.dataloader import default_collate

from .randaugment import RandAugmentMC
# from .randaugment_uda import RandAugment

logger = logging.getLogger(__name__)

# Per-channel mean/std fed to transforms.Normalize.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
cifar100_mean = (0.5071, 0.4867, 0.4408)
cifar100_std = (0.2675, 0.2565, 0.2761)
normal_mean = (0.5,) * 3
normal_std = (0.5,) * 3

# CIFAR-10 class names, indexed by label id.
classes_names = tuple('plane car bird cat deer dog frog horse ship truck'.split())
# Label ids 2..7 (bird, cat, deer, dog, frog, horse) form the "animals" group;
# every remaining id (plane, car, ship, truck) forms the "others" group.
cifar10_animals_cls_idxs = list(range(2, 8))
cifar10_others_cls_idxs = [c for c in range(len(classes_names))
                           if c not in cifar10_animals_cls_idxs]


def get_cifar10(args, root):
    """Build the standard FixMatch CIFAR-10 splits.

    Labeled samples get a random horizontal flip plus a reflect-padded random
    crop; unlabeled samples go through ``TransformFixMatch`` (a weak/strong
    pair); the test split is only converted to a tensor and normalized.

    Args:
        args: options object forwarded to ``x_u_split``.
        root: CIFAR-10 data directory (the train split is downloaded if absent).

    Returns:
        Tuple ``(labeled_train, unlabeled_train, test)``.
    """
    labeled_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
    ])
    eval_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
    ])

    full_train = datasets.CIFAR10(root, train=True, download=True)
    labeled_idxs, unlabeled_idxs = x_u_split(args, full_train.targets)

    labeled_set = CIFAR10SSL(root, labeled_idxs, train=True, transform=labeled_tf)
    unlabeled_set = CIFAR10SSL(
        root, unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std),
    )
    test_set = datasets.CIFAR10(root, train=False, transform=eval_tf, download=False)

    return labeled_set, unlabeled_set, test_set


def get_cifar10_auxUL(args, root, do_rotations=False, return_index=False):
    """Build CIFAR-10 splits whose unlabeled pool is extended with auxiliary data.

    The labeled indices come from ``x_u_aux_split``. The CIFAR part of the
    unlabeled pool uses the indices that function returns, and the
    ``AuxDataset`` loaded from ``args.aux_datapath`` is appended to it with
    ``combine_datasets``, which mutates that unlabeled dataset in place.

    Args:
        args: options object; ``aux_datapath`` is read here, the rest is
            forwarded to ``x_u_aux_split``.
        root: CIFAR-10 data directory (the train split is downloaded if absent).
        do_rotations: forwarded to every dataset built here.
        return_index: forwarded to the unlabeled and auxiliary datasets.

    Returns:
        Tuple ``(labeled, unlabeled_plus_aux, aux, test)``.
    """
    aux_dir = args.aux_datapath

    def to_normalized_tensor():
        # Fresh ToTensor + Normalize pair for each pipeline.
        return [transforms.ToTensor(),
                transforms.Normalize(mean=cifar10_mean, std=cifar10_std)]

    transform_labeled = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect')]
        + to_normalized_tensor()
    )
    transform_val = transforms.Compose(to_normalized_tensor())

    full_train = datasets.CIFAR10(root, train=True, download=True)
    labeled_idxs, unlabeled_idxs = x_u_aux_split(args, full_train.targets)

    labeled_set = CIFAR10SSL(root, labeled_idxs, train=True,
                             transform=transform_labeled, do_rotations=do_rotations)
    n_classes = len(labeled_set.class_to_idx)

    cifar_unlabeled = CIFAR10SSL(
        root, unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std),
        do_rotations=do_rotations, return_index=return_index, num_classes=n_classes,
    )
    aux_set = AuxDataset(aux_dir, transform=transform_labeled, do_rotations=do_rotations,
                         num_classes=n_classes, return_index=return_index)

    unlabeled_set = combine_datasets([cifar_unlabeled, aux_set])

    test_set = CIFAR10SSL(root, indexs=None, train=False, transform=transform_val,
                          download=False, do_rotations=do_rotations)

    return labeled_set, unlabeled_set, aux_set, test_set


def get_cifar10_animals_others(args, root, do_rotations=False, return_index=False):
    """Build the "animals labeled / others auxiliary" CIFAR-10 splits.

    ``x_u_animals_others_split`` selects labeled animal samples (also reused
    as unlabeled CIFAR samples) and collects every non-animal training sample
    as auxiliary data; the auxiliary dataset is appended to the unlabeled one
    via ``combine_datasets`` (in place). The test split keeps only the
    animal classes.

    Args:
        args: options object forwarded to the split helpers.
        root: CIFAR-10 data directory (the train split is downloaded if absent).
        do_rotations: forwarded to every dataset built here.
        return_index: forwarded to the unlabeled and auxiliary datasets.

    Returns:
        Tuple ``(labeled, unlabeled_plus_aux, aux, test)``.
    """
    def to_normalized_tensor():
        # Fresh ToTensor + Normalize pair for each pipeline.
        return [transforms.ToTensor(),
                transforms.Normalize(mean=cifar10_mean, std=cifar10_std)]

    transform_labeled = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode='reflect')]
        + to_normalized_tensor()
    )
    transform_val = transforms.Compose(to_normalized_tensor())

    full_train = datasets.CIFAR10(root, train=True, download=True)
    labeled_idxs, unlabeled_idxs, aux_idxs = x_u_animals_others_split(args, full_train.targets)

    labeled_set = CIFAR10SSL_Animals_Others(root, labeled_idxs, train=True,
                                            transform=transform_labeled,
                                            do_rotations=do_rotations)
    n_classes = len(labeled_set.class_to_idx)

    cifar_unlabeled = CIFAR10SSL_Animals_Others(
        root, unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std),
        do_rotations=do_rotations, return_index=return_index, num_classes=n_classes,
    )
    aux_set = CIFAR10SSL_Animals_Others(
        root, aux_idxs, train=True,
        transform=transform_labeled, do_rotations=do_rotations,
        return_index=return_index, num_classes=n_classes,
    )

    unlabeled_set = combine_datasets([cifar_unlabeled, aux_set])

    # Evaluate on the animal classes only.
    full_test = datasets.CIFAR10(root, train=False, download=False)
    test_idxs = x_u_animals_others_split_test(args, full_test.targets)
    test_set = CIFAR10SSL_Animals_Others(root, test_idxs, train=False,
                                         transform=transform_val,
                                         do_rotations=do_rotations)

    return labeled_set, unlabeled_set, aux_set, test_set


def x_u_split(args, labels):
    """Pick a class-balanced labeled subset; the unlabeled split is every sample.

    NumPy's global RNG is reseeded with ``args.seed``. For each class id in
    ``range(args.num_classes)``, ``args.num_labeled // args.num_classes``
    indices are drawn without replacement. If ``args.expand_labels`` is set or
    fewer labeled samples exist than ``args.batch_size``, the labeled indices
    are repeated ``ceil(batch_size * eval_step / num_labeled)`` times. The
    labeled indices are shuffled in place before returning.

    Returns:
        ``(labeled_idx, unlabeled_idx)`` as NumPy arrays.
    """
    np.random.seed(args.seed)
    per_class = args.num_labeled // args.num_classes
    labels = np.array(labels)
    # Unlabeled split = the whole training set
    # (https://github.com/kekmodel/FixMatch-pytorch/issues/10).
    unlabeled_idx = np.array(range(len(labels)))

    labeled_idx = np.array([
        sample
        for cls in range(args.num_classes)
        for sample in np.random.choice(np.where(labels == cls)[0], per_class, False)
    ])
    assert len(labeled_idx) == args.num_labeled

    needs_expansion = args.expand_labels or args.num_labeled < args.batch_size
    if needs_expansion:
        repeats = math.ceil(args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.concatenate([labeled_idx] * repeats)
    np.random.shuffle(labeled_idx)
    return labeled_idx, unlabeled_idx


def x_u_aux_split(args, labels):
    """Pick a class-balanced labeled subset and reuse it as the unlabeled split.

    Unlike ``x_u_split``, the unlabeled CIFAR indices here are only the
    labeled ones (get_cifar10_auxUL appends auxiliary data to that split).

    NumPy's global RNG is reseeded with ``args.seed``. For each class id in
    ``range(args.num_classes)``, ``args.num_labeled // args.num_classes``
    indices are drawn without replacement. If ``args.expand_labels`` is set or
    fewer labeled samples exist than ``args.batch_size``, the labeled indices
    are repeated ``ceil(batch_size * eval_step / num_labeled)`` times. Both
    returned arrays are shuffled in place.

    Returns:
        ``(labeled_idx, unlabeled_idx)``: two independent NumPy arrays.
    """
    np.random.seed(args.seed)
    label_per_class = args.num_labeled // args.num_classes
    labels = np.array(labels)
    labeled_idx = []
    for i in range(args.num_classes):
        idx = np.where(labels == i)[0]
        idx = np.random.choice(idx, label_per_class, False)
        labeled_idx.extend(idx)
    labeled_idx = np.array(labeled_idx)
    assert len(labeled_idx) == args.num_labeled

    # Copy rather than alias: without expansion both names would otherwise
    # refer to the same array, so both in-place shuffles below would act on
    # one shared buffer and the caller would receive the same object twice.
    unlabeled_idx = labeled_idx.copy()

    if args.expand_labels or args.num_labeled < args.batch_size:
        num_expand_x = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
    np.random.shuffle(labeled_idx)
    np.random.shuffle(unlabeled_idx)
    return labeled_idx, unlabeled_idx


def x_u_animals_others_split(args, labels):
    """Split CIFAR-10 training indices into labeled animals and auxiliary others.

    NumPy's global RNG is reseeded with ``args.seed``. For each class in
    ``cifar10_animals_cls_idxs``, ``args.num_labeled // 6`` indices (6 being
    the number of animal classes) are drawn without replacement; those labeled
    indices are also copied into the unlabeled split. Every sample of a class in
    ``cifar10_others_cls_idxs`` becomes auxiliary data. If
    ``args.expand_labels`` is set or fewer labeled samples exist than
    ``args.batch_size``, the labeled indices are repeated
    ``ceil(batch_size * eval_step / num_labeled)`` times. All three arrays are
    shuffled in place (labeled, then unlabeled, then auxiliary).

    Returns:
        ``(labeled_idx, unlabeled_idx, aux_idx)`` as NumPy arrays.
    """
    np.random.seed(args.seed)
    per_class = args.num_labeled // len(cifar10_animals_cls_idxs)
    labels = np.array(labels)

    labeled_idx = np.array([
        sample
        for cls in cifar10_animals_cls_idxs
        for sample in np.random.choice(np.where(labels == cls)[0], per_class, False)
    ])
    assert len(labeled_idx) == args.num_labeled
    # The labeled animal samples double as unlabeled data (np.array copies).
    unlabeled_idx = np.array(labeled_idx)

    # All non-animal samples form the auxiliary pool.
    aux_idx = np.array([
        sample
        for cls in cifar10_others_cls_idxs
        for sample in np.where(labels == cls)[0]
    ])

    needs_expansion = args.expand_labels or args.num_labeled < args.batch_size
    if needs_expansion:
        repeats = math.ceil(args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.concatenate([labeled_idx] * repeats)
    for arr in (labeled_idx, unlabeled_idx, aux_idx):
        np.random.shuffle(arr)
    return labeled_idx, unlabeled_idx, aux_idx


def x_u_animals_others_split_test(args, labels):
    """Return the test-set indices belonging to the animal classes, shuffled.

    ``args`` is accepted for signature parity with the training splitters but
    is not read. The shuffle uses NumPy's global RNG without reseeding.
    """
    labels = np.array(labels)
    test_idx = np.array([
        sample
        for cls in cifar10_animals_cls_idxs
        for sample in np.where(labels == cls)[0]
    ])
    np.random.shuffle(test_idx)
    return test_idx


class TransformFixMatch(object):
    """Produce FixMatch's (weak, strong) augmented views of one image.

    Both views resize to 32x32, apply a random horizontal flip and a
    reflect-padded random 32x32 crop; the strong view additionally applies
    ``RandAugmentMC(n=2, m=10)``. Each view is then converted to a tensor and
    normalized with ``mean``/``std``.
    """

    def __init__(self, mean, std):
        def geometric_ops():
            # Fresh instances of the ops shared by both views.
            return [
                transforms.Resize(size=(32, 32)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=32, padding=int(32 * 0.125),
                                      padding_mode='reflect'),
            ]

        self.weak = transforms.Compose(geometric_ops())
        self.strong = transforms.Compose(geometric_ops() + [RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        # Weak view is drawn first, then the strong one (keeps RNG order).
        views = (self.weak(x), self.strong(x))
        return tuple(self.normalize(v) for v in views)


def collate_fn(batch):
    """Collate samples and fold the rotation axis into the batch axis.

    Batches whose samples have exactly four fields (the rotation-mode output
    of the SSL datasets: images, targets, rotation labels, domain flags) are
    shaped ``(B, R, ...)`` after ``default_collate``; each field is reshaped to
    ``(B * R, ...)``. When the image field is a list (weak/strong views), both
    views are reshaped using the first view's size. Any other batch is
    returned exactly as ``default_collate`` produced it.
    """
    batch = default_collate(batch)
    if len(batch) != 4:
        return batch

    images = batch[0]
    paired = isinstance(images, list)
    n, rots, c, h, w = (images[0] if paired else images).size()
    flat = n * rots
    if paired:
        images[0] = images[0].view([flat, c, h, w])
        images[1] = images[1].view([flat, c, h, w])
    else:
        batch[0] = images.view([flat, c, h, w])
    for field in (1, 2, 3):
        batch[field] = batch[field].view([flat])
    return batch


class CIFAR10SSL(datasets.CIFAR10):
    """CIFAR-10 subset (restricted to ``indexs`` when given) for SSL training.

    Entries of ``self.data`` that are ``str`` (presumably image file paths
    appended by ``combine_datasets`` from an auxiliary dataset -- verify
    against callers) are opened from disk, converted to RGB and given a random
    target in ``[0, num_classes)``; every other entry is a CIFAR array used
    with its stored target.

    ``__getitem__`` returns:
      * with ``do_rotations=True``: ``(images, targets, rotation_labels,
        domain)``, each with one entry per rotation (0/90/180/270 degrees, in
        a freshly shuffled order). ``domain`` is 0 for file-path samples and 1
        for CIFAR arrays. If the transform returns a tuple (e.g. weak/strong
        views), ``images`` is a list of two stacked tensors. ``return_index``
        is ignored in this mode.
      * otherwise ``(img, target)``, or ``(img, target, index)`` when
        ``return_index`` is set.
    """
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False, do_rotations=False, return_index=False, num_classes=None):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        # Shuffled in place on every rotated __getitem__ call.
        self.rotations = [0, 90, 180, 270]
        self.do_rotations = do_rotations
        self.return_index = return_index
        # Exclusive upper bound for the random targets of file-path samples.
        self.num_classes = num_classes
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Load sample ``index``; see the class docstring for the return layout."""
        if isinstance(self.data[index], str):
            # File-path sample: load it from disk.
            with open(self.data[index], 'rb') as f:
                img = Image.open(f)
                img = img.convert('RGB')
            # No stored label is used for these samples; a random target in
            # [0, num_classes) is drawn from NumPy's global RNG instead.
            assert self.num_classes is not None
            target = np.random.randint(low=0, high=self.num_classes)
            dom = 0
        else:
            img, target = self.data[index], self.targets[index]
            img = Image.fromarray(img)
            dom = 1

        if self.do_rotations is True:
            # Python's global RNG picks a new rotation order for this sample.
            random.shuffle(self.rotations)
            r_imgs = []
            r_targets = []
            for rot in self.rotations:
                r_img, r_target = self.rotate_img(img, rot)
                r_img = self.transform(r_img)
                r_imgs.append(r_img)
                r_targets.append(r_target)

            r_targets = torch.LongTensor(r_targets)
            if isinstance(r_imgs[0], tuple):
                # Transform returned (weak, strong): stack each view separately.
                r_imgs_w = torch.stack([i[0] for i in r_imgs], dim=0)
                r_imgs_s = torch.stack([i[1] for i in r_imgs], dim=0)
                r_imgs = [r_imgs_w, r_imgs_s]
            else:
                r_imgs = torch.stack(r_imgs, dim=0)

            if self.target_transform is not None:
                target = self.target_transform(target)
            # Class target and domain flag are repeated once per rotation.
            return r_imgs, torch.LongTensor([target] * len(self.rotations)), r_targets, \
                   torch.LongTensor([dom] * len(self.rotations))
        else:
            if self.transform is not None:
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

            if self.return_index:
                return img, target, index

            return img, target

    def rotate_img(self, img, rot):
        """Rotate a PIL image by ``rot`` degrees and return ``(image, label)``.

        Labels: 0 -> 0 degrees, 1 -> 90, 2 -> 180, 3 -> 270.

        Raises:
            ValueError: if ``rot`` is not one of 0, 90, 180, 270.
        """
        if rot == 0:  # 0 degrees rotation
            lab = 0
            return img, lab
        elif rot == 90:  # 90 degrees rotation
            lab = 1
            return img.transpose(Image.ROTATE_90), lab
        elif rot == 180:  # 180 degrees rotation
            lab = 2
            return img.transpose(Image.ROTATE_180), lab
        elif rot == 270:  # 270 degrees rotation / or -90
            lab = 3
            return img.transpose(Image.ROTATE_270), lab
        else:
            raise ValueError('rotation should be 0, 90, 180, or 270 degrees')


class CIFAR10SSL_Animals_Others(datasets.CIFAR10):
    """CIFAR-10 subset for the "animals labeled / others auxiliary" setup.

    Behaves like a CIFAR-10 subset restricted to ``indexs`` (when given), except
    that stored CIFAR targets are shifted down by 2 so that the animal classes
    2..7 become labels 0..5. Entries of ``self.data`` that are ``str``
    (presumably image file paths appended by ``combine_datasets`` -- verify
    against callers) are loaded from disk and receive a random target in
    ``[0, num_classes)``, which is used unshifted.

    ``__getitem__`` returns:
      * with ``do_rotations=True``: ``(images, targets, rotation_labels,
        domain)``, one entry per rotation (0/90/180/270 degrees, freshly
        shuffled order); ``domain`` is 0 for file-path samples and 1 for CIFAR
        arrays. ``return_index`` is ignored in this mode.
      * otherwise ``(img, target)``, or ``(img, target, index)`` when
        ``return_index`` is set.
    ``collate_fn`` relies on these tuple lengths to decide how to collate.
    """
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False, do_rotations=False, return_index=False, num_classes=None):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        # Shuffled in place on every rotated __getitem__ call.
        self.rotations = [0, 90, 180, 270]
        self.do_rotations = do_rotations
        self.return_index = return_index
        # Exclusive upper bound for the random targets of file-path samples.
        self.num_classes = num_classes
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Load sample ``index``; see the class docstring for the return layout."""
        if isinstance(self.data[index], str):
            # File-path sample: load it from disk.
            with open(self.data[index], 'rb') as f:
                img = Image.open(f)
                img = img.convert('RGB')
            # Random target drawn directly in label space [0, num_classes);
            # it must not be shifted like the CIFAR targets below.
            assert self.num_classes is not None
            target = np.random.randint(low=0, high=self.num_classes)
            dom = 0
        else:
            img, target = self.data[index], self.targets[index]
            img = Image.fromarray(img)
            dom = 1
            # Map CIFAR animal classes 2..7 onto labels 0..5.
            # NOTE(review): non-animal classes 0, 1, 8, 9 map to -2, -1, 6, 7;
            # confirm consumers ignore the targets of auxiliary samples.
            target = target - 2

        if self.do_rotations is True:
            # Python's global RNG picks a new rotation order for this sample.
            random.shuffle(self.rotations)
            r_imgs = []
            r_targets = []
            for rot in self.rotations:
                r_img, r_target = self.rotate_img(img, rot)
                r_img = self.transform(r_img)
                r_imgs.append(r_img)
                r_targets.append(r_target)

            r_targets = torch.LongTensor(r_targets)
            if isinstance(r_imgs[0], tuple):
                # Transform returned (weak, strong): stack each view separately.
                r_imgs_w = torch.stack([i[0] for i in r_imgs], dim=0)
                r_imgs_s = torch.stack([i[1] for i in r_imgs], dim=0)
                r_imgs = [r_imgs_w, r_imgs_s]
            else:
                r_imgs = torch.stack(r_imgs, dim=0)

            if self.target_transform is not None:
                target = self.target_transform(target)

            # Class target and domain flag are repeated once per rotation.
            return r_imgs, torch.LongTensor([target] * len(self.rotations)), r_targets, \
                   torch.LongTensor([dom] * len(self.rotations))
        else:
            if self.transform is not None:
                img = self.transform(img)

            if self.target_transform is not None:
                target = self.target_transform(target)

            if self.return_index:
                return img, target, index

            return img, target

    def rotate_img(self, img, rot):
        """Rotate a PIL image by ``rot`` degrees and return ``(image, label)``.

        Labels: 0 -> 0 degrees, 1 -> 90, 2 -> 180, 3 -> 270.

        Raises:
            ValueError: if ``rot`` is not one of 0, 90, 180, 270.
        """
        if rot == 0:  # 0 degrees rotation
            lab = 0
            return img, lab
        elif rot == 90:  # 90 degrees rotation
            lab = 1
            return img.transpose(Image.ROTATE_90), lab
        elif rot == 180:  # 180 degrees rotation
            lab = 2
            return img.transpose(Image.ROTATE_180), lab
        elif rot == 270:  # 270 degrees rotation / or -90
            lab = 3
            return img.transpose(Image.ROTATE_270), lab
        else:
            raise ValueError('rotation should be 0, 90, 180, or 270 degrees')


def combine_datasets(dataset_list):
    """Append the samples of every later dataset to the first one and return it.

    The first dataset is mutated in place: its ``data`` and ``targets`` become
    Python lists extended, in order, with the entries of the remaining
    datasets. The remaining datasets are not modified.
    """
    merged = dataset_list[0]
    merged.data = list(merged.data)
    merged.targets = list(merged.targets)

    extras = dataset_list[1:]
    for extra in extras:
        merged.data.extend(list(extra.data))
        merged.targets.extend(list(extra.targets))
    return merged


def show_sample(sample_loader):
    """Plot images from the first batch of ``sample_loader`` (debug helper).

    The loader is expected to yield ``(images, labels)``. When ``images`` is a
    list or tuple (e.g. the (weak, strong) views produced by
    ``TransformFixMatch``), the first view is plotted; otherwise ``images`` is
    used directly. Images are un-normalized with the CIFAR-10 statistics and
    handed to ``plot_images``, which exits the process afterwards.
    """
    # Python 3 iterators have no .next() method; use the builtin next().
    images, labels = next(iter(sample_loader))
    batch = images[0] if isinstance(images, (list, tuple)) else images

    # CIFAR-10 mean *and* std (the original mixed in the CIFAR-100 std).
    unnorm = UnNormalize(mean=cifar10_mean, std=cifar10_std)
    # Un-normalize one (C, H, W) image at a time, then reorder the stacked
    # (N, C, H, W) batch to (N, H, W, C) for imshow.
    X1 = torch.stack([unnorm(img) for img in batch]).numpy().transpose([0, 2, 3, 1])

    plot_images(X1, labels)


def plot_images(images, cls_true, cls_pred=None):
    """Show up to 9 images in a 3x3 grid labelled with ``cls_true``, then exit.

    Args:
        images: array indexed as ``images[i, :, :, :]``, one image per entry
            in a layout ``imshow`` accepts (presumably H x W x C, already
            un-normalized -- TODO confirm).
        cls_true: per-image labels; ``cls_true[i]`` is used as the axis label.
        cls_pred: accepted for interface compatibility; currently unused.

    This is a debugging helper: once ``plt.show()`` returns, the process
    exits with status 0.
    """
    import sys

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(3, 3)
    # Avoid IndexError when the batch has fewer than 9 images.
    num_shown = min(len(images), len(cls_true), axes.size)

    for i, ax in enumerate(axes.flat):
        if i < num_shown:
            ax.imshow(images[i, :, :, :], interpolation='spline16')
            ax.set_xlabel("{}".format(cls_true[i]))
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
    # The builtin exit() is injected by the `site` module for interactive use
    # only; sys.exit works in every interpreter configuration.
    sys.exit(0)

class UnNormalize(object):
    """Invert ``transforms.Normalize`` in place: ``x * std + mean`` per channel.

    Accepts a single image ``(C, H, W)`` or a batch ``(..., C, H, W)``; for
    inputs with at least three dimensions the channel axis is the third from
    last. Tensors with fewer than three dimensions are processed along axis 0.
    Only ``min(channels, len(mean), len(std))`` channels are touched.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image ``(C, H, W)`` or batch
                ``(..., C, H, W)``; modified in place.
        Returns:
            Tensor: the same tensor object, un-normalized.
        """
        # Iterating the tensor directly would walk the batch axis for 4-D
        # input; select along the channel axis instead.
        channel_axis = -3 if tensor.dim() >= 3 else 0
        num_channels = tensor.size(channel_axis)
        for c, (m, s) in enumerate(zip(self.mean, self.std)):
            if c >= num_channels:
                break
            # Inverse of Normalize's (x - mean) / std; select() returns a view.
            tensor.select(channel_axis, c).mul_(s).add_(m)
        return tensor

# Maps a dataset name to the function that builds its splits.  Both getters
# take (args, root, do_rotations=False, return_index=False) and return
# (labeled, unlabeled_plus_aux, aux, test) datasets.
DATASET_GETTERS = dict([
    ('cifar10', get_cifar10_auxUL),
    ('cifar10-animals-others', get_cifar10_animals_others),
])


def show_imgs(dataset):
    """Plot samples from one randomly drawn batch of ``dataset`` (debug helper).

    Wraps ``dataset`` in a single-process DataLoader (random sampling, batch
    size 90, incomplete last batch dropped) and passes it to ``show_sample``.
    """
    from torch.utils.data import DataLoader, RandomSampler

    loader = DataLoader(
        dataset,
        batch_size=90,
        sampler=RandomSampler(dataset),
        num_workers=0,
        drop_last=True,
    )
    show_sample(loader)




