import tensorflow as tf
import numpy as np
import cv2, imageio, cosy, georeg, os, math, tfcv

class CorrelationLoss:
    """Loss over the last block's correlation volume and the discrete pose predictions.

    The concrete loss implementation is selected at construction time from the
    environment variable ``LOSS_<NAME>_TYPE``, where ``<NAME>`` is ``name``
    upper-cased with ``-`` replaced by ``_``. Supported values are
    ``"contrastive"`` (default) and ``"entropy"``.
    """

    def __init__(self, name, model_constants, loss_params):
        """Select and construct the underlying loss.

        Args:
            name: Loss name. Used as the key into ``loss_params`` and to derive
                the environment variable ``LOSS_<NAME>_TYPE``.
            model_constants: Object exposing ``blocks``. ``loss_fn`` reads the
                correlation outputs of block number ``blocks``.
            loss_params: Mapping of per-loss parameter dicts. NOTE: mutated in
                place, because ``loss_params[name]["type"]`` is set to the
                selected loss type.

        Raises:
            ValueError: If ``LOSS_<NAME>_TYPE`` holds an unsupported value.
        """
        self.name = name
        self.model_constants = model_constants

        self.pred_type = "discrete"

        env_key = f"LOSS_{name.upper().replace('-', '_')}_TYPE"
        loss_type = os.environ.get(env_key, "contrastive")
        # Written before the loss object is built; presumably so the selected
        # loss class (which receives loss_params) can read it.
        loss_params[self.name]["type"] = loss_type

        if loss_type == "contrastive":
            self.loss = georeg.model.correlation.loss.MetricLoss(name, model_constants, loss_params)
        elif loss_type == "entropy":
            self.loss = georeg.model.correlation.loss.CrossEntropyLoss(name, model_constants, loss_params)
        else:
            # Explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would leave `self.loss` unset.
            raise ValueError(
                f"Unsupported loss type {loss_type!r} for loss {name!r} "
                f"(from {env_key}); expected 'contrastive' or 'entropy'"
            )

    def get_output(self, model, model_input, loss_input, config, model_params):
        """Collect the model outputs this loss needs, keyed by layer name.

        Always includes ``rotation-<pred_type>`` and ``translation-<pred_type>``.
        For each block index ``i`` in ``1..model_constants.blocks``, also
        includes ``correlation-logits-i`` and ``correlation-mask-i`` when a layer
        named ``correlation-logits-i`` exists. In that case the matching mask
        layer is expected to exist as well.

        ``model_input``, ``loss_input``, ``config`` and ``model_params`` are not
        used here. Presumably they are accepted to match a shared loss interface.

        Args:
            model: Keras-style model exposing ``layers`` and ``get_layer(name)``.

        Returns:
            dict mapping layer name to that layer's ``.output``.
        """
        result = {
            f"rotation-{self.pred_type}": model.get_layer(f"rotation-{self.pred_type}").output,
            f"translation-{self.pred_type}": model.get_layer(f"translation-{self.pred_type}").output,
        }

        # Set for O(1) membership tests inside the loop below.
        layer_names = {layer.name for layer in model.layers}

        for i in range(1, self.model_constants.blocks + 1):
            if f"correlation-logits-{i}" in layer_names:
                result[f"correlation-logits-{i}"] = model.get_layer(f"correlation-logits-{i}").output
                result[f"correlation-mask-{i}"] = model.get_layer(f"correlation-mask-{i}").output
        return result

    def loss_fn(self, model_output, model_input, loss_input, schedule_factor, metrics):
        """Compute the loss from the last block's correlation and the pose predictions.

        Only block number ``model_constants.blocks`` is used, so that block's
        ``correlation-logits-*`` / ``correlation-mask-*`` entries must be present
        in ``model_output``.

        Returns:
            Whatever the selected loss object (``self.loss``) returns.
        """
        last_block = self.model_constants.blocks
        corr = model_output[f"correlation-logits-{last_block}"]
        valid_corr = model_output[f"correlation-mask-{last_block}"]

        # Collapse the `h` axis: logits are averaged. The mask is reduced with
        # "all", presumably keeping a position valid only if it is valid for
        # every `h`. (Semantics of `h` are not visible here — TODO confirm.)
        corr = tfcv.model.einops.apply("b h a s... -> b a s...", corr, reduction="mean")
        valid_corr = tfcv.model.einops.apply("b h a s... -> b a s...", valid_corr, reduction="all")
        # NOTE(review): presumably a masked softmax over the correlation entries.
        corr = georeg.model.correlation.math.softmax(corr, valid_corr)

        pred_rotation = model_output[f"rotation-{self.pred_type}"]
        pred_translation = model_output[f"translation-{self.pred_type}"]

        return self.loss(corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, schedule_factor, metrics)
