import json
import os
import tempfile
import unittest
from unittest.mock import patch

import numpy as np

from drone_base.stream.processing.streaming_metadata import get_data, get_data_until, save_data, \
    calculate_forward_distance


class TestStreamingMetadata(unittest.TestCase):
    """Unit tests for the helpers in ``drone_base.stream.processing.streaming_metadata``."""

    def setUp(self):
        """Write a three-sample metadata recording to a fresh temporary JSON file."""
        # Samples at t = 1, 2, 3; local_position moves (0,0,0) -> (1,2,0) -> (2,3,0).
        self.sample_data = [
            {
                "time": 1.0,
                "drone": {
                    "quat": {"w": 1.0, "x": 0.0, "y": 0.0, "z": 0.0},
                    "speed": {"north": 1.0, "east": 2.0, "down": 0.0},
                    "local_position": {"x": 0.0, "y": 0.0, "z": 0.0},
                    "location": {"latitude": 37.7749, "longitude": -122.4194},
                    "flying_state": "FS_FLYING"
                }
            },
            {
                "time": 2.0,
                "drone": {
                    "quat": {"w": 0.707, "x": 0.0, "y": 0.0, "z": 0.707},
                    "speed": {"north": 2.0, "east": 1.0, "down": 0.0},
                    "local_position": {"x": 1.0, "y": 2.0, "z": 0.0},
                    "location": {"latitude": 37.7750, "longitude": -122.4195},
                    "flying_state": "FS_FLYING"
                }
            },
            {
                "time": 3.0,
                "drone": {
                    "quat": {"w": 0.0, "x": 0.0, "y": 0.0, "z": 1.0},
                    "speed": {"north": 0.0, "east": 0.0, "down": 0.0},
                    "local_position": {"x": 2.0, "y": 3.0, "z": 0.0},
                    "location": {"latitude": 37.7751, "longitude": -122.4196},
                    "flying_state": "FS_LANDED"
                }
            }
        ]

        self.temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away: if anything below raises, tearDown() is
        # skipped, but cleanups added via addCleanup() still run.
        self.addCleanup(self.temp_dir.cleanup)
        self.temp_file_path = os.path.join(self.temp_dir.name, "test_data.json")
        with open(self.temp_file_path, "w", encoding="utf-8") as file:
            json.dump(self.sample_data, file)

    def test_get_data(self):
        """get_data returns the full recording exactly as it was written."""
        result = get_data(self.temp_file_path)
        self.assertEqual(result, self.sample_data)
        self.assertEqual(len(result), 3)

    def test_get_data_until(self):
        """get_data_until keeps only samples whose time does not exceed the cutoff."""
        cases = [
            (2.0, [1.0, 2.0]),  # cutoff equal to a sample time is inclusive
            (1.5, [1.0]),       # cutoff falls between two samples
            (0.5, []),          # cutoff precedes the first sample
        ]
        for cutoff, expected_times in cases:
            with self.subTest(cutoff=cutoff):
                result = get_data_until(self.temp_file_path, cutoff)
                self.assertEqual([entry["time"] for entry in result], expected_times)

    def test_save_data(self):
        """save_data writes JSON that round-trips, and replaces existing content."""
        test_data = {"test": "data"}
        test_file = os.path.join(self.temp_dir.name, "save_test.json")
        save_data(test_file, test_data)
        with open(test_file, "r", encoding="utf-8") as f:
            loaded_data = json.load(f)

        self.assertEqual(loaded_data, test_data)

        # Saving to the same path again must leave only the new payload behind.
        test_list_data = [{"id": 1}, {"id": 2}]
        save_data(test_file, test_list_data)
        with open(test_file, "r", encoding="utf-8") as f:
            loaded_list_data = json.load(f)

        self.assertEqual(loaded_list_data, test_list_data)

    @patch("drone_base.stream.processing.streaming_metadata.get_forward_vector")
    def test_calculate_forward_distance(self, mock_get_forward_vector):
        """calculate_forward_distance projects each step's displacement on the forward vector."""
        mock_get_forward_vector.return_value = np.array([1.0, 0.0, 0.0])

        with self.subTest("forward vector aligned with +x"):
            # Transition 1: displacement [1, 2, 0] . [1, 0, 0] = 1.0
            # Transition 2: displacement [1, 1, 0] . [1, 0, 0] = 1.0
            # Total = 2.0
            self.assertAlmostEqual(calculate_forward_distance(self.sample_data), 2.0)

        with self.subTest("single sample has no transitions"):
            short_data = [self.sample_data[0]]
            self.assertAlmostEqual(calculate_forward_distance(short_data), 0.0)

        mock_get_forward_vector.return_value = np.array([-1.0, 0.0, 0.0])
        with self.subTest("forward vector opposite to motion"):
            # Both projections are -1.0; expecting 0.0 means backward motion does
            # not contribute (whether clamped per step or on the total is not
            # visible from this file — confirm in the implementation).
            self.assertAlmostEqual(calculate_forward_distance(self.sample_data), 0.0)


if __name__ == "__main__":
    # Allow running this test module directly as a script; the module name is
    # passed explicitly (it is "__main__" here, matching unittest's default).
    unittest.main(module=__name__)
