import time
from pathlib import Path

import cv2
import numpy as np

from nn.infer import MaskSplitterInference
from nn.vis_nn_spliter_video import get_frame_paths, load_image
from path_manager import Paths, get_path_of_files

# Default benchmark scene; override by calling main(dataset_path=..., scene_name=...).
DEFAULT_DATASET_PATH = Path("/home/user/Desktop/work/data/car-follow/train")
DEFAULT_SCENE_NAME = "around-car-30-45-60-75-90-high-quality"
WARM_UP_ITERATIONS = 10


def visualize_test_pairs(mask_splitter, image_paths, mask_paths):
    """Run inference on each (image, mask) test pair and visualize the split.

    Images are read as BGR colour and masks as grayscale via ``cv2.imread``.
    Pairs where either file cannot be read are skipped with a warning instead
    of being handed to the model as ``None``.

    Args:
        mask_splitter: Object exposing ``infer(image, mask)`` and
            ``visualize(image, front, back)``.
        image_paths: Iterable of image file paths.
        mask_paths: Iterable of mask file paths. These are paired with
            ``image_paths`` by position.

    Returns:
        ``(image, mask)`` of the last pair that loaded successfully, or
        ``(None, None)`` if no pair could be loaded.
    """
    image_paths = list(image_paths)
    mask_paths = list(mask_paths)
    # NOTE(review): pairing is purely positional; this assumes
    # get_path_of_files returns both lists in matching order. Confirm.
    if len(image_paths) != len(mask_paths):
        print(
            f"[Warning] {len(image_paths)} test images but {len(mask_paths)} "
            "test masks; unmatched files are ignored."
        )

    last_pair = (None, None)
    for image_path, mask_path in zip(image_paths, mask_paths):
        print(f"{Path(image_path).name} -> {Path(mask_path).name}")
        image = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if image is None or mask is None:
            # cv2.imread signals failure by returning None rather than raising.
            print(f"[Warning] Could not read test pair: {image_path} / {mask_path}")
            continue

        front, back = mask_splitter.infer(image, mask)
        mask_splitter.visualize(image, front, back)
        last_pair = (image, mask)
    return last_pair


def load_scene_frames(dataset_path, scene_name):
    """Load every ``*.png`` frame of a scene together with its segmentation.

    Frames whose image or segmentation cannot be loaded are skipped with a
    warning.

    Args:
        dataset_path: Root directory passed to ``get_frame_paths``.
        scene_name: Scene sub-directory name passed to ``get_frame_paths``.

    Returns:
        ``(images, masks)``: parallel lists of loaded frames and
        grayscale segmentations.
    """
    paths = get_frame_paths(dataset_path, scene_name)
    images, masks = [], []
    for frame_path in sorted(paths["images"].glob("*.png")):
        frame_name = frame_path.name
        image = load_image(frame_path)
        segmentation = load_image(paths["segmented"] / frame_name, grayscale=True)
        if image is None or segmentation is None:
            print(f"[Warning] Missing data for: {frame_name}")
            continue
        images.append(image)
        masks.append(segmentation)
    return images, masks


def time_inference(mask_splitter, images, masks):
    """Time ``mask_splitter.infer`` on each (image, mask) pair.

    Args:
        mask_splitter: Object exposing ``infer(image, mask)``.
        images: Sequence of inputs, paired with ``masks`` by position.
        masks: Sequence of masks.

    Returns:
        ``(per_call_seconds, total_seconds)``. ``per_call_seconds`` holds one
        ``time.perf_counter`` duration per pair. ``total_seconds`` spans the
        whole loop.
    """
    inference_times = []
    loop_start = time.perf_counter()
    for image, mask in zip(images, masks):
        call_start = time.perf_counter()
        mask_splitter.infer(image, mask)  # outputs unused: timing only
        inference_times.append(time.perf_counter() - call_start)
    total_seconds = time.perf_counter() - loop_start
    return inference_times, total_seconds


def summarize_inference_times(inference_times, total_seconds):
    """Build the human-readable benchmark report lines.

    Args:
        inference_times: Per-call durations in seconds.
        total_seconds: Wall time of the whole timed loop in seconds.

    Returns:
        List of report lines. If ``inference_times`` is empty, a single
        notice line is returned instead of NaN statistics (``np.mean`` of an
        empty array would warn and return NaN).
    """
    count = len(inference_times)
    if count == 0:
        return ["[INFO] No frames were inferred; no timing statistics to report."]

    times = np.asarray(inference_times, dtype=float)
    mean = float(np.mean(times))
    fps = 1.0 / mean if mean > 0 else float("inf")
    return [
        f"[INFO] Total inference for {count} took {total_seconds} seconds.",
        f"[INFO] Inference times (mean): {mean}",
        f"[INFO] Inference times (median): {float(np.median(times))}",
        f"[INFO] Inference times (std): {float(np.std(times))}",
        f"[INFO] FPS: {fps}",
    ]


def main(dataset_path=DEFAULT_DATASET_PATH, scene_name=DEFAULT_SCENE_NAME):
    """Visualize the test pairs, then benchmark inference on one scene.

    Args:
        dataset_path: Dataset root containing the scene (str or Path).
        scene_name: Name of the scene to benchmark.
    """
    test_image_paths = get_path_of_files(directory=Paths.TEST_IMAGES_DIR, file_type=".png")
    test_mask_paths = get_path_of_files(directory=Paths.TEST_MASKS_DIR, file_type=".png")
    mask_splitter = MaskSplitterInference(
        model_path=Paths.CAR_MASK_SPLITTER_MODEL_V2_PATH, is_model_compiled=True
    )

    sample_image, sample_mask = visualize_test_pairs(
        mask_splitter, test_image_paths, test_mask_paths
    )
    images, masks = load_scene_frames(Path(dataset_path), scene_name)

    # Warm up right before the timed loop. If no test pair was readable, use
    # the first benchmark frame so warm_up never receives None inputs.
    if sample_image is None and images:
        sample_image, sample_mask = images[0], masks[0]
    if sample_image is None:
        print("[Warning] No valid image/mask available; skipping warm-up.")
    else:
        mask_splitter.warm_up(
            image=sample_image, mask=sample_mask, num_iterations=WARM_UP_ITERATIONS
        )

    print("Prepared data... Doing inference testing...")
    inference_times, total_seconds = time_inference(mask_splitter, images, masks)
    for line in summarize_inference_times(inference_times, total_seconds):
        print(line)


if __name__ == '__main__':
    main()
