from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm

from path_manager import get_path_of_files, Paths
from yolo_model import YoloSegmentation


def generate_segmentation_masks_for_dataset(
        data_path: str | Path, model_path: str | Path = Paths.YOLO_SEGMENTATION_MODEL_V2_PATH
) -> None:
    """
    Generate one segmentation mask image per frame for every scene in a dataset.

    Input frames are the ``.png`` files that ``get_path_of_files`` returns for
    each scene directory ``<data_path>/images/<scene>``. Each mask is written to
    ``<data_path>/segmented/<scene>/<frame file name>``. The output directory is
    created if needed and existing files are overwritten.

    Scenes whose directory name contains ``"just-environment"`` bypass the model.
    Each of their frames gets an all-zero single-channel ``uint8`` mask with the
    frame's height and width.

    Frames that OpenCV cannot read are skipped with a warning. Masks that fail
    to be written are reported with a warning.

    Args:
        data_path: Dataset root directory (``str`` or ``Path``).
        model_path: Weights path forwarded to ``YoloSegmentation(model_path=...)``.

    Raises:
        FileNotFoundError: If ``<data_path>/images`` is not an existing directory.
    """
    # The signature allows str, but the `/` operator below requires a Path.
    data_path = Path(data_path)
    images_dir = data_path / "images"
    save_dir = data_path / "segmented"
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")

    # Sorted so scenes are processed and reported in a stable order across runs.
    scene_dirs = sorted(item for item in images_dir.iterdir() if item.is_dir())
    yolo_model = YoloSegmentation(model_path=model_path)

    for index, scene_dir in enumerate(scene_dirs):
        scene_name = scene_dir.name
        segmented_scene_dir = save_dir / scene_name
        segmented_scene_dir.mkdir(parents=True, exist_ok=True)
        frame_paths = [Path(p) for p in get_path_of_files(scene_dir, ".png")]
        # This depends only on the scene name, so decide it once per scene.
        is_environment_only = "just-environment" in scene_name
        print(f"\nGenerating segmentation masks for scene {scene_name}, Scene {index + 1}/{len(scene_dirs)}\n")

        for frame_path in tqdm(frame_paths, desc=f"Processing frames in {scene_name}", leave=False):
            frame = cv2.imread(str(frame_path))
            if frame is None:
                # cv2.imread signals unreadable/corrupt files by returning None, not by raising.
                tqdm.write(f"Warning: could not read frame {frame_path}, skipping")
                continue

            if is_environment_only:
                # These scenes skip the model and get an all-zero mask of the frame's size.
                mask = np.zeros(frame.shape[:2], dtype=np.uint8)
            else:
                _, mask = yolo_model.segment_image(frame=frame)

            out_path = segmented_scene_dir / frame_path.name
            # cv2.imwrite reports failure via its boolean return value rather than raising.
            if not cv2.imwrite(str(out_path), mask):
                tqdm.write(f"Warning: failed to write mask {out_path}")


if __name__ == '__main__':
    import argparse

    # Kept from the original hard-coded location so running with no arguments behaves as before.
    dataset_path = Path("/home/user/Desktop/work/data/car-follow/sim")

    parser = argparse.ArgumentParser(
        description="Generate segmentation masks for every scene in a dataset."
    )
    parser.add_argument(
        "data_path", nargs="?", type=Path, default=dataset_path,
        help="dataset root containing an 'images/' directory (default: %(default)s)",
    )
    parser.add_argument(
        "--model-path", type=Path, default=None,
        help="YOLO segmentation weights; omit to use the function's default model",
    )
    args = parser.parse_args()

    # Only forward model_path when given, so the function's own default stays authoritative.
    if args.model_path is None:
        generate_segmentation_masks_for_dataset(args.data_path)
    else:
        generate_segmentation_masks_for_dataset(args.data_path, model_path=args.model_path)
