import tensorflow as tf
import numpy as np
import cv2, imageio, georeg, os, cosy, math, tfcv
from distinctipy import distinctipy

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker

class Base:
    """Shared metric, output-collection and debug-visualization logic for pose heads.

    Heads are discovered by name: every model output (or layer) named
    ``rotation-<name>`` is paired with the matching ``translation-<name>``.
    """

    def __init__(self, model_constants, loss_params):
        """Store the model constants.

        Args:
            model_constants: Object providing ``bev_shapes``, ``aerial_shapes`` and
                ``aerial_image_shape`` (read by the methods below).
            loss_params: Accepted for interface compatibility; not used by this class.
        """
        self.model_constants = model_constants
        self.head_colors = None

    def loss_fn(self, model_output, model_input, loss_input, schedule_factor, metrics):
        """Record pose-error metrics for every predicted head.

        For each head ``<name>``, the predicted BEV->aerial transform built from
        ``rotation-<name>`` / ``translation-<name>`` is inverted and compared with the
        inverted ground-truth transform from ``loss_input``. The entries
        ``translation-error-<name>``, ``translation-error-lon-<name>``,
        ``translation-error-lat-<name>`` and ``angle-error-<name>`` (degrees) are
        written into ``metrics``.

        Args:
            model_output: Mapping of output names to tensors.
            model_input: Unused here.
            loss_input: Object providing ``get_bevpixels_to_aerialpixels_with_shape``.
            schedule_factor: Unused here.
            metrics: Dict updated in place.

        Returns:
            None -- this base class contributes metrics only, no loss term.
        """
        prefix = "rotation-"
        names = [k[len(prefix):] for k in model_output.keys() if k.startswith(prefix)]
        if not names:
            return None

        # The ground-truth transform does not depend on the head, so compute it,
        # its inverse and its rotation angle once instead of once per head.
        gt_bevpixels_to_aerialpixels = loss_input.get_bevpixels_to_aerialpixels_with_shape(
            bev_shape=self.model_constants.bev_shapes[-1],
            aerial_shape=self.model_constants.aerial_shapes[-1],
            model_constants=self.model_constants,
        )
        gt_aerialpixels_to_bevpixels = tf.linalg.inv(gt_bevpixels_to_aerialpixels)
        gt_atb_translation = gt_aerialpixels_to_bevpixels[..., :2, 2]
        gt_atb_angle = cosy.tf.rotation_matrix_to_angle(gt_aerialpixels_to_bevpixels[..., :2, :2])

        for name in names:
            pred_rotation = model_output[f"rotation-{name}"]
            pred_translation = model_output[f"translation-{name}"]
            # Concatenate [R | t] and pad it (georeg pad_matrix) so that it can be inverted.
            pred_bevpixels_to_aerialpixels = georeg.model.util.pad_matrix(
                tf.concat([pred_rotation, pred_translation[:, :, tf.newaxis]], axis=2),
                rank=3,
            )
            pred_aerialpixels_to_bevpixels = tf.linalg.inv(pred_bevpixels_to_aerialpixels)
            pred_atb_translation = pred_aerialpixels_to_bevpixels[..., :2, 2]
            pred_atb_angle = cosy.tf.rotation_matrix_to_angle(pred_aerialpixels_to_bevpixels[..., :2, :2])

            translation_error = gt_atb_translation - pred_atb_translation
            metrics[f"translation-error-{name}"] = tf.norm(translation_error, axis=-1)
            metrics[f"translation-error-lon-{name}"] = tf.math.abs(translation_error[..., 0])
            metrics[f"translation-error-lat-{name}"] = tf.math.abs(translation_error[..., 1])

            # Absolute angle difference, folded so that a difference above pi is
            # measured the short way round (e.g. 350 deg reads as 10 deg).
            # NOTE(review): assumes rotation_matrix_to_angle returns angles within a
            # single 2*pi interval -- confirm.
            angle_error = tf.math.abs(pred_atb_angle - gt_atb_angle)
            angle_error = tf.where(angle_error > math.pi, 2 * math.pi - angle_error, angle_error)
            metrics[f"angle-error-{name}"] = angle_error / math.pi * 180.0

        return None

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the outputs of the pose and competition layers of ``model``.

        Returns:
            Dict mapping layer name to ``layer.output`` for every layer whose name
            starts with ``rotation-`` or ``translation-``, or contains
            ``compete-weights``. The remaining arguments are unused here.
        """
        return {
            layer.name: layer.output
            for layer in model.layers
            if layer.name.startswith(("rotation-", "translation-")) or "compete-weights" in layer.name
        }

    @staticmethod
    def _blend_points(image, pixels, s=0.5, color=(255, 242, 0)):
        """Blend ``color`` into ``image`` in place at the given pixel positions.

        Args:
            image: Array indexed as ``image[row, col]``; modified in place.
            pixels: ``[N, 2]`` array of (row, col) coordinates, truncated to int.
                Positions outside ``image`` are skipped instead of raising or
                wrapping around via negative indexing.
            s: Weight kept for the existing pixel value; ``1 - s`` goes to ``color``.
            color: Overlay colour.
        """
        idx = np.asarray(pixels).astype("int32")
        in_bounds = (
            (idx[:, 0] >= 0) & (idx[:, 0] < image.shape[0])
            & (idx[:, 1] >= 0) & (idx[:, 1] < image.shape[1])
        )
        rows, cols = idx[in_bounds, 0], idx[in_bounds, 1]
        image[rows, cols] = s * image[rows, cols] + (1 - s) * np.asarray(color, "float32")

    def debug_fn(self, model_output, model_input, loss_input, frames, path, names):
        """Write per-sample debug artifacts into the directory ``path``.

        For each sample ``b`` (paired with ``names[b]``), writes
        ``<name>-info.txt`` with frame metadata and ``<name>-color.jpg``. The JPEG
        is the aerial image with the projection of ``model_input.bev_pixels``
        (computed with ``loss_input``) blended on top, when any BEV pixels exist.
        """
        bev_pixels = model_input.bev_pixels
        has_points = bev_pixels.shape[0] > 0
        if has_points:
            # Shift BEV pixel coordinates by half the BEV shape before projecting.
            # The second cast matters: with an integer dtype, TF's "/" yields a float.
            offset = tf.cast(self.model_constants.bev_shapes[-1], bev_pixels.dtype) / 2
            bev_pixels = bev_pixels + tf.cast(offset[tf.newaxis, tf.newaxis, :], bev_pixels.dtype)

            aerial_pixels, aerial_pixels_mask = georeg.model.util.project.bevpixels_to_aerialpixels_onimage(
                bev_shape=self.model_constants.bev_shapes[-1],
                aerial_shape=self.model_constants.aerial_image_shape,
                loss_input=loss_input,
                bev_pixels=bev_pixels,
                model_constants=self.model_constants,
            )
            aerial_pixels = aerial_pixels.numpy()
            aerial_pixels_mask = aerial_pixels_mask.numpy()

        for b, (frame, b_name) in enumerate(zip(frames, names)):
            ground_frame = frame.ground_frame
            with open(os.path.join(path, f"{b_name}-info.txt"), "w", encoding="utf-8") as f:
                f.write(f"Ground frame: {ground_frame.dataset_name} {ground_frame.location} {ground_frame.scene_id} {ground_frame.index_in_scene}\n")
                f.write(f"Aerial frame: {frame.aerial_frame.name}\n")
                f.write(f"Latlon: {ground_frame.latlon}\n")

            # astype() copies, so blending below never touches model_input.
            color = model_input.aerial_image[b].astype("uint8")
            if has_points:
                self._blend_points(color, aerial_pixels[b][aerial_pixels_mask[b]])
            imageio.imwrite(os.path.join(path, f"{b_name}-color.jpg"), color)
