import math, cosy, georeg, tfcv
import numpy as np
import tensorflow as tf
from collections import OrderedDict

def gaussian_probs(axy_curr_to_world, mean, covariance):
    """Evaluate an un-normalised Gaussian density at a batch of sample points.

    For every row ``x`` of ``axy_curr_to_world`` this computes
    ``exp(-0.5 * (x - mean)^T @ inv(covariance) @ (x - mean))``. The usual
    normalisation constant ``1 / sqrt((2 * pi) ** k * det(covariance))`` is
    deliberately left out, so the value at ``mean`` is 1.

    Args:
        axy_curr_to_world: rank-2 array/tensor with one sample per row.
        mean: rank-1 array/tensor, broadcast against every sample row.
        covariance: rank-2 square array/tensor matching the sample width.

    Returns:
        A float64 NumPy array holding one value per sample row.

    Raises:
        AssertionError: if inverting ``covariance`` fails.
    """
    axy_curr_to_world = tf.cast(axy_curr_to_world, "float64")
    mean = tf.cast(mean, "float64")
    covariance = tf.cast(covariance, "float64")

    # Only the inversion is guarded: that is the failure the message describes.
    # The error is raised explicitly (instead of via `assert False`) so that it
    # still fires when Python runs with -O, which strips assert statements.
    try:
        precision = tf.linalg.inv(covariance)
    except Exception as err:
        raise AssertionError(f"Covariance not invertible: {covariance.numpy()}") from err

    offsets = axy_curr_to_world - mean[tf.newaxis, :]
    # Batched quadratic form: (N, 1, D) @ (1, D, D) @ (N, D, 1) -> (N, 1, 1).
    squared_distance = tf.matmul(
        tf.matmul(offsets[:, tf.newaxis, :], precision[tf.newaxis, :, :]),
        offsets[:, :, tf.newaxis],
    )[:, 0, 0]
    probs = tf.math.exp(-0.5 * squared_distance)
    return probs.numpy()

class ProbabilityDistribution:
    """Per-sample probabilities together with a mean/covariance summary.

    Attributes are stored exactly as passed to the constructor:
    ``ref_to_world``, ``probs``, ``mean`` and ``covariance``. ``mean`` is
    indexed as ``[angle, translation...]`` by the ``curr_to_world`` property.
    """

    def __init__(self, ref_to_world, probs, mean, covariance):
        """Store the four fields on the instance without copying or converting."""
        self.ref_to_world = ref_to_world
        self.probs = probs
        self.mean = mean
        self.covariance = covariance

    @staticmethod
    def gaussian(axy_curr_to_world, mean, covariance, ref_to_world):
        """Build a distribution whose probabilities come from ``gaussian_probs``.

        ``mean`` and ``covariance`` are kept on the result as given.
        """
        sample_probs = gaussian_probs(axy_curr_to_world, mean, covariance)
        return ProbabilityDistribution(
            ref_to_world=ref_to_world,
            probs=sample_probs,
            mean=mean,
            covariance=covariance,
        )

    @staticmethod
    def from_probs_simple(axy_curr_to_world, probs, ref_to_world):
        """Summarise sampled probabilities by their mode and a covariance around it.

        Args:
            axy_curr_to_world: NumPy array with one sample per row.
            probs: NumPy array with one non-negative weight per sample; it is
                normalised to sum to one.
            ref_to_world: stored on the result unchanged.

        Returns:
            A ProbabilityDistribution whose ``probs``, ``mean`` and
            ``covariance`` are float32 NumPy arrays.

        Raises:
            AssertionError: if the weights do not sum to a positive value, or
                if any eigenvalue of the covariance is exactly zero.
        """
        samples = axy_curr_to_world.astype("float64")
        weights = probs.astype("float64")
        assert np.sum(weights) > 0
        weights = weights / np.sum(weights)

        # The "mean" is the most probable sample (the mode), not the
        # probability-weighted average of all samples.
        mode = samples[np.argmax(weights)]

        # Probability-weighted sum of outer products of offsets from the mode.
        offsets = samples - mode[np.newaxis, :]
        outer_products = offsets[:, :, np.newaxis] @ offsets[:, np.newaxis, :]
        scatter = np.sum(outer_products * weights[:, np.newaxis, np.newaxis], axis=0)

        # Rebuild the matrix from absolute eigenvalues, floored at min_std ** 2.
        eig_val, eig_vec = np.linalg.eigh(scatter)
        eig_val = np.abs(eig_val)
        assert np.all(eig_val > 0)
        min_std = 0.001
        eig_val = np.maximum(eig_val, min_std * min_std)
        covariance = eig_vec @ np.diag(eig_val) @ eig_vec.T

        return ProbabilityDistribution(
            ref_to_world=ref_to_world,
            probs=weights.astype("float32"),
            mean=mode.astype("float32"),
            covariance=covariance.astype("float32"),
        )

    @property
    def curr_to_world(self):
        """Rigid transform built from ``mean``: ``mean[0]`` is the angle, ``mean[1:]`` the translation."""
        return cosy.np.Rigid(
            translation=self.mean[1:],
            rotation=cosy.np.angle_to_rotation_matrix(self.mean[0]),
        )

    @property
    def curr_to_ref(self):
        """``curr_to_world`` composed with the inverse of ``ref_to_world``."""
        world_to_ref = self.ref_to_world.inverse()
        return world_to_ref * self.curr_to_world

class TrackingStep:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Copy all keyword arguments straight into the instance namespace.
        self.__dict__.update(kwargs)

class KalmanTracker:
    """Pose tracker that feeds predictor output and IMU readings into a filter.

    The tracked pose is read back from the filter through its ``heading`` and
    ``position`` means. Geographic coordinates are converted through a local
    "world" frame created from the initial ``latlon`` with
    ``cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857``.
    """

    def __init__(self, predictor, tileloader, filter, timestamp, latlon, bearing, recalibrate, process_noise_multiplier=1.0, **kwargs):
        """Set up the coordinate frames and construct the filter.

        Args:
            predictor: callable used in ``__call__``; must also expose
                ``epsg4326_to_epsg3857``, ``epsg3857_to_epsg4326``,
                ``corr_shape``, ``angles`` and
                ``model_constants.meters_per_pixel``.
            tileloader: passed through to ``predictor`` on every step.
            filter: factory invoked as ``filter(**kwargs)``; the result must
                provide ``predict``, ``update``, ``get_mean`` and
                ``get_covariance``.
            timestamp: time of the initial state, in the same unit as
                ``aligned_ground_frame.timestamp`` (differences are scaled by
                1e-6 to obtain seconds).
            latlon: initial geographic position.
            bearing: initial bearing in degrees.
            recalibrate: selects how the measurement update is derived in
                ``__call__``.
            process_noise_multiplier: scales the process-noise values.
            **kwargs: forwarded to ``filter``. Must contain ``"position"`` and
                ``"heading"`` mappings; their ``"mean"`` entries are
                overwritten in place with the initial pose.
        """
        self.predictor = predictor
        self.tileloader = tileloader

        self.last_timestamp = timestamp
        self.world_to_epsg3857 = cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(latlon)

        self.epsg4326_to_epsg3857 = predictor.epsg4326_to_epsg3857
        self.epsg3857_to_epsg4326 = predictor.epsg3857_to_epsg4326

        initial_to_world = self.from_epsg4326(latlon, bearing)
        # NOTE: this writes into the mappings supplied by the caller.
        kwargs["position"]["mean"] = initial_to_world.translation
        kwargs["heading"]["mean"] = cosy.np.rotation_matrix_to_angle(initial_to_world.rotation)
        print(f"initial_to_world={initial_to_world}")

        self.recalibrate = recalibrate

        self.filter = filter(**kwargs)

        def process_noise(dt, filter):
            # `filter` is accepted because __call__ passes it; the current
            # noise model does not depend on the filter state.
            return KalmanTracker._default_process_noise(dt, process_noise_multiplier)
        self.process_noise = process_noise

    @staticmethod
    def _default_process_noise(dt, multiplier=1.0):
        """Return the per-field process-noise values handed to ``filter.predict``.

        The values are derived from an acceleration scale of ``10.0`` and a
        yaw-rate scale of ``radians(45)``, both multiplied by ``multiplier``.

        NOTE(review): a state-dependent 6x6 process-noise matrix is described in
        https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8916654 and
        could replace these independent per-field values.
        """
        std_acceleration = 10.0 * multiplier
        std_yawrate = math.radians(45.0) * multiplier

        return {
            "position": 1.0 / 6.0 * std_acceleration * dt ** 3 * 100,
            "heading": 1.0 / 2.0 * std_yawrate * dt ** 2,
            "velocity": 1.0 / 2.0 * std_acceleration * dt ** 2,
            "yawrate": std_yawrate * dt,
            "acceleration": std_acceleration * dt,
        }

    def to_epsg4326(self, curr_to_world):
        """Convert a world-frame pose to ``(latlon, bearing)`` with bearing in degrees."""
        curr_to_epsg3857 = self.world_to_epsg3857 * curr_to_world

        latlon = self.epsg3857_to_epsg4326(curr_to_epsg3857.translation.astype("float64"))
        bearing = math.degrees(self.epsg3857_to_epsg4326.transform_angle(cosy.np.rotation_matrix_to_angle(curr_to_epsg3857.rotation.astype("float64"))))
        return latlon, bearing

    def from_epsg4326(self, latlon, bearing):
        """Convert ``latlon`` and a bearing in degrees to a pose in the world frame."""
        curr_to_epsg3857 = cosy.np.Rigid(
            rotation=cosy.np.angle_to_rotation_matrix(self.epsg4326_to_epsg3857.transform_angle(math.radians(bearing))).astype("float64"),
            translation=self.epsg4326_to_epsg3857(latlon).astype("float64"),
        )
        curr_to_world = self.world_to_epsg3857.inverse() * curr_to_epsg3857
        return curr_to_world

    @property
    def curr_to_world(self):
        """Current pose built from the filter's ``position`` and ``heading`` means."""
        return cosy.np.Rigid(
            translation=self.filter.get_mean("position"),
            rotation=cosy.np.angle_to_rotation_matrix(self.filter.get_mean("heading")),
        )

    @property
    def latlon(self):
        """Current geographic position, derived from ``curr_to_world``."""
        return self.to_epsg4326(self.curr_to_world)[0]

    @property
    def bearing(self):
        """Current bearing in degrees, derived from ``curr_to_world``."""
        return self.to_epsg4326(self.curr_to_world)[1]

    def get_axy(self):
        """Return ``(mean, covariance)`` of the filter over heading, position_0, position_1.

        ``mean`` has one entry per field; ``covariance[i][j]`` is
        ``filter.get_covariance(field_i, field_j)``.
        """
        fields = ["heading", "position_0", "position_1"]
        mean = np.asarray([self.filter.get_mean(f) for f in fields])
        covariance = np.asarray([[self.filter.get_covariance(f0, f1) for f1 in fields] for f0 in fields])
        return mean, covariance

    def __call__(self, aligned_ground_frame, update_covariance_multiplier=(1.0, 1.0, 1.0), weak_prior_covariance_multiplier=(1.0, 1.0, 1.0), pred_prob_power=1.0, use_imu=True, use_prediction=True):
        """Advance the filter to the frame's timestamp and apply its measurements.

        Args:
            aligned_ground_frame: frame providing ``timestamp`` and ``imu``; it
                is also passed to the predictor.
            update_covariance_multiplier: per-axis factors (heading, x, y)
                whose outer product scales the update covariance. Only applied
                when ``recalibrate`` is set.
            weak_prior_covariance_multiplier: per-axis factors whose outer
                product scales the prior covariance to form the weak prior.
                Only used when ``recalibrate`` is set.
            pred_prob_power: exponent applied to ``prediction.scores``.
            use_imu: feed IMU yaw rate / acceleration to the filter when the
                frame has them.
            use_prediction: feed the predictor-derived pose to the filter. The
                predictor is run either way.

        Returns:
            A TrackingStep with ``probability_distributions`` (an OrderedDict
            of the intermediate distributions, in evaluation order) and
            ``frame`` (second return value of the predictor).
        """
        # Expand the per-axis factors to matrices via their outer product.
        update_covariance_multiplier = np.asarray(update_covariance_multiplier)
        update_covariance_multiplier = update_covariance_multiplier[:, np.newaxis] * update_covariance_multiplier[np.newaxis, :]

        weak_prior_covariance_multiplier = np.asarray(weak_prior_covariance_multiplier)
        weak_prior_covariance_multiplier = weak_prior_covariance_multiplier[:, np.newaxis] * weak_prior_covariance_multiplier[np.newaxis, :]

        probability_distributions = OrderedDict()

        prev_mean, prev_covariance = self.get_axy()

        # Advance timestep in filter
        dt = (aligned_ground_frame.timestamp - self.last_timestamp) * 1e-6 # seconds
        print(f"dt={dt}")
        self.filter.predict(dt=dt, **self.process_noise(dt, self.filter))
        self.last_timestamp = aligned_ground_frame.timestamp

        prior_mean, prior_covariance = self.get_axy()
        # The predicted pose is the reference around which samples are laid out.
        ref_to_world = self.curr_to_world

        # Sample grid of (angle, x, y) offsets relative to the reference pose,
        # with the translation part scaled by meters_per_pixel and the grid
        # flattened to one sample per row.
        axy_curr_to_ref = georeg.model.correlation.math.axy_volume(self.predictor.corr_shape, angles=[self.predictor.angles], dtype="float64")[0].numpy()
        axy_curr_to_ref[..., 1:] *= self.predictor.model_constants.meters_per_pixel[-1]
        axy_curr_to_ref = tfcv.model.einops.apply("a s... f -> (a s...) f", axy_curr_to_ref)
        # Move the samples into the world frame: add the reference angle and
        # rotate + translate the positions.
        axy_curr_to_world = tf.concat([
            axy_curr_to_ref[:, :1] + cosy.np.rotation_matrix_to_angle(ref_to_world.rotation),
            tf.matmul(ref_to_world.rotation[tf.newaxis, :, :], axy_curr_to_ref[:, 1:, tf.newaxis])[:, :, 0] + ref_to_world.translation[tf.newaxis, :]
        ], axis=-1)
        axy_curr_to_world = axy_curr_to_world.numpy()

        probability_distributions["prev"] = ProbabilityDistribution.gaussian(axy_curr_to_world, prev_mean, prev_covariance, ref_to_world)
        probability_distributions["prior"] = ProbabilityDistribution.gaussian(axy_curr_to_world, prior_mean, prior_covariance, ref_to_world)

        filter_args = {}

        # Predict update
        prior_latlon, prior_bearing = self.to_epsg4326(self.curr_to_world)
        prediction, frame = self.predictor(aligned_ground_frame, self.tileloader, prior_latlon, prior_bearing)
        if not self.recalibrate:
            update = probability_distributions["update"] = ProbabilityDistribution.from_probs_simple(
                axy_curr_to_world=axy_curr_to_world,
                probs=np.power(prediction.scores, pred_prob_power),
                ref_to_world=ref_to_world,
            )
        else:
            weak_prior = probability_distributions["weak_prior"] = ProbabilityDistribution.gaussian(
                axy_curr_to_world=axy_curr_to_world,
                mean=prior_mean,
                covariance=weak_prior_covariance_multiplier * prior_covariance,
                ref_to_world=ref_to_world,
            )

            uncalibrated_update = probability_distributions["uncalibrated_update"] = ProbabilityDistribution.from_probs_simple(
                axy_curr_to_world=axy_curr_to_world,
                probs=np.power(prediction.scores, pred_prob_power),
                ref_to_world=ref_to_world,
            )

            # Multiply the predictor scores with the weak prior and renormalise.
            probs = weak_prior.probs.astype("float64") * weak_prior.probs.shape[0] * uncalibrated_update.probs.astype("float64")
            probs = probs / np.sum(probs)
            probs = probs.astype("float32")
            F_uncalibrated_update = probability_distributions["F_uncalibrated_update"] = ProbabilityDistribution.from_probs_simple(
                axy_curr_to_world=axy_curr_to_world,
                probs=probs,
                ref_to_world=ref_to_world,
            )

            # Remove the weak prior from the fused covariance again. With
            # P = weak prior and F = fused covariance this evaluates
            # P @ inv(P - F) @ F, the rearranged form of inv(inv(F) - inv(P)).
            calibrated_update_covariance = weak_prior.covariance @ tf.linalg.inv(weak_prior.covariance - F_uncalibrated_update.covariance) @ F_uncalibrated_update.covariance

            # Rebuild the matrix from the absolute values of its eigenvalues.
            eig_val, eig_vec = tf.linalg.eigh(calibrated_update_covariance)
            eig_val = tf.math.abs(eig_val)
            calibrated_update_covariance = eig_vec @ tf.linalg.diag(eig_val) @ tf.transpose(eig_vec, (1, 0))

            update = ProbabilityDistribution.gaussian(
                axy_curr_to_world=axy_curr_to_world,
                mean=F_uncalibrated_update.mean,
                covariance=calibrated_update_covariance * update_covariance_multiplier,
                ref_to_world=ref_to_world,
            )

            probability_distributions["update"] = update

        if use_prediction:
            filter_args["axy"] = {
                "mean": update.mean,
                "covariance": update.covariance,
            }

        if use_imu and aligned_ground_frame.imu.has_angular_velocity():
            filter_args["yawrate"] = {
                "mean": aligned_ground_frame.imu.angular_velocity[2],
                "std": math.radians(5.0),
            }
        if use_imu and aligned_ground_frame.imu.has_linear_acceleration():
            filter_args["acceleration"] = {
                "mean": aligned_ground_frame.imu.linear_acceleration[0],
                "std": 0.5,
            }
        # Skip the update entirely when no measurement was collected.
        if filter_args:
            self.filter.update(**filter_args)

        posterior_mean, posterior_covariance = self.get_axy()
        probability_distributions["posterior"] = ProbabilityDistribution.gaussian(axy_curr_to_world, posterior_mean, posterior_covariance, ref_to_world)

        return TrackingStep(
            probability_distributions=probability_distributions,
            frame=frame,
        )
