import os, glob
from torch.utils.data import Dataset
from iconflow.dataset.transforms import *
from PIL import ImageOps, Image


class UnpairedIconContourDataset(Dataset):
    """Dataset of icons and contour images sampled independently (unpaired).

    Expected directory layout::

        <root>/<image_size>/img/*.png
        <root>/<image_size>/contour<contour_width>/*.png

    Every icon must have a contour file with the same base name. The keys
    (base names) are sorted, shuffled with a fixed seed (1337) and sliced by
    ``split``. The selection is therefore reproducible, and complementary
    splits such as ``(0, 0.9)`` / ``(0.9, 1)`` are disjoint.

    ``__getitem__`` ignores its index. Each call draws an icon and a contour
    independently at random from the selected keys and returns them as a
    ``(contour, img)`` tuple of tensors. Assuming torchvision's
    ``ToTensor``/``Normalize``, the values lie in [-1, 1].

    Args:
        root: Dataset root directory; ``~`` is expanded.
        image_size: Name of the resolution sub-directory and the side length
            of the square output images.
        random_crop: If true, both images go through ``BatchRandomResizedCrop``;
            otherwise both are bicubically resized to ``image_size``.
        random_transpose: If true, both images go through
            ``BatchRandomTranspose``.
        split: ``(start, end)`` fractions of the shuffled key list to keep.
        contour_width: Selects the ``contour<contour_width>`` sub-directory.

    Raises:
        ValueError: If the icon and contour file sets differ, or no PNG
            files are found.
    """

    def __init__(
        self,
        root,
        image_size,
        random_crop=False,
        random_transpose=False,
        split=(0, 1),
        contour_width=5,
    ):
        root = os.path.expanduser(root)
        img_dir = os.path.join(root, str(image_size), 'img')
        contour_dir = os.path.join(root, str(image_size), f'contour{contour_width}')

        img_paths = self._index_pngs(img_dir)
        contour_paths = self._index_pngs(contour_dir)

        # Explicit raises instead of ``assert``, so the checks survive ``python -O``.
        # Without them, a mismatch would surface later as an opaque KeyError.
        if img_paths.keys() != contour_paths.keys():
            only_img = sorted(img_paths.keys() - contour_paths.keys())
            only_contour = sorted(contour_paths.keys() - img_paths.keys())
            raise ValueError(
                f'icon and contour files do not match: '
                f'{len(only_img)} only in {img_dir!r} (e.g. {only_img[:5]}), '
                f'{len(only_contour)} only in {contour_dir!r} (e.g. {only_contour[:5]})'
            )
        if not img_paths:
            raise ValueError(f'no .png files found in {img_dir!r} or {contour_dir!r}')

        # Sort before the seeded shuffle so the split does not depend on the
        # order in which the filesystem lists files.
        keys = sorted(img_paths)
        random.Random(1337).shuffle(keys)
        keys = keys[int(len(keys) * split[0]):int(len(keys) * split[1])]
        self.keys = keys

        self.img_paths = {key: img_paths[key] for key in keys}
        self.contour_paths = {key: contour_paths[key] for key in keys}

        self.batch_resized_crop = BatchRandomResizedCrop(
            (image_size, image_size), (0.8, 1.0), (1.0, 1.0),
            (T.InterpolationMode.BICUBIC, T.InterpolationMode.BICUBIC)
        ) if random_crop else TupleTransform([
            transforms.Resize((image_size, image_size),
                              T.InterpolationMode.BICUBIC),
            transforms.Resize((image_size, image_size),
                              T.InterpolationMode.BICUBIC),
        ])
        # None (not ``False``) when disabled; ``__getitem__`` only tests truthiness.
        self.batch_transpose = BatchRandomTranspose() if random_transpose else None
        self.img_or_contour_to_tensor = transforms.ToTensor()
        self.img_or_contour_normalize = transforms.Normalize(0.5, 0.5)

    @staticmethod
    def _index_pngs(directory):
        """Map the base name (without extension) of each ``*.png`` in *directory* to its path."""
        # Escape the directory so characters such as '[' in the path are matched literally.
        pattern = os.path.join(glob.escape(directory), '*.png')
        return {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in glob.glob(pattern)
        }

    def get_icon(self, index):
        """Load the icon for ``self.keys[index]`` as an RGB PIL image."""
        with Image.open(self.img_paths[self.keys[index]]) as image:
            # convert() loads the pixels and returns a new image, so the file
            # can be closed when the ``with`` block exits.
            return image.convert('RGB')

    def get_contour(self, index):
        """Load the contour for ``self.keys[index]`` as a colour-inverted RGB PIL image."""
        with Image.open(self.contour_paths[self.keys[index]]) as image:
            # Convert before inverting: ImageOps.invert only handles modes
            # "1", "L" and "RGB" and raises for e.g. "LA", "RGBA" or "P".
            return ImageOps.invert(image.convert('RGB'))

    def __getitem__(self, _):
        """Return a random ``(contour, img)`` tensor pair; the index is ignored.

        The icon and the contour are drawn independently (icon first), so
        they generally come from different keys.
        """
        num_keys = len(self.keys)
        img = self.get_icon(random.randrange(num_keys))
        contour = self.get_contour(random.randrange(num_keys))

        # The same joint transform is applied to the (img, contour) tuple.
        img, contour = self.batch_resized_crop((img, contour))
        if self.batch_transpose:
            img, contour = self.batch_transpose((img, contour))

        img, contour = map(self.img_or_contour_to_tensor, (img, contour))
        img, contour = map(self.img_or_contour_normalize, (img, contour))

        return contour, img

    def __len__(self):
        """Number of keys in the selected split."""
        return len(self.keys)


if __name__ == '__main__':
    # Smoke test: build the first 90% of the shuffled keys.
    # Options are passed by keyword. Passed positionally, an extra argument
    # had shifted the split tuple onto ``contour_width`` and ``False`` onto
    # ``split``, which broke construction.
    dataset = UnpairedIconContourDataset(
        'datasets/icon4/data',
        128,
        random_crop=True,
        random_transpose=True,
        split=(0.0, 0.9),
    )
    
    