import os
import glob
import random
import numpy as np
from PIL import Image, ImageOps, ImageDraw

import torch
from torch.utils.data import Dataset

import torchvision.transforms as transforms
import torchvision.transforms.functional as T

from iconflow.dataset.transforms import (
    RandomTranspose
)


class MyDataset(Dataset):
    """Dataset of icon or (inverted) contour images stored as PNG files.

    Images are read from ``<root>/<image_size>/<subdir>/*.png``, where
    ``subdir`` is ``'img'`` for ``folder='icon'`` and ``'contour5'`` for
    ``folder='contour'``. File stems are sorted, shuffled with a fixed seed
    and sliced by ``split``, so complementary splits such as ``(0, 0.9)`` and
    ``(0.9, 1)`` are reproducible and disjoint.

    Each item is an RGB image tensor produced by ``ToTensor`` followed by
    ``Normalize`` with mean/std 0.5 per channel (values mapped to [-1, 1]).
    """

    # Public ``folder`` name -> on-disk sub-directory name.
    _SUBDIRS = {'icon': 'img', 'contour': 'contour5'}

    def __init__(
        self,
        root,
        folder,
        image_size,
        random_crop=False,
        random_transpose=False,
        split=(0, 1),
        seed=1337,
    ):
        """
        Args:
            root: dataset root directory; ``~`` is expanded.
            folder: ``'icon'`` or ``'contour'``.
            image_size: resolution sub-directory to read and the side length
                of the square output images.
            random_crop: if true, use ``RandomResizedCrop`` (area scale
                0.8-1.0, square aspect); otherwise a bicubic ``Resize``.
            random_transpose: if true, apply ``RandomTranspose`` per item.
            split: ``(start, end)`` fractions selecting a slice of the
                shuffled key list.
            seed: seed of the deterministic shuffle performed before slicing.

        Raises:
            ValueError: if ``folder`` is not a known folder name.
            FileNotFoundError: if the image directory holds no ``.png`` file.
        """
        if folder not in self._SUBDIRS:
            raise ValueError(
                f"folder must be one of {sorted(self._SUBDIRS)}, got {folder!r}"
            )
        root = os.path.expanduser(root)
        self.folder = folder
        img_dir = os.path.join(root, str(image_size), self._SUBDIRS[folder])
        img_paths = self._list_images(img_dir)

        self.keys = self._select_keys(img_paths, split, seed)
        self.img_paths = {key: img_paths[key] for key in self.keys}

        if random_crop:
            self.resized_crop = transforms.RandomResizedCrop(
                size=(image_size, image_size),
                scale=(0.8, 1.0),
                ratio=(1.0, 1.0),  # keep crops square
                interpolation=T.InterpolationMode.BICUBIC,
            )
        else:
            self.resized_crop = transforms.Resize(
                (image_size, image_size),
                interpolation=T.InterpolationMode.BICUBIC,
            )
        # None when disabled (previously ``False``); both are falsy, which is
        # what ``__getitem__`` checks.
        self.transpose = RandomTranspose() if random_transpose else None
        self.to_tensor = transforms.ToTensor()
        # Maps the [0, 1] output of ToTensor to [-1, 1].
        self.normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    @staticmethod
    def _list_images(img_dir):
        """Return ``{file stem: path}`` for every ``*.png`` directly in ``img_dir``.

        Raises:
            FileNotFoundError: if no PNG file is found. An explicit raise is
                used instead of ``assert``, which ``python -O`` strips.
        """
        img_paths = {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in glob.glob(os.path.join(img_dir, '*.png'))
        }
        if not img_paths:
            raise FileNotFoundError(f"no .png images found in {img_dir!r}")
        return img_paths

    @staticmethod
    def _select_keys(keys, split, seed=1337):
        """Return the ``split`` slice of ``keys`` after a seeded shuffle.

        Keys are sorted before shuffling so the result depends only on the set
        of keys and ``seed``, not on filesystem listing order.
        """
        ordered = sorted(keys)
        random.Random(seed).shuffle(ordered)
        n = len(ordered)
        return ordered[int(n * split[0]):int(n * split[1])]

    def __getitem__(self, index):
        """Load item ``index`` and return it as a normalized RGB image tensor."""
        path = self.img_paths[self.keys[index]]
        # ``convert`` returns a new, fully loaded image, so the file handle can
        # be closed here instead of staying open until garbage collection.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.folder == 'contour':
            # Inverting after the RGB conversion gives the same pixels as
            # inverting first for 1/L/RGB sources, and also works for modes
            # that ImageOps.invert rejects (e.g. RGBA, P). Alpha is dropped.
            img = ImageOps.invert(img)
        img = self.resized_crop(img)
        if self.transpose:
            img = self.transpose(img)
        return self.normalize(self.to_tensor(img))

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.keys)
