import tensorflow as tf
import numpy as np
import tensorflow_addons as tfa
import math, os, tfcv, georeg

class TransformerBlock:
    """Correlate BEV features against aerial features over a set of rotation angles.

    Each call projects both inputs to ``filters`` channels split into ``heads``
    heads, rotates the BEV embedding by every angle in ``model_input.angles``,
    cross-correlates it with the aerial embedding and returns masked correlation
    logits laid out as ``b h a s...`` (batch, head, angle, spatial...).

    All settings are read from environment variables once, in ``__init__``, and
    written into the caller's ``model_params`` dict:

    - ``PP_FC_HEADS`` -> ``heads`` (int, default 1)
    - ``PP_FC_FILTERS`` -> ``filters`` (int, default 8)
    - ``PP_FC_SHORTCUT_LOGITS`` -> ``shortcut-logits`` (``"1"`` means True, default True)
    - ``PP_FC_LEARN_SCALE`` -> ``learn-scale`` (``"1"`` means True, default True)
    - ``PP_FC_VARIANCE_CORRECT`` -> ``variance_correction`` (``"1"`` means True, default True)
    - ``ROTATE_TYPE`` -> ``rotate_type`` (default ``"bilinear"``)
    - ``CORR_RESIZE_TYPE`` -> ``corr_resize_type`` (``"prob"`` or ``"logit"``, default ``"logit"``)
    """

    @staticmethod
    def _env_int(key, default):
        """Return ``int(os.environ[key])``, or ``default`` when ``key`` is unset."""
        value = os.environ.get(key)
        return default if value is None else int(value)

    @staticmethod
    def _env_flag(key, default):
        """Return whether ``os.environ[key]`` parses to the integer 1, or ``default`` when unset."""
        value = os.environ.get(key)
        return default if value is None else int(value) == 1

    def __init__(self, model_input, model_constants, model_params):
        """Read configuration from the environment and record it in ``model_params``.

        Args:
            model_input: object whose ``angles``, ``corr_shape`` and ``batches``
                attributes are read in ``__call__``.
            model_constants: object whose ``aerial_shapes`` / ``bev_shapes``
                (indexed by iteration) are read in ``__call__``.
            model_params: dict updated in place with the resolved settings.

        Raises:
            ValueError: if ``CORR_RESIZE_TYPE`` is neither ``"prob"`` nor ``"logit"``.
        """
        self.model_constants = model_constants
        self.model_input = model_input

        self.heads = self._env_int("PP_FC_HEADS", 1)
        model_params["heads"] = self.heads
        self.filters = self._env_int("PP_FC_FILTERS", 8)
        model_params["filters"] = self.filters

        self.shortcut_logits = self._env_flag("PP_FC_SHORTCUT_LOGITS", True)
        model_params["shortcut-logits"] = self.shortcut_logits

        self.learn_scale = self._env_flag("PP_FC_LEARN_SCALE", True)
        model_params["learn-scale"] = self.learn_scale

        self.variance_correction = self._env_flag("PP_FC_VARIANCE_CORRECT", True)
        model_params["variance_correction"] = self.variance_correction

        self.rotate_type = os.environ.get("ROTATE_TYPE", "bilinear")
        model_params["rotate_type"] = self.rotate_type

        self.corr_resize_type = os.environ.get("CORR_RESIZE_TYPE", "logit")
        model_params["corr_resize_type"] = self.corr_resize_type
        # Explicit raise rather than assert: asserts are stripped under `python -O`.
        if self.corr_resize_type not in ("prob", "logit"):
            raise ValueError(f"CORR_RESIZE_TYPE must be 'prob' or 'logit', got {self.corr_resize_type!r}")

        # Logits from the previous call; added as a shortcut when shortcut_logits is on.
        self.last_logits = None

    def _resize_shortcut_logits(self, last_logits, target_hw, config):
        """Bilinearly resize ``b h a s...`` logits to the spatial size ``target_hw``.

        With ``corr_resize_type == "logit"`` the logits are interpolated directly.
        With ``"prob"`` the interpolation happens on ``exp(logits + offset)``.
        The offset shifts each map's maximum to 80, which keeps ``exp`` finite in
        float32. The result is mapped back with ``log``. Entries that underflowed
        to exactly zero are replaced by that map's pre-resize minimum logit.
        """
        b = tf.shape(last_logits)[0]
        h = tf.shape(last_logits)[1]
        a = tf.shape(last_logits)[2]
        last_logits = tfcv.model.einops.apply("b h a s... -> (b h a) s... 1", last_logits)

        use_prob = self.corr_resize_type == "prob"
        if use_prob:
            min_logit = tf.stop_gradient(tfcv.model.einops.apply("bha s... 1 -> bha 1... 1", last_logits, output_ndims=4, reduction="min"))
            numerical_offset = -tfcv.model.einops.apply("bha s... 1 -> bha 1... 1", last_logits, output_ndims=4, reduction="max") + 80.0
            last_logits = tf.math.exp(last_logits + numerical_offset)

        last_logits = tfcv.model.util.resize(last_logits, target_hw, method="bilinear", config=config)

        if use_prob:
            positive = last_logits > 0
            # log() must never see a zero, even in the branch tf.where discards.
            # Otherwise its gradient 0 * (1/0) turns into NaN.
            safe_probs = tf.where(positive, last_logits, tf.ones_like(last_logits))
            last_logits = tf.where(positive, tf.math.log(safe_probs) - numerical_offset, min_logit)

        return tfcv.model.einops.apply("(b h a) s... 1 -> b h a s...", last_logits, b=b, h=h, a=a)

    def _correlation_mask(self, corr_shape, dtype, iteration):
        """Build a bool ``b h a s...`` mask of correlation offsets inside the unit disk.

        The disk is laid out on a grid of size ``model_input.corr_shape``, using
        coordinates normalized by half the grid size. It is then bilinearly
        resized to ``corr_shape``; any cell with a positive resized value counts
        as valid.
        """
        grid_shape = self.model_input.corr_shape
        xs = tf.cast(tf.range(grid_shape[0]) - grid_shape[0] // 2, dtype) / tf.cast(grid_shape[0] // 2, dtype)
        ys = tf.cast(tf.range(grid_shape[1]) - grid_shape[1] // 2, dtype) / tf.cast(grid_shape[1] // 2, dtype)
        xs = tf.broadcast_to(xs[:, tf.newaxis], grid_shape)
        ys = tf.broadcast_to(ys[tf.newaxis, :], grid_shape)
        pos = tf.stack([xs, ys], axis=-1)  # s... 2

        valid_corr = tf.math.sqrt(tf.reduce_sum(tf.math.square(pos), axis=-1)) <= 1.0
        valid_corr = tfcv.model.einops.apply("s... -> 1 s... 1", valid_corr)
        valid_corr = tf.math.greater(tfcv.model.util.resize(tf.where(valid_corr, 1.0, 0.0), corr_shape, method="bilinear"), 0.0)
        valid_corr = tfcv.model.einops.apply("1 s... 1 -> b h a s...", valid_corr, b=self.model_input.batches, h=self.heads, a=tf.shape(self.model_input.angles)[1])
        valid_corr = tfcv.model.util.set_name(valid_corr, f"correlation-mask-{iteration + 1}")
        valid_corr = tfcv.model.util.set_name(valid_corr, f"correlation-mask-unrotated-{iteration + 1}")
        return valid_corr

    def __call__(self, bev, bev_mask, aerial, corr_shape, name, iteration, config):
        """Compute masked correlation logits between a BEV and an aerial feature map.

        Pipeline:

        1. Normalize ``bev``.
        2. Project both inputs with 1x1 convs to ``self.filters`` channels, split
           into ``self.heads`` heads.
        3. Zero BEV features outside ``bev_mask``.
        4. Rotate the BEV embedding by each angle in ``model_input.angles``.
        5. Zero-pad it to the aerial spatial size.
        6. Cross-correlate the two embeddings (FFT), crop to ``corr_shape`` and
           sum over features.

        Args:
            bev: rank-4 tensor whose spatial size must equal
                ``model_constants.bev_shapes[iteration]``.
            bev_mask: bool mask over ``bev`` without its channel axis, or None for
                all-valid. Presumably ``(b, H, W)``; TODO confirm at call sites.
            aerial: rank-4 tensor whose spatial size must equal
                ``model_constants.aerial_shapes[iteration]``. Assumed to be at
                least as large as ``bev`` spatially, since ``tf.pad`` rejects
                negative padding.
            corr_shape: spatial size the correlation is cropped to.
            name: prefix for the names of the created layers.
            iteration: index into the shape tables; also used in tensor names.
            config: passed through to the ``tfcv.model.util`` helpers.

        Returns:
            ``(corr_logits, valid_corr)``: logits laid out ``b h a s...``, zeroed
            where the bool mask ``valid_corr`` (same layout) is False.

        Side effects:
            Prints a line to stdout. When ``shortcut_logits`` is on, the
            pre-scaling logits are stored in ``self.last_logits`` and added to the
            next call's correlation.
        """
        print("Aerial attention block")
        if bev_mask is None:
            bev_mask = tf.ones(tf.shape(bev)[:-1], dtype="bool")

        aerial = tf.ensure_shape(aerial, [None, self.model_constants.aerial_shapes[iteration][0], self.model_constants.aerial_shapes[iteration][1], None])
        bev = tf.ensure_shape(bev, [None, self.model_constants.bev_shapes[iteration][0], self.model_constants.bev_shapes[iteration][1], None])

        bev = tfcv.model.util.norm(bev, name=tfcv.model.util.join(name, "norm"), config=config)
        bev_shape = tf.shape(bev)[-3:-1]
        num_angles = tf.shape(self.model_input.angles)[1]

        # Predict correlation inputs
        bev_into_corr = tfcv.model.util.conv(bev, filters=self.filters, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "bev-query"), config=config)
        aerial_into_corr = tfcv.model.util.conv(aerial, filters=self.filters, kernel_size=1, stride=1, bias=True, name=tfcv.model.util.join(name, "aerial-key"), config=config)

        bev_into_corr = tfcv.model.einops.apply("b s... (h f) -> b h s... f", bev_into_corr, h=self.heads)
        aerial_into_corr = tfcv.model.einops.apply("b s... (h f) -> b h s... f", aerial_into_corr, h=self.heads)
        bev_mask = tfcv.model.einops.apply("b s... -> b h s...", bev_mask, h=self.heads)

        bev_mask = tfcv.model.util.set_name(bev_mask, f"bev/embedding-mask-no-rotation-{iteration + 1}")
        bev_into_corr = tf.where(bev_mask[..., tf.newaxis], bev_into_corr, 0.0)
        bev_into_corr = tfcv.model.util.set_name(bev_into_corr, f"bev/embedding-no-rotation-{iteration + 1}")

        aerial_into_corr = tfcv.model.util.set_name(aerial_into_corr, f"aerial/embedding-no-rotation-{iteration + 1}")

        # Rotate inputs: fold batch, head and angle into one batch axis so every
        # (sample, head, angle) triple gets its own rotated copy of the BEV embedding.
        bev_into_corr = tfcv.model.einops.apply("b h s... f -> (b h a) s... f", bev_into_corr, a=num_angles)
        aerial_into_corr = tfcv.model.einops.apply("b h s... f -> (b h a) s... f", aerial_into_corr, a=num_angles)
        angles = tfcv.model.einops.apply("b a -> (b h a)", self.model_input.angles, h=self.heads)

        def rotate(inputs, rotate_type=self.rotate_type):
            image, image_angles = inputs
            return tfa.image.rotate(image, image_angles, interpolation=rotate_type, fill_mode="constant", fill_value=0.0)

        bev_into_corr = tf.keras.layers.Lambda(rotate)([bev_into_corr, angles])

        # Zero-pad the BEV embedding (centered) up to the aerial spatial size.
        aerial_shape = tf.cast(tf.shape(aerial_into_corr)[-3:-1], "int64")
        pad_size = aerial_shape - tf.cast(bev_shape, "int64")
        front = pad_size // 2
        back = pad_size - front
        paddings = [[0, 0], [front[0], back[0]], [front[1], back[1]], [0, 0]]
        bev_into_corr = tf.pad(bev_into_corr, paddings, mode="CONSTANT", constant_values=0.0)

        bev_into_corr = tfcv.model.einops.apply("(b h a) s... f -> b h a s... f", bev_into_corr, h=self.heads, a=num_angles)
        aerial_into_corr = tfcv.model.einops.apply("(b h a) s... f -> b h a s... f", aerial_into_corr, h=self.heads, a=num_angles)

        # Perform correlation
        corr = tf.keras.layers.Lambda(lambda x: georeg.model.correlation.math.cross_correlate_2d(x[0], x[1], method="fft"))([bev_into_corr, aerial_into_corr])  # b h a s... f
        corr = tf.ensure_shape(corr, [None, self.heads, None, None, None, self.filters // self.heads])

        # Center-crop to corr_shape
        full_shape = tf.cast(tf.shape(corr)[-3:-1], "int64")
        pad_size = full_shape - tf.cast(corr_shape, "int64")
        front = pad_size // 2
        back = full_shape - (pad_size - front)
        corr = corr[:, :, :, front[0]:back[0], front[1]:back[1], :]

        corr = tf.reduce_sum(corr, axis=-1)  # b h a s...

        if self.variance_correction:
            # Divide by sqrt(#valid BEV pixels * features per head) per (b, h).
            variance = tfcv.model.einops.apply("b h s... -> b h", bev_mask, reduction="count_nonzero") * (self.filters // self.heads)
            # A fully masked BEV gives 0 ** -0.5 = inf, and 0 * inf = NaN.
            # Clamping to 1 leaves that (all-zero) correlation at zero instead.
            variance = tf.maximum(variance, tf.ones_like(variance))
            corr = corr * (tf.cast(variance[:, :, tf.newaxis, tf.newaxis, tf.newaxis], "float32") ** -0.5)

        if self.shortcut_logits:
            if self.last_logits is not None:
                last_logits = self.last_logits
                if self.learn_scale:
                    last_logits = tfcv.model.util.ScaleLayer(axis=[], initial_value=1.0, name=tfcv.model.util.join(name, "shortcut-logits-scale"))(last_logits)
                corr = corr + self._resize_shortcut_logits(last_logits, tf.shape(corr)[-2:], config)
            self.last_logits = corr

            # The shortcut accumulates one extra term per iteration; rescale by 1/sqrt(iteration + 1).
            corr = corr * (tf.cast(iteration + 1, "float32") ** -0.5)

        if self.learn_scale:
            corr = tfcv.model.util.ScaleLayer(axis=[], initial_value=0.1, name=tfcv.model.util.join(name, "logits-scale"))(corr)
        else:
            corr = corr * 0.1

        valid_corr = self._correlation_mask(corr_shape, corr.dtype, iteration)

        corr = tfcv.model.util.set_name(corr, f"correlation-logits-{iteration + 1}")
        corr = tf.where(valid_corr, corr, 0.0)

        return corr, valid_corr
