import os
import torch.utils.data as data
from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize
import random

import glob
from iconflow.dataset.transforms import *
from PIL import ImageOps


class SingleDataset(Dataset):
    """Dataset over every entry of the folder ``<dataset_root>/<phase><setname>``.

    Each item is one image: loaded as RGB, resized to ``resize_size`` x
    ``resize_size`` (bicubic), center-cropped to ``crop_size``, converted to a
    tensor and normalized with mean/std 0.5 per channel (i.e. [0, 1] -> [-1, 1]).
    With ``input_dim == 1`` the result is collapsed to a single luma channel.

    Args:
        dataset_root: directory that contains the ``<phase><setname>`` folder.
        phase: folder-name prefix (presumably e.g. ``'train'``/``'test'``).
        opts: unused; kept for interface compatibility with callers.
        setname: folder-name suffix; also used in the size printout.
        input_dim: 1 for a single-channel (luma) tensor, otherwise 3-channel RGB.
        resize_size: side length of the square bicubic resize.
        crop_size: size passed to ``CenterCrop``.
    """

    def __init__(
        self,
        dataset_root,
        phase,
        opts,
        setname,
        input_dim,
        resize_size,
        crop_size
    ):
        self.dataset_root = dataset_root
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.img = self._list_files(os.path.join(dataset_root, phase + setname))
        self.size = len(self.img)
        self.input_dim = input_dim

        # setup image transformation
        self.transforms = Compose([
            Resize((resize_size, resize_size), Image.BICUBIC),
            CenterCrop(crop_size),
            ToTensor(),
            # ToTensor yields [0, 1]; mean/std 0.5 maps that to [-1, 1].
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        print('%s: %d images' % (setname, self.size))

    @staticmethod
    def _list_files(folder):
        """Return full paths of all entries in *folder*, sorted by name."""
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def __getitem__(self, index):
        """Return the transformed image tensor at position *index*."""
        return self.load_img(self.img[index], self.input_dim)

    def load_img(self, img_name, input_dim):
        """Load *img_name* as RGB and apply ``self.transforms``.

        When ``input_dim == 1`` the three channels are combined into one with
        BT.601 luma weights and returned with a leading channel axis of size 1.
        """
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_name) as image:
            rgb = image.convert('RGB')
        img = self.transforms(rgb)
        if input_dim == 1:
            # The weights sum to 1, so applying them after the affine Normalize
            # equals normalizing the luma of the unnormalized image.
            img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
            img = img.unsqueeze(0)
        return img

    def __len__(self):
        """Number of entries found in the image folder."""
        return self.size

# class UnpairedDataset(data.Dataset):
#     def __init__(
#         self,
#         dataset_root,
#         phase,
#         input_dim_a,
#         input_dim_b,
#         resize_size,
#         crop_size,
#         no_flip
#     ):
#         self.dataset_root = dataset_root

#         # A
#         images_A = os.listdir(os.path.join(self.dataset_root, phase + 'A'))
#         self.A = [os.path.join(self.dataset_root, phase + 'A', x) for x in images_A]

#         # B
#         images_B = os.listdir(os.path.join(self.dataset_root, phase + 'B'))
#         self.B = [os.path.join(self.dataset_root, phase + 'B', x) for x in images_B]

#         self.A_size = len(self.A)
#         self.B_size = len(self.B)
#         self.dataset_size = max(self.A_size, self.B_size)
#         self.input_dim_A = input_dim_a
#         self.input_dim_B = input_dim_b

#         # setup image transformation
#         transforms = [Resize((resize_size, resize_size), Image.BICUBIC)]
#         if phase == 'train':
#             transforms.append(RandomCrop(crop_size))
#         else:
#             transforms.append(CenterCrop(crop_size))
#         if not no_flip:
#             transforms.append(RandomHorizontalFlip())
#         transforms.append(ToTensor())
#         transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
#         self.transforms = Compose(transforms)
#         print('A: %d, B: %d images'%(self.A_size, self.B_size))
#         return

#     def __getitem__(self, index):
#         if self.dataset_size == self.A_size:
#             data_A = self.load_img(self.A[index], self.input_dim_A)
#             data_B = self.load_img(self.B[random.randint(0, self.B_size - 1)], self.input_dim_B)
#         else:
#             data_A = self.load_img(self.A[random.randint(0, self.A_size - 1)], self.input_dim_A)
#             data_B = self.load_img(self.B[index], self.input_dim_B)
#         return data_A, data_B

#     def load_img(self, img_name, input_dim):
#         img = Image.open(img_name).convert('RGB')
#         img = self.transforms(img)
#         if input_dim == 1:
#             img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
#             img = img.unsqueeze(0)
#         return img

#     def __len__(self):
#         return self.dataset_size

class UnpairedIconContourDataset(Dataset):
    """Unpaired icon/contour dataset yielding ``(contour, icon)`` tensor pairs.

    Files are looked up as ``<root>/<image_size>/img/*.png`` (icons) and
    ``<root>/<image_size>/contour<contour_width>/*.png`` (contours). Both folders
    must contain exactly the same file stems. The sorted stems are shuffled with
    a fixed seed (1337) and the slice ``split = (start, end)`` (fractions of the
    total) is kept, so complementary splits such as ``(0, 0.9)`` / ``(0.9, 1)``
    are reproducible and disjoint.

    Each item pairs a randomly drawn icon with an independently drawn contour;
    the index passed to ``__getitem__`` is ignored.

    Args:
        root: dataset root directory (``~`` is expanded).
        image_size: resolution sub-folder name and output side length.
        random_crop: use ``BatchRandomResizedCrop`` instead of a plain bicubic
            resize of both images.
        random_transpose: additionally apply ``BatchRandomTranspose``.
        split: ``(start, end)`` fractions of the shuffled key list to keep.
        contour_width: selects the ``contour<contour_width>`` sub-folder.

    Raises:
        ValueError: if the icon and contour folders hold different stems, or if
            no icons are found.
    """

    def __init__(
        self,
        root,
        image_size,
        random_crop=False,
        random_transpose=False,
        split=(0, 1),
        contour_width=5,
    ):
        root = os.path.expanduser(root)
        size_dir = os.path.join(root, str(image_size))
        self.keys, self.img_paths, self.contour_paths = self._collect_paths(
            os.path.join(size_dir, 'img'),
            os.path.join(size_dir, f'contour{contour_width}'),
            split,
        )

        interpolation = T.InterpolationMode.BICUBIC
        if random_crop:
            # Arguments presumably are (size, scale, ratio, interpolations);
            # see BatchRandomResizedCrop in iconflow.dataset.transforms.
            self.batch_resized_crop = BatchRandomResizedCrop(
                (image_size, image_size), (0.8, 1.0), (1.0, 1.0),
                (interpolation, interpolation),
            )
        else:
            # One resize per tuple element: (icon, contour).
            self.batch_resized_crop = TupleTransform([
                transforms.Resize((image_size, image_size), interpolation),
                transforms.Resize((image_size, image_size), interpolation),
            ])
        self.batch_transpose = BatchRandomTranspose() if random_transpose else None
        self.img_or_contour_to_tensor = transforms.ToTensor()
        self.img_or_contour_normalize = transforms.Normalize(0.5, 0.5)

    @staticmethod
    def _collect_paths(img_dir, contour_dir, split):
        """Return ``(keys, img_paths, contour_paths)`` for the requested split.

        Keys are ``*.png`` file stems present in both folders, sorted and then
        shuffled with a fixed seed before slicing by the ``split`` fractions.
        """
        def index(directory):
            return {
                os.path.splitext(os.path.basename(path))[0]: path
                for path in glob.glob(os.path.join(directory, '*.png'))
            }

        img_paths = index(img_dir)
        contour_paths = index(contour_dir)

        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if img_paths.keys() != contour_paths.keys():
            missing_contours = len(img_paths.keys() - contour_paths.keys())
            missing_imgs = len(contour_paths.keys() - img_paths.keys())
            raise ValueError(
                f'{img_dir} and {contour_dir} do not contain the same file stems '
                f'({missing_contours} icons lack a contour, '
                f'{missing_imgs} contours lack an icon)'
            )
        if not img_paths:
            raise ValueError(f'no *.png files found in {img_dir}')

        # Sort first so the seeded shuffle does not depend on glob order.
        keys = sorted(img_paths)
        random.Random(1337).shuffle(keys)
        keys = keys[int(len(keys) * split[0]):int(len(keys) * split[1])]
        return (
            keys,
            {key: img_paths[key] for key in keys},
            {key: contour_paths[key] for key in keys},
        )

    def get_icon(self, index):
        """Load the icon at position *index* of ``self.keys`` as an RGB image."""
        with Image.open(self.img_paths[self.keys[index]]) as image:
            return image.convert('RGB')

    def get_contour(self, index):
        """Load the contour at position *index* of ``self.keys``, color-inverted, as RGB."""
        with Image.open(self.contour_paths[self.keys[index]]) as image:
            # Convert before inverting: ImageOps.invert on palette ('P') images
            # would invert palette indices, not colors, and rejects some modes.
            return ImageOps.invert(image.convert('RGB'))

    def __getitem__(self, _):
        """Return a randomly drawn ``(contour, icon)`` tensor pair; the index is ignored."""
        img = self.get_icon(random.randrange(len(self.keys)))
        contour = self.get_contour(random.randrange(len(self.keys)))

        img, contour = self.batch_resized_crop((img, contour))
        if self.batch_transpose:
            img, contour = self.batch_transpose((img, contour))

        img, contour = map(self.img_or_contour_to_tensor, (img, contour))
        img, contour = map(self.img_or_contour_normalize, (img, contour))

        return contour, img

    def __len__(self):
        """Number of keys in this split."""
        return len(self.keys)
    
