import sklearn.decomposition, sklearn.cluster
import numpy as np
import tensorflow as tf
import imageio, os, cv2, tfcv, georeg
from distinctipy import distinctipy

def _colorize_component(colormap, values, valid=None, image=None, alpha=0.6):
    """Color one 2D scalar map with ``colormap`` and optionally blend it onto ``image``.

    Args:
        colormap: [256, 3] RGB lookup table indexed by a uint8 level.
        values: 2D array of scalar values (one PCA component).
        valid: Optional boolean map with the same spatial shape as ``values``
            (numpy array or eager tensor); positions where it is false are set
            to black before blending.
        image: Optional background image; when given the result is
            ``alpha * image + (1 - alpha) * color``.
        alpha: Weight of ``image`` in the blend.

    Returns:
        uint8 RGB image.
    """
    # NOTE(review): georeg.model.util.rescale presumably maps ``values`` into [0, 1];
    # the result is scaled to 0..255 and used as an index into the colormap.
    levels = (georeg.model.util.rescale(values) * 255.0).astype("uint8")
    color = colormap[levels].astype("float32")
    if valid is not None:
        color = np.where(np.asarray(valid)[..., np.newaxis], color, 0)
    if image is None:
        return color.astype("uint8")
    return (image * alpha + (1 - alpha) * color).astype("uint8")


def save_embedding_pca(path, name, embedding, image=None, bev_pixels_for_aerial=None, bev_pixels_for_aerial_mask=None, mask=None, components=6):
    """Write one JET-colored PNG per PCA component of a feature map.

    A PCA with ``min(components, embedding.shape[-1])`` components is fitted on
    the feature vectors of ``embedding`` (only those where ``mask`` is true when
    a mask is given) and then applied to every spatial position.

    Without ``bev_pixels_for_aerial`` each component is drawn on the embedding's
    own grid and written to ``{name}-pca{e}_min{..}_max{..}.png``. With it, the
    projected features are resampled onto the grid of ``image`` by gathering at
    those coordinates, and written to ``{name}-pca{e}-on-aerial_min{..}_max{..}.png``.
    The min/max in the file name are those of the drawn component.

    Args:
        path: Output directory.
        name: File name prefix.
        embedding: Feature map; assumed [H, W, F] (the projection is reshaped
            back using ``shape[-3:-1]``) — TODO confirm for other ranks.
        image: Optional background image blended as ``0.6 * image + 0.4 * color``.
            Required when ``bev_pixels_for_aerial`` is given, because its first
            two dimensions define the output grid.
        bev_pixels_for_aerial: Optional (N, 2) coordinates into the embedding
            grid, one per ``image`` pixel (N == image H * W). They are cast to
            int32, so fractional parts are truncated.
        bev_pixels_for_aerial_mask: (N,) boolean validity of
            ``bev_pixels_for_aerial``; required whenever that argument is given.
            Invalid pixels are drawn black.
        mask: Optional boolean map over the embedding's spatial dims. It selects
            the PCA fitting samples and blacks out masked positions in the output.
        components: Maximum number of PCA components.
    """
    # OpenCV's applyColorMap returns BGR; reverse the channel axis to get an RGB [256, 3] table.
    colormap = cv2.applyColorMap(np.asarray([np.arange(256).astype("uint8")]), cv2.COLORMAP_JET)[0, :, ::-1]

    # Fit only on the feature vectors selected by ``mask`` (all of them if there is no mask).
    pca = sklearn.decomposition.PCA(n_components=min(components, embedding.shape[-1]))
    fit_features = tf.reshape(embedding, [-1, embedding.shape[-1]])
    if mask is not None:
        fit_features = tf.boolean_mask(fit_features, tf.reshape(mask, [-1]), axis=0)
    pca.fit(fit_features.numpy())

    # Project every spatial position (masked or not) and restore the spatial layout.
    features_on_bev = tfcv.model.einops.apply("s... f -> (s...) f", embedding)
    features_on_bev = pca.transform(features_on_bev.numpy())
    features_on_bev = tf.convert_to_tensor(tfcv.model.einops.apply("(s...) f -> s... f", features_on_bev, s=tf.shape(embedding)[-3:-1])).numpy()

    if bev_pixels_for_aerial is None:
        features, valid, suffix = features_on_bev, mask, ""
    else:
        # Redirect invalid coordinates to (0, 0) so the gather stays in bounds;
        # those pixels are blacked out through ``valid`` below.
        bev_pixels_for_aerial = tf.where(bev_pixels_for_aerial_mask[..., tf.newaxis], bev_pixels_for_aerial, 0)
        indices = tf.cast(bev_pixels_for_aerial, "int32")
        features = tf.gather_nd(features_on_bev, indices, batch_dims=0)
        features = tfcv.model.einops.apply("(s...) f -> s... f", features, s=tf.shape(image)[:2]).numpy()

        # Always honour the coordinate validity mask. Previously it was dropped when
        # ``mask`` was None, so invalid pixels showed the features of BEV pixel (0, 0).
        valid = bev_pixels_for_aerial_mask
        if mask is not None:
            mask_on_aerial = tf.gather_nd(mask[..., tf.newaxis], indices, batch_dims=0)[..., 0]
            valid = tf.math.logical_and(mask_on_aerial, valid)
        valid = tfcv.model.einops.apply("(s...) -> s...", valid, s=tf.shape(image)[:2])
        suffix = "-on-aerial"

    for e in range(features_on_bev.shape[-1]):
        values = features[..., e]
        color = _colorize_component(colormap, values, valid=valid, image=image)
        imageio.imwrite(os.path.join(path, f"{name}-pca{e}{suffix}_min{np.amin(values)}_max{np.amax(values)}.png"), color)


# def save_clusters(path, name, embedding, image=None, mask=None):
#     class_to_color = np.asarray(distinctipy.get_colors(16)) * 255.0
#
#     masked_embedding = tf.reshape(embedding, [-1, embedding.shape[-1]])
#     if not mask is None:
#         masked_embedding = tf.boolean_mask(masked_embedding, tf.reshape(mask, [-1]), axis=0)
#     masked_embedding = masked_embedding.numpy()
#
#     embedding_flat = tf.reshape(embedding, [-1, embedding.shape[-1]]).numpy()
#
#     for classes_num in range(2, 6):
#         algo = sklearn.cluster.KMeans(
#             n_clusters=classes_num,
#             n_init=10,
#             max_iter=40,
#             tol=1e-3,
#             algorithm="elkan",
#         )
#         algo.fit(masked_embedding[::100])
#
#         classes = np.argmin(algo.transform(embedding_flat), axis=-1)
#         classes = classes.reshape([embedding.shape[0], embedding.shape[1]])
#         color = tfcv.image.class_to_color(classes, image=image, classes_num=classes_num, class_alpha=0.3, class_to_color=class_to_color[:classes_num])
#
#         if not mask is None:
#             color = np.where(mask[..., tf.newaxis].numpy(), color, 0)
#         imageio.imwrite(os.path.join(path, f"{name}-cluster{classes_num}.png"), color)

def save_weights():
    """Write per-head weight (embedding-norm) visualizations for one batch element.

    Outputs (all JPEGs under ``path``, prefixed with ``b_name``):
    per-head BEV/aerial validity masks, an "argmax-norm head" color map for the
    aerial and BEV embeddings, per-head and summed norm heat maps, and a per-head
    norm histogram plot.

    NOTE(review): this function takes no parameters but reads many names that are
    neither parameters nor locals: ``model_input``, ``self``, ``loss_input``,
    ``bev_embedding_mask``, ``corr_heads``, ``b``, ``b_name``, ``path``,
    ``model_output``, ``aerial_embedding``, ``bev_embedding``, ``t``,
    ``aerial_image_b``, ``colormap`` and ``plt``. None of them is defined at module
    level in the visible part of this file (``self`` in particular cannot be a
    module global here), so calling it as-is will presumably raise ``NameError``.
    It looks like code lifted out of a per-batch loop of a method such as
    ``SaveEmbedding.debug_fn`` — TODO confirm, then restore it there or delete it.
    """
    # Pixel coordinates on the aerial-embedding grid, and their projection to
    # BEV-embedding pixel coordinates plus a validity mask.
    aerial_embed_pixels = georeg.model.util.project.generate_bev_pixels(model_input.batches, self.model_constants.aerial_embed_shape)
    bev_embed_pixels, bev_embed_pixels_mask = georeg.model.util.project.aerialembedpixels_to_bevembedpixels_onimage(self.model_constants, loss_input, aerial_embed_pixels)
    # Redirect invalid coordinates to 0 so the gathers below stay in bounds.
    bev_embed_pixels = tf.where(bev_embed_pixels_mask[..., tf.newaxis], bev_embed_pixels, 0)
    # Additionally require that bev_embedding_mask is true for at least one entry
    # along axis 1 (presumably the head axis, since it is indexed [b, h] below)
    # at the gathered BEV pixel.
    bev_embed_pixels_mask = tf.math.logical_and(
        bev_embed_pixels_mask,
        tf.gather_nd(tf.math.reduce_any(bev_embedding_mask, axis=1)[..., tf.newaxis], tf.cast(bev_embed_pixels, "int32"), batch_dims=1)[..., 0]
    )

    # Per head: write the BEV (and, if present, aerial) validity masks as 0/255 images.
    for h in range(corr_heads):
        image = tf.where(bev_embedding_mask[b, h, ..., tf.newaxis], 255, 0)
        imageio.imwrite(os.path.join(path, f"{b_name}-bev-mask-h{h}.jpg"), image.numpy().astype("uint8"))

        if "aerial/embedding-mask-no-rotation" in model_output:
            image = tf.where(model_output["aerial/embedding-mask-no-rotation"][b, h, ..., tf.newaxis], 255, 0)
            imageio.imwrite(os.path.join(path, f"{b_name}-aerial-mask-h{h}.jpg"), image.numpy().astype("uint8"))

    # For each aerial pixel, the head (axis 0 of aerial_embedding[b]) whose embedding
    # has the largest L2 norm, colored per head and blended with the aerial image by ``t``.
    aerial_head = tf.argmax(tf.linalg.norm(aerial_embedding[b], axis=-1), axis=0).numpy()
    aerial_head_color = self.head_colors[aerial_head]
    color = (aerial_head_color * t + (1 - t) * aerial_image_b.astype("uint8")).astype("uint8")
    imageio.imwrite(os.path.join(path, f"{b_name}-aerial-weight-maxhead.jpg"), color)

    # Same for the BEV embedding, resampled onto the aerial-embedding grid; invalid
    # pixels use head 0 for the lookup and are then blacked out.
    bev_head = tf.argmax(tf.linalg.norm(bev_embedding[b], axis=-1), axis=0)
    bev_head = tf.gather_nd(bev_head[..., tf.newaxis], tf.cast(bev_embed_pixels[b], "int32"), batch_dims=0)[..., 0]
    bev_head = np.where(bev_embed_pixels_mask[b].numpy(), bev_head, 0)
    bev_head = tfcv.model.einops.apply("(s...) -> s...", bev_head, s=self.model_constants.aerial_embed_shape).numpy()
    bev_head_color = self.head_colors[bev_head]
    bev_head_color = np.where(bev_embed_pixels_mask[b].numpy().reshape(self.model_constants.aerial_embed_shape)[..., np.newaxis], bev_head_color, 0)
    color = (bev_head_color * t + (1 - t) * aerial_image_b.astype("uint8")).astype("uint8")
    imageio.imwrite(os.path.join(path, f"{b_name}-bev-weight-maxhead.jpg"), color)

    def save_weight(name, embedding, image=None, bev_pixels_for_aerial=None, bev_pixels_for_aerial_mask=None):
        """Write per-head and head-summed embedding-norm heat maps plus a norm histogram plot.

        ``embedding`` is indexed by head on axis 0. When ``bev_pixels_for_aerial``
        is given, each norm map is resampled onto the grid of ``image`` by gathering
        at those coordinates, and invalid pixels are set to 0.0.
        """
        histograms = []
        # Histogram range spans the norms of all heads, so all per-head histograms share bins.
        w = np.linalg.norm(embedding.numpy(), axis=-1)
        hist_range = (np.amin(w), np.amax(w))
        hist_bins = 50

        for h in range(corr_heads):
            w = np.linalg.norm(embedding[h].numpy(), axis=-1)

            if not bev_pixels_for_aerial is None:
                w = tf.gather_nd(w[..., tf.newaxis], tf.cast(bev_pixels_for_aerial, "int32"), batch_dims=0)
                w = tf.where(bev_pixels_for_aerial_mask[..., tf.newaxis], w, 0.0)
                w = tfcv.model.einops.apply("(s...) f -> s... f", w, s=tf.shape(image)[:2]).numpy()[..., 0]

            color = georeg.model.util.rescale(w)
            color = colormap[(color * 255.0).astype("uint8")]
            if not image is None:
                color = (color * t + (1 - t) * image).astype("uint8")

            imageio.imwrite(os.path.join(path, f"{b_name}-{name}-h{h}.jpg"), color)

            # Bins and range are fixed, so bin_centers is the same on every iteration.
            histogram, bin_edges = np.histogram(w, hist_bins, hist_range)
            bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
            histograms.append(histogram)

        # Norms summed over all heads (axis 0).
        w = np.sum(np.linalg.norm(embedding.numpy(), axis=-1), axis=0)

        if not bev_pixels_for_aerial is None:
            w = tf.gather_nd(w[..., tf.newaxis], tf.cast(bev_pixels_for_aerial, "int32"), batch_dims=0)
            w = tf.where(bev_pixels_for_aerial_mask[..., tf.newaxis], w, 0.0)
            w = tfcv.model.einops.apply("(s...) f -> s... f", w, s=tf.shape(image)[:2]).numpy()[..., 0]

        color = georeg.model.util.rescale(w)
        color = colormap[(color * 255.0).astype("uint8")]
        if not image is None:
            color = (color * t + (1 - t) * image).astype("uint8")

        imageio.imwrite(os.path.join(path, f"{b_name}-{name}-hsum.jpg"), color)

        # One normalized histogram curve per head, colored with that head's color.
        plt.figure(1)
        plt.xlabel("Weight")
        plt.ylabel("Density")
        for head, histogram in enumerate(histograms):
            color = tuple((self.head_colors[head] / 255.0).tolist()) # tuple((colormap[int(float(head) / corr_heads * 255)].astype("float32") / 255.0).tolist())
            plt.plot(bin_centers, histogram / np.sum(histogram), color=color, label=f"h{head}")
        plt.legend()
        plt.savefig(os.path.join(path, f"{b_name}-{name}-weighthistogram.jpg"), dpi=300)
        plt.close(1)
    save_weight("aerial-weight", aerial_embedding[b], image=aerial_image_b.astype("uint8"))
    save_weight("bev-weight", bev_embedding[b], image=aerial_image_b.astype("uint8"), bev_pixels_for_aerial=bev_embed_pixels[b], bev_pixels_for_aerial_mask=bev_embed_pixels_mask[b])


class SaveEmbedding:
    """Debug hook that dumps PCA visualizations of per-block aerial and BEV embeddings.

    ``get_output`` selects which model layer outputs to expose, and ``debug_fn``
    writes images for every exposed block, batch element and head.
    """

    def __init__(self, model_constants):
        self.model_constants = model_constants
        # Per-head colors (distinctipy colors scaled by 255), created lazily in debug_fn
        # from the head count of the first block processed.
        self.head_colors = None

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the numbered per-block embedding layers of ``model``.

        Blocks are probed as i = 1, 2, ...; collection stops at the first ``i`` for
        which no ``bev/embedding-no-rotation-{i}`` layer exists. For each block
        found, the outputs of ``bev/embedding-no-rotation-{i}``,
        ``aerial/embedding-no-rotation-{i}`` and ``bev/embedding-mask-no-rotation-{i}``
        are returned under those same keys. The remaining arguments are unused and
        are kept for interface compatibility.

        Returns:
            dict mapping layer name to that layer's ``output``.
        """
        layer_names = {layer.name for layer in model.layers}
        result = {}
        i = 1
        while f"bev/embedding-no-rotation-{i}" in layer_names:
            for key in (
                f"bev/embedding-no-rotation-{i}",
                f"aerial/embedding-no-rotation-{i}",
                f"bev/embedding-mask-no-rotation-{i}",
            ):
                result[key] = model.get_layer(key).output
            i += 1
        return result

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write PCA visualizations for every embedding block in ``model_output``.

        Blocks are processed as i = 1, 2, ... until ``bev/embedding-no-rotation-{i}``
        is missing from ``model_output``. For each block, batch element ``b`` (paired
        with ``names[b]``; iteration stops at the shorter of ``frames`` and ``names``)
        and head ``h``, ``save_embedding_pca`` is called twice. One call is for the
        aerial embedding on its own grid. The other is for the BEV embedding,
        resampled onto the aerial grid. Both write into ``os.path.join(path, names[b])``
        with file prefixes ``block{i}-h{h}-aerial`` / ``block{i}-h{h}-bev``.
        """
        def save(bev_embedding, bev_embedding_mask, aerial_embedding, name):
            # Embeddings are indexed [b, h, ...]: axis 1 is the head axis, and axes 2:4
            # of the aerial embedding give the size the aerial image is resized to.
            heads = aerial_embedding.shape[1]
            if self.head_colors is None:
                self.head_colors = np.asarray(distinctipy.get_colors(heads)) * 255.0

            # Map every aerial-embedding pixel to a BEV-embedding pixel (plus validity mask).
            aerial_embed_pixels = georeg.model.util.project.generate_bev_pixels(model_input.batches, self.model_constants.aerial_embed_shape)
            bev_embed_pixels, bev_embed_pixels_mask = georeg.model.util.project.aerialembedpixels_to_bevembedpixels_onimage(self.model_constants, loss_input, aerial_embed_pixels)

            for b, b_name in zip(range(len(frames)), names):
                aerial_image = model_input.aerial_image[b]
                aerial_image = tfcv.image.transform.resize_to(aerial_embedding.shape[2:4], ndim=2)((aerial_image, "color"))

                out_dir = os.path.join(path, b_name)
                # exist_ok avoids the check-then-create race; still raises if out_dir is a file.
                os.makedirs(out_dir, exist_ok=True)

                for h in range(heads):
                    save_embedding_pca(out_dir, f"{name}-h{h}-aerial", aerial_embedding[b, h], image=aerial_image)
                    save_embedding_pca(out_dir, f"{name}-h{h}-bev", bev_embedding[b, h], image=aerial_image, mask=bev_embedding_mask[b, h], bev_pixels_for_aerial=bev_embed_pixels[b], bev_pixels_for_aerial_mask=bev_embed_pixels_mask[b])

        i = 1
        while f"bev/embedding-no-rotation-{i}" in model_output:
            save(model_output[f"bev/embedding-no-rotation-{i}"], model_output[f"bev/embedding-mask-no-rotation-{i}"], model_output[f"aerial/embedding-no-rotation-{i}"], name=f"block{i}")
            i += 1
