import os
import random
import torch
import torchvision
import torchvision.datasets.imagenet
import torchvision.transforms as transforms
from utils import GaussianBlur
from rotated_image_folder import RotatedImageFolder

class RandomGaussianNoise:
    """Add zero-mean Gaussian noise to a tensor with probability ``p``.

    Args:
        sigma: standard deviation of the additive noise.
        p: probability of applying the noise. The default of 1.0 always adds
            noise, which matches the previous hard-coded behaviour.
    """

    def __init__(self, sigma=0.1, p=1.0):
        self.sigma = sigma
        self.p = p

    def __call__(self, tensor):
        """Return ``tensor`` plus noise (new tensor), or ``tensor`` unchanged."""
        # random.random() lies in [0, 1), so p=1.0 always applies the noise.
        # The draw is still made in that case, so the global `random` stream
        # advances exactly as before.
        if random.random() < self.p:
            noise = torch.randn_like(tensor) * self.sigma
            return tensor + noise
        return tensor

    def __repr__(self):
        return f"{type(self).__name__}(sigma={self.sigma}, p={self.p})"
    
class RandomRotationWithLabel:
    """Rotate an image by a random multiple of 90 degrees and return the index.

    Index ``i`` selects ``self.angles[i]`` (0, 90, 180 or 270 degrees).
    Index 0 is drawn with probability ``args.zero_prob``. The remaining
    probability mass is split evenly among the three non-zero rotations.

    Raises:
        ValueError: if ``args.zero_prob`` is outside ``[0, 1]``.
    """

    def __init__(self, args):
        self.angles = [0, 90, 180, 270]
        p0 = args.zero_prob
        # Values outside [0, 1] would make some weights negative, and
        # random.choices would then sample a meaningless distribution.
        if not 0.0 <= p0 <= 1.0:
            raise ValueError(f"zero_prob must be in [0, 1], got {p0!r}")
        other = (1 - p0) / 3
        self.probs = [p0, other, other, other]

    def __call__(self, img):
        """Return ``(rotated_img, i)`` where ``i`` indexes ``self.angles``."""
        i = random.choices(range(len(self.angles)), weights=self.probs, k=1)[0]
        img = transforms.functional.rotate(img, self.angles[i])
        return img, i
    
class ContrastiveLearningTransform:
    """Build two augmented views of one image for contrastive pre-training.

    ``args.pretrain_set`` selects the augmentation pipeline stored in
    ``self.transform``:

    * ``'stl10'`` / ``'imagenet100'``: random resized crop, horizontal flip,
      colour jitter, random grayscale, Gaussian blur, then per-dataset
      normalisation.
    * ``'EMNIST'`` / ``'AffNIST'``: grayscale, random resized crop, small
      affine translate/scale jitter, an optional +/-15 degree rotation (when
      ``args.random_rotation`` is truthy), Gaussian noise, normalisation and
      random erasing.
    * Appending ``'-R'`` to any of the names above also inserts a rotation by
      one of 0/90/180/270 degrees, chosen uniformly at random, into the
      pipeline.

    Calling the instance returns ``(y1, y2, r1, r2)``. ``y1`` and ``y2`` are
    two augmented views. When ``args.gie`` is truthy, ``r1`` and ``r2`` are the
    90-degree rotation indices produced by ``RandomRotationWithLabel``;
    otherwise ``r1 == r2 == 0``.

    Raises:
        ValueError: if ``args.pretrain_set`` is not a supported name.
    """

    # Angles (degrees) the '-R' pipeline variants choose between.
    _RIGHT_ANGLES = (0, 90, 180, 270)

    # Natural-image datasets: base name -> (blur kernel size, mean, std).
    _NATURAL_IMAGE_SETTINGS = {
        'stl10': (9, (0.43, 0.42, 0.39), (0.27, 0.26, 0.27)),
        'imagenet100': (23, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    }

    # Character/digit datasets that share the grayscale pipeline.
    _DIGIT_SETS = ('EMNIST', 'AffNIST')

    def __init__(self, args):
        self.args = args

        name = args.pretrain_set
        with_right_angles = name.endswith('-R')
        base = name[:-2] if with_right_angles else name
        # Validate before building anything, so an unsupported name fails
        # here rather than later with a missing `self.transform`.
        if base not in self._NATURAL_IMAGE_SETTINGS and base not in self._DIGIT_SETS:
            raise ValueError(f"Unsupported pretrain_set: {name!r}")

        self.rot_transform = RandomRotationWithLabel(args)

        if base in self._NATURAL_IMAGE_SETTINGS:
            self.transform = self._natural_image_transform(base, with_right_angles)
        else:
            self.transform = self._digit_transform(with_right_angles)

    def _right_angle_choice(self):
        """Rotate by one of ``_RIGHT_ANGLES``, picked uniformly at random."""
        return transforms.RandomChoice([
            transforms.RandomRotation(
                degrees=[angle, angle],
                interpolation=transforms.InterpolationMode.BILINEAR,
            )
            for angle in self._RIGHT_ANGLES
        ])

    def _natural_image_transform(self, base, with_right_angles):
        """Build the crop/flip/jitter/blur pipeline for RGB datasets."""
        kernel_size, mean, std = self._NATURAL_IMAGE_SETTINGS[base]
        steps = [transforms.RandomResizedCrop(self.args.imsize, scale=(0.2, 1.0))]
        if with_right_angles:
            steps.append(self._right_angle_choice())
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.4, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=kernel_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(mean), std=list(std)),
        ]
        return transforms.Compose(steps)

    def _digit_transform(self, with_right_angles):
        """Build the single-channel pipeline for EMNIST/AffNIST."""
        steps = [
            transforms.Grayscale(num_output_channels=1),
            transforms.RandomResizedCrop(self.args.imsize, scale=(0.6, 1.0)),
            transforms.RandomAffine(
                degrees=0,
                translate=(2/28, 2/28),
                scale=(0.9, 1.1),
            ),
        ]
        if self.args.random_rotation:
            steps.append(transforms.RandomRotation(degrees=15))
        if with_right_angles:
            steps.append(self._right_angle_choice())
        steps += [
            transforms.ToTensor(),
            RandomGaussianNoise(sigma=0.1),
            transforms.Normalize((0.5,), (0.5,)),
            transforms.RandomErasing(p=0.15, scale=(0.02, 0.15), value=0.0),
        ]
        return transforms.Compose(steps)

    def __call__(self, x):
        """Return ``(y1, y2, r1, r2)`` for input image ``x``."""
        if not self.args.gie:
            y1 = self.transform(x)
            y2 = self.transform(x)
            return y1, y2, 0, 0

        # Statement order is deliberate: both augmentations (or both labelled
        # rotations) run before the other step, so random draws happen in the
        # same sequence as before.
        if self.args.horizontal_flip_first:
            # Despite the flag name, this runs the whole augmentation pipeline
            # (not only the flip) before the labelled rotation.
            x1 = self.transform(x)
            x2 = self.transform(x)
            y1, r1 = self.rot_transform(x1)
            y2, r2 = self.rot_transform(x2)
        else:
            x1, r1 = self.rot_transform(x)
            x2, r2 = self.rot_transform(x)
            y1 = self.transform(x1)
            y2 = self.transform(x2)
        return y1, y2, r1, r2
        

def load_pretrain_datasets(args):
    """Return the unlabeled pre-training ``ImageFolder`` for ``args.pretrain_set``.

    ``args.pretrain_set`` is matched by substring. STL-10 variants read
    ``args.data / 'unlabeled'``. ImageNet-100, EMNIST and AffNIST variants
    read ``args.data / 'train'``. Every sample is passed through
    ``ContrastiveLearningTransform(args)``.

    Raises:
        ValueError: if ``args.pretrain_set`` matches none of the known datasets.
    """
    name = args.pretrain_set
    if 'stl10' in name:
        split = 'unlabeled'
    elif 'imagenet100' in name or 'EMNIST' in name or 'AffNIST' in name:
        split = 'train'
    else:
        # Previously this fell through to an UnboundLocalError on `dataset`.
        raise ValueError(f"Unsupported pretrain_set: {name!r}")
    return torchvision.datasets.ImageFolder(args.data / split, ContrastiveLearningTransform(args))


def load_eval_datasets(args):
        
    # set evaluation transform
    if 'stl10' in args.pretrain_set:
        mean = [0.43, 0.42, 0.39]
        std  = [0.27, 0.26, 0.27]
        transform = transforms.Compose([transforms.Resize(args.imsize),
                                        transforms.CenterCrop(args.imsize),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std)])

    elif 'imagenet100' in args.pretrain_set:
        mean = torch.tensor([0.485, 0.456, 0.406])
        std  = torch.tensor([0.229, 0.224, 0.225])
        transform = transforms.Compose([transforms.Resize(args.imsize),
                                        transforms.CenterCrop(args.imsize),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std)])
    
    elif 'EMNIST' in args.pretrain_set or 'AffNIST' in args.pretrain_set:
        mean = torch.tensor([0.5])
        std = torch.tensor([0.5])
        transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                        transforms.Resize(args.imsize),
                                        transforms.CenterCrop(args.imsize),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std)])

    # set evaluation dataset
    if args.eval_set == 'stl10':
        num_classes=10
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)
            
    elif args.eval_set == 'imagenet100':
        num_classes=100
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'val', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'val', transform)

    elif args.eval_set == 'imagenet1k':
        num_classes=1000
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'val', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'val', transform)
            
    elif args.eval_set == 'stanford_cars':
        num_classes=196
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)

    elif args.eval_set == 'fgvc_aircraft':
        num_classes=100
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'val', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'val', transform)
            
    elif args.eval_set == 'cub_200_2011':
        num_classes=200
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)
    
    elif args.eval_set == 'cifar10':
        num_classes=10
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)
    
    elif args.eval_set == 'cifar100':
        num_classes=100
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)
            
    elif args.eval_set == 'caltech256':
        num_classes=256
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)

    elif args.eval_set == 'mtarsi':
        num_classes=20
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)

    elif args.eval_set == 'oxford_flowers':
        num_classes=102
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'valid', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'valid', transform)
    
    elif args.eval_set == 'EMNIST':
        if args.emnist_type == 'byclass':
            num_classes = 62
        elif args.emnist_type == 'balanced':
            num_classes = 47
        
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'valid', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'valid', transform)
    
    elif args.eval_set == 'RotNIST':
        num_classes = 10
        train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
        val_dataset = torchvision.datasets.ImageFolder(args.data / 'valid', transform)
    
    elif args.eval_set == 'AffNIST':
        num_classes = 10
        if not args.rotated:
            train_dataset = torchvision.datasets.ImageFolder(args.data / 'train', transform)
            val_dataset = torchvision.datasets.ImageFolder(args.data / 'test', transform)
        else:
            train_dataset = RotatedImageFolder(args.data / 'train', transform)
            val_dataset = RotatedImageFolder(args.data / 'test', transform)

    return train_dataset, val_dataset, num_classes




