import subprocess
from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm

from nn.infer import MaskSplitterInference

ESCAPE_KEY = 0x1B  # ASCII Esc; compared against cv2.waitKey() results to stop playback


def load_image(path: Path, grayscale: bool = False) -> np.ndarray | None:
    """Read an image file with OpenCV.

    Args:
        path: File to read.
        grayscale: Read as a single-channel image (``cv2.IMREAD_GRAYSCALE``)
            instead of a color image (``cv2.IMREAD_COLOR``).

    Returns:
        The decoded image, or ``None`` when OpenCV cannot read the file
        (``cv2.imread`` does not raise); callers must check for ``None``.
    """
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    return cv2.imread(str(path), flag)


def get_frame_paths(base_path: Path, scene_name: str) -> dict:
    return {
        "images": base_path / "images" / scene_name,
        "segmented": base_path / "segmented" / scene_name,
        "labels": base_path / "labels" / scene_name,
    }


def stack_frame(original: np.ndarray, segmentation: np.ndarray, front: np.ndarray, back: np.ndarray) -> np.ndarray:
    """Build the side-by-side panel ``[original | segmentation | colored front/back]``.

    ``front`` and ``back`` are merged into one color image by ``colorize_masks``
    and placed right-most; the three panels are joined with ``np.hstack``, so
    they must agree in height and channel count.
    """
    panels = (original, segmentation, colorize_masks(front, back))
    return np.hstack(panels)


def colorize_masks(front: np.ndarray, back: np.ndarray) -> np.ndarray:
    """
    Convert front and back binary masks into a single color image.

    The output uses OpenCV's BGR channel order (this module displays it with
    ``cv2.imshow`` and writes it with ``cv2.VideoWriter``):
        - Front mask pixels (== 255): red  -> BGR (0, 0, 255)
        - Back mask pixels  (== 255): blue -> BGR (255, 0, 0)
    Where both masks are 255 the back color wins, because it is painted last.
    Any other pixel value (e.g. 1 in a 0/1 mask) stays black.

    Args:
        front: 2-D mask; only pixels equal to 255 are painted.
        back: 2-D mask with the same shape as ``front``.

    Returns:
        A new ``uint8`` array of shape ``(H, W, 3)``.
    """
    h, w = front.shape
    # Named ``colored`` (not ``color_mask``) to avoid shadowing the module-level helper.
    colored = np.zeros((h, w, 3), dtype=np.uint8)
    colored[front == 255] = (0, 0, 255)
    colored[back == 255] = (255, 0, 0)

    return colored

def color_mask(mask: np.ndarray) -> np.ndarray:
    """Paint pixels of ``mask`` equal to 255 green (``(0, 255, 0)``) on a black canvas.

    ``mask`` must be 2-D; returns a new ``uint8`` array of shape ``(H, W, 3)``.
    """
    rows, cols = mask.shape
    painted = np.zeros((rows, cols, 3), dtype=np.uint8)
    painted[mask == 255] = (0, 255, 0)
    return painted

def play_inference_vs_ground_truth_video(
        dataset_path: str | Path,
        scene_name: str,
        model_path: str | Path,
        fps: int = 10,
        output_path: str | Path | None = None
):
    """Compare mask-splitter predictions with ground-truth front/back labels.

    For every ``*.png`` in ``<dataset_path>/images/<scene_name>`` the frame, its
    segmentation (``segmented/``) and its ground-truth ``labels/front`` and
    ``labels/back`` masks are loaded, and the model splits the segmentation
    into predicted front/back masks. Each frame is rendered as
    ``[original | ground-truth front/back | predicted front/back]`` via
    ``colorize_masks`` (front red, back blue).

    Without ``output_path`` the composite is shown in an OpenCV window at about
    ``fps`` frames per second; press ESC to stop early. With ``output_path``
    every composite is written to an ``mp4v`` video and no window is opened.
    Frames with a missing image, segmentation or ground-truth mask are skipped
    with a warning.

    Args:
        dataset_path: Root of the dataset split (``str`` or ``Path``).
        scene_name: Scene subdirectory to process.
        model_path: Checkpoint passed to ``MaskSplitterInference``.
        fps: Playback rate / output frame rate; must be positive.
        output_path: Destination video file, or ``None`` to display interactively.

    Raises:
        ValueError: If ``fps`` is not positive.
        RuntimeError: If the video writer cannot be opened for ``output_path``.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Path() so plain-string roots work (the annotation allows str).
    paths = get_frame_paths(Path(dataset_path), scene_name)
    frame_paths = sorted(paths["images"].glob("*.png"))
    if not frame_paths:
        print(f"[Warning] No frames found in: {paths['images']}")
        return

    mask_splitter = MaskSplitterInference(model_path=model_path)

    # Clamp to >= 1 ms: cv2.waitKey(0) would block until a key is pressed.
    delay_ms = max(1, int(1000 / fps))
    video_writer = None
    try:
        for frame_path in tqdm(frame_paths, desc="Inference video"):
            frame_name = frame_path.name
            image = load_image(paths["images"] / frame_name)
            segmentation = load_image(paths["segmented"] / frame_name, grayscale=True)
            front_true = load_image(paths["labels"] / "front" / frame_name, grayscale=True)
            back_true = load_image(paths["labels"] / "back" / frame_name, grayscale=True)

            # Validate all inputs before running the model; colorize_masks cannot take None.
            if any(data is None for data in (image, segmentation, front_true, back_true)):
                print(f"[Warning] Missing data for: {frame_name}")
                continue

            front_mask, back_mask = mask_splitter.infer(image=image, mask=segmentation)
            combined = stack_frame(
                image,
                colorize_masks(front_true, back_true),
                front_mask.astype(np.uint8),
                back_mask.astype(np.uint8),
            )

            if output_path is not None:
                # Open lazily so the size comes from a real composite frame
                # instead of an extra inference on a dummy frame.
                if video_writer is None:
                    height, width = combined.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' or 'avc1'
                    video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
                    if not video_writer.isOpened():
                        raise RuntimeError(f"Could not open video writer for: {output_path}")
                video_writer.write(combined)
                continue

            cv2.imshow("Inference - [Original | Ground Truth | Predicted Front & Back]", combined)
            if (cv2.waitKey(delay_ms) & 0xFF) == ESCAPE_KEY:
                break
    finally:
        # Runs on normal end, ESC, and exceptions alike.
        if video_writer is not None:
            video_writer.release()
        if output_path is None:
            cv2.destroyAllWindows()


def run_inference_video(
        dataset_path: str | Path,
        scene_name: str,
        model_path: str | Path,
        fps: int = 10,
        output_path: str | Path | None = None
):
    """Visualize mask-splitter predictions for a scene (no ground truth needed).

    For every ``*.png`` in ``<dataset_path>/images/<scene_name>`` the frame and
    its segmentation (``segmented/``) are loaded, and the model splits the
    segmentation into front/back masks. Each frame is rendered as
    ``[original | input segmentation (green) | predicted front/back]`` using
    ``color_mask`` and ``colorize_masks`` (front red, back blue).

    Without ``output_path`` the composite is shown in an OpenCV window at about
    ``fps`` frames per second; press ESC to stop early. With ``output_path``
    every composite is written to an ``mp4v`` video and no window is opened.
    Frames with a missing image or segmentation are skipped with a warning.

    Args:
        dataset_path: Root of the dataset split (``str`` or ``Path``).
        scene_name: Scene subdirectory to process.
        model_path: Checkpoint passed to ``MaskSplitterInference``.
        fps: Playback rate / output frame rate; must be positive.
        output_path: Destination video file, or ``None`` to display interactively.

    Raises:
        ValueError: If ``fps`` is not positive.
        RuntimeError: If the video writer cannot be opened for ``output_path``.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Path() so plain-string roots work (the annotation allows str).
    paths = get_frame_paths(Path(dataset_path), scene_name)
    frame_paths = sorted(paths["images"].glob("*.png"))
    if not frame_paths:
        print(f"[Warning] No frames found in: {paths['images']}")
        return

    mask_splitter = MaskSplitterInference(model_path=model_path)

    # Clamp to >= 1 ms: cv2.waitKey(0) would block until a key is pressed.
    delay_ms = max(1, int(1000 / fps))
    video_writer = None
    try:
        for frame_path in tqdm(frame_paths, desc="Inference video"):
            frame_name = frame_path.name
            image = load_image(paths["images"] / frame_name)
            segmentation = load_image(paths["segmented"] / frame_name, grayscale=True)

            if image is None or segmentation is None:
                print(f"[Warning] Missing data for: {frame_name}")
                continue

            front_mask, back_mask = mask_splitter.infer(image=image, mask=segmentation)
            combined = stack_frame(
                image,
                color_mask(segmentation),
                front_mask.astype(np.uint8),
                back_mask.astype(np.uint8),
            )

            if output_path is not None:
                # Open lazily so the size comes from a real composite frame
                # instead of an extra inference on a dummy frame.
                if video_writer is None:
                    height, width = combined.shape[:2]
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'mp4v' or 'avc1'
                    video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
                    if not video_writer.isOpened():
                        raise RuntimeError(f"Could not open video writer for: {output_path}")
                video_writer.write(combined)
                continue

            # The middle panel is the model's input segmentation, not ground truth.
            cv2.imshow("Inference - [Original | Input Segmentation | Predicted Front & Back]", combined)
            if (cv2.waitKey(delay_ms) & 0xFF) == ESCAPE_KEY:
                break
    finally:
        # Runs on normal end, ESC, and exceptions alike.
        if video_writer is not None:
            video_writer.release()
        if output_path is None:
            cv2.destroyAllWindows()


def reencode_for_functionality(input_path: str | Path, output_path: str | Path):
    subprocess.run([
        'ffmpeg', '-y', '-i', str(input_path),
        '-vcodec', 'libx264', '-pix_fmt', 'yuv420p',
        '-crf', '23', '-preset', 'veryfast',
        str(output_path)
    ])


if __name__ == "__main__":
    # Local experiment configuration: edit these values to evaluate a different
    # split, checkpoint or set of scenes.
    dataset_path = Path("/home/user/Desktop/work/data/car-follow/validation")
    model_path = "/home/user/Desktop/work/space-time-vision-repos/data-labeling-tool/nn/checkpoints/mask_splitter-epoch_10-dropout_0-low_x1-and-high_x0_quality_early_stop.pt"
    fps = 15
    scene_names = [
        "around-car-45-high-quality",
        "around-car-45-low-quality",
        "around-car-45-low-quality-car-at-45",
    ]

    # When True, render the first scene to a file and re-encode it with ffmpeg
    # instead of playing the scenes interactively. (run_inference_video shows
    # predictions for scenes without ground-truth labels.)
    save_video = False
    output_video_path = Path("./inference_video.mp4")
    output_video_path_reencoded = Path("./inference_video-reencoded.mp4")

    if save_video:
        play_inference_vs_ground_truth_video(
            dataset_path, scene_names[0], model_path, fps, output_path=output_video_path
        )
        reencode_for_functionality(output_video_path, output_video_path_reencoded)
    else:
        for scene_name in scene_names:
            play_inference_vs_ground_truth_video(dataset_path, scene_name, model_path, fps)
