from pathlib import Path

import cv2
import numpy as np

from drone_base.config.logger import LoggerSetup


class BufferedFrameSaver:
    def __init__(self, output_dir: str | Path, logger_dir: str | Path | None = None, save_extension: str = "png"):
        self.output_dir = Path(output_dir)
        self.save_extension = save_extension
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.frames: list[tuple[np.ndarray, float]] = []

        if logger_dir is not None:
            logger_dir = Path(logger_dir) / f"{self.__class__.__name__}.log"
        self.logger = LoggerSetup.setup_logger(logger_name=self.__class__.__name__, log_file=logger_dir)

    def add_frame(self, frame: np.ndarray, timestamp: float):
        """Add a frame to the buffer with its timestamp."""
        self.frames.append((frame, timestamp))

    def save_all(self):
        """Save all frames in the buffer to disk."""
        self.logger.info("Saving all frames to disk")
        for index, (frame, timestamp) in enumerate(self.frames):
            timestamp_ms = int(timestamp * 1000)
            output_path = self.output_dir / f"frame_{index:06d}_{timestamp_ms}.{self.save_extension}"
            cv2.imwrite(str(output_path), frame)
            if index % 100 == 0:
                self.logger.info("Saved %s/%s frames", f"{index:06d}", len(self.frames))
        self.logger.info("Saved all frames to disk -> %s", self.output_dir)
        self.frames.clear()
