import os
import glob
import numpy as np
import random

from torch.utils import data
from torchvision import transforms as T
from .tps_transformation import tps_transform
from PIL import Image, ImageOps
from .utils import elastic_transform
from iconflow.dataset.transforms import BatchRandomResizedCrop, BatchRandomTranspose


class DataSet(data.Dataset):
    """Paired icon / contour dataset with additive noise and optional warping.

    Expects ``<IMG_DIR>/<IMG_SIZE>/img/*.png`` and
    ``<IMG_DIR>/<IMG_SIZE>/contour5/*.png`` whose file stems match one-to-one.
    The stems are sorted, shuffled with a fixed seed (1337) and sliced into a
    90% ``'train'`` / 10% ``'test'`` split selected by ``TRAINING_CONFIG['MODE']``.
    """

    # Fraction range of the shuffled key list assigned to each split.
    _SPLITS = {
        'train': (0.0, 0.9),
        'test': (0.9, 1.0),
    }
    # Suffix of the contour directory name (``contour<width>``).
    _CONTOUR_WIDTH = 5

    def __init__(self, config, img_transform_gt, img_transform_sketch):
        """Index the icon/contour pairs and read the augmentation settings.

        Args:
            config: nested dict. Reads ``MODEL_CONFIG['IMG_SIZE']`` and
                ``TRAINING_CONFIG`` keys ``IMG_DIR``, ``MODE``, ``AUGMENT``,
                ``DIST`` and either ``A``/``B`` (when ``DIST == 'uniform'``)
                or ``MEAN``/``STD``.
            img_transform_gt: callable applied to the reference images.
            img_transform_sketch: callable applied to the sketch image.

        Raises:
            KeyError: if ``MODE`` is not ``'train'`` or ``'test'``, or a
                required config key is missing.
            ValueError: if the image and contour directories do not contain
                the same file stems, or no images are found.
        """
        training_config = config['TRAINING_CONFIG']
        dataset_root = training_config['IMG_DIR']
        image_size = config['MODEL_CONFIG']['IMG_SIZE']
        split = self._SPLITS[training_config['MODE']]

        self.img_transform_gt = img_transform_gt
        self.img_transform_sketch = img_transform_sketch
        self.img_dir = os.path.join(dataset_root, str(image_size))
        # Not read inside this class; kept for external users of the attribute.
        self.img_size = (image_size, image_size, 3)
        self.image_size = image_size

        def get_key(path):
            # File stem; pairs an icon with its contour.
            return os.path.splitext(os.path.basename(path))[0]

        img_dir = os.path.join(self.img_dir, 'img')
        contour_dir = os.path.join(self.img_dir, f'contour{self._CONTOUR_WIDTH}')
        img_paths = {get_key(path): path
                     for path in glob.glob(os.path.join(img_dir, '*.png'))}
        contour_paths = {get_key(path): path
                         for path in glob.glob(os.path.join(contour_dir, '*.png'))}

        # Explicit checks (not `assert`) so validation survives `python -O`.
        unmatched = set(img_paths) ^ set(contour_paths)
        if unmatched:
            raise ValueError(
                f'{len(unmatched)} file stem(s) present in only one of '
                f'{img_dir!r} and {contour_dir!r}, e.g. {sorted(unmatched)[:5]}'
            )
        if not img_paths:
            raise ValueError(f'no *.png images found in {img_dir!r}')

        # Sort first so the seeded shuffle, and hence the split, is reproducible
        # regardless of glob order.
        keys = sorted(img_paths)
        random.Random(1337).shuffle(keys)
        lo, hi = split
        n_keys = len(keys)
        self.keys = keys[int(n_keys * lo):int(n_keys * hi)]

        self.img_paths = {key: img_paths[key] for key in self.keys}
        self.contour_paths = {key: contour_paths[key] for key in self.keys}

        # Applied jointly to the (reference, sketch) pair in __getitem__.
        self.batch_resized_crop = BatchRandomResizedCrop(
            (image_size, image_size), (0.8, 1.0), (1.0, 1.0),
            (T.InterpolationMode.BICUBIC, T.InterpolationMode.BICUBIC)
        )
        self.batch_transpose = BatchRandomTranspose()

        self.augment = training_config['AUGMENT']

        self.dist = training_config['DIST']
        if self.dist == 'uniform':
            self.a = training_config['A']
            self.b = training_config['B']
        else:
            self.mean = training_config['MEAN']
            self.std = training_config['STD']

    def get_icon(self, index):
        """Load the icon for ``index`` as an in-memory PIL image (file closed)."""
        with Image.open(self.img_paths[self.keys[index]]) as img:
            return img.copy()

    def get_contour(self, index):
        """Load the contour for ``index`` and return it with its colours inverted."""
        # NOTE(review): ImageOps.invert may reject some image modes (e.g. RGBA,
        # palette); confirm the contour PNGs are stored as L/RGB.
        with Image.open(self.contour_paths[self.keys[index]]) as img:
            return ImageOps.invert(img)

    def _apply_noise(self, image_array):
        """Add random noise to a uint8 image array and return a uint8 array.

        The noise is uniform on ``[a, b)`` when ``dist == 'uniform'``, otherwise
        normal with ``mean``/``std``. The sum is clipped to ``[0, 255]`` before
        the cast, so out-of-range pixels saturate. Without the clip they would
        wrap around or take platform-dependent values.
        """
        if self.dist == 'uniform':
            noise = np.random.uniform(self.a, self.b, image_array.shape)
        else:
            noise = np.random.normal(self.mean, self.std, image_array.shape)
        return np.clip(image_array + noise, 0, 255).astype(np.uint8)

    def __getitem__(self, index):
        """Return ``(fid, augmented_reference, reference, sketch)`` for ``index``.

        ``fid`` is the file stem. The reference (RGB) and sketch (L) pass
        through the pair-wise random resized crop and random transpose. Noise
        is then added to the reference, and ``AUGMENT`` (``'elastic'``,
        ``'tps'``, or anything else for none) produces the augmented reference.
        Both references use ``img_transform_gt``; the sketch uses
        ``img_transform_sketch``.
        """
        fid = self.keys[index]
        reference = self.get_icon(index).convert('RGB')
        sketch = self.get_contour(index).convert('L')

        # The Batch* transforms take the pair together (presumably sharing
        # random parameters so the images stay aligned — verify in iconflow).
        reference, sketch = self.batch_resized_crop((reference, sketch))
        reference, sketch = self.batch_transpose((reference, sketch))

        assert (set(reference.size) | set(sketch.size)) == {self.image_size}

        reference = Image.fromarray(self._apply_noise(np.array(reference)))

        if self.augment == 'elastic':
            augmented_reference = Image.fromarray(
                elastic_transform(np.array(reference), 1000, 8, random_state=None))
        elif self.augment == 'tps':
            augmented_reference = Image.fromarray(tps_transform(np.array(reference)))
        else:
            augmented_reference = reference

        return (fid,
                self.img_transform_gt(augmented_reference),
                self.img_transform_gt(reference),
                self.img_transform_sketch(sketch))

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.keys)


def get_loader(config):
    """Build the DataLoader for the icon/contour dataset described by ``config``.

    Reference images are resized to ``IMG_SIZE`` x ``IMG_SIZE``, converted to
    tensors and normalized per channel with mean 0.5 / std 0.5 (mapping [0, 1]
    to [-1, 1]). Sketches get the same treatment as single-channel images.
    Shuffling is enabled only when ``TRAINING_CONFIG['MODE'] == 'train'``.

    Args:
        config: nested dict. Reads ``MODEL_CONFIG['IMG_SIZE']`` and
            ``TRAINING_CONFIG`` keys ``MODE``, ``BATCH_SIZE`` and ``NUM_WORKER``,
            plus whatever ``DataSet`` reads.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``DataSet``.
    """
    training_config = config['TRAINING_CONFIG']
    img_size = config['MODEL_CONFIG']['IMG_SIZE']

    img_transform_gt = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # One-element tuples for the single-channel sketch; `(0.5)` would be a
    # bare float rather than a tuple.
    img_transform_sketch = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ])

    dataset = DataSet(config, img_transform_gt, img_transform_sketch)
    # NOTE(review): drop_last=True also applies in 'test' mode, so up to
    # BATCH_SIZE - 1 test samples are never evaluated — confirm intended.
    return data.DataLoader(dataset=dataset,
                           batch_size=training_config['BATCH_SIZE'],
                           shuffle=(training_config['MODE'] == 'train'),
                           num_workers=training_config['NUM_WORKER'],
                           drop_last=True)
