from pathlib import Path

import cv2
import numpy as np

from car_mask_splitter import CarMaskSplitter
from path_manager import get_path_of_files


def _save_mask(path: Path, mask: np.ndarray) -> None:
    """Write ``mask`` to ``path`` as an image, printing a warning if the write fails."""
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(str(path), mask):
        print(f"Failed to write mask: {path}")


def generate_front_back_masks_for_dataset(data_path: str | Path):
    """
    Generate front/back segmentation masks for every scene of a dataset.

    Expected layout under ``data_path``::

        images/<scene>/*.png       input frames
        segmented/<scene>/*.png    segmentation masks
        labels/<scene>/front/      output front masks (created if missing)
        labels/<scene>/back/       output back masks (created if missing)

    Frames and masks of a scene are paired in the order returned by ``get_path_of_files``.
    Scenes whose frame and mask counts differ are skipped. For each remaining frame, the user
    is asked (through ``CarMaskSplitter.annotate``) which part of the object is the front.
    Frames whose front and back masks both already exist are skipped, so an interrupted run
    can be resumed. Scenes whose name contains "just-environment" get all-zero front/back
    masks without prompting.

    Args:
        data_path: Root directory of the dataset. A ``str`` is accepted and converted to ``Path``.
    """
    data_path = Path(data_path)  # allow plain strings, as the annotation advertises
    images_dir = data_path / "images"
    mask_input_dir = data_path / "segmented"
    labels_dir = data_path / "labels"
    annotator = CarMaskSplitter()

    # Sort so the scene order (and the "Scene i/N" counter) is stable across runs;
    # Path.iterdir() yields entries in arbitrary order.
    scenes_dirs = sorted(item for item in images_dir.iterdir() if item.is_dir())
    print(scenes_dirs)
    for index, scene_dir in enumerate(scenes_dirs):
        scene_name = scene_dir.name
        mask_scene_dir = mask_input_dir / scene_name
        labels_front_dir = labels_dir / scene_name / "front"
        labels_back_dir = labels_dir / scene_name / "back"

        labels_front_dir.mkdir(parents=True, exist_ok=True)
        labels_back_dir.mkdir(parents=True, exist_ok=True)

        print(f"\nAnnotating front/back masks for scene {scene_name}, Scene {index + 1}/{len(scenes_dirs)}\n")
        if not mask_scene_dir.is_dir():
            print(f"Segmentation directory not found: {mask_scene_dir}, skipping scene.")
            continue

        frame_paths_per_scene = [Path(p) for p in get_path_of_files(scene_dir, ".png")]
        masks_paths_per_scene = [Path(p) for p in get_path_of_files(mask_scene_dir, ".png")]
        print(f"Total frames: {len(frame_paths_per_scene)} | Total masks: {len(masks_paths_per_scene)}")
        # Frames and masks are paired positionally below, so a count mismatch would misalign them.
        if len(frame_paths_per_scene) != len(masks_paths_per_scene):
            print("Inconsistent number of frames and masks, skipping scene.")
            continue

        for frame_path, mask_path in zip(frame_paths_per_scene, masks_paths_per_scene):
            frame_name = frame_path.name

            front_mask_save_path = labels_front_dir / frame_name
            back_mask_save_path = labels_back_dir / frame_name
            # Already annotated in a previous run: don't prompt again.
            if front_mask_save_path.exists() and back_mask_save_path.exists():
                continue

            if not mask_path.exists():
                print(f"Segmentation mask not found for: {frame_name}, skipping.")
                continue

            frame = cv2.imread(str(frame_path))
            # cv2.imread returns None (instead of raising) for missing/corrupt files.
            if frame is None:
                print(f"Could not read frame: {frame_path}, skipping.")
                continue

            if "just-environment" in scene_name:
                # These scenes get all-zero masks with no prompt (presumably they contain no objects).
                empty_mask = np.zeros(frame.shape[:2], dtype=np.uint8)
                _save_mask(front_mask_save_path, empty_mask)
                _save_mask(back_mask_save_path, empty_mask)
                continue

            mask = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)
            if mask is None:
                # Nothing usable to annotate; don't hand None to the annotator.
                print(f"Could not read segmentation mask: {mask_path}, skipping.")
                continue
            if np.max(mask) == 0:
                # Still shown to the annotator so the user can decide what to do with it.
                print(f"Empty mask for: {frame_name}, you should skip it.")

            front_mask, back_mask = annotator.annotate(image=frame, mask=mask, frame_name=frame_name)
            _save_mask(front_mask_save_path, front_mask)
            _save_mask(back_mask_save_path, back_mask)


if __name__ == '__main__':
    import argparse

    # The dataset root can be passed on the command line; with no argument, the
    # original hard-coded path is used, so existing invocations behave the same.
    parser = argparse.ArgumentParser(
        description="Interactively split segmentation masks into front/back masks for a dataset."
    )
    parser.add_argument(
        "dataset_path",
        nargs="?",
        type=Path,
        default=Path("/home/user/Desktop/work/data/car-follow/validation"),
        help="Dataset root containing 'images/' and 'segmented/' (default: %(default)s).",
    )
    args = parser.parse_args()
    generate_front_back_masks_for_dataset(args.dataset_path)
