import shutil
import tempfile
import unittest
from pathlib import Path
from unittest.mock import Mock, patch

import numpy as np

from drone_base.stream.saving.frame_saver import BufferedFrameSaver


class TestBufferedFrameSaver(unittest.TestCase):
    """Unit tests for ``BufferedFrameSaver``.

    Every test runs against a fresh temporary directory. Tests that save
    frames patch ``cv2.imwrite`` so no image is actually written to disk.
    """

    def setUp(self):
        """Create a scratch directory plus a sample frame and timestamp."""
        self.test_dir = tempfile.mkdtemp()
        # Neither directory is created here; individual tests decide whether
        # it should exist before the saver is constructed.
        self.output_dir = Path(self.test_dir) / "output"
        self.logger_dir = Path(self.test_dir) / "logs"

        # Small all-zero 100x100x3 uint8 image and an arbitrary timestamp.
        self.test_frame = np.zeros((100, 100, 3), dtype=np.uint8)
        self.test_timestamp = 1234.567

    def tearDown(self):
        """Remove the scratch directory created in ``setUp``."""
        shutil.rmtree(self.test_dir)

    def test_init_with_existing_directory(self):
        """Constructing the saver on a pre-existing output directory works."""
        self.output_dir.mkdir(parents=True)
        saver = BufferedFrameSaver(self.output_dir)
        self.assertTrue(saver.output_dir.exists())

    @patch('drone_base.config.logger.LoggerSetup.setup_logger')
    def test_logger_initialization(self, mock_setup_logger):
        """Check logger initialization with and without a ``logger_dir``."""
        mock_logger = Mock()
        mock_setup_logger.return_value = mock_logger

        # Without logger_dir: no log file is requested.
        saver = BufferedFrameSaver(self.output_dir)
        mock_setup_logger.assert_called_once_with(
            logger_name='BufferedFrameSaver',
            log_file=None
        )
        self.assertEqual(saver.logger, mock_logger)

        mock_setup_logger.reset_mock()

        # With logger_dir: the log file is named after the class.
        self.logger_dir.mkdir(parents=True)
        expected_log_file = self.logger_dir / "BufferedFrameSaver.log"

        _ = BufferedFrameSaver(self.output_dir, logger_dir=self.logger_dir)
        mock_setup_logger.assert_called_once_with(
            logger_name='BufferedFrameSaver',
            log_file=expected_log_file
        )

    def test_add_frame(self):
        """A single added frame is buffered as a ``(frame, timestamp)`` pair."""
        saver = BufferedFrameSaver(self.output_dir)

        saver.add_frame(self.test_frame, self.test_timestamp)
        self.assertEqual(len(saver.frames), 1)

        frame, timestamp = saver.frames[0]
        np.testing.assert_array_equal(frame, self.test_frame)
        self.assertEqual(timestamp, self.test_timestamp)

    def test_add_multiple_frames(self):
        """Every call to ``add_frame`` adds one entry to the buffer."""
        saver = BufferedFrameSaver(self.output_dir)

        num_frames = 5
        for i in range(num_frames):
            saver.add_frame(self.test_frame, self.test_timestamp + i)

        self.assertEqual(len(saver.frames), num_frames)

    @patch("cv2.imwrite")
    def test_save_all(self, mock_imwrite):
        """``save_all`` writes each buffered frame once and empties the buffer."""
        saver = BufferedFrameSaver(self.output_dir)
        num_frames = 3
        for i in range(num_frames):
            saver.add_frame(self.test_frame, self.test_timestamp + i)

        saver.save_all()
        self.assertEqual(mock_imwrite.call_count, num_frames)
        self.assertEqual(len(saver.frames), 0)

    @patch("cv2.imwrite")
    def test_save_all_correct_filenames(self, mock_imwrite):
        """``save_all`` passes the expected file name and image to ``imwrite``."""
        saver = BufferedFrameSaver(self.output_dir)

        timestamp = 1.234
        saver.add_frame(self.test_frame, timestamp)

        saver.save_all()
        # The expected name pairs a zero-padded frame index with the
        # timestamp scaled to an integer (1.234 -> 1234).
        expected_filename = str(self.output_dir / "frame_000000_1234.png")

        mock_imwrite.assert_called_once()
        args, kwargs = mock_imwrite.call_args
        # Unpack and compare explicitly instead of using
        # assert_called_once_with. Mock compares call args with ``==``, which
        # is elementwise on ndarrays and raises ValueError unless the saver
        # happens to pass the very same array object.
        self.assertEqual(kwargs, {})
        self.assertEqual(len(args), 2)
        filename, frame = args
        self.assertEqual(filename, expected_filename)
        np.testing.assert_array_equal(frame, self.test_frame)


if __name__ == "__main__":
    # Discover and run the TestCase classes in this module when executed
    # directly as a script; importing the module runs nothing.
    unittest.main(module="__main__")
