import os
import re
from collections import defaultdict
from glob import glob
from pathlib import Path

import cv2
import numpy as np

from drone_base.config.logger import LoggerSetup


class CreateVideoFromFrames:
    """
    This class creates a video from a frame folder.
    It assumes that the images were generated by BufferedFrameSaver, or they have the timestamp metadata in their names.
    """

    def __init__(self, input_dir: str | Path, output_path: str | Path = "output.mp4", frame_type: str = "png"):
        self.logger = LoggerSetup.setup_logger(logger_name=self.__class__.__name__)
        self.input_dir = input_dir
        self.output_path = output_path
        self.frame_type = frame_type
        self.frame_files = []
        self.timestamps = []

    def __populate_frame_paths(self) -> None:
        """Get and sort frame files from the input directory."""
        self.frame_files = glob(os.path.join(self.input_dir, f"frame_*.{self.frame_type}"))
        if len(self.frame_files) == 0:
            self.logger.error("No frames found in the input directory: %s.", self.input_dir)
            raise FileNotFoundError("No frames found! Check file type.")

        self.frame_files.sort(key=lambda x: int(re.search(r"frame_(\d+)_", x).group(1)))
        self.timestamps = [int(re.search(fr"_(\d+)\.{self.frame_type}$", f).group(1)) for f in self.frame_files]
        self.logger.info("Found %s frames", len(self.frame_files))

    def __compute_fps(self) -> tuple[float, float]:
        """Calculate the actual FPS based on frame timestamps and the total duration of the video."""
        duration_ms = max(self.timestamps) - min(self.timestamps)
        actual_fps = (len(self.frame_files)) / (duration_ms / 1000)

        self.logger.info("Sequence duration: %.2f seconds", duration_ms / 1000)
        self.logger.info("Calculated FPS: %.2f", actual_fps)

        return actual_fps, duration_ms

    def __analyze_frame_timing(self):
        """Analyze the timing between frames in detail."""
        if not self.timestamps:
            return

        time_diffs = []
        for i in range(len(self.timestamps) - 1):
            diff = (self.timestamps[i + 1] - self.timestamps[i]) / 1000.0  # Convert to seconds
            time_diffs.append(diff)

        avg_diff = np.mean(time_diffs)
        median_diff = np.median(time_diffs)
        std_diff = np.std(time_diffs)
        min_diff = min(time_diffs)
        max_diff = max(time_diffs)

        ranges = defaultdict(int)
        for diff in time_diffs:
            range_key = int(diff * 1000)
            ranges[range_key] += 1

        self.logger.info("Frame Timing Analysis:")
        self.logger.info("Average time between frames: %.3f s (%.2f fps)", avg_diff, 1 / avg_diff)
        self.logger.info("Median time between frames: %.3f s (%.2f fps)", median_diff, 1 / median_diff)
        self.logger.info("Standard deviation: %.3f s", std_diff)
        try:
            self.logger.info("Min time between frames: %.3f s (%.2f fps)", min_diff, 1 / min_diff)
            self.logger.info("Max time between frames: %.3f s (%.2f fps)", max_diff, 1 / max_diff)
        except ZeroDivisionError:
            self.logger.info("Division by zero, skipping min and max fps calculation.")

        self.logger.info("Frame timing distribution:")
        for range_key, count in sorted(ranges.items()):
            percentage = (count / len(time_diffs)) * 100
            self.logger.info("%s ms: %s frames (%.1f %%)", range_key, count, percentage)

        try:
            return {
                'avg_fps': 1 / avg_diff,
                'min_fps': 1 / max_diff,
                'max_fps': 1 / min_diff,
                'std_dev': std_diff
            }
        except ZeroDivisionError:
            self.logger.info("Division by zero, skipping fps calculation.")
            return {
                'avg_fps': 0,
                'min_fps': 0,
                'max_fps': 0,
                'std_dev': 0
            }

    def create_video(self):
        self.__populate_frame_paths()
        fps, duration_ms = self.__compute_fps()
        _ = self.__analyze_frame_timing()

        first_frame = cv2.imread(self.frame_files[0])
        if first_frame is None:
            self.logger.error("Could not read the first frame")
            raise ValueError("Invalid first frame.")
        height, width, layers = first_frame.shape
        self.logger.info("Frame dimensions: %s x %s", width, height)

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_writer = cv2.VideoWriter(self.output_path, fourcc, fps, (width, height))

        for index, frame_file in enumerate(self.frame_files):
            frame = cv2.imread(frame_file)
            if frame is not None:
                video_writer.write(frame)
                if index % 100 == 0:
                    self.logger.debug(f"Written {index:06d}/{len(self.frame_files):06d}")  # noqa: G004
            else:
                self.logger.warning("Could not read frame: %s", frame_file)

        video_writer.release()
        self.logger.info("Video created successfully at %s", self.output_path)
        self.logger.info("Original sequence duration: %.2f seconds", duration_ms / 1000)
        self.logger.info("Video duration at %.2f fps: %.2f seconds", fps, len(self.frame_files) / fps)


def create_video(input_dir: str | Path, output_path: str | Path = "output.mp4", frame_type: str = "png"):
    """
    Build a ``CreateVideoFromFrames`` for the given arguments and run it in one call.

    Args:
        input_dir: Directory containing the ``frame_*`` images.
        output_path: Destination video file.
        frame_type: Image file extension of the frames.
    """
    CreateVideoFromFrames(input_dir=input_dir, output_path=output_path, frame_type=frame_type).create_video()


if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the three options and forward them to create_video().
    cli = argparse.ArgumentParser(description="Create a video from a frame folder.")
    cli.add_argument(
        "--input_dir",
        type=str,
        default="./examples/results/2025-04-01_12-54-59/frames",
        help="Path to the input directory containing frames",
    )
    cli.add_argument("--output_path", type=str, default="./output.mp4", help="Path to save the output video")
    cli.add_argument("--frame_type", type=str, default="png", help="Frame type (e.g., png, jpg)")
    cli_args = cli.parse_args()

    # The namespace holds exactly input_dir / output_path / frame_type, matching create_video's keywords.
    create_video(**vars(cli_args))
