import math, yaml, skimage.filters, skimage.color
import numpy as np
import tensorflow as tf
from distinctipy import distinctipy
from . import param, op
# TODO: some of these are with tensorflow, others with numpy

def op_colorspace(inner, colorspace_in="rgb", colorspace="rgb", colorspace_out="rgb"):
    """Wrap ``inner`` so that it runs on images converted to ``colorspace``.

    Before ``inner`` runs, every image is converted ``colorspace_in -> colorspace``.
    Afterwards, every image is converted ``colorspace -> colorspace_out``. Each
    conversion uses ``skimage.color.<from>2<to>``: the image is divided by 255
    before the call and the result is multiplied by 255 again. When two colorspaces
    are equal, that step is the identity.

    Args:
        inner: Callable ``(images, types) -> images`` operating in ``colorspace``.
        colorspace_in: Colorspace of the incoming images.
        colorspace: Colorspace ``inner`` works in.
        colorspace_out: Colorspace of the returned images.

    Returns:
        The result of ``op.op(outer, type="color")``.

    Raises:
        ValueError: when the op is built, if ``skimage.color`` has no matching
            conversion function.
    """
    def get_transform(colorspace_in, colorspace_out):
        if colorspace_in == colorspace_out:
            return lambda image: image
        name = f"{colorspace_in}2{colorspace_out}"
        # getattr (not vars()) so that lazily-loaded skimage submodules, whose public
        # functions are not in the module __dict__ until first accessed, still resolve.
        convert = getattr(skimage.color, name, None)
        if not callable(convert):
            raise ValueError(f"No color conversion method from {colorspace_in} to {colorspace_out}")
        return lambda x: convert(x.astype("float32") / 255.0).astype("float32") * 255.0

    transform_input = get_transform(colorspace_in, colorspace)
    transform_output = get_transform(colorspace, colorspace_out)

    def outer(images, types):
        images = [transform_input(i) for i in images]
        images = inner(images, types)
        images = [transform_output(i) for i in images]
        return images

    return op.op(outer, type="color")

def per_image(func, out_of_bounds="clip", *args, **kwargs):
    """Build a color op that applies ``func`` to every image independently.

    ``func(image, type)`` is called with each image cast to float32 and the matching
    entry of ``types``. Its result is then brought back into range as selected by
    ``out_of_bounds``:

    * ``"clip"``: clamp values to [0, 255];
    * ``"repeat"``: wrap values modulo 256 into [0, 256).

    Args:
        func: Per-image callable ``(image, type) -> image``.
        out_of_bounds: ``"clip"`` or ``"repeat"``.
        *args, **kwargs: Forwarded to ``op_colorspace``.

    Returns:
        The op built by ``op_colorspace``.

    Raises:
        ValueError: immediately, if ``out_of_bounds`` is not one of the values above.
    """
    if out_of_bounds == "clip":
        def fix_range(image):
            return np.clip(image, 0.0, 255.0)
    elif out_of_bounds == "repeat":
        def fix_range(image):
            # fmod keeps the sign of its input; the second fmod maps negatives into [0, 256).
            return np.fmod(np.fmod(image, 256.0) + 256.0, 256.0)
    else:
        # Fail when the op is built instead of on the first non-empty batch.
        raise ValueError(f"Invalid out_of_bounds argument {out_of_bounds}")

    def call(images, types):
        # Index types by position (not zip) so a too-short `types` still raises.
        return [fix_range(func(image.astype("float32"), types[i])) for i, image in enumerate(images)]

    return op_colorspace(call, *args, **kwargs)

def add(offset, ndim=2, *args, **kwargs):
    """Build an op that adds ``offset`` to every pixel of every image.

    ``offset`` is passed through ``param.array`` (float32) each time an image is
    processed, and the result is added to the image. ``ndim`` is accepted for
    signature parity with the other ops but is not used. ``*args``/``**kwargs``
    go to ``per_image``.
    """
    def shift(pixels, _image_type):
        delta = param.array(offset, name="offset", dtype=np.float32)
        return pixels + delta

    return per_image(shift, *args, **kwargs)

def multiply(factor, ndim=2, *args, **kwargs):
    """Build an op that multiplies every pixel of every image by ``factor``.

    ``factor`` is passed through ``param.array`` (float32) each time an image is
    processed. ``ndim`` is accepted for signature parity with the other ops but is
    not used. ``*args``/``**kwargs`` go to ``per_image``.
    """
    def scale(pixels, _image_type):
        gain = param.array(factor, name="factor", dtype=np.float32)
        return pixels * gain

    return per_image(scale, *args, **kwargs)

def add_gaussian_noise(std, ndim=2, *args, **kwargs):
    """Build an op that adds zero-mean Gaussian noise to every image.

    ``std`` is resolved to a float32 scalar via ``param.array`` for each image.
    When the image rank differs from ``ndim``, the noise gets size 1 along the last
    axis (presumably the channel axis). The same noise value is then broadcast over
    every entry of that axis. ``*args``/``**kwargs`` go to ``per_image``.
    """
    def noisy(pixels, _image_type):
        sigma = param.array(std, name="std", dtype=np.float32, shape=())
        noise_shape = list(pixels.shape)
        if len(noise_shape) != ndim:
            noise_shape[-1] = 1
        noise = np.random.normal(loc=0.0, scale=sigma, size=noise_shape)
        return pixels + noise

    return per_image(noisy, *args, **kwargs)

def normalize_color(src, dest, *args, **kwargs):
    """Build an op that maps the color statistics of ``src`` onto those of ``dest``.

    ``src`` and ``dest`` are either mappings with ``"mean"`` and ``"std"`` entries or
    paths to YAML files (read with ``yaml.safe_load``) containing such mappings. Each
    image is transformed as ``scale * image + shift``, which equals
    ``(image - src_mean) * dest_std / src_std + dest_mean``.
    ``*args``/``**kwargs`` go to ``per_image``.
    """
    def load(stats):
        if not isinstance(stats, str):
            return stats
        with open(stats, "r") as handle:
            return yaml.safe_load(handle)

    src_stats = load(src)
    dest_stats = load(dest)

    src_mean = np.asarray(src_stats["mean"])
    scale = np.asarray(dest_stats["std"]) / np.asarray(src_stats["std"])
    shift = np.asarray(dest_stats["mean"]) - scale * src_mean

    def renormalize(pixels, _image_type):
        return scale * pixels + shift

    return per_image(renormalize, *args, **kwargs)

# amount > 0 => sharpen, amount < 0 => blur
def blur_sharpen(amount, std, ndim=2, *args, **kwargs):
    """Build an op that sharpens or blurs images by unsharp masking.

    Each image becomes ``image + (image - gaussian(image, std)) * amount``:

    * a positive ``amount`` sharpens;
    * ``amount = -1`` yields the Gaussian-blurred image;
    * values in (-1, 0) blend between the two.

    Args:
        amount: Strength, resolved to a float32 scalar via ``param.array`` per image.
        std: Gaussian sigma, resolved to a float32 scalar via ``param.array`` per image.
        ndim: Number of spatial axes. An image with a different rank is blurred with
            its last axis treated as channels (not smoothed across).
        *args, **kwargs: Forwarded to ``per_image``.
    """
    import inspect

    gaussian = skimage.filters.gaussian
    # scikit-image 0.19 replaced `multichannel` with `channel_axis` and later removed
    # `multichannel`; use `channel_axis` unless this version only knows the old keyword.
    gaussian_params = inspect.signature(gaussian).parameters
    use_channel_axis = "channel_axis" in gaussian_params or "multichannel" not in gaussian_params

    def func(image, type2, amount=amount, std=std):
        std = param.array(std, name="std", dtype=np.float32, shape=())
        amount = param.array(amount, name="amount", dtype=np.float32, shape=())

        multichannel = ndim != image.ndim
        if use_channel_axis:
            channel_kwargs = {"channel_axis": -1 if multichannel else None}
        else:
            channel_kwargs = {"multichannel": multichannel}
        g = gaussian(image, sigma=float(std), mode="nearest", **channel_kwargs)
        return image + (image - g) * amount

    return per_image(func, *args, **kwargs)

def color_to_class(color, color_to_class=None, class_to_color=None, find_closest=False):
    """Convert a 3-channel color-coded label image into an image of class ids.

    Args:
        color: Array indexed as ``color[row, col, channel]`` with 3 channels.
        color_to_class: Mapping from a color (3-element sequence) to its class id.
        class_to_color: Sequence whose i-th entry is the color of class i (tuples,
            lists or array rows). Exactly one of the two mappings must be given.
        find_closest: If False, every pixel color must appear exactly in the mapping.
            If True, each pixel gets the class whose color has the smallest
            Euclidean distance to it.

    Returns:
        uint8 array of shape ``(rows, cols)`` holding class ids.

    Raises:
        ValueError: if ``color`` does not have 3 channels, or not exactly one
            mapping is given.
        KeyError: if ``find_closest`` is False and a pixel color is not mapped.
    """
    if color.shape[-1] != 3:
        raise ValueError(f"Expects a 3-channel input, got image shape {color.shape}")
    if (color_to_class is None) == (class_to_color is None):
        raise ValueError(f"Expected exactly one of color_to_class and class_to_color")
    if color_to_class is None:
        # Tuple keys so that list / ndarray colors are hashable.
        color_to_class = {tuple(np.asarray(rgb).tolist()): c for c, rgb in enumerate(class_to_color)}
    if class_to_color is None:
        id_to_color = {c: rgb for rgb, c in color_to_class.items()}
        classes_num = int(np.max(np.array(list(color_to_class.values())))) + 1
        # Exactly one color per class id; ids without a color get black.
        class_to_color = [id_to_color.get(c, (0, 0, 0)) for c in range(classes_num)]

    if not find_closest:
        rows, cols = color.shape[0], color.shape[1]
        pixels = color.astype("int32")
        # Pack each (c0, c1, c2) triple into a single integer: c2 * 256^2 + c1 * 256 + c0.
        color_ids = pixels[:, :, 2] * 256 * 256 + pixels[:, :, 1] * 256 + pixels[:, :, 0]
        colorid_to_id = {(rgb[2] * 256 * 256 + rgb[1] * 256 + rgb[0]): c for rgb, c in color_to_class.items()}

        # Look each distinct color up once, then scatter the ids back to the pixels.
        unique_color_ids, inv_mapping = np.unique(color_ids, return_inverse=True)
        unique_color_ids = unique_color_ids.tolist()
        missing = [cid for cid in unique_color_ids if cid not in colorid_to_id]
        if missing:
            cid = missing[0]
            raise KeyError(f"Color {(cid % 256, cid // 256 % 256, cid // (256 * 256))} is not in color_to_class")
        unique_ids = np.array([colorid_to_id[cid] for cid in unique_color_ids])

        labels = unique_ids[inv_mapping]
        labels = np.reshape(labels, (rows, cols)).astype("uint8")
    else:
        # Float arithmetic so unsigned inputs cannot wrap around when subtracted.
        pixels = np.asarray(color, dtype=np.float64)
        palette = np.asarray(class_to_color, dtype=np.float64)  # [class, channel]
        color_diff = pixels[:, :, np.newaxis, :] - palette[np.newaxis, np.newaxis, :, :]  # [row, col, class, channel]
        color_diff = np.linalg.norm(color_diff, axis=-1)  # [row, col, class]
        labels = np.argmin(color_diff, axis=-1).astype("uint8")

    return labels

def class_to_color(segmentation, image=None, color_to_class=None, class_to_color=None, classes_num=None, dont_care_color=(0, 0, 0), dont_care_threshold=0.5, class_alpha=0.5):
    """Render a segmentation as colors, optionally blended over ``image``.

    Args:
        segmentation: Map of class ids, or a rank-3 array of per-class scores along
            the last axis. For scores, pixels whose scores sum to less than
            ``dont_care_threshold`` become "don't care"; the others take the argmax.
        image: Optional image to blend the colors over. If None, the colors alone
            are returned.
        color_to_class: Optional mapping color -> class id.
        class_to_color: Optional sequence whose i-th entry is the color of class i.
            At most one of the two mappings may be given. With neither,
            ``classes_num`` colors are generated with ``distinctipy`` and scaled by 255.
        classes_num: Number of classes for generated colors; defaults to the
            largest class id + 1.
        dont_care_color: Color used for ids that are negative or >= the number of classes.
        dont_care_threshold: Score-sum threshold, see ``segmentation``.
        class_alpha: Weight of the class colors when blending with ``image``.

    Returns:
        uint8 tensor: the colors, with the palette's channel axis appended to the
        (score-reduced) segmentation shape; or, when ``image`` is given, the blend
        clipped to [0, 255].

    Raises:
        ValueError: if both mappings are given, or ``classes_num`` disagrees with
            the mapping.
    """
    segmentation = tf.convert_to_tensor(segmentation)
    if len(segmentation.shape) == 3:
        # Cast integer scores so they can be compared against a float threshold.
        scores = segmentation if segmentation.dtype.is_floating else tf.cast(segmentation, tf.float32)
        dont_care = tf.reduce_sum(scores, axis=-1) < dont_care_threshold
        segmentation = tf.where(dont_care, -1, tf.argmax(segmentation, axis=-1))
    # Integer class ids from here on, whatever dtype the caller passed.
    segmentation = tf.cast(segmentation, tf.int64)

    if color_to_class is not None and class_to_color is not None:
        raise ValueError(f"Cannot pass both of color_to_class and class_to_color")
    if color_to_class is None and class_to_color is None:
        if classes_num is None:
            classes_num = int(tf.math.reduce_max(segmentation).numpy()) + 1
        # reshape keeps a (0, 3) palette when there are no classes.
        class_to_color = np.asarray(distinctipy.get_colors(classes_num), dtype=np.float64).reshape(-1, 3) * 255.0
        color_to_class = {tuple(color.tolist()): c for c, color in enumerate(class_to_color)}
    elif color_to_class is None:
        color_to_class = {tuple(color if not isinstance(color, np.ndarray) else color.tolist()): c for c, color in enumerate(class_to_color)}
        classes_num = len(class_to_color)
    else:
        id_to_color = {c: color for color, c in color_to_class.items()}
        classes_num = int(np.max(np.array(list(color_to_class.values())))) + 1
        # Exactly one color per class id; ids without a color get black.
        class_to_color = [id_to_color.get(c, (0, 0, 0)) for c in range(classes_num)]
    if classes_num != len(class_to_color):
        raise ValueError(f"Number of classes {classes_num} does not match class_to_color")
    if classes_num != len(color_to_class):
        print(f"WARNING: Got non-unqiue colors in color-class mapping ({len(class_to_color)} classes, {len(color_to_class)} colors)")

    # Add the dont-care color at index num_colors. Both parts go to float64 numpy
    # first, so int and float palettes concatenate without a dtype clash.
    num_colors = len(class_to_color)
    palette = np.concatenate(
        [np.asarray(class_to_color, dtype=np.float64), np.asarray([dont_care_color], dtype=np.float64)],
        axis=0,
    )
    dont_care = tf.logical_or(segmentation >= num_colors, segmentation < 0)
    indices = tf.where(dont_care, tf.constant(num_colors, dtype=tf.int64), segmentation)

    # Gather segmentation colors
    flat_indices = tf.cast(tf.reshape(indices, [-1]), "int32")
    colors = tf.gather(palette, flat_indices)
    segmentation_rgb = tf.reshape(colors, tf.concat([tf.shape(segmentation), tf.shape(colors)[-1:]], axis=0))
    segmentation_rgb = tf.cast(segmentation_rgb, dtype=tf.uint8)

    # Overlay with color image
    if image is not None:
        image = (1.0 - class_alpha) * tf.cast(image, "float32") + class_alpha * tf.cast(segmentation_rgb, "float32")
        image = tf.clip_by_value(image, 0.0, 255.0)
    else:
        image = segmentation_rgb

    return tf.cast(image, "uint8")
