import os
import glob
import random
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.transforms import functional as tvF

class RandomResizedCrop(transforms.RandomResizedCrop):
    """Random resized crop that applies one shared crop window to two images.

    The crop parameters are sampled once per call and reused for both
    inputs, so spatially aligned pairs (e.g. an icon and its contour)
    stay aligned after augmentation.
    """

    def __init__(self, size, scale=(0.84, 0.9), ratio=(1.0, 1.0),
                 interpolation=transforms.InterpolationMode.BICUBIC):
        super().__init__(size, scale=scale, ratio=ratio,
                         interpolation=interpolation)

    def __call__(self, img1, img2):
        """Crop and resize ``img1`` and ``img2`` with identical parameters.

        Both images must report the same ``size``. Returns the transformed
        pair as a tuple ``(img1, img2)``.
        """
        assert img1.size == img2.size

        # Sample the window once; it is reused verbatim for the second image.
        top, left, height, width = self.get_params(img1, self.scale, self.ratio)

        def crop(img):
            return tvF.resized_crop(img, top, left, height, width,
                                    self.size, self.interpolation)

        return crop(img1), crop(img2)

class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontally flip a pair of images together.

    Either both images are flipped or neither is, so aligned pairs stay
    aligned. The flip happens with probability ``p``, the argument accepted
    by the parent constructor (0.5 when not given).
    """

    def __call__(self, img1, img2):
        """Return ``(img1, img2)``, both flipped or both untouched.

        Both images must report the same ``size``.
        """
        assert img1.size == img2.size

        # A single draw decides for both images. Compare against the
        # configured probability instead of a hard-coded 0.5 so that the
        # constructor's ``p`` argument is actually honoured.
        if random.random() < self.p:
            img1 = tvF.hflip(img1)
            img2 = tvF.hflip(img2)
        return img1, img2

class RandomVerticalFlip(transforms.RandomVerticalFlip):
    """Vertically flip a pair of images together.

    Either both images are flipped or neither is, so aligned pairs stay
    aligned. The flip happens with probability ``p``, the argument accepted
    by the parent constructor (0.5 when not given).
    """

    def __call__(self, img1, img2):
        """Return ``(img1, img2)``, both flipped or both untouched.

        Both images must report the same ``size``.
        """
        assert img1.size == img2.size

        # A single draw decides for both images. Compare against the
        # configured probability instead of a hard-coded 0.5 so that the
        # constructor's ``p`` argument is actually honoured.
        if random.random() < self.p:
            img1 = tvF.vflip(img1)
            img2 = tvF.vflip(img2)
        return img1, img2


class IconDataset(Dataset):
    """Icons paired with their contour images, grouped by a cluster label.

    Files read from ``root`` (after ``~`` expansion):

    * ``<image_size>/img/*.png``      -- icon images
    * ``<image_size>/contour5/*.png`` -- contour images with the same stems
    * ``labels.pt``                   -- loaded with ``torch.load``; must hold
      ``'labels'`` (one entry per icon, indexed by the integer value of the
      file stem) and ``'groups'`` (for each label, the icon indices in it)

    File stems must therefore be parseable by ``int``.

    Args:
        root: dataset root directory.
        image_size: selects the sub-directory and the output side length of
            the random resized crops.
        pad_ratio: controls the minimum crop area,
            ``((pad_ratio + 1) / (pad_ratio + 2)) ** 2``.
        split: ``(start, end)`` fractions selecting a slice of the
            deterministically shuffled key list (e.g. ``(0.0, 0.9)`` and
            ``(0.9, 1.0)`` give two disjoint subsets).
    """

    def __init__(self, root, image_size=128, pad_ratio=8, split=(0.9, 1.0)):
        self.image_size = image_size
        # Lower bound on the cropped area fraction used by all augmentations.
        self.min_crop_area = ((pad_ratio+1) / (pad_ratio+2)) ** 2

        root = os.path.expanduser(root)
        contour_width = 5
        img_dir = os.path.join(root, str(image_size), 'img')
        contour_dir = os.path.join(root, str(image_size), f'contour{contour_width}')

        img_paths = self._index_pngs(img_dir)
        contour_paths = self._index_pngs(contour_dir)

        assert set(img_paths.keys()) == set(contour_paths.keys()), \
            'img and contour directories must contain the same file stems'
        assert len(img_paths) > 0, f'no PNG files found in {img_dir}'

        keys = self._split_keys(list(img_paths.keys()), split)
        self.keys = keys

        self.img_paths = {key: img_paths[key] for key in keys}
        self.contour_paths = {key: contour_paths[key] for key in keys}

        # Integer form of each file stem; this is the index space used by
        # labels.pt.
        self.idxs = list(map(int, keys))
        self.idx2key = dict(zip(self.idxs, keys))

        idxs_set = set(self.idxs)
        # NOTE: torch.load unpickles the file, so labels.pt must come from a
        # trusted source.
        label_group = torch.load(os.path.join(root, 'labels.pt'))
        assert len(label_group['labels']) == len(img_paths), \
            'labels.pt must contain exactly one label per image'
        # Keep only labels / group members that belong to the selected split.
        self.labels = {idx: label for idx, label in enumerate(label_group['labels']) if idx in idxs_set}
        self.groups = [[label for label in group if label in idxs_set] for group in label_group['groups']]

        # Independent augmentation for the style images s1 and s2.
        self.style_img_aug = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(self.min_crop_area, 1.0), ratio=(1.0, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
        ])

        # Augmentations applied jointly to (s3, contour) so they stay aligned.
        self.paired_aug = [
            RandomResizedCrop(image_size, scale=(self.min_crop_area, 1.0), ratio=(1.0, 1.0)),
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
        ]

    @staticmethod
    def _index_pngs(directory):
        """Map file stem -> path for every ``*.png`` directly in ``directory``."""
        return {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in glob.glob(os.path.join(directory, '*.png'))
        }

    @staticmethod
    def _split_keys(keys, split):
        """Return the ``split`` slice of ``keys`` after a seeded shuffle.

        The keys are sorted first and shuffled with a fixed seed, so the
        result does not depend on the input order and different ``split``
        ranges always refer to the same permutation. ``keys`` is not mutated.
        """
        ordered = sorted(keys)
        random.Random(1337).shuffle(ordered)
        count = len(ordered)
        return ordered[int(count*split[0]):int(count*split[1])]

    @staticmethod
    def _load_image(path, mode):
        """Load ``path`` converted to ``mode``, closing the file afterwards."""
        # ``convert`` returns an independent image, so it remains usable after
        # the context manager has released the underlying file handle.
        with Image.open(path) as img:
            return img.convert(mode)

    def __len__(self):
        """Number of icons in the selected split."""
        return len(self.idxs)

    def __getitem__(self, i):
        """
        returns s1, s2, s3, contour such that
        s1, s2 are in the same cluster
        s3, contour are a paired icon and its contour
        note that s3 can be in a different cluster

        All four values are tensors produced by ``to_tensor``; only the first
        channel of the contour tensor is returned.
        """

        idx1 = self.idxs[i]

        label = self.labels[idx1]
        group = self.groups[label]

        # pick the icon in the same color cluster
        idx2 = random.choice(group)
        # the paired icon is drawn from the whole split, independent of cluster
        idx3 = random.choice(self.idxs)

        key1, key2, key3 = map(self.idx2key.__getitem__, (idx1, idx2, idx3))

        s1 = self._load_image(self.img_paths[key1], 'RGB')
        s2 = self._load_image(self.img_paths[key2], 'RGB')
        s3 = self._load_image(self.img_paths[key3], 'RGB')
        # The contour image is inverted after conversion to grayscale.
        contour = ImageOps.invert(self._load_image(self.contour_paths[key3], 'L'))

        s1 = self.style_img_aug(s1)
        s2 = self.style_img_aug(s2)

        # Same random crop / flips for the icon and its contour.
        for aug in self.paired_aug:
            s3, contour = aug(s3, contour)

        s1 = tvF.to_tensor(s1)
        s2 = tvF.to_tensor(s2)
        s3 = tvF.to_tensor(s3)
        contour = tvF.to_tensor(contour)

        return s1, s2, s3, contour[:1, :, :]
