import numpy as np
from collections import defaultdict

class UnnamedCTRA:
    """Extended Kalman filter with a constant turn rate and acceleration (CTRA) motion model.

    The state vector is ordered as ``fields``: planar position (two components),
    heading in radians, speed along the heading, yaw rate and longitudinal
    acceleration. Position is propagated with the constant-turn-rate arc equations
    using the speed at the start of the step; the acceleration only changes the speed.
    """

    fields = ["position_0", "position_1", "heading", "velocity", "yawrate", "acceleration"]

    # Yaw rates smaller than this (in magnitude) are propagated as straight-line motion.
    _STRAIGHT_YAWRATE = 0.0001

    def __init__(self, x, P):
        """
        :param x: initial state vector, ordered as ``fields``
        :param P: initial state covariance matrix (one row/column per field)
        """
        # predict() writes fractional values into x in place; an integer array
        # would silently truncate them, so force a float dtype.
        self.x = np.asarray(x, dtype=float)
        self.P = np.asarray(P, dtype=float)
        self.fields = UnnamedCTRA.fields

    @staticmethod
    def _position_jacobian(heading, velocity, yawrate, dt, straight):
        """Return d(position)/d(heading, velocity, yawrate) at the prior state.

        The result is ``((dx/dh, dx/dv, dx/dw), (dy/dh, dy/dv, dy/dw))``.
        """
        if straight:
            # Exact yawrate -> 0 limits of the arc expressions below. Evaluating those
            # expressions at a tiny yaw rate instead loses all precision in the
            # 1/yawrate**2 terms (catastrophic cancellation).
            s, c = np.sin(heading), np.cos(heading)
            return (
                (-velocity * dt * s, dt * c, -0.5 * velocity * dt**2 * s),
                (velocity * dt * c, dt * s, 0.5 * velocity * dt**2 * c),
            )
        s1, c1 = np.sin(yawrate * dt + heading), np.cos(yawrate * dt + heading)
        ds = s1 - np.sin(heading)  # change of sin(heading) over the step
        dc = c1 - np.cos(heading)  # change of cos(heading) over the step
        ratio = velocity / yawrate
        return (
            (ratio * dc, ds / yawrate, (dt * ratio) * c1 - (velocity / yawrate**2) * ds),
            (ratio * ds, -dc / yawrate, (dt * ratio) * s1 + (velocity / yawrate**2) * dc),
        )

    def predict(self, dt, Q):
        """Propagate the state (in place) and its covariance forward by ``dt``.

        :param dt: time step
        :param Q: process noise covariance added to the projected covariance
        """
        x = self.x
        # The Jacobian must be evaluated at the state *before* propagation, so keep
        # the prior heading, speed and yaw rate (these are copies, not views).
        heading, velocity, yawrate = x[2], x[3], x[4]
        straight = np.abs(yawrate) < self._STRAIGHT_YAWRATE

        # Project the state ahead
        if straight:  # Driving straight
            x[0] = x[0] + velocity * dt * np.cos(heading)
            x[1] = x[1] + velocity * dt * np.sin(heading)
        else:
            ratio = velocity / yawrate
            turned = heading + yawrate * dt
            x[0] = x[0] + ratio * (np.sin(turned) - np.sin(heading))
            x[1] = x[1] + ratio * (np.cos(heading) - np.cos(turned))
            # Wrap the new heading into [-pi, pi).
            x[2] = (turned + np.pi) % (2.0 * np.pi) - np.pi
        # Constant acceleration changes the speed; yaw rate and acceleration are constant.
        x[3] = velocity + x[5] * dt

        # Jacobian of the motion model with respect to the (prior) state
        (a13, a14, a15), (a23, a24, a25) = self._position_jacobian(
            heading, velocity, yawrate, dt, straight
        )
        JA = np.array([
            [1.0, 0.0, a13, a14, a15, 0.0],
            [0.0, 1.0, a23, a24, a25, 0.0],
            [0.0, 0.0, 1.0, 0.0, dt, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, dt],
            [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
        ], dtype=float)

        # Project the error covariance ahead
        self.P = JA @ self.P @ JA.T + Q

    def update(self, z, R, H):
        """Correct the estimate with a linear measurement ``z = H x + noise(R)``.

        An empty measurement leaves the estimate unchanged.

        :param z: measurement vector
        :param R: measurement covariance matrix
        :param H: observation matrix
        """
        z = np.asarray(z, dtype=float)
        if z.size == 0:
            return
        H = np.asarray(H, dtype=float)
        R = np.asarray(R, dtype=float)

        PHt = self.P @ H.T
        S = H @ PHt + R
        # K = P H^T S^-1, computed as a linear solve (K^T = S^-T (P H^T)^T) rather
        # than forming an explicit inverse.
        K = np.linalg.solve(S.T, PHt.T).T

        self.x = self.x + K @ (z - H @ self.x)
        I_KH = np.eye(H.shape[1]) - K @ H
        # Joseph form: equal to (I - K H) P for this gain, but keeps P symmetric
        # positive semi-definite in the presence of rounding errors.
        self.P = I_KH @ self.P @ I_KH.T + K @ R @ K.T

class UnnamedCTRV:
    """Extended Kalman filter with a constant turn rate and velocity (CTRV) motion model.

    The state vector is ordered as ``fields``: planar position (two components),
    heading in radians, speed along the heading and yaw rate.
    """

    fields = ["position_0", "position_1", "heading", "velocity", "yawrate"]

    # Yaw rates smaller than this (in magnitude) are propagated as straight-line motion.
    _STRAIGHT_YAWRATE = 0.0001

    def __init__(self, x, P):
        """
        :param x: initial state vector, ordered as ``fields``
        :param P: initial state covariance matrix (one row/column per field)
        """
        # predict() writes fractional values into x in place; an integer array
        # would silently truncate them, so force a float dtype.
        self.x = np.asarray(x, dtype=float)
        self.P = np.asarray(P, dtype=float)
        self.fields = UnnamedCTRV.fields

    @staticmethod
    def _position_jacobian(heading, velocity, yawrate, dt, straight):
        """Return d(position)/d(heading, velocity, yawrate) at the prior state.

        The result is ``((dx/dh, dx/dv, dx/dw), (dy/dh, dy/dv, dy/dw))``.
        """
        if straight:
            # Exact yawrate -> 0 limits of the arc expressions below. Evaluating those
            # expressions at a tiny yaw rate instead loses all precision in the
            # 1/yawrate**2 terms (catastrophic cancellation).
            s, c = np.sin(heading), np.cos(heading)
            return (
                (-velocity * dt * s, dt * c, -0.5 * velocity * dt**2 * s),
                (velocity * dt * c, dt * s, 0.5 * velocity * dt**2 * c),
            )
        s1, c1 = np.sin(yawrate * dt + heading), np.cos(yawrate * dt + heading)
        ds = s1 - np.sin(heading)  # change of sin(heading) over the step
        dc = c1 - np.cos(heading)  # change of cos(heading) over the step
        ratio = velocity / yawrate
        return (
            (ratio * dc, ds / yawrate, (dt * ratio) * c1 - (velocity / yawrate**2) * ds),
            (ratio * ds, -dc / yawrate, (dt * ratio) * s1 + (velocity / yawrate**2) * dc),
        )

    def predict(self, dt, Q):
        """Propagate the state (in place) and its covariance forward by ``dt``.

        :param dt: time step
        :param Q: process noise covariance added to the projected covariance
        """
        x = self.x
        # The Jacobian must be evaluated at the state *before* propagation, so keep
        # the prior heading, speed and yaw rate (these are copies, not views).
        heading, velocity, yawrate = x[2], x[3], x[4]
        straight = np.abs(yawrate) < self._STRAIGHT_YAWRATE

        # Project the state ahead; speed and yaw rate are constant in this model.
        if straight:  # Driving straight
            x[0] = x[0] + velocity * dt * np.cos(heading)
            x[1] = x[1] + velocity * dt * np.sin(heading)
        else:
            ratio = velocity / yawrate
            turned = heading + yawrate * dt
            x[0] = x[0] + ratio * (np.sin(turned) - np.sin(heading))
            x[1] = x[1] + ratio * (np.cos(heading) - np.cos(turned))
            # Wrap the new heading into [-pi, pi).
            x[2] = (turned + np.pi) % (2.0 * np.pi) - np.pi

        # Jacobian of the motion model with respect to the (prior) state
        (a13, a14, a15), (a23, a24, a25) = self._position_jacobian(
            heading, velocity, yawrate, dt, straight
        )
        JA = np.array([
            [1.0, 0.0, a13, a14, a15],
            [0.0, 1.0, a23, a24, a25],
            [0.0, 0.0, 1.0, 0.0, dt],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
        ], dtype=float)

        # Project the error covariance ahead
        self.P = JA @ self.P @ JA.T + Q

    def update(self, z, R, H):
        """Correct the estimate with a linear measurement ``z = H x + noise(R)``.

        An empty measurement leaves the estimate unchanged.

        :param z: measurement vector
        :param R: measurement covariance matrix
        :param H: observation matrix
        """
        z = np.asarray(z, dtype=float)
        if z.size == 0:
            return
        H = np.asarray(H, dtype=float)
        R = np.asarray(R, dtype=float)

        PHt = self.P @ H.T
        S = H @ PHt + R
        # K = P H^T S^-1, computed as a linear solve (K^T = S^-T (P H^T)^T) rather
        # than forming an explicit inverse.
        K = np.linalg.solve(S.T, PHt.T).T

        self.x = self.x + K @ (z - H @ self.x)
        I_KH = np.eye(H.shape[1]) - K @ H
        # Joseph form: equal to (I - K H) P for this gain, but keeps P symmetric
        # positive semi-definite in the presence of rounding errors.
        self.P = I_KH @ self.P @ I_KH.T + K @ R @ K.T

class NamedFilter:
    """Name-addressable front end for an unnamed Kalman filter class.

    Initial estimates, measurements and process noise are passed as keyword
    arguments keyed by the wrapped filter's ``fields``. The key ``"position"`` is a
    shorthand for ``position_0``/``position_1``: its ``"mean"`` is indexed as a
    2-sequence and its ``"std"`` is applied to both components (no cross-correlation).
    """

    def __init__(self, unnamed_filter, **kwargs):
        """
        :param unnamed_filter: filter class exposing ``fields``, constructed as ``unnamed_filter(x, P)``
        :param kwargs: for every field, ``{"mean": ..., "std": ...}`` of the initial estimate
        :raises KeyError: if a field of the filter has no initial value
        """
        fields = unnamed_filter.fields
        means = {}
        variances = defaultdict(lambda: defaultdict(lambda: 0.0))

        for name, spec in kwargs.items():
            self._collect(means, variances, name, spec)

        missing = [f for f in fields if f not in means]
        if missing:
            raise KeyError(f"no initial value given for field(s) {missing}")

        # A float dtype is required: the filters update x in place, and an integer
        # array (all-integer initial means) would truncate every prediction.
        x = np.asarray([means[f] for f in fields], dtype=float)
        P = np.asarray([[variances[f0][f1] for f1 in fields] for f0 in fields], dtype=float)

        self.unnamed_filter = unnamed_filter(x, P)

    @staticmethod
    def _collect(means, variances, name, spec):
        """Store ``spec["mean"]`` and ``spec["std"]**2`` under ``name``, expanding ``"position"``."""
        if name == "position":
            means["position_0"], means["position_1"] = spec["mean"][0], spec["mean"][1]
            variances["position_0"]["position_0"] = variances["position_1"]["position_1"] = np.square(spec["std"])
        else:
            means[name] = spec["mean"]
            variances[name][name] = np.square(spec["std"])

    def get_mean(self, field):
        """Return the current estimate of ``field``; ``"position"`` returns both components."""
        if field == "position":
            return np.asarray([self.get_mean("position_0"), self.get_mean("position_1")])
        return self.unnamed_filter.x[self.unnamed_filter.fields.index(field)]

    def get_covariance(self, field0, field1):
        """Return the covariance entry between two (non-shorthand) fields."""
        fields = self.unnamed_filter.fields
        return self.unnamed_filter.P[fields.index(field0)][fields.index(field1)]

    def update(self, **kwargs):
        """Fuse named measurements into the estimate.

        Accepts ``{"mean": ..., "std": ...}`` per field (``"position"`` expanded as in
        the constructor) and, optionally, ``axy={"mean": [heading, x, y],
        "covariance": 3x3 matrix indexed in that order}``. Measurements of fields the
        wrapped filter does not have are ignored; if nothing remains, this is a no-op.
        """
        fields = self.unnamed_filter.fields
        z = {}
        R = defaultdict(lambda: defaultdict(lambda: 0.0))

        if "axy" in kwargs:
            mean = kwargs["axy"]["mean"]
            covariance = kwargs["axy"]["covariance"]

            z["position_0"] = mean[1]
            z["position_1"] = mean[2]
            z["heading"] = mean[0]

            R["position_0"]["position_0"] = covariance[1, 1]
            R["position_1"]["position_1"] = covariance[2, 2]
            R["heading"]["heading"] = covariance[0, 0]
            R["position_1"]["heading"] = R["heading"]["position_1"] = covariance[2, 0]
            R["position_0"]["heading"] = R["heading"]["position_0"] = covariance[0, 1]
            R["position_0"]["position_1"] = R["position_1"]["position_0"] = covariance[1, 2]

            del kwargs["axy"]

        for name, spec in kwargs.items():
            self._collect(z, R, name, spec)

        if "heading" in z and "heading" in fields:
            # Re-express the heading measurement relative to the current estimate so
            # the innovation lies in [-pi, pi) instead of jumping by ~2*pi when the
            # measurement and the estimate sit on opposite sides of the wrap point.
            estimate = self.get_mean("heading")
            z["heading"] = estimate + (z["heading"] - estimate + np.pi) % (2.0 * np.pi) - np.pi

        observed_fields = [f for f in fields if f in z]
        if not observed_fields:
            return

        z = np.asarray([z[f] for f in observed_fields], dtype=float)
        R = np.asarray([[R[f0][f1] for f1 in observed_fields] for f0 in observed_fields], dtype=float)
        # H selects the observed fields out of the full state vector.
        H = np.asarray([[1.0 if f0 == f1 else 0.0 for f1 in fields] for f0 in observed_fields])

        self.unnamed_filter.update(z, R, H)

    def predict(self, dt, **kwargs):
        """Propagate the estimate by ``dt``.

        Process noise is given either as a full matrix ``Q=...`` (ordered like the
        filter's fields) or as per-field standard deviations, e.g. ``velocity=0.5``;
        ``position=s`` applies ``s`` to both position components. Fields that are not
        mentioned get zero process noise.

        :raises ValueError: if ``Q`` is combined with per-field values
        """
        fields = self.unnamed_filter.fields
        if "Q" in kwargs:
            if len(kwargs) != 1:
                raise ValueError("pass either Q or per-field process noise, not both")
            Q = np.asarray(kwargs["Q"], dtype=float)
        else:
            Q = defaultdict(lambda: defaultdict(lambda: 0.0))

            for name, std in kwargs.items():
                if name == "position":
                    Q["position_0"]["position_0"] = Q["position_1"]["position_1"] = np.square(std)
                else:
                    Q[name][name] = np.square(std)

            Q = np.asarray([[Q[f0][f1] for f1 in fields] for f0 in fields], dtype=float)

        self.unnamed_filter.predict(dt, Q)

def CTRV(**kwargs):
    """Create a name-addressable constant turn rate and velocity (CTRV) filter.

    Every keyword argument is forwarded unchanged to ``NamedFilter`` as the
    initial state description.
    """
    model = UnnamedCTRV
    return NamedFilter(unnamed_filter=model, **kwargs)

def CTRA(**kwargs):
    """Create a name-addressable constant turn rate and acceleration (CTRA) filter.

    Every keyword argument is forwarded unchanged to ``NamedFilter`` as the
    initial state description.
    """
    model = UnnamedCTRA
    return NamedFilter(unnamed_filter=model, **kwargs)
