import tensorflow as tf
import tensorflow_addons as tfa
import tfcv, georeg, math, cv2, imageio, os, sys, yaml, cosy, re
from functools import partial
import numpy as np

class ModelConstants:
    """Per-block shape and resolution constants for the attention blocks.

    Block ``i`` is coarser than the final resolution by a factor of
    ``aerial_attn_strides[i]``. The last aerial stride must be 1, so the last
    block is at the final resolution.

    Attributes:
        ground_attn_strides: ``np.ndarray`` of per-block ground strides (only
            stored here; its length must match ``aerial_attn_strides``).
        aerial_attn_strides: ``np.ndarray`` of per-block aerial strides.
        aerial_stride: Multiplier from the final aerial shape to the aerial
            image shape (see ``aerial_image_shape``).
        bev_final_shape: ``np.ndarray`` BEV shape at the final resolution.
        bev_shapes: One row per block, ``bev_final_shape // stride``. Note
            that these are derived from the *aerial* strides.
        aerial_shapes: One row per block, ``aerial_final_shape // stride``.
        meters_per_pixel: Per-block resolution,
            ``final_meters_per_pixel * stride``.
        max_cameras: Stored as given.
    """

    def __init__(self, bev_final_shape, aerial_final_shape, final_meters_per_pixel, ground_attn_strides, aerial_attn_strides, aerial_stride, max_cameras):
        # Ground and aerial strides describe the same blocks; the last block
        # must be at full resolution.
        assert len(ground_attn_strides) == len(aerial_attn_strides)
        assert aerial_attn_strides[-1] == 1

        self.ground_attn_strides = np.asarray(ground_attn_strides)
        self.aerial_attn_strides = np.asarray(aerial_attn_strides)
        strides = self.aerial_attn_strides

        self.aerial_stride = aerial_stride

        def shapes_per_block(final_shape):
            # Every stride has to divide the final shape exactly; otherwise the
            # integer-divided shape would not scale back to the final shape.
            divisible = [np.all(final_shape // stride * stride == final_shape) for stride in strides]
            assert all(divisible)
            return np.asarray([final_shape // stride for stride in strides])

        self.bev_final_shape = np.asarray(bev_final_shape)
        self.bev_shapes = shapes_per_block(self.bev_final_shape)

        self.aerial_shapes = shapes_per_block(np.asarray(aerial_final_shape))

        self.meters_per_pixel = np.asarray([final_meters_per_pixel * stride for stride in strides])

        self.max_cameras = max_cameras

    @property
    def aerial_image_shape(self):
        """Aerial image shape: the final-resolution aerial shape times ``aerial_stride``."""
        return self.aerial_shapes[-1] * self.aerial_stride

    @property
    def blocks(self):
        """Number of attention blocks (the number of aerial strides)."""
        return len(self.aerial_attn_strides)

def build(model_constants, variant, model_params):
    """Build the inference model, the training model and their loss/debug closures.

    Args:
        model_constants: Shape/resolution constants (presumably a
            ``ModelConstants``); forwarded to ``variant.get_modules`` and to
            the training modules.
        variant: Model variant. Must provide
            ``predict(model_input, model_params=..., config=...)`` returning
            ``(corr_logits, valid_corr, preprocess_aerial, preprocess_ground)``
            and ``get_modules(model_constants, loss_params)`` returning a list
            of training modules.
        model_params: Dict of model hyperparameters; must contain ``"loss"``.
            NOTE: mutated in place -- the names of the chosen normalization
            and activation are written to ``model_params["norm"]`` and
            ``model_params["act"]``.

    Returns:
        Tuple ``(model, train_model, preprocess_aerial, preprocess_ground,
        loss_fn, debug_fn)``:

        * ``model``: Keras model over the ``ModelInput`` tensors. Its outputs
          are a flat list with one (rotation, translation) pair per
          correlation-argmax op, named ``rotation-<op>`` and
          ``translation-<op>`` via ``tfcv.model.util.set_name``.
        * ``train_model``: Keras model over ``ModelInput + LossInput``. Its
          outputs are the training-module outputs (dict insertion order)
          followed by ``model.outputs``.
        * ``preprocess_aerial`` / ``preprocess_ground``: as returned by
          ``variant.predict``.
        * ``loss_fn(model_input, loss_input, schedule_factor, training)``:
          returns ``(loss, metrics, valid_sample)``.
        * ``debug_fn(frames, model_input, loss_input, path, names)``: forwards
          to each training module's ``debug_fn``.

    NOTE(review): Keras auto-names layers in creation order, so reordering
    the graph-building calls below may change layer names (and name-based
    checkpoint compatibility).
    """
    # Normalization and activation used throughout the network. The
    # commented-out lines are alternative choices; only the active line is used.
    # norm_name, norm = ("BN", lambda x, **kwargs: tf.keras.layers.BatchNormalization(**kwargs)(x))
    norm_name, norm = ("LN", lambda x, **kwargs: tf.keras.layers.LayerNormalization(**kwargs)(x))
    # norm_name, norm = ("GN16", lambda x, **kwargs: tfa.layers.GroupNormalization(groups=16, **kwargs)(x))

    # act_name, act = ("ReLU", lambda x, **kwargs: tf.keras.layers.Activation(tf.keras.activations.relu, **kwargs)(x))
    act_name, act = ("GELU", lambda x, **kwargs: tf.keras.layers.Activation(tf.keras.activations.gelu, **kwargs)(x))

    config = tfcv.model.config.Config(
        norm=norm,
        act=act,
        resize=tfcv.model.config.partial_with_default_args(tfcv.model.config.resize, align_corners=False),
    )

    # Record the chosen norm/act in the caller's params dict (in-place mutation).
    model_params["norm"] = norm_name
    model_params["act"] = act_name

    # ---- Inference model ----
    model_input = georeg.model.io.ModelInput.keras()

    corr_logits, valid_corr, preprocess_aerial, preprocess_ground = variant.predict(
        model_input,
        model_params=model_params,
        config=config,
    )

    corr = georeg.model.correlation.math.softmax(corr_logits, valid_corr)

    # Correlation-argmax strategies, each yielding one (rotation, translation)
    # output pair. Windowed variants use window_shape=(a, s, s) for spatial
    # sizes s in 3..14 and a = 3 (presumably (angle, y, x) -- TODO confirm).
    outputs = []
    ops = [
        ("discrete", georeg.model.correlation.math.corr_argmax_discrete),
        # ("weighted", georeg.model.correlation.math.corr_argmax_weighted),
    ]
    for s in range(3, 15):
        # ops.append((f"paraboloid2-w{s}", partial(georeg.model.correlation.math.corr_argmax_paraboloid_2d, window_shape=s)))
        for a in range(3, 4):
            # ops.append((f"paraboloid3-w{a}.{s}", partial(georeg.model.correlation.math.corr_argmax_paraboloid_3d, window_shape=(a, s, s))))
            ops.append((f"weighted-w{a}.{s}", partial(georeg.model.correlation.math.corr_argmax_weighted_window, window_shape=(a, s, s))))
    for name, func in ops:
        axy = func(corr, valid_corr, model_input.angles)

        # Column 0 of axy is an angle; the remaining columns are the translation.
        pred_rotation = cosy.tf.angle_to_rotation_matrix(axy[:, 0]) # b 2 2
        pred_translation = axy[:, 1:] # b 2

        pred_rotation = tfcv.model.util.set_name(pred_rotation, f"rotation-{name}")
        pred_translation = tfcv.model.util.set_name(pred_translation, f"translation-{name}")

        outputs.append(pred_rotation)
        outputs.append(pred_translation)

    model = tf.keras.Model(
        inputs=model_input.to_list(),
        outputs=outputs,
    )

    # ---- Training model ----
    # Variant-specific modules come first, then the shared base module. The
    # correlation loss module is only added when LOSS_CORR_WEIGHT is set in
    # the environment.
    train_modules = variant.get_modules(model_constants, model_params["loss"]) + [m for m in [
        georeg.model.module.Base(model_constants, model_params["loss"]),
        georeg.model.module.CorrelationLoss("corr", model_constants, model_params["loss"]) if "LOSS_CORR_WEIGHT" in os.environ else None,
        # georeg.model.module.SaveEmbedding(model_constants),
    ] if m is not None]

    loss_input = georeg.model.io.LossInput.keras()

    # Each module gets the train_model built so far (initially the inference
    # model). The model is then rebuilt so that it also exposes that module's
    # outputs; presumably this lets later modules consume earlier outputs.
    # train_modules is never empty (Base is always present), so
    # train_output_names is always bound after the loop.
    train_output = {}
    train_model = model
    for m in train_modules:
        train_output.update(m.get_output(train_model, model_input, loss_input, config, model_params))
        train_output_names = list(train_output.keys())

        train_model = tf.keras.Model(
            inputs=model_input.to_list() + loss_input.to_list(),
            outputs=[train_output[n] for n in train_output_names] + model.outputs,
        )

    def loss_fn(model_input, loss_input, schedule_factor, training):
        """Run train_model and sum the losses of all modules that define ``loss_fn``.

        ``model_input``/``loss_input`` here are runtime input objects; they
        shadow the symbolic Keras inputs of the enclosing function.

        A module's ``loss_fn`` may return a bare loss, ``None`` (skipped), or
        a ``(loss, valid_sample)`` tuple. Valid flags are AND-combined.
        Modules record metrics into the shared ``metrics`` dict.

        Returns:
            ``(loss, metrics, valid_sample)``. ``loss`` stays ``0.0`` if no
            module contributes, and ``valid_sample`` stays ``True`` if no
            module returns a tuple.
        """
        metrics = {}

        train_output = train_model(model_input.to_list(preprocess_aerial=preprocess_aerial, preprocess_ground=preprocess_ground) + loss_input.to_list(), training=training)
        # zip truncates to the module outputs, dropping the trailing model.outputs.
        train_output = {k: v for k, v in zip(train_output_names, train_output)}

        valid_sample = True
        loss = 0.0
        for m in train_modules:
            if hasattr(m, "loss_fn"):
                m_args = m.loss_fn(train_output, model_input, loss_input, schedule_factor, metrics)
                if isinstance(m_args, tuple):
                    m_loss, m_valid_sample = m_args
                    valid_sample = tf.math.logical_and(valid_sample, m_valid_sample)
                else:
                    m_loss = m_args
                if m_loss is not None:
                    loss = loss + m_loss

        # NOTE: the loss is not masked here; valid_sample is returned instead.
        # loss = tf.where(valid_sample[tf.newaxis], loss, tf.zeros_like(loss))

        return loss, metrics, valid_sample

    def debug_fn(frames, model_input, loss_input, path, names):
        """Run train_model with training=False and forward to every module's ``debug_fn``."""
        train_output = train_model(model_input.to_list(preprocess_aerial=preprocess_aerial, preprocess_ground=preprocess_ground) + loss_input.to_list(), training=False)
        train_output = {k: v for k, v in zip(train_output_names, train_output)}

        for m in train_modules:
            if hasattr(m, "debug_fn"):
                m.debug_fn(train_output, model_input, loss_input, frames, path, names)

    return model, train_model, preprocess_aerial, preprocess_ground, loss_fn, debug_fn
