import os, glob
import numpy as np
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from numpy.core.numeric import full
from skimage.feature import canny as get_canny_feature
from collections import Counter
from tqdm import tqdm
import cv2
import json
import shutil
import multiprocessing as mp


def load_image(path):
    """Open the image at *path* and check it is a 512x512 RGB/RGBA picture.

    Palette ('P') and grayscale+alpha ('LA') images are converted to RGBA
    first so their colours/transparency are kept; any other mode that is not
    already RGB or RGBA fails the assertion, as does any size other than
    512x512.
    """
    image = Image.open(path)
    promote = image.mode in ('P', 'LA')
    image = image.convert('RGBA') if promote else image

    assert image.mode in ('RGB', 'RGBA')
    assert image.size == (512, 512)
    return image

def pad_image(img):
    """Surround *img* with a white border one eighth of its width thick.

    The same border width, ``int(width / 8)``, is added on all four sides.
    RGBA images are padded with transparent white ``(255, 255, 255, 0)`` and
    RGB images with opaque white ``(255, 255, 255)``.

    Raises:
        ValueError: if ``img.mode`` is neither 'RGB' nor 'RGBA'. This is an
            explicit raise rather than ``assert False``: ``python -O`` strips
            asserts, which made the function silently return None.
    """
    n = int(img.size[0] / 8)
    if img.mode == 'RGBA':
        return ImageOps.expand(img, n, (255, 255, 255, 0))
    if img.mode == 'RGB':
        return ImageOps.expand(img, n, (255, 255, 255))
    raise ValueError(f'unsupported image mode: {img.mode!r}')
    
def get_contour(img):
    """Return a dilated edge mask of *img* as an 8-bit PIL image.

    Canny edges (sigma=0) are detected independently in every channel of the
    image array (the last axis), including alpha for RGBA input, and OR-ed
    together. The combined mask is scaled to 0/255 and thickened by dilation
    with a 5x5 structuring element whose four corners are cleared.
    """
    pixels = np.array(img)
    channel_edges = [
        get_canny_feature(pixels[..., c], 0)
        for c in range(pixels.shape[-1])
    ]
    mask = np.logical_or.reduce(channel_edges).astype(np.uint8) * 255

    # 5x5 block of ones with the four corner cells zeroed.
    footprint = np.ones((5, 5), dtype=np.uint8)
    footprint[[0, 0, -1, -1], [0, -1, 0, -1]] = 0

    return Image.fromarray(cv2.dilate(mask, footprint))


def get_pair(path):
    """Load and pad the image at *path*; return ``(rgb_image, contour_mask)``.

    The contour is computed on the padded image *before* it is converted to
    RGB, so an alpha channel (when present) still contributes edges.
    """
    padded = pad_image(load_image(path))
    return padded.convert('RGB'), get_contour(padded)

def preprocess(args):
    """Worker entry point: write resized image/contour PNGs for one raw file.

    *args* is a packed ``(raw_path, resolutions, output_sizes, name)`` tuple
    so it can travel through ``Pool.imap_unordered``. For each
    ``(resolution, output_size)`` pair, the padded RGB image and its contour
    mask are bicubic-resized to a square of side ``output_size`` and saved
    under ``data/<resolution>/img/`` and ``data/<resolution>/contour/`` as
    ``<name>.png`` (paths are relative to the current working directory).

    Returns *raw_path* so the caller can report progress.
    """
    raw_path, resolutions, output_sizes, name = args
    image, contour = get_pair(raw_path)

    for resolution, side in zip(resolutions, output_sizes):
        size = (side, side)
        image_out = image.resize(size, Image.BICUBIC)
        image_out.save(f'data/{resolution}/img/{name}.png')
        contour_out = contour.resize(size, Image.BICUBIC)
        contour_out.save(f'data/{resolution}/contour/{name}.png')

    return raw_path

def build_dataset(
    path=None,
    resolutions=(64, 128, 256, 512),
    pool_size=24
):
    """Build multi-resolution image/contour data from ``raw/*.png``.

    Layout, relative to *path* (or to the current working directory when
    *path* is None)::

        raw/*.png                      # input images
        data/<res>/img/<key>.png       # padded RGB images      (output)
        data/<res>/contour/<key>.png   # dilated edge masks     (output)
        data/<res>/key_to_raw.json     # {key: raw file name}   (output)

    Keys are the zero-padded (6-digit) index of each raw file in sorted
    order. Files written for resolution ``r`` are ``r + 2 * (r // 8)``
    pixels square, because every image is padded by 1/8 of its width on
    each side.

    Args:
        path: Directory to ``os.chdir`` into first. This changes the working
            directory of the whole process.
        resolutions: Iterable of target resolutions, a subset of
            {64, 128, 256, 512}. A single int is also accepted, which is what
            the ``fire`` CLI passes for a lone value such as
            ``--resolutions 64``.
        pool_size: Number of worker processes, 1..64.

    Any existing ``data/<res>/`` directory for a requested resolution is
    deleted before processing.
    """
    if path is not None:
        os.chdir(path)

    assert os.path.isdir('raw'), "missing 'raw' directory"
    # Glob once; this same sorted list defines both keys and work items.
    raw_paths = sorted(glob.glob(os.path.join('raw', '*.png')))
    assert len(raw_paths) > 0, "no *.png files found in 'raw'"

    # Accept a bare int; `set(64)` would otherwise raise TypeError.
    if isinstance(resolutions, int):
        resolutions = [resolutions]
    resolutions = sorted(set(resolutions))
    assert set(resolutions) <= {64, 128, 256, 512}, resolutions
    assert len(resolutions) > 0, 'no resolutions requested'
    print('resolutions:', resolutions)

    assert 0 < pool_size <= 64, pool_size
    print('workers:', pool_size)

    for resolution in resolutions:
        shutil.rmtree(f'data/{resolution}/', ignore_errors=True)
        os.makedirs(f'data/{resolution}/img', exist_ok=True)
        os.makedirs(f'data/{resolution}/contour', exist_ok=True)

    # Output side = resolution plus the 1/8 padding on both sides
    # (integer form of `resolution + 2 * resolution / 8`).
    output_sizes = [
        resolution + 2 * (resolution // 8)
        for resolution in resolutions
    ]

    print('number of images:', len(raw_paths))

    args_list = [
        (raw_path, resolutions, output_sizes, '{:06d}'.format(i))
        for i, raw_path in enumerate(raw_paths)
    ]

    key_to_raw = {
        name: os.path.basename(raw_path)
        for raw_path, _, _, name in args_list
    }
    for resolution in resolutions:
        with open(f'data/{resolution}/key_to_raw.json', 'w') as f:
            json.dump(key_to_raw, f)

    # Consume every result inside the `with`: Pool.__exit__ terminates the
    # workers, so leaving early would abort outstanding jobs.
    with mp.Pool(pool_size) as pool:
        paths_done = pool.imap_unordered(preprocess, args_list)
        with tqdm(paths_done, desc='Preprocessing', total=len(raw_paths)) as pbar:
            for raw_path in pbar:
                pbar.set_postfix_str(raw_path)
            

if __name__ == '__main__':
    # Expose build_dataset as a command-line tool; its keyword arguments
    # (path, resolutions, pool_size) become CLI flags via python-fire.
    from fire import Fire

    Fire(build_dataset)
