import torch.utils.data as data
import numpy as np

from PIL import Image

import os
import os.path
import sys
import random
import torch
from torchvision import transforms

# Filename suffixes (lowercase, including the leading dot) accepted as images.
# Every entry needs the dot; without it, e.g. 'webp' would also match 'foo_webp'.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']

def has_file_allowed_extension(filename, extensions):
    """Return True if *filename* ends with any suffix in *extensions*.

    Only *filename* is lowercased before comparison, so the entries of
    *extensions* should already be lowercase for the check to be
    case-insensitive.
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False


def is_image_file(filename):
    """Return True if *filename* ends with one of the suffixes in ``IMG_EXTENSIONS``."""
    allowed_suffixes = IMG_EXTENSIONS
    return has_file_allowed_extension(filename, allowed_suffixes)


def make_dataset(dir, class_to_idx, extensions):
    """Collect ``(path, class_index)`` pairs for every matching file under *dir*.

    For each class name in *class_to_idx* (visited in sorted order), the
    sub-directory ``dir/<class_name>`` is walked recursively. Each file whose
    name passes ``has_file_allowed_extension(name, extensions)`` is recorded
    together with that class's index. Class names without a matching
    directory are skipped. Walk entries and file names are sorted, so the
    result order is deterministic.

    Args:
        dir: Root directory; a leading ``~`` is expanded.
        class_to_idx: Mapping from class (sub-directory) name to integer index.
        extensions: Filename suffixes to accept (presumably lowercase, since
            only the filename is lowercased during matching).

    Returns:
        list of ``(path, class_index)`` tuples.
    """
    base = os.path.expanduser(dir)
    collected = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(base, class_name)
        if os.path.isdir(class_dir):
            label = class_to_idx[class_name]
            walk_entries = sorted(os.walk(class_dir))
            for folder, _subdirs, files in walk_entries:
                collected.extend(
                    (os.path.join(folder, name), label)
                    for name in sorted(files)
                    if has_file_allowed_extension(name, extensions)
                )
    return collected


def pil_loader(path):
    """Open the image file at *path* with PIL and return it converted to RGB.

    The file is opened explicitly inside a ``with`` block rather than handing
    PIL the path, which avoids a ResourceWarning about an unclosed file
    (https://github.com/python-pillow/Pillow/issues/835). The RGB conversion
    runs while the file is still open.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image


class AuxDataset(data.Dataset):
    """Image-folder dataset that pairs each auxiliary image with a random class label.

    Images are discovered under ``root/<class_name>/...`` (one sub-directory per
    class, searched recursively via ``make_dataset``). The folder-derived labels
    are kept in ``self.targets`` but are NOT returned. Instead, ``__getitem__``
    draws a fresh label uniformly from ``[0, num_classes)`` on every access.

    Args:
        root: Directory whose immediate sub-directories are treated as classes.
        transform: Callable applied to each loaded image. Required when
            ``do_rotations`` is enabled, because the rotated views are combined
            with ``torch.stack`` and must therefore be tensors.
        target_transform: Stored on the instance but not applied by this class.
        num_classes: Exclusive upper bound for the random labels. Must be set
            before items are fetched.
        do_rotations: If truthy, each item yields every angle in
            ``self.rotations`` (0/90/180/270 degrees) in random order, together
            with the corresponding rotation labels.
        return_index: If true and rotations are disabled, the sample index is
            returned as a third element.

    Raises:
        RuntimeError: If no image files are found under ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None, num_classes=None, do_rotations=False,
                 return_index=False):
        classes, class_to_idx = self._find_classes(root)
        samples = make_dataset(root, class_to_idx, IMG_EXTENSIONS)
        if not samples:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                               "Supported extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.loader = pil_loader

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.data = [s[0] for s in samples]
        self.targets = [s[1] for s in samples]

        self.transform = transform
        self.target_transform = target_transform

        self.num_classes = num_classes
        self.do_rotations = do_rotations
        self.rotations = [0, 90, 180, 270]
        self.return_index = return_index

    def _find_classes(self, dir):
        """Return ``(classes, class_to_idx)`` for the immediate sub-directories of *dir*.

        ``classes`` is the sorted list of sub-directory names, and
        ``class_to_idx`` maps each name to its position in that list.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {name: i for i, name in enumerate(classes)}
        return classes, class_to_idx

    def rotate_img(self, img, rot):
        """Rotate *img* by *rot* degrees using PIL's ``transpose``.

        Returns:
            ``(rotated_img, label)``, where label is 0, 1, 2 or 3 for
            0, 90, 180 or 270 degrees respectively.

        Raises:
            ValueError: If *rot* is not one of 0, 90, 180, 270.
        """
        if rot == 0:  # 0 degrees rotation
            return img, 0
        elif rot == 90:  # 90 degrees rotation
            return img.transpose(Image.ROTATE_90), 1
        elif rot == 180:  # 180 degrees rotation
            return img.transpose(Image.ROTATE_180), 2
        elif rot == 270:  # 270 degrees rotation / or -90
            return img.transpose(Image.ROTATE_270), 3
        else:
            raise ValueError('rotation should be 0, 90, 180, or 270 degrees')

    def __getitem__(self, index):
        """Load sample *index* and attach a uniformly random class label.

        Returns:
            Without rotations: ``(sample, target)``, or ``(sample, target, index)``
            when ``return_index`` is set.
            With rotations: ``(imgs, targets, rot_targets, doms)``, where ``imgs``
            stacks the transformed rotated views along dim 0, ``targets`` repeats
            the random class label once per view, ``rot_targets`` holds the
            rotation labels, and ``doms`` is all zeros. ``return_index`` is not
            honoured in this mode.

        Raises:
            ValueError: If ``num_classes`` is not set.
            TypeError: If rotations are enabled but ``transform`` is None.
        """
        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O``; otherwise np.random.randint(0, None) fails obscurely.
        if self.num_classes is None:
            raise ValueError('AuxDataset requires num_classes to assign random labels')

        path, _ = self.samples[index]
        sample = self.loader(path)
        # Uniform random label in [0, num_classes).
        # NOTE(review): drawn from NumPy's global RNG. Older PyTorch DataLoader
        # workers did not reseed NumPy, so workers could repeat label
        # sequences; confirm for the PyTorch version in use.
        target = np.random.randint(low=0, high=self.num_classes)
        dom = 0

        # Truthiness (not ``is True``) so flags such as np.bool_(True) or 1 work.
        if not self.do_rotations:
            if self.transform is not None:
                sample = self.transform(sample)
            if self.return_index:
                return sample, target, index
            return sample, target

        if self.transform is None:
            raise TypeError('AuxDataset with do_rotations=True requires a transform '
                            'that converts images to tensors')

        # Shuffle a local copy rather than mutating self.rotations on every access.
        rotations = random.sample(self.rotations, len(self.rotations))
        r_imgs = []
        r_targets = []
        for rot in rotations:
            r_img, r_target = self.rotate_img(sample, rot)
            r_imgs.append(self.transform(r_img))
            r_targets.append(r_target)

        n_views = len(rotations)
        return (torch.stack(r_imgs, dim=0),
                torch.LongTensor([target] * n_views),
                torch.LongTensor(r_targets),
                torch.LongTensor([dom] * n_views))

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.samples)

