import georeg, math
import numpy as np
import tensorflow as tf

class IntegerArgmax:
    """Pose estimator that reads the transform straight off the correlation peak.

    Delegates peak finding to ``georeg.model.phase_correlate.correlation_argmax``
    and packages the result as a ``georeg.transform.tf.Rigid``.
    """

    def __call__(self, corr, angles):
        """Return the rigid transform located at the peak of ``corr``.

        Args:
            corr: A single (unbatched) correlation volume. A batch axis of size
                one is added before it is handed to ``correlation_argmax``.
            angles: Forwarded unchanged to ``correlation_argmax``.

        Returns:
            The result of ``georeg.transform.tf.Rigid.numpy`` applied to a
            ``Rigid`` built from the peak rotation and translation.
        """
        batch_of_one = corr[tf.newaxis]
        translations, rotations = georeg.model.phase_correlate.correlation_argmax(batch_of_one, angles)
        # Strip the batch axis that was added above.
        translation, rotation = translations[0], rotations[0]
        rigid = georeg.transform.tf.Rigid
        return rigid.numpy(rigid.create(rotation=rotation, translation=translation))

class WeightedCovariance:
    """Estimate pose uncertainty from a correlation volume.

    The correlation volume has three axes: one slice per entry of ``angles``,
    followed by two spatial axes. Each voxel is assigned a coordinate of
    ``(angle, s0, s1)`` from ``georeg.model.register.axy_volume``. The two
    spatial components are rotated into the world frame by ``angle_to_world``.

    Calling an instance returns ``(covariance_matrix, translation_variance)``:

    * ``covariance_matrix`` -- 3x3 correlation-weighted second moment of the
      voxel coordinates about the *predicted* pose (not about the weighted
      mean). Only voxels inside a window of ``window_radius`` around the
      correlation argmax contribute.
    * ``translation_variance`` -- correlation-weighted mean squared distance
      between the spatial coordinates and the predicted translation, taken
      over the whole volume.
    """

    def __init__(self, window_radius, rescale=lambda x: (x + 1.0) / 2.0):
        """
        Args:
            window_radius: Half-width of the window around the correlation
                argmax. Either a scalar or one value per axis of the
                correlation volume; it is cast to int32 at call time.
            rescale: Callable applied to the raw correlation before use. Its
                output must lie in [0, 1]; this is checked at call time. The
                default maps [-1, 1] linearly onto [0, 1].
        """
        self.window_radius = np.asarray(window_radius)
        self.rescale = rescale
        # Computation function, bound lazily on the first call. It runs eagerly
        # as a plain Python function; it could be wrapped in tf.function.
        self.tf_func = None

    @staticmethod
    def _compute(correlation, rotation_matrix_to_world, pred_transform_rotation,
                 pred_transform_translation, angles, window_radius):
        """Tensor-level computation behind ``__call__``; all inputs are tensors."""
        tf.debugging.assert_greater_equal(correlation, 0.0)
        tf.debugging.assert_less_equal(correlation, 1.0)
        tf.debugging.assert_greater(tf.reduce_sum(correlation), 0.0)

        # Per-voxel coordinates, shape [angles, dims..., 3]; column 0 is the angle.
        axy = georeg.model.register.axy_volume(tf.shape(correlation)[-2:], angles=angles, dtype="float32")
        axy = tf.reshape(axy, [-1, 3])
        # Rotate the spatial columns into the world frame and keep the angle column.
        axy = tf.concat([
            axy[:, :1],
            tf.matmul(rotation_matrix_to_world[tf.newaxis, :, :], axy[:, 1:, tf.newaxis])[:, :, 0],
        ], axis=-1)
        axy = tf.reshape(axy, tf.stack([tf.shape(angles)[0], tf.shape(correlation)[-2], tf.shape(correlation)[-1], 3]))

        # Predicted pose as (angle, t0, t1), with the translation rotated into
        # the same world frame as the voxel coordinates.
        mean = tf.stack([
            georeg.transform.tf.rotation_matrix_to_angle(pred_transform_rotation),
            pred_transform_translation[0],
            pred_transform_translation[1],
        ])
        mean = tf.concat([
            mean[:1],
            tf.matmul(rotation_matrix_to_world, mean[1:, tf.newaxis])[:, 0],
        ], axis=0)
        # NOTE(review): angle differences below are not wrapped to (-pi, pi];
        # confirm this is intended when the peak is near the angle wrap-around.

        # Inclusive window [argmax - r, argmax + r], clipped to the volume.
        # Slice ends are exclusive, so the upper bound is argmax + r + 1 and
        # is clipped to the full shape. Clipping to shape - 1 would exclude the
        # last index, and it previously dropped the peak itself.
        argmax = tf.cast(georeg.model.register.argmax3d(correlation), window_radius.dtype)
        window_min = tf.math.maximum(argmax - window_radius, 0)
        window_max = tf.math.minimum(argmax + window_radius + 1, tf.cast(tf.shape(correlation), argmax.dtype))
        axy_window = axy[window_min[0]:window_max[0], window_min[1]:window_max[1], window_min[2]:window_max[2]]
        corr_window = correlation[window_min[0]:window_max[0], window_min[1]:window_max[1], window_min[2]:window_max[2]]
        tf.debugging.assert_greater(tf.reduce_sum(corr_window), 0.0)

        axy_window = tf.reshape(axy_window, [-1, 3])
        corr_window = tf.reshape(corr_window, [-1])
        diff = axy_window - mean[tf.newaxis, :]
        outer = diff[:, :, tf.newaxis] @ diff[:, tf.newaxis, :]  # [-1, 3, 3]
        covariance_matrix = tf.reduce_sum(outer * corr_window[:, tf.newaxis, tf.newaxis], axis=0) / tf.reduce_sum(corr_window)

        # Translation variance uses the whole volume, not only the window.
        diff = tf.reshape(axy, [-1, 3]) - mean[tf.newaxis, :]
        squared_distance = tf.reduce_sum(diff[:, 1:] ** 2, axis=1)
        weights = tf.reshape(correlation, [-1])
        translation_variance = tf.reduce_sum(squared_distance * weights, axis=0) / tf.reduce_sum(weights)

        return covariance_matrix, translation_variance

    def __call__(self, correlation, angle_to_world, pred_transform, angles):
        """Return ``(covariance_matrix, translation_variance)``.

        Args:
            correlation: Raw correlation volume with axes (angle, spatial,
                spatial). It is passed through ``self.rescale`` first.
            angle_to_world: Angle converted with
                ``georeg.transform.np.angle_to_rotation_matrix``. It rotates
                the spatial coordinates into the world frame.
            pred_transform: Predicted transform exposing ``rotation`` (a 2x2
                rotation matrix) and ``translation`` (a 2-vector).
            angles: 1-D angles, one per slice along the first axis of
                ``correlation``.

        Returns:
            A ``[3, 3]`` float32 tensor and a scalar float32 tensor.

        Raises:
            tf.errors.InvalidArgumentError: if the rescaled correlation lies
                outside [0, 1], or sums to zero overall or inside the window.
        """
        if self.tf_func is None:
            self.tf_func = self._compute
        correlation = self.rescale(correlation)
        return self.tf_func(
            tf.cast(correlation, "float32"),
            tf.cast(georeg.transform.np.angle_to_rotation_matrix(angle_to_world), "float32"),
            tf.cast(pred_transform.rotation, "float32"),
            tf.cast(pred_transform.translation, "float32"),
            tf.cast(angles, "float32"),
            tf.cast(self.window_radius, "int32"),
        )
