import json
from collections import deque
from pathlib import Path

import numpy as np

from drone_base.stream.processing.formulas import rotate_velocity, get_forward_vector, haversine_distance, \
    quat_to_rotation_matrix


def get_data(file_path: str | Path) -> list[dict]:
    with open(file_path, 'r') as file:
        data = json.load(file)
    return data


def get_data_until(file_path: str | Path, until_time: float) -> list[dict]:
    """Reads JSON data from a file and returns entries up to the specified time."""
    with open(file_path, 'r') as file:
        data = json.load(file)
    filtered_data = [entry for entry in data if entry["time"] <= until_time]

    return filtered_data


def save_data(file_path: str | Path, data: dict | list[dict]):
    with open(file_path, "w") as file:
        json.dump(data, file)


def calculate_distance(logs: list[dict]) -> float:
    """
    Compute the horizontal distance traveled by integrating velocity over time.

    For each entry, ``entry['drone']['speed']`` is rotated with the entry's
    quaternion via ``rotate_velocity``. The norm of the first two components is
    multiplied by the time elapsed since the previous entry. This is a rectangle
    rule that uses the current sample's velocity.

    Entries whose north/east/down speeds are all exactly zero are treated as
    stationary. They contribute no distance, but their timestamp still advances
    the integration clock. Otherwise the whole idle interval would be credited to
    the next moving sample and inflate the result.

    :param logs: Log entries. Each is read as ``entry['time']``,
        ``entry['drone']['quat']`` and
        ``entry['drone']['speed']['north'|'east'|'down']``.
    :return: Total horizontal distance, or 0.0 for empty input. Units follow the
        log's speed and time units (presumably m/s and s, giving meters; TODO confirm).
    """
    total_distance = 0.0
    prev_time = None

    for entry in logs:
        drone_data = entry['drone']
        quat = drone_data['quat']
        velocity = drone_data['speed']
        timestamp = entry['time']

        is_stationary = velocity['north'] == 0.0 and velocity['east'] == 0.0 and velocity['down'] == 0.0

        # The first sample only starts the clock; stationary samples add nothing.
        if not is_stationary and prev_time is not None:
            global_velocity = rotate_velocity(quat, velocity)
            total_distance += np.linalg.norm(global_velocity[:2]) * (timestamp - prev_time)

        # Advance the clock for every sample, including stationary ones, so an idle
        # period is not multiplied by the next moving sample's velocity.
        prev_time = timestamp

    return total_distance


def calculate_forward_distance(data: list[dict]) -> float:
    """
    Sum the positive forward components of consecutive local-position steps.

    Each step between two consecutive samples is projected onto the forward
    vector from ``get_forward_vector``, computed with the later sample's
    quaternion. Only projections greater than zero are added, so backward
    motion is ignored.

    :param data: Log entries. Each is read as
        ``entry['drone']['local_position']['x'|'y'|'z']`` and ``entry['drone']['quat']``.
    :return: Total forward distance, or 0.0 when fewer than two samples are given.
        Documented as meters, assuming ``local_position`` is in meters (TODO confirm).
    """
    if len(data) < 2:
        return 0.0

    forward_total = 0.0

    for before, after in zip(data, data[1:]):
        start = before['drone']['local_position']
        end = after['drone']['local_position']
        step = np.array([end[axis] - start[axis] for axis in ('x', 'y', 'z')])

        heading = get_forward_vector(after['drone']['quat'])
        along_heading = np.dot(step, heading)

        # Motion against the heading (negative projection) is not counted.
        if along_heading > 0:
            forward_total += along_heading

    return forward_total


def compute_forward_distance_hybrid(data: list[dict], distance_threshold: float = 0.001, window_size: int = 5) -> float:
    """
    Compute the total forward distance traveled using both GPS and orientation data.

    For each pair of consecutive samples:

    1. The step length is computed with ``haversine_distance``. Steps shorter than
       ``distance_threshold`` are skipped.
    2. The step's horizontal direction is rotated into the body frame using the
       later sample's quaternion.
    3. Only the fraction of the step along the body x-axis ("forward") is kept.
    4. Kept forward step lengths are smoothed with a moving average over the last
       ``window_size`` kept steps.
    5. The smoothed value is accumulated whenever the later sample's
       ``flying_state`` is not ``'FS_LANDED'``.

    :param data: The data from the drone logs. Each entry is read as
        ``entry['drone']['location']['latitude'|'longitude']``,
        ``entry['drone']['quat']`` and ``entry['drone']['flying_state']``.
    :param distance_threshold: Minimum distance to consider in meters.
    :param window_size: How many of the most recent forward distances are averaged; must be >= 1.
    :return: The total distance traveled over the forward displacement
        (0.0 when fewer than two samples are given).
    :raises ValueError: If ``window_size`` is smaller than 1 (previously this
        silently produced NaN from averaging an empty window).
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    total_distance = 0.0
    # A bounded deque drops the oldest value automatically in O(1).
    distance_window = deque(maxlen=window_size)

    for prev_entry, curr_entry in zip(data, data[1:]):
        prev_pos = prev_entry['drone']['location']
        curr_pos = curr_entry['drone']['location']

        distance = haversine_distance(
            prev_pos['latitude'], prev_pos['longitude'],
            curr_pos['latitude'], curr_pos['longitude']
        )

        if distance < distance_threshold:
            continue

        rotation_matrix = quat_to_rotation_matrix(curr_entry['drone']['quat']).as_matrix()

        # Horizontal movement direction in the NED frame (altitude changes are ignored).
        # A degree of longitude spans only cos(latitude) of a degree of latitude, so
        # scale dlon to keep the north and east components comparable. Without this,
        # the heading (and hence forward_ratio) is distorted away from the equator.
        dlat = curr_pos['latitude'] - prev_pos['latitude']
        # Wrap the longitude difference into [-180, 180) so that crossing the
        # antimeridian does not look like a ~360-degree jump.
        dlon = ((curr_pos['longitude'] - prev_pos['longitude'] + 180.0) % 360.0) - 180.0
        mean_lat_rad = np.radians((prev_pos['latitude'] + curr_pos['latitude']) / 2.0)
        movement_ned = np.array([dlat, dlon * np.cos(mean_lat_rad), 0.0])

        # Transform to body frame. This assumes the matrix rotates body -> NED,
        # so its transpose maps NED -> body (TODO confirm against quat_to_rotation_matrix).
        movement_body = rotation_matrix.T @ movement_ned

        # Forward component is x in body frame; only forward motion is counted.
        movement_norm = np.linalg.norm(movement_body)
        forward_ratio = movement_body[0] / movement_norm if movement_norm > 0 else 0.0
        forward_distance = distance * max(0.0, forward_ratio)

        # Moving average over the most recent kept forward distances.
        distance_window.append(forward_distance)
        smoothed_distance = np.mean(distance_window)

        if curr_entry['drone']['flying_state'] != 'FS_LANDED':
            total_distance += smoothed_distance

    return total_distance
