import os
import glob
import random
import numpy as np
from PIL import Image, ImageOps, ImageDraw

import torch
from torch.utils.data import Dataset

import torchvision.transforms as transforms
import torchvision.transforms.functional as T

from iconflow.dataset.transforms import (
    BatchRandomResizedCrop,
    BatchRandomHue,
    BatchRandomTranspose,
    TupleTransform
)


class IconContourDataset(Dataset):
    """Dataset of paired icon images and their (inverted) contour images.

    PNG files are discovered in ``<root>/<image_size>/img`` and
    ``<root>/<image_size>/contour<contour_width>`` and paired by file name
    (without extension). The sorted keys are shuffled with a fixed seed and
    the ``split`` fractions select the slice of keys this instance serves.
    """

    # Fixed seed: every instance shuffles the same key set identically, so
    # instances built with non-overlapping ``split`` ranges never share a key.
    _SPLIT_SEED = 1337

    def __init__(
        self,
        root,
        random_crop=False,
        random_transpose=False,
        split=(0, 1),
        image_size=256,
        output_size=224,
        contour_width=5
    ):
        """Index the image/contour files and build the transform pipeline.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            random_crop: If true, use ``BatchRandomResizedCrop`` on the
                (image, contour) pair; otherwise both are resized to
                ``(image_size, image_size)`` with bicubic interpolation.
            random_transpose: If true, additionally apply
                ``BatchRandomTranspose`` to the pair.
            split: ``(start, stop)`` fractions of the shuffled key list to
                keep, e.g. ``(0, 0.9)`` and ``(0.9, 1)``.
            image_size: Name of the size sub-directory under ``root`` and the
                target size of the crop/resize step.
            output_size: Side length of the final bicubic resize done in
                ``__getitem__``.
            contour_width: Suffix of the ``contour<N>`` directory to read.

        Raises:
            AssertionError: If the image and contour directories do not hold
                the same set of file names, or if no PNG files are found.
                (Raised explicitly so the checks survive ``python -O``; the
                exception type matches the former ``assert`` statements.)
        """
        super().__init__()

        root = os.path.expanduser(root)
        img_dir = os.path.join(root, str(image_size), 'img')
        contour_dir = os.path.join(root, str(image_size), f'contour{contour_width}')
        self.output_size = output_size

        img_paths = self._index_pngs(img_dir)
        contour_paths = self._index_pngs(contour_dir)

        img_keys, contour_keys = set(img_paths), set(contour_paths)
        if img_keys != contour_keys:
            unpaired = sorted(img_keys ^ contour_keys)
            raise AssertionError(
                f'{len(unpaired)} file(s) have no image/contour counterpart '
                f'between {img_dir!r} and {contour_dir!r}, e.g. {unpaired[:5]}'
            )
        if not img_keys:
            raise AssertionError(f'no PNG files found in {img_dir!r}')

        keys = self._select_split(img_keys, split)
        self.keys = keys

        self.img_paths = {key: img_paths[key] for key in keys}
        self.contour_paths = {key: contour_paths[key] for key in keys}

        if random_crop:
            # NOTE(review): the (0.8, 1.0) / (1.0, 1.0) arguments are presumably
            # the scale and aspect-ratio ranges - confirm against
            # iconflow.dataset.transforms.BatchRandomResizedCrop.
            self.batch_resized_crop = BatchRandomResizedCrop(
                (image_size, image_size), (0.8, 1.0), (1.0, 1.0),
                (T.InterpolationMode.BICUBIC, T.InterpolationMode.BICUBIC)
            )
        else:
            # One Resize per element of the (image, contour) pair.
            self.batch_resized_crop = TupleTransform([
                transforms.Resize((image_size, image_size),
                                  T.InterpolationMode.BICUBIC),
                transforms.Resize((image_size, image_size),
                                  T.InterpolationMode.BICUBIC),
            ])
        # ``None`` (falsy) disables the transpose step in ``__getitem__``.
        self.batch_transpose = BatchRandomTranspose() if random_transpose else None
        self.img_or_contour_to_tensor = transforms.ToTensor()

    @staticmethod
    def _index_pngs(directory):
        """Return ``{file stem: path}`` for each ``*.png`` directly in *directory*."""
        # Escape the directory so glob metacharacters in the path are literal.
        pattern = os.path.join(glob.escape(directory), '*.png')
        return {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in glob.glob(pattern)
        }

    @classmethod
    def _select_split(cls, keys, split):
        """Return the ``split`` slice of *keys* after a deterministic shuffle."""
        # Sort first so the result does not depend on filesystem listing order.
        ordered = sorted(keys)
        random.Random(cls._SPLIT_SEED).shuffle(ordered)
        count = len(ordered)
        return ordered[int(count * split[0]):int(count * split[1])]

    def get_icon(self, index):
        """Load the icon image at *index* as an in-memory PIL image."""
        with Image.open(self.img_paths[self.keys[index]]) as icon:
            # copy() forces the pixel data to load before the file is closed.
            return icon.copy()

    def get_contour(self, index):
        """Load the contour image at *index* and return it colour-inverted."""
        with Image.open(self.contour_paths[self.keys[index]]) as contour:
            # invert() already returns a new image, so no extra copy is needed.
            return ImageOps.invert(contour)

    def __getitem__(self, index):
        """Return the ``(image, contour)`` pair at *index* as tensors.

        Both images go through the same crop/resize step, the optional
        transpose, a bicubic resize to ``output_size`` and ``ToTensor``.
        """
        img = self.get_icon(index)
        contour = self.get_contour(index)

        img, contour = self.batch_resized_crop((img, contour))
        if self.batch_transpose:
            img, contour = self.batch_transpose((img, contour))

        img = img.resize((self.output_size, self.output_size), Image.BICUBIC)
        contour = contour.resize((self.output_size, self.output_size), Image.BICUBIC)

        img, contour = map(self.img_or_contour_to_tensor, (img, contour))

        return img, contour

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.keys)


if __name__ == '__main__':
    # Manual smoke test: build the dataset from the default location and
    # drop into an interactive shell to inspect it.
    dataset = IconContourDataset('datasets/icon4/data/in_memory')

    # Debug-only dependency, imported lazily so the module itself never needs it.
    import ei
    ei.embed()
