import unittest

import numpy as np
from math import pi

from drone_base.stream.processing.formulas import quat_to_rotation_matrix, rotate_velocity, quaternion_to_euler, \
    get_forward_vector, haversine_distance


class TestFormulas(unittest.TestCase):
    """Tests for the helpers in ``drone_base.stream.processing.formulas``.

    Quaternions are passed as dicts with keys ``'w'``, ``'x'``, ``'y'``, ``'z'``;
    velocities as dicts with keys ``'north'``, ``'east'``, ``'down'``.
    """

    # Absolute tolerance used whenever an expected value contains exact zeros.
    # assert_allclose checks |actual - expected| <= atol + rtol * |expected|;
    # with the default atol=0 the rtol term vanishes for a zero target, so any
    # floating-point residue (~1e-16) would make the comparison fail.
    _ZERO_ATOL = 1e-12

    def setUp(self):
        # Identity quaternion (no rotation)
        self.identity_quat = {'w': 1.0, 'x': 0.0, 'y': 0.0, 'z': 0.0}

        # 90-degree rotation around z-axis (half-angle π/4).
        # This rotates +x to +y direction.
        self.z90_quat = {
            'w': 0.7071067811865476,  # cos(π/4)
            'x': 0.0,
            'y': 0.0,
            'z': 0.7071067811865475  # sin(π/4)
        }

        # 90-degree rotation around x-axis (half-angle π/4).
        # This rotates +y to +z direction.
        self.x90_quat = {
            'w': 0.7071067811865476,  # cos(π/4)
            'x': 0.7071067811865475,  # sin(π/4)
            'y': 0.0,
            'z': 0.0
        }

        self.velocity = {'north': 10.0, 'east': 5.0, 'down': -2.0}

        # NED coordinate system unit vectors
        self.north_vector = np.array([1.0, 0.0, 0.0])
        self.east_vector = np.array([0.0, 1.0, 0.0])
        self.down_vector = np.array([0.0, 0.0, 1.0])

    def test_quat_to_rotation_matrix(self):
        # Identity quaternion -> identity matrix
        rot_matrix = quat_to_rotation_matrix(self.identity_quat)
        np.testing.assert_allclose(rot_matrix.as_matrix(), np.eye(3), rtol=1e-6, atol=self._ZERO_ATOL)

        # 90-degree rotation around z-axis: applying it to [1,0,0] should give [0,1,0]
        rot_matrix = quat_to_rotation_matrix(self.z90_quat)
        rotated = rot_matrix.apply(self.north_vector)
        # Non-zero atol: the expected zeros cannot be matched by rtol alone
        np.testing.assert_allclose(rotated, self.east_vector, rtol=1e-6, atol=self._ZERO_ATOL)

        # 90-degree rotation around x-axis: applying it to [0,1,0] should give [0,0,1]
        rot_matrix = quat_to_rotation_matrix(self.x90_quat)
        rotated = rot_matrix.apply(self.east_vector)
        np.testing.assert_allclose(rotated, self.down_vector, rtol=1e-6, atol=self._ZERO_ATOL)

        # Normalization: a non-unit quaternion should be normalized before use
        non_normalized_quat = {'w': 2.0, 'x': 0.0, 'y': 0.0, 'z': 0.0}
        rot_matrix = quat_to_rotation_matrix(non_normalized_quat)
        np.testing.assert_allclose(rot_matrix.as_matrix(), np.eye(3), rtol=1e-6, atol=self._ZERO_ATOL)

    def test_rotate_velocity(self):
        # Identity quaternion: velocity should be unchanged
        rotated_vel = rotate_velocity(self.identity_quat, self.velocity)
        expected = np.array([self.velocity['north'], self.velocity['east'], self.velocity['down']])
        np.testing.assert_allclose(rotated_vel, expected, rtol=1e-6)

        # 90-degree z rotation maps (x, y, z) -> (-y, x, z),
        # so [10, 5, -2] -> [-5, 10, -2]
        rotated_vel = rotate_velocity(self.z90_quat, self.velocity)
        expected = np.array([-5.0, 10.0, -2.0])
        np.testing.assert_allclose(rotated_vel, expected, rtol=1e-6)

        # 90-degree x rotation maps (x, y, z) -> (x, -z, y),
        # so [10, 5, -2] -> [10, 2, 5]
        rotated_vel = rotate_velocity(self.x90_quat, self.velocity)
        expected = np.array([10.0, 2.0, 5.0])
        np.testing.assert_allclose(rotated_vel, expected, rtol=1e-6)

    def test_quaternion_to_euler(self):
        # Identity quaternion -> all angles zero
        euler = quaternion_to_euler(self.identity_quat)
        expected = np.array([0.0, 0.0, 0.0])
        np.testing.assert_allclose(euler, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # 90-degree rotation around z-axis (yaw = 90 degrees)
        euler = quaternion_to_euler(self.z90_quat)
        expected = np.array([0.0, 0.0, pi / 2])  # [0, 0, π/2]
        np.testing.assert_allclose(euler, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # 90-degree rotation around x-axis (roll = 90 degrees)
        euler = quaternion_to_euler(self.x90_quat)
        expected = np.array([pi / 2, 0.0, 0.0])  # [π/2, 0, 0]
        np.testing.assert_allclose(euler, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # Combined rotation: (0.5, 0.5, 0.5, 0.5) is a 120° rotation about
        # (1, 1, 1)/√3, cycling x -> y -> z -> x. Its Euler angles depend on the
        # convention, e.g. [π/2, 0, π/2] for extrinsic x-y-z (roll, pitch, yaw),
        # and not [π/3, π/3, π/3]. Only convention-independent properties are
        # checked here.
        combined_quat = {
            'w': 0.5,
            'x': 0.5,
            'y': 0.5,
            'z': 0.5
        }
        euler = quaternion_to_euler(combined_quat)
        self.assertEqual(len(euler), 3)  # Should return three angles
        self.assertTrue(np.all(np.isfinite(euler)))

        # Gimbal lock case (pitch = 90 degrees)
        gimbal_lock_quat = {
            'w': 0.7071067811865476,
            'y': 0.7071067811865475,
            'x': 0.0,
            'z': 0.0
        }
        euler = quaternion_to_euler(gimbal_lock_quat)
        expected = np.array([0.0, pi / 2, 0.0])  # [0, π/2, 0]
        np.testing.assert_allclose(euler, expected, rtol=1e-6, atol=self._ZERO_ATOL)

    def test_get_forward_vector(self):
        # Identity quaternion: forward should be [1,0,0]
        forward = get_forward_vector(self.identity_quat)
        expected = np.array([1.0, 0.0, 0.0])
        np.testing.assert_allclose(forward, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # 90-degree rotation around z-axis: forward should be [0,1,0]
        forward = get_forward_vector(self.z90_quat)
        expected = np.array([0.0, 1.0, 0.0])
        # Non-zero atol: the expected zeros cannot be matched by rtol alone
        np.testing.assert_allclose(forward, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # 180-degree rotation around z-axis: forward should be [-1,0,0]
        z180_quat = {
            'w': 0.0,
            'x': 0.0,
            'y': 0.0,
            'z': 1.0
        }
        forward = get_forward_vector(z180_quat)
        expected = np.array([-1.0, 0.0, 0.0])
        np.testing.assert_allclose(forward, expected, rtol=1e-6, atol=self._ZERO_ATOL)

        # A 45° roll (rotation about x) must not give the forward vector a
        # vertical (z) component. Only roll is exercised here; pitch (rotation
        # about y) is not covered by this check.
        x45_quat = {
            'w': 0.9238795325112867,  # cos(π/8)
            'x': 0.3826834323650898,  # sin(π/8)
            'y': 0.0,
            'z': 0.0
        }
        forward = get_forward_vector(x45_quat)
        self.assertAlmostEqual(forward[2], 0.0)

    def test_haversine_distance(self):
        # Same point -> zero distance
        dist = haversine_distance(45.0, -75.0, 45.0, -75.0)
        self.assertAlmostEqual(dist, 0.0, places=5)

        # Pole to pole: half the circumference, π·R with R = 6371 km.
        # places=-3 means the difference must round to 0 at the thousands
        # digit, i.e. agree within ~500 m.
        dist = haversine_distance(90.0, 0.0, -90.0, 0.0)
        self.assertAlmostEqual(dist, 20015086.796, places=-3)

        # Equator, 90 degrees apart: a quarter of the circumference, π·R/2
        dist = haversine_distance(0.0, 0.0, 0.0, 90.0)
        self.assertAlmostEqual(dist, 10007543.398, places=-3)

        # Tiny distance: ~0.000016° in latitude and ~0.000021° in longitude,
        # roughly 2.5 m apart. Bound from both sides so a function that
        # returns 0 for small separations does not pass.
        lat1, lon1 = 40.748817, -73.985428  # Times Square
        lat2, lon2 = 40.748801, -73.985407  # A few meters away
        dist = haversine_distance(lat1, lon1, lat2, lon2)
        self.assertGreater(dist, 1.0)
        self.assertLess(dist, 5.0)


if __name__ == '__main__':
    # Allow running this module directly (python <file>.py) as well as via a
    # test runner; verbosity=1 is unittest.main's default.
    unittest.main(verbosity=1)
