import numpy as np
import tensorflow as tf
import tfcv, cosy, georeg, types, os, math

class LossBase:
    """Shared preprocessing for losses defined on a pose correlation volume.

    This class defines no constructor state of its own. ``prepare`` reads
    ``self.name`` (prefix for metric keys) and ``self.model_constants``
    (``bev_shapes``, ``aerial_shapes``), so subclasses must set both before
    calling it.
    """

    def prepare(self, corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, metrics):
        """Flatten the correlation volume and sort its cells by groundtruth error.

        Args:
            corr: Correlation scores. Flattened with ``"b a s... -> b (a s...)"``,
                so presumably laid out as (batch, angle, spatial...) - TODO confirm.
            valid_corr: Boolean mask with the same layout as ``corr``.
            model_input: Provides ``angles``, forwarded to ``error_volume``.
            loss_input: Provides ``get_bevpixels_to_aerialpixels_with_shape``, from
                which the groundtruth rotation and translation are sliced.
            pred_rotation: Predicted rotation, wrapped in a ``cosy.tf.Rigid``.
            pred_translation: Predicted translation, wrapped in the same ``Rigid``.
            metrics: Dict receiving per-sample metrics, or ``None`` to skip them.

        Returns:
            ``types.SimpleNamespace`` with ``corr``, ``valid_corr``, ``gt_error``
            and ``pred_error`` flattened to ``[batch, cells]`` and reordered so that
            column 0 is the cell with the smallest groundtruth error, plus
            ``valid_sample``: the validity flag of that column-0 cell.
        """
        bevpixels_to_aerialpixels = loss_input.get_bevpixels_to_aerialpixels_with_shape(
            bev_shape=self.model_constants.bev_shapes[-1],
            aerial_shape=self.model_constants.aerial_shapes[-1],
            model_constants=self.model_constants,
        )
        gt_translation = bevpixels_to_aerialpixels[..., :2, 2]
        gt_rotation = bevpixels_to_aerialpixels[..., :2, :2]

        new_to_gt = cosy.tf.Rigid.create(rotation=gt_rotation, translation=gt_translation)
        new_to_gt_pred = cosy.tf.Rigid.create(rotation=pred_rotation, translation=pred_translation)

        # Generate error volumes matching the correlation volume, once relative to
        # the groundtruth transform and once relative to the predicted transform.
        corr_shape = tf.shape(corr)[-2:]
        gt_error = georeg.model.correlation.math.error_volume(new_to_gt, corr_shape, model_input.angles)
        pred_error = georeg.model.correlation.math.error_volume(new_to_gt_pred, corr_shape, model_input.angles)

        # Collapse angle and spatial axes into a single "cell" axis.
        flatten = "b a s... -> b (a s...)"
        gt_error = tfcv.model.einops.apply(flatten, gt_error)
        pred_error = tfcv.model.einops.apply(flatten, pred_error)
        corr = tfcv.model.einops.apply(flatten, corr)
        valid_corr = tfcv.model.einops.apply(flatten, valid_corr)

        # Sort by groundtruth error so that column 0 is the best-matching cell.
        indices = tf.argsort(gt_error, axis=1, direction="ASCENDING")
        corr = tf.gather(corr, indices, axis=1, batch_dims=1)
        valid_corr = tf.gather(valid_corr, indices, axis=1, batch_dims=1)
        gt_error = tf.gather(gt_error, indices, axis=1, batch_dims=1)
        pred_error = tf.gather(pred_error, indices, axis=1, batch_dims=1)

        # The groundtruth cell may be invalid; instead of asserting, report it so
        # the caller can decide what to do with such samples.
        # tf.debugging.assert_equal(valid_corr[:, 0], True, message="Got invalid groundtruth correlation")
        valid_sample = valid_corr[:, 0]

        # A negative is "misclassified" when it scores at least as high as the
        # groundtruth cell. Only valid negatives are considered.
        negatives_valid = valid_corr[:, 1:]
        num_misclassified = tf.math.count_nonzero(tf.math.logical_and(
            corr[:, :1] <= corr[:, 1:],
            negatives_valid,
        ), axis=1)
        # Count the valid negatives directly rather than "all valid cells minus one":
        # the latter is off by one whenever the groundtruth cell itself is invalid.
        num_total = tf.math.count_nonzero(negatives_valid, axis=1)
        # divide_no_nan yields 0 (not NaN) for samples without any valid negative.
        misclassified_ratio = tf.math.divide_no_nan(
            tf.cast(num_misclassified, "float32"),
            tf.cast(num_total, "float32"),
        )

        if metrics is not None:
            metrics[self.name + "-num-total"] = num_total
            metrics[self.name + "-misclassified-ratio"] = misclassified_ratio
            metrics[self.name + "-num-misclassified"] = tf.cast(num_misclassified, "float32")

        return types.SimpleNamespace(
            corr=corr,
            valid_corr=valid_corr,
            gt_error=gt_error,
            pred_error=pred_error,
            valid_sample=valid_sample,
        )
#
# class MetricLoss(LossBase):
#     def __init__(self, name, model_constants, loss_params):
#         LossBase.__init__(self)
#         self.name = name
#         os_name = name.upper().replace("-", "_")
#         self.model_constants = model_constants
#
#         self.margin = float(os.environ[f"LOSS_{os_name}_MARGIN"]) if f"LOSS_{os_name}_MARGIN" in os.environ else 0.1
#         self.margin_in_schedule = False
#         loss_params[self.name]["margin"] = self.margin
#         loss_params[self.name]["margin-in-schedule"] = self.margin_in_schedule
#
#         self.loss_denom = str(os.environ[f"LOSS_{os_name}_DENOM"]) if f"LOSS_{os_name}_DENOM" in os.environ else "soft_below_margin"
#         self.temperature = float(os.environ[f"LOSS_{os_name}_TEMPERATURE"]) if f"LOSS_{os_name}_TEMPERATURE" in os.environ else 0.1
#         loss_params[self.name]["denom"] = self.loss_denom
#         loss_params[self.name]["temperature"] = self.temperature
#
#         self.denom_grad = int(os.environ[f"LOSS_{os_name}_DENOM_GRAD"]) == 1 if f"LOSS_{os_name}_DENOM_GRAD" in os.environ else True
#         loss_params[self.name]["denom-grad"] = self.denom_grad
#
#         self.skip_easy_examples = float(os.environ[f"LOSS_{os_name}_SKIP_FACTOR"]) if f"LOSS_{os_name}_SKIP_FACTOR" in os.environ else 0.0
#         self.easy_indicator = str(os.environ[f"LOSS_{os_name}_EASY_INDICATOR"]) if f"LOSS_{os_name}_EASY_INDICATOR" in os.environ else "below_margin_gt_error"
#         loss_params[self.name]["skip-easy-examples-factor"] = self.skip_easy_examples
#         loss_params[self.name]["skip-easy-examples-indicator"] = self.easy_indicator
#
#         self.possible_positives = int(os.environ[f"LOSS_{os_name}_POSSIBLE_POSITIVES"]) if f"LOSS_{os_name}_POSSIBLE_POSITIVES" in os.environ else 1
#         loss_params[self.name]["possible-positives"] = self.possible_positives
#
#         self.weight = float(os.environ[f"LOSS_{os_name}_WEIGHT"]) if f"LOSS_{os_name}_WEIGHT" in os.environ else 1.0
#         loss_params[self.name]["weight"] = self.weight
#
#         self.below_margin_ratio_ema = tfcv.metric.ExponentialMovingAverage(decay=0.97, dtype="float32")
#         self.below_margin_gt_error_ema = tfcv.metric.ExponentialMovingAverage(decay=0.97, dtype="float32")
#         self.soft_num_below_margin_ema = tfcv.metric.ExponentialMovingAverage(decay=0.97, dtype="float32")
#
#     def __call__(self, corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, schedule_factor, metrics):
#         data = self.prepare(corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, metrics)
#
#         if self.margin_in_schedule:
#             margin_this_iteration = schedule_factor * self.margin
#         else:
#             margin_this_iteration = self.margin
#
#         positive_distance = -data.corr[:, 0]
#         negatives_distance = -data.corr[:, self.possible_positives:]
#         negatives_valid = data.valid_corr[:, self.possible_positives:]
#         negatives_gt_error = data.gt_error[:, self.possible_positives:]
#
#         below_margin = tf.math.logical_and(
#             positive_distance[:, tf.newaxis] + margin_this_iteration > negatives_distance,
#             negatives_valid,
#         )
#         num_below_margin = tf.math.count_nonzero(below_margin, axis=1)
#         num_total = tf.math.count_nonzero(negatives_valid, axis=1)
#
#         below_margin_gt_error = tf.reduce_max(tf.where(below_margin, negatives_gt_error, 0.0), axis=1)
#         self.below_margin_gt_error_ema.update_state(tf.reduce_mean(below_margin_gt_error))
#
#         loss = (positive_distance[:, tf.newaxis] + margin_this_iteration) - negatives_distance
#         if self.temperature == 0:
#             loss = tf.math.maximum(loss, 0.0) # [batch, :]
#             soft_below_margin = tf.cast(tf.where(below_margin, 1.0, 0.0), loss.dtype) # [batch, :]
#         else:
#             loss = georeg.model.util.softplus(loss, temperature=self.temperature) # [batch, :]
#             soft_below_margin = tf.math.sigmoid(loss / self.temperature) # [batch, :]
#
#         loss = tf.where(negatives_valid, loss, 0.0)
#         soft_below_margin = tf.where(negatives_valid, soft_below_margin, 0.0)
#
#         soft_below_margin_num = tf.reduce_mean(tf.reduce_sum(soft_below_margin, axis=1))
#         self.soft_num_below_margin_ema.update_state(soft_below_margin_num)
#
#         if self.loss_denom == "total":
#             denom = num_total
#         elif self.loss_denom == "soft_below_margin":
#             denom = tf.reduce_sum(soft_below_margin, axis=1)
#         elif self.loss_denom == "below_margin":
#             denom = num_below_margin
#         elif self.loss_denom == "soft_below_margin_ema":
#             metrics[self.name + "-soft-below-margin-num-ema"] = self.soft_num_below_margin_ema.result()
#             denom = tf.stop_gradient(self.soft_num_below_margin_ema.result())
#         else:
#             raise ValueError(f"Invalid value for loss_denom: {self.loss_denom}")
#         if not self.denom_grad:
#             denom = tf.stop_gradient(denom)
#         loss = tf.math.divide_no_nan(tf.reduce_sum(loss, axis=1), tf.cast(denom, loss.dtype)) # [batch]
#
#         below_margin_ratio = tf.cast(num_below_margin, "float32") / tf.cast(num_total, "float32")
#         self.below_margin_ratio_ema.update_state(tf.reduce_mean(below_margin_ratio))
#
#         if not metrics is None:
#             metrics[self.name + "-soft-below-margin-num"] = soft_below_margin_num
#             metrics[self.name + "-soft-below-margin-ratio"] = tf.cast(soft_below_margin_num, "float32") / tf.cast(num_total, "float32")
#             metrics[self.name + "-loss"] = tf.reduce_mean(loss)
#             metrics[self.name + "-below-margin-ratio"] = tf.reduce_mean(below_margin_ratio)
#             metrics[self.name + "-below-margin-num"] = tf.reduce_mean(num_below_margin)
#             metrics[self.name + "-below-margin-ema"] = self.below_margin_ratio_ema.result()
#
#         if self.skip_easy_examples > 0:
#             if self.easy_indicator == "below_margin_ratio":
#                 x = below_margin_ratio
#                 x_ema = self.below_margin_ratio_ema
#             elif self.easy_indicator == "below_margin_gt_error":
#                 x = below_margin_gt_error
#                 x_ema = self.below_margin_gt_error_ema
#             else:
#                 assert False
#
#             kept = tf.math.logical_or(
#                 x > x_ema.result() * self.skip_easy_examples,
#                 x_ema.num < 0,
#             )
#             loss = tf.where(kept, loss, 0.0)
#             metrics[self.name + "-skipped-easy-sample"] = tf.reduce_mean(tf.where(kept, 0.0, 1.0))
#
#         return self.weight * loss

class CrossEntropyLoss(LossBase):
    """Cross-entropy between predicted pose probabilities and a Gaussian target.

    The target distribution is a Gaussian over (angle, x, y) centred on the
    groundtruth pose, evaluated on the cells of the correlation volume,
    restricted to valid cells and normalised per sample. It can optionally be
    blended with a uniform distribution over the valid cells (label smoothing).
    """

    def __init__(self, name, model_constants, loss_params):
        """Read hyperparameters from the environment and record them.

        Environment variables (``<NAME>`` is ``name`` upper-cased with ``-``
        replaced by ``_``):

        * ``LOSS_<NAME>_ANGLE_STD``: angular standard deviation in degrees (default 2.0).
        * ``LOSS_<NAME>_TRANSLATION_STD``: translational standard deviation (default 1.0);
          divided by ``model_constants.meters_per_pixel[-1]`` to obtain pixels.
        * ``LOSS_<NAME>_LABELSMOOTH``: label-smoothing factor (default 0.0).
        * ``LOSS_<NAME>_WEIGHT``: multiplier applied to the returned loss (default 1.0).

        Args:
            name: Loss name; used as metric-key prefix and as key into ``loss_params``.
            model_constants: Provides ``meters_per_pixel`` here and ``bev_shapes`` /
                ``aerial_shapes`` when the loss is evaluated.
            loss_params: Mapping in which ``loss_params[name]`` must already be
                indexable; the resolved hyperparameters are written into it.
        """
        super().__init__()
        self.name = name
        os_name = name.upper().replace("-", "_")
        self.model_constants = model_constants

        self.possible_positives = 1

        self.angle_std = math.radians(float(os.environ.get(f"LOSS_{os_name}_ANGLE_STD", 2.0)))
        loss_params[self.name]["angle-std"] = math.degrees(self.angle_std)

        self.translation_std = float(os.environ.get(f"LOSS_{os_name}_TRANSLATION_STD", 1.0))
        loss_params[self.name]["translation-std"] = self.translation_std

        self.label_smoothing = float(os.environ.get(f"LOSS_{os_name}_LABELSMOOTH", 0.0))
        loss_params[self.name]["label-smooth"] = self.label_smoothing

        # Diagonal covariance over (angle [rad], x [px], y [px]); only its inverse is kept.
        covariance = np.zeros(shape=(3, 3), dtype="float32")
        covariance[0, 0] = self.angle_std * self.angle_std
        translation_std_pixels = self.translation_std / model_constants.meters_per_pixel[-1]
        covariance[1, 1] = covariance[2, 2] = translation_std_pixels * translation_std_pixels
        self.inv_covariance = np.linalg.inv(covariance)

        self.weight = float(os.environ.get(f"LOSS_{os_name}_WEIGHT", 1.0))
        loss_params[self.name]["weight"] = self.weight

    def __call__(self, corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, schedule_factor, metrics):
        """Compute the weighted cross-entropy loss per sample.

        Args:
            corr: Predicted values per cell. The logarithm is taken directly, so
                these are presumably already normalised probabilities - TODO confirm.
            valid_corr: Boolean mask with the same layout as ``corr``.
            model_input: Provides ``angles`` for the correlation volume.
            loss_input: Provides the groundtruth BEV-to-aerial pixel transform.
            pred_rotation: Predicted rotation matrix (used for metrics and ``prepare``).
            pred_translation: Predicted translation (used for metrics and ``prepare``).
            schedule_factor: Not read by this loss.
            metrics: Dict receiving per-sample metrics, or ``None`` to skip them.

        Returns:
            Tuple ``(self.weight * loss, valid_sample)`` where ``loss`` has one entry
            per batch element and ``valid_sample`` comes from ``prepare``.
        """
        data = self.prepare(corr, valid_corr, model_input, loss_input, pred_rotation, pred_translation, metrics)

        bevpixels_to_aerialpixels = loss_input.get_bevpixels_to_aerialpixels_with_shape(
            bev_shape=self.model_constants.bev_shapes[-1],
            aerial_shape=self.model_constants.aerial_shapes[-1],
            model_constants=self.model_constants,
        )
        gt_translation = bevpixels_to_aerialpixels[..., :2, 2]
        gt_rotation = bevpixels_to_aerialpixels[..., :2, :2]

        gt_axy = tf.concat([
            cosy.tf.rotation_matrix_to_angle(gt_rotation)[..., tf.newaxis],
            gt_translation,
        ], axis=-1) # b 3

        corr_axy = georeg.model.correlation.math.axy_volume(tf.shape(corr)[-2:], angles=model_input.angles, dtype="float32") # b a s... 3
        corr_axy = tfcv.model.einops.apply("b a s... f -> b (a s...) f", corr_axy)
        valid_corr = tfcv.model.einops.apply("b a s... -> b (a s...)", valid_corr)

        # Gaussian target around the groundtruth pose, evaluated at every cell.
        # NOTE(review): the angle difference is not wrapped to (-pi, pi]; confirm the
        # range of model_input.angles makes that unnecessary.
        gt_offsets = corr_axy - gt_axy[:, tf.newaxis, :] # b p 3
        gt_probs = tfcv.model.einops.apply("b p f0, f0 f1, b p f1 -> b p", gt_offsets, self.inv_covariance, gt_offsets)
        gt_probs = tf.math.exp(-0.5 * gt_probs)
        gt_probs = tf.where(valid_corr, gt_probs, 0.0)
        # divide_no_nan: if no valid cell carries any mass (all cells invalid, or the
        # groundtruth lies so far outside that exp() underflows) the target becomes
        # all-zero instead of NaN, so the sample yields zero loss rather than
        # tripping the finiteness assertion below. Such samples are reported through
        # valid_sample.
        gt_probs = tf.math.divide_no_nan(gt_probs, tf.reduce_sum(gt_probs, axis=1, keepdims=True)) # b p

        if self.label_smoothing > 0:
            gt_probs_uniform = tf.cast(tf.where(valid_corr, 1.0, 0.0), "float32")
            gt_probs_uniform = tf.math.divide_no_nan(
                gt_probs_uniform,
                tf.reduce_sum(gt_probs_uniform, axis=1, keepdims=True),
            ) # b p

            gt_probs = self.label_smoothing * gt_probs_uniform + (1.0 - self.label_smoothing) * gt_probs

        pred_probs = tfcv.model.einops.apply("b a s... -> b (a s...)", corr) # b p

        # Cells with non-positive predicted probability contribute zero loss; the
        # inner tf.where keeps log() away from them so no NaN/inf enters the gradient.
        is_positive = pred_probs > 0.0
        log_pred_probs = tf.where(is_positive, tf.math.log(tf.where(is_positive, pred_probs, 1.0)), 0.0)
        loss = tf.reduce_sum(-gt_probs * log_pred_probs, axis=1)
        loss = tf.debugging.assert_all_finite(loss, message="non-finite loss")

        if metrics is not None:
            metrics[self.name + "-loss"] = loss
            metrics[self.name + "-max-pred-prob"] = tf.reduce_max(pred_probs, axis=1)
            metrics[self.name + "-max-gt-prob"] = tf.reduce_max(gt_probs, axis=1)
            metrics[self.name + "-prob-at-gt"] = tf.gather(pred_probs, tf.argmax(gt_probs, axis=1), axis=1, batch_dims=1)

            pred_axy = tf.concat([
                cosy.tf.rotation_matrix_to_angle(pred_rotation)[..., tf.newaxis],
                pred_translation,
            ], axis=-1) # b 3

            # Second moment of the predicted distribution about the predicted pose,
            # pseudo-inverted to measure the groundtruth's Mahalanobis distance.
            pred_offsets = corr_axy - pred_axy[:, tf.newaxis, :] # b p 3
            outer_products = tf.linalg.matmul(pred_offsets[..., :, tf.newaxis], pred_offsets[..., tf.newaxis, :]) # b p 3 3
            pred_cov = tf.reduce_sum(outer_products * pred_probs[..., tf.newaxis, tf.newaxis], axis=1) # b 3 3
            pred_inv_cov = tf.linalg.pinv(pred_cov)

            gt_to_pred = gt_axy - pred_axy
            mahalanobis_distance = tfcv.model.einops.apply("b f0, b f0 f1, b f1 -> b", gt_to_pred, pred_inv_cov, gt_to_pred)
            mahalanobis_distance = tf.math.sqrt(mahalanobis_distance)

            metrics[self.name + "-md"] = mahalanobis_distance

            # Probability-weighted spread of the prediction around the groundtruth pose.
            spread_gt = tf.reduce_sum(tf.math.square(gt_offsets) * pred_probs[..., tf.newaxis], axis=1) # b 3
            metrics[self.name + "-std-gtmean"] = tf.math.sqrt(tf.reduce_sum(spread_gt, axis=1))

            # Probability-weighted spread of the prediction around the predicted pose.
            spread_pred = tf.reduce_sum(tf.math.square(pred_offsets) * pred_probs[..., tf.newaxis], axis=1) # b 3
            metrics[self.name + "-std-predmean"] = tf.math.sqrt(tf.reduce_sum(spread_pred, axis=1))

        return self.weight * loss, data.valid_sample
