import os
import glob
import random
import numpy as np
from PIL import Image, ImageOps, ImageDraw
import pickle
import torch
from torch.utils.data import Dataset
from utils.style import load_style_lists
import torchvision.transforms as transforms
import torchvision.transforms.functional as T

from .transforms import (
    BatchRandomResizedCrop,
    BatchRandomHue,
    BatchRandomTranspose,
    RandomTranspose,
    TupleTransform
)


class IconContourDataset(Dataset):
    """Paired icon / contour images stored as PNG files.

    Images are looked up in ``<root>/<image_size>/img`` and
    ``<root>/<image_size>/contour``; a pair is formed by files sharing the
    same file stem. The sorted stems are shuffled with a fixed seed (1337)
    and then sliced by ``split`` so that different splits of the same folder
    are stable across runs.

    Args:
        root: Dataset root directory (``~`` is expanded).
        image_size: Side length; selects the sub-folder and the output size.
        random_crop: Use a shared random resized crop instead of a plain
            resize.
        random_transpose: Apply ``BatchRandomTranspose`` to the pair.
        random_color: Apply ``BatchRandomHue`` to the icon only.
        split: ``(start, stop)`` fractions of the shuffled key list to keep.
        legacy_normalize: Apply ``Normalize(0.5, 1.0)`` after ``ToTensor``.
        as_pil_image: Return the augmented PIL images without converting
            them to tensors.
    """

    def __init__(
        self,
        root,
        image_size,
        random_crop=False,
        random_transpose=False,
        random_color=False,
        split=(0, 1),
        legacy_normalize=False,
        as_pil_image=False,
    ):
        size_dir = os.path.join(os.path.expanduser(root), str(image_size))

        def index_pngs(directory):
            # Map file stem -> path for every PNG directly inside `directory`.
            found = {}
            for png in glob.glob(os.path.join(directory, '*.png')):
                stem, _ = os.path.splitext(os.path.basename(png))
                found[stem] = png
            return found

        all_imgs = index_pngs(os.path.join(size_dir, 'img'))
        all_contours = index_pngs(os.path.join(size_dir, 'contour'))

        # Every icon needs a contour with the same stem, and vice versa.
        assert set(all_imgs) == set(all_contours)
        ordered = sorted(all_imgs)
        assert len(ordered) > 0

        # Fixed seed: the same folder always yields the same ordering, so
        # fractional splits do not overlap between dataset instances.
        random.Random(1337).shuffle(ordered)
        start = int(len(ordered) * split[0])
        stop = int(len(ordered) * split[1])
        self.keys = ordered[start:stop]

        self.img_paths = {key: all_imgs[key] for key in self.keys}
        self.contour_paths = {key: all_contours[key] for key in self.keys}

        target = (image_size, image_size)
        bicubic = T.InterpolationMode.BICUBIC
        if random_crop:
            self.batch_resized_crop = BatchRandomResizedCrop(
                target, (0.8, 1.0), (1.0, 1.0), (bicubic, bicubic))
        else:
            self.batch_resized_crop = TupleTransform(
                [transforms.Resize(target, bicubic) for _ in range(2)])

        # Disabled augmentations keep the (falsy) flag value itself so that
        # __getitem__ can simply truth-test the attribute.
        self.batch_transpose = (
            BatchRandomTranspose() if random_transpose else random_transpose)
        self.batch_shift_hue = (
            BatchRandomHue() if random_color else random_color)
        self.img_or_contour_to_tensor = transforms.ToTensor()
        self.img_or_contour_normalize = (
            transforms.Normalize(0.5, 1.0) if legacy_normalize
            else legacy_normalize)
        self.as_pil_image = as_pil_image

    def get_icon(self, index):
        """Return a copy of the icon image at ``index``."""
        path = self.img_paths[self.keys[index]]
        return Image.open(path).copy()

    def get_contour(self, index):
        """Return a copy of the inverted contour image at ``index``."""
        path = self.contour_paths[self.keys[index]]
        inverted = ImageOps.invert(Image.open(path))
        return inverted.copy()

    def __getitem__(self, index):
        """Return the augmented ``(icon, contour)`` pair at ``index``.

        The pair is returned as PIL images when ``as_pil_image`` is set,
        otherwise as tensors (optionally normalized).
        """
        pair = (self.get_icon(index), self.get_contour(index))

        pair = self.batch_resized_crop(pair)
        if self.batch_transpose:
            pair = self.batch_transpose(pair)
        img, contour = pair
        if self.batch_shift_hue:
            # Hue shift is applied to the icon only, never to the contour.
            img = self.batch_shift_hue([img])[0]

        if self.as_pil_image:
            return img, contour

        img = self.img_or_contour_to_tensor(img)
        contour = self.img_or_contour_to_tensor(contour)
        if self.img_or_contour_normalize:
            img = self.img_or_contour_normalize(img)
            contour = self.img_or_contour_normalize(contour)

        return img, contour

    def __len__(self):
        """Return the number of icon/contour pairs in this split."""
        return len(self.keys)


class StylePaletteDataset(Dataset):
    """Dataset yielding one ``(reference image, condition)`` pair per style.

    For every style name a list of exactly ``max_samples`` reference slots is
    kept. A slot is either an index into the wrapped ``IconContourDataset``
    (an icon whose distance to the style is below a global threshold) or
    ``None``, in which case a synthetic image is drawn from the style's
    colour combination. Styles with few real references therefore receive
    proportionally more synthetic ones.

    Args:
        root: Root directory passed to ``IconContourDataset``.
        image_size: Side length of the reference images.
        pickle_folder: Folder containing ``color_image_scale.pkl`` and
            ``cis_dismat.npz``.
        max_samples: Number of reference slots per style.
        legacy_normalize: Apply ``Normalize(0.5, 1.0)`` to the reference
            tensor.
    """

    def __init__(self,
                 root,
                 image_size,
                 pickle_folder,
                 max_samples=1000,
                 legacy_normalize=False):
        self.image_size = image_size
        self.dataset = IconContourDataset(root, image_size, split=(0.0, 0.9))
        self.max_samples = max_samples

        cis_pickle_path = os.path.join(pickle_folder, 'color_image_scale.pkl')
        dismat_path = os.path.join(pickle_folder, 'cis_dismat.npz')

        assert os.path.exists(cis_pickle_path), cis_pickle_path
        assert os.path.exists(dismat_path), dismat_path

        style_info = load_style_lists(cis_pickle_path)
        self.style_names = style_info['name_list']
        self.style_to_cmb = style_info['name_to_cmb']
        self.style_to_pos = style_info['name_to_pos']

        # Close the archive as soon as the matrix is in memory.
        # NOTE(review): after the transpose, rows are assumed to follow
        # `style_names` and columns to index `self.dataset` — confirm
        # against the script that produced cis_dismat.npz.
        with np.load(dismat_path) as archive:
            dismat = archive['dismat'].T

        real_refs = self._select_style_refs(
            self.style_names, dismat, max_samples)
        print('reference counts:', list(map(len, real_refs.values())))
        print('minimum reference count:', min(map(len, real_refs.values())))
        self.style_refs = self._pad_style_refs(real_refs, max_samples)

        self.random_resized_crop = transforms.RandomResizedCrop(
            (image_size, image_size), (0.8, 1.0), (1.0, 1.0), T.InterpolationMode.BICUBIC)
        self.random_transpose = RandomTranspose()
        self.to_tensor = transforms.ToTensor()
        self.legacy_normalize = legacy_normalize and transforms.Normalize(0.5, 1.0)

    @staticmethod
    def _select_style_refs(style_names, dismat, max_samples):
        """Pick, per style, the column indices closer than a shared threshold.

        The threshold is the smallest, over all rows, of the
        ``max_samples``-th sorted distance, so no row can contribute more
        than ``max_samples`` indices.

        Args:
            style_names: Names paired row-by-row with ``dismat``.
            dismat: 2-D distance matrix, one row per style.
            max_samples: Upper bound on references per style.

        Returns:
            Dict mapping each style name to a list of integer column indices.
        """
        dismat = np.asarray(dismat)
        # Clamp so a matrix with <= max_samples columns does not raise
        # IndexError; for larger matrices this is exactly `max_samples`.
        kth = min(max_samples, dismat.shape[1] - 1)
        threshold = np.sort(dismat, 1)[:, kth].min()

        refs = {}
        for style_name, dis in zip(style_names, dismat):
            labels = np.where(dis < threshold)[0].tolist()
            refs[style_name] = labels[:max_samples]
        return refs

    @staticmethod
    def _pad_style_refs(style_refs, max_samples):
        """Pad every reference list with ``None`` up to ``max_samples`` slots.

        The padding is computed from the length of each style's own list
        (previously the number of styles was used by mistake), so every
        returned list has exactly ``max_samples`` entries.
        """
        return {
            style_name: refs + [None] * (max_samples - len(refs))
            for style_name, refs in style_refs.items()
        }

    @property
    def condition_size(self):
        """Dimensionality of the condition vector."""
        return 2

    def position_to_condition(self, position, perturb=0.0):
        """Convert a style position into a condition tensor.

        Args:
            position: Sequence of floats, expected in (-3, +3).
            perturb: Noise multiplier; ``0`` disables the perturbation.

        Returns:
            Float tensor holding ``position / 3`` plus optional Gaussian
            noise.
        """
        # map (-3, +3) to (-1, +1)
        position = torch.FloatTensor(position) / 3
        # add noise to make it more dense
        if perturb > 0.0:
            scale = torch.FloatTensor([0.1, 0.05]) * perturb
            noise = torch.randn_like(position) * scale
            position = position + noise
        return position

    def __len__(self):
        """Return the number of styles."""
        return len(self.style_names)

    def random_style_ref(self, style_name):
        """Return a random reference tensor for ``style_name``.

        A ``None`` slot yields a synthetic image; any other slot yields the
        corresponding icon after a random resized crop and random transpose.
        """
        refs = self.style_refs[style_name]
        idx = random.choice(refs)
        if idx is None:
            ref = self.random_pseudo_ref(style_name)
        else:
            ref = self.dataset.get_icon(idx)
            ref = self.random_resized_crop(ref)
            ref = self.random_transpose(ref)
        ref = self.to_tensor(ref)
        if self.legacy_normalize:
            ref = self.legacy_normalize(ref)
        return ref

    def random_pseudo_ref(self, style_name):
        """Draw a synthetic reference image from the style's colours.

        One shape per colour (in random order) is painted on a white canvas;
        each shape is smaller than the previous one and is randomly an
        ellipse or a rectangle, optionally squeezed along one axis.

        Returns:
            An RGB PIL image of size ``image_size`` x ``image_size``.
        """
        cmb = self.style_to_cmb[style_name]
        image_size = self.image_size

        img = Image.new('RGB', (image_size, image_size), (255, 255, 255))
        draw = ImageDraw.Draw(img)

        size = int(image_size * 0.9)
        padding = int(image_size * 0.1)

        colors = list(map(tuple, cmb))
        random.shuffle(colors)

        for color in colors:
            size = random.randint(int(size * 0.6), int(size * 0.8))
            # Keep the border padding constant for every shape; it used to
            # be overwritten by the random inset below. The max() guards
            # keep randrange() valid for very small images / shapes.
            free = max(1, image_size - size - 2 * padding)
            x = random.randrange(free) + padding
            y = random.randrange(free) + padding

            ex, ey = x + size, y + size

            r = random.randrange(4)
            if r < 2:
                # Squeeze the shape along one axis by a random inset.
                inset = random.randrange(max(1, int(size * 0.1)))
                if r == 0:
                    x, ex = x + inset, ex - inset
                else:
                    y, ey = y + inset, ey - inset
            if r % 2 == 0:
                draw.ellipse([(x, y), (ex, ey)], fill=color)
            else:
                draw.rectangle((x, y, ex, ey), fill=color)

        return img

    def __getitem__(self, index):
        """Return ``(reference, condition)`` for the style at ``index``."""
        style_name = self.style_names[index]
        condition = self.position_to_condition(self.style_to_pos[style_name], 0.1)
        reference = self.random_style_ref(style_name)
        return reference, condition


class IconContourDownscaleDataset(Dataset):
    """Icon / contour pairs returned at both a reduced and the full size.

    Files are read from ``<root>/<image_size>/img`` and
    ``<root>/<image_size>/contour`` and paired by file stem. The sorted stems
    are shuffled with a fixed seed (1337) and sliced by ``split``.

    Args:
        root: Dataset root directory (``~`` is expanded).
        image_size: Side length of the full-size outputs; also selects the
            sub-folder.
        down_size: Side length of the downscaled outputs.
        random_crop: Use a shared random resized crop instead of a plain
            resize.
        random_transpose: Apply ``BatchRandomTranspose`` to the pair.
        random_color: Apply ``BatchRandomHue`` to the icon only.
        split: ``(start, stop)`` fractions of the shuffled key list to keep.
        legacy_normalize: Apply ``Normalize(0.5, 1.0)`` after ``ToTensor``.
    """

    def __init__(self, root, image_size, down_size,
                 random_crop=False, random_transpose=False, random_color=False,
                 split=(0, 1), legacy_normalize=False):
        self.down_size = down_size
        size_dir = os.path.join(os.path.expanduser(root), str(image_size))

        def index_pngs(directory):
            # Map file stem -> path for every PNG directly inside `directory`.
            found = {}
            for png in glob.glob(os.path.join(directory, '*.png')):
                stem, _ = os.path.splitext(os.path.basename(png))
                found[stem] = png
            return found

        all_imgs = index_pngs(os.path.join(size_dir, 'img'))
        all_contours = index_pngs(os.path.join(size_dir, 'contour'))

        # Every icon needs a contour with the same stem, and vice versa.
        assert set(all_imgs) == set(all_contours)
        ordered = sorted(all_imgs)
        assert len(ordered) > 0

        # Fixed seed keeps the ordering (and thus the splits) reproducible.
        random.Random(1337).shuffle(ordered)
        start = int(len(ordered) * split[0])
        stop = int(len(ordered) * split[1])
        self.keys = ordered[start:stop]

        self.img_paths = {key: all_imgs[key] for key in self.keys}
        self.contour_paths = {key: all_contours[key] for key in self.keys}

        target = (image_size, image_size)
        bicubic = T.InterpolationMode.BICUBIC
        if random_crop:
            self.batch_resized_crop = BatchRandomResizedCrop(
                target, (0.8, 1.0), (1.0, 1.0), (bicubic, bicubic))
        else:
            self.batch_resized_crop = TupleTransform(
                [transforms.Resize(target, bicubic) for _ in range(2)])

        # Disabled augmentations keep the (falsy) flag value itself so that
        # __getitem__ can simply truth-test the attribute.
        self.batch_transpose = (
            BatchRandomTranspose() if random_transpose else random_transpose)
        self.batch_shift_hue = (
            BatchRandomHue() if random_color else random_color)
        self.img_or_contour_to_tensor = transforms.ToTensor()
        self.img_or_contour_normalize = (
            transforms.Normalize(0.5, 1.0) if legacy_normalize
            else legacy_normalize)

    def __getitem__(self, index):
        """Return ``(img, contour, original_img, original_contour)`` tensors.

        The first two are the augmented pair resized to ``down_size``; the
        last two are the same augmented pair at full size.
        """
        key = self.keys[index]
        full_img = Image.open(self.img_paths[key]).copy()
        full_contour = ImageOps.invert(
            Image.open(self.contour_paths[key]).copy())

        full_img, full_contour = self.batch_resized_crop(
            (full_img, full_contour))
        if self.batch_transpose:
            full_img, full_contour = self.batch_transpose(
                (full_img, full_contour))
        if self.batch_shift_hue:
            # Hue shift is applied to the icon only, never to the contour.
            full_img = self.batch_shift_hue([full_img])[0]

        # Downscale after augmentation so both resolutions show the same view.
        small_size = (self.down_size, self.down_size)
        small_img = full_img.resize(small_size, Image.BICUBIC)
        small_contour = full_contour.resize(small_size, Image.BICUBIC)

        outputs = [
            self.img_or_contour_to_tensor(item)
            for item in (small_img, small_contour, full_img, full_contour)
        ]
        if self.img_or_contour_normalize:
            outputs = [self.img_or_contour_normalize(item) for item in outputs]

        img, contour, original_img, original_contour = outputs
        return img, contour, original_img, original_contour

    def __len__(self):
        """Return the number of icon/contour pairs in this split."""
        return len(self.keys)



if __name__ == '__main__':
    # Manual debugging entry point: build the style dataset and drop into an
    # interactive shell to inspect it.
    #
    # The previous code instantiated `StylePseudoDataset`, which is neither
    # defined nor imported in this module and therefore raised NameError.
    # NOTE(review): `StylePaletteDataset` is assumed to be its replacement;
    # only the arguments that map onto its signature (root, image_size,
    # pickle_folder) are kept — confirm the intended settings.
    dataset = StylePaletteDataset(
        'datasets/icon4/data',
        128,
        'datasets/icon4',
    )
    import ei
    ei.embed()
