from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm

from car_mask_splitter import CarMaskSplitter

# Key code compared against the return value of cv2.waitKey to stop playback
# early (27 is the ASCII code of the Esc key).
ESCAPE_KEY = 27


def load_image(path: Path, grayscale: bool = False) -> np.ndarray:
    """Read an image file through ``cv2.imread``.

    Args:
        path: Location of the image file; converted to ``str`` for OpenCV.
        grayscale: When true the file is decoded with ``cv2.IMREAD_GRAYSCALE``,
            otherwise with ``cv2.IMREAD_COLOR``.

    Returns:
        The value produced by ``cv2.imread``. Callers in this module check the
        result against ``None`` before using it, since an unreadable or
        missing file does not raise here.
    """
    if grayscale:
        read_mode = cv2.IMREAD_GRAYSCALE
    else:
        read_mode = cv2.IMREAD_COLOR
    return cv2.imread(str(path), read_mode)


def get_frame_paths(base_path: Path, scene_name: str) -> dict:
    """Build the per-scene directory layout used by the visualization helpers.

    Args:
        base_path: Dataset root directory.
        scene_name: Name of the scene sub-directory.

    Returns:
        A dict with the keys ``"images"``, ``"segmented"``, ``"front"`` and
        ``"back"`` (in that order). The first two map to
        ``base_path/<key>/scene_name``; the label masks map to
        ``base_path/labels/scene_name/<key>``.
    """
    # Front and back masks live side by side under the scene's label directory.
    label_dir = base_path / "labels" / scene_name

    frame_dirs = {
        kind: base_path / kind / scene_name for kind in ("images", "segmented")
    }
    frame_dirs["front"] = label_dir / "front"
    frame_dirs["back"] = label_dir / "back"
    return frame_dirs


def stack_frame(original: np.ndarray, segmentation: np.ndarray, front: np.ndarray, back: np.ndarray) -> np.ndarray:
    """Lay out one visualization frame as three panels side by side.

    The panels, left to right, are: the original image, the segmentation
    converted from grayscale to three channels, and a mask panel that holds
    ``back`` in channel 0 and ``front`` in channel 2 (channel 1 stays zero).

    Args:
        original: Reference image; the mask panel takes its shape and dtype.
        segmentation: Single-channel segmentation image.
        front: Single-channel front mask.
        back: Single-channel back mask.

    Returns:
        The horizontal concatenation of the three panels.
    """
    segmentation_panel = cv2.cvtColor(segmentation, cv2.COLOR_GRAY2BGR)

    # With OpenCV's BGR channel order, channel 0 renders blue and channel 2
    # renders red, so the two masks stay distinguishable in a single panel.
    mask_panel = np.zeros_like(original)
    mask_panel[..., 0] = back
    mask_panel[..., 2] = front

    panels = [original, segmentation_panel, mask_panel]
    return np.hstack(panels)


def read_frame_set(frame_name: str, paths: dict) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load the four images that belong to one frame.

    Args:
        frame_name: File name shared by the frame across all directories.
        paths: Mapping with the keys ``"images"``, ``"segmented"``,
            ``"front"`` and ``"back"`` pointing at the respective directories.

    Returns:
        ``(original, segmentation, front, back)``. The original is loaded in
        color, the other three in grayscale. Each entry is whatever
        ``load_image`` returned for that file.
    """
    original = load_image(paths["images"] / frame_name)
    segmentation, front, back = (
        load_image(paths[kind] / frame_name, grayscale=True)
        for kind in ("segmented", "front", "back")
    )
    return original, segmentation, front, back


def play_video(dataset_path: Path, scene_name: str, fps: int = 10) -> None:
    """Play a scene's frames in an OpenCV window as side-by-side composites.

    Every ``*.png`` in the scene's image directory is shown (in sorted order)
    as ``[original | segmentation | front & back]``. Frames for which any of
    the four images cannot be loaded are skipped with a warning. Playback
    stops early when Esc is pressed, and all OpenCV windows are closed on
    exit, including when an error interrupts playback.

    Args:
        dataset_path: Dataset root directory.
        scene_name: Name of the scene sub-directory.
        fps: Target playback rate in frames per second; must be positive.

    Raises:
        ZeroDivisionError: If ``fps`` is 0.
    """
    paths = get_frame_paths(dataset_path, scene_name)
    frame_paths = sorted(paths["images"].glob("*.png"))

    # A delay of 0 makes cv2.waitKey block until a key is pressed, so never
    # let a high fps round the per-frame delay down to zero.
    delay_ms = max(1, int(1000 / fps))
    window_title = "Visualization - [Original | Segmentation | Front & Back]"

    try:
        for frame_path in tqdm(frame_paths, desc="Playing video"):
            frame_name = frame_path.name
            original, segmentation, front, back = read_frame_set(frame_name, paths)

            if original is None or segmentation is None or front is None or back is None:
                print(f"[Warning] Missing frame data for: {frame_name}")
                continue

            combined = stack_frame(original, segmentation, front, back)
            cv2.imshow(window_title, combined)

            # Keep only the low byte: some backends set higher bits on the
            # returned key code, which would defeat a plain equality check.
            if (cv2.waitKey(delay_ms) & 0xFF) == ESCAPE_KEY:
                break
    finally:
        # Close the window even if loading or stacking a frame raised.
        cv2.destroyAllWindows()


def save_video(dataset_path: Path, scene_name: str, output_path: Path, fps: int = 10) -> None:
    """Write a scene's side-by-side composite frames to an ``mp4v`` video.

    The output frame size is derived from the first image of the scene: three
    panels wide, one panel high. Frames with missing inputs, or whose
    composite does not match that size, are skipped with a warning. If the
    scene has no frames, the first frame cannot be read, or the writer cannot
    be opened, a warning is printed and nothing is written.

    Args:
        dataset_path: Dataset root directory.
        scene_name: Name of the scene sub-directory.
        output_path: Destination video file.
        fps: Frame rate stored in the output video.
    """
    paths = get_frame_paths(dataset_path, scene_name)
    frame_paths = sorted(paths["images"].glob("*.png"))
    if not frame_paths:
        print(f"[Warning] No frames found in: {paths['images']}")
        return

    first_frame = load_image(frame_paths[0])
    if first_frame is None:
        print(f"[Warning] Could not read first frame: {frame_paths[0]}")
        return

    height, width = first_frame.shape[:2]
    # Three panels side by side: original, segmentation, front/back masks.
    frame_size = (width * 3, height)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(output_path), fourcc, fps, frame_size)
    try:
        # Without this check an unopened writer ignores every write and the
        # function would still report success.
        if not writer.isOpened():
            print(f"[Warning] Could not open video writer for: {output_path}")
            return

        for frame_path in tqdm(frame_paths, desc="Saving video"):
            frame_name = frame_path.name
            original, segmentation, front, back = read_frame_set(frame_name, paths)

            if original is None or segmentation is None or front is None or back is None:
                print(f"[Warning] Skipping incomplete frame: {frame_name}")
                continue

            combined = stack_frame(original, segmentation, front, back)
            # The writer expects every frame to match the size it was opened
            # with; report mismatches instead of handing them to write().
            if (combined.shape[1], combined.shape[0]) != frame_size:
                print(f"[Warning] Skipping frame with unexpected size: {frame_name}")
                continue

            writer.write(combined)
    finally:
        # Always finalize the file, even if a frame failed to process.
        writer.release()

    print(f"[Info] Video saved to: {output_path}")


def save_overlay_video(dataset_path: Path, scene_name: str, output_path: Path, fps: int = 10) -> None:
    """Write a video of the scene's images overlaid with front/back masks.

    Each frame is produced by ``CarMaskSplitter.overlay(frame, front, back)``
    and written to an ``mp4v`` video whose size is taken from the first image
    of the scene. Frames with a missing image or mask are skipped with a
    warning. If the scene has no frames, the first frame cannot be read, or
    the writer cannot be opened, a warning is printed and nothing is written.

    Args:
        dataset_path: Dataset root directory.
        scene_name: Name of the scene sub-directory.
        output_path: Destination video file.
        fps: Frame rate stored in the output video.
    """
    # Use the shared layout helper so the directories stay in sync with the
    # other functions in this module.
    paths = get_frame_paths(dataset_path, scene_name)
    front_dir = paths["front"]
    back_dir = paths["back"]

    frame_paths = sorted(paths["images"].glob("*.png"))
    if not frame_paths:
        print(f"[Warning] No frames found in: {paths['images']}")
        return

    sample_frame = load_image(frame_paths[0])
    if sample_frame is None:
        print(f"[Warning] Could not read first frame: {frame_paths[0]}")
        return
    height, width = sample_frame.shape[:2]

    annotator = CarMaskSplitter()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
    try:
        # Without this check an unopened writer ignores every write and the
        # function would still report success.
        if not writer.isOpened():
            print(f"[Warning] Could not open video writer for: {output_path}")
            return

        for frame_path in tqdm(frame_paths, desc="Creating overlay video"):
            frame_name = frame_path.name
            frame = load_image(frame_path)
            front_mask = load_image(front_dir / frame_name, grayscale=True)
            back_mask = load_image(back_dir / frame_name, grayscale=True)

            if frame is None or front_mask is None or back_mask is None:
                print(f"[Warning] Skipping frame with missing masks: {frame_name}")
                continue

            # NOTE(review): assumes overlay() returns an image of the same
            # size as `frame` — confirm against CarMaskSplitter.
            overlaid = annotator.overlay(frame, front_mask, back_mask)
            writer.write(overlaid)
    finally:
        # Always finalize the file, even if a frame failed to process.
        writer.release()

    print(f"[Info] Overlay video saved to: {output_path}")


if __name__ == "__main__":
    # Run configuration for a single scene (hard-coded for local use).
    scene_name = "around-car-30-45-60-75-90-high-quality"
    dataset_path = Path("/home/user/Desktop/work/data/car-follow/sim")
    fps = 15
    is_video_play = True

    # Both output videos are written to the current working directory.
    save_video_path = Path(f"./{scene_name}_visualization.mp4")
    overlay_output_path = Path(f"./{scene_name}_visualization_overlay.mp4")

    print(f"[Start] Processing scene: {scene_name}")

    # Optional interactive preview before the videos are rendered.
    if is_video_play:
        play_video(dataset_path, scene_name, fps=fps)

    save_video(dataset_path, scene_name, save_video_path, fps=fps)
    save_overlay_video(dataset_path, scene_name, overlay_output_path, fps=fps)

    print(f"[Done] All processing completed for scene: {scene_name}")
