import random
from pathlib import Path
from typing import Tuple, List

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torchvision.transforms import functional, InterpolationMode


class AdvancedTransform:
    """Random joint augmentation of an input tensor and its target tensor.

    Each call applies, in order: optional flips along dims 2 and 1, a shared
    random affine warp (rotation, shear, zoom), a saturation/brightness jitter
    on the first three input channels, and salt-and-pepper noise on those same
    three channels. Geometric steps are applied to both tensors with the same
    random parameters; photometric steps touch the input only.

    Tensors are assumed to be channel-first (C, H, W) with values in [0, 1]
    -- TODO confirm against the caller.
    """

    def __init__(
            self,
            image_size=(360, 640),
            flip_prob=0.5,
            rotate_deg=15,
            shear_deg=10,
            zoom_range=(1.0, 1.2),
            brightness=0.15,
            saturation=0.25,
            noise_prob=0.0014  # 0.14% = 0.0014
    ):
        """Store the augmentation ranges and pre-compute the noise budget.

        :param image_size: (height, width) used to size the noise budget.
        :param flip_prob: probability of each of the two independent flips.
        :param rotate_deg: rotation angle is drawn from [-rotate_deg, rotate_deg].
        :param shear_deg: each shear angle is drawn from [-shear_deg, shear_deg].
        :param zoom_range: (low, high) bounds for the affine scale factor.
        :param brightness: brightness factor is drawn from 1 +/- brightness.
        :param saturation: saturation factor is drawn from 1 +/- saturation.
        :param noise_prob: fraction of ``image_size`` pixels overwritten by noise.
        """
        self.image_size = image_size
        self.flip_prob = flip_prob
        self.rotate_deg = rotate_deg
        self.shear_deg = shear_deg
        self.zoom_range = zoom_range
        self.brightness = brightness
        self.saturation = saturation
        self.noise_prob = noise_prob

        height, width = image_size[0], image_size[1]
        self.total_pixels = height * width

        # The noise budget is fixed here from the configured size, not from
        # the tensors seen later in __call__.
        self.noise_pixels = 0
        if noise_prob > 0:
            self.noise_pixels = int(self.total_pixels * noise_prob)

    def __call__(self, input_tensor, target_tensor):
        """Augment one (input, target) pair and return the new pair.

        :param input_tensor: tensor whose first three channels are colour and
            whose remaining channels are passed through the photometric steps
            untouched.
        :param target_tensor: tensor warped with the same geometry as the input,
            using nearest-neighbour interpolation.
        :return: tuple ``(input_tensor, target_tensor)`` after augmentation.
        """
        # ----- Flipping: dim 2 (horizontal) first, then dim 1 (vertical) -----
        for axis in (2, 1):
            if random.random() < self.flip_prob:
                input_tensor = torch.flip(input_tensor, dims=[axis])
                target_tensor = torch.flip(target_tensor, dims=[axis])

        # ----- Geometric Transform: Affine -----
        # Parameters are drawn once and reused so both tensors stay aligned.
        theta = random.uniform(-self.rotate_deg, self.rotate_deg)
        shear_h = random.uniform(-self.shear_deg, self.shear_deg)
        shear_v = random.uniform(-self.shear_deg, self.shear_deg)
        zoom = random.uniform(*self.zoom_range)

        def warp(tensor, mode):
            return functional.affine(
                tensor, angle=theta, translate=[0, 0], scale=zoom,
                shear=[shear_h, shear_v], interpolation=mode
            )

        input_tensor = warp(input_tensor, InterpolationMode.BILINEAR)
        target_tensor = warp(target_tensor, InterpolationMode.NEAREST)

        # ----- Brightness & Saturation (first three channels only) -----
        colour, passthrough = input_tensor[:3], input_tensor[3:]

        saturation_gain = 1.0 + random.uniform(-self.saturation, self.saturation)
        brightness_gain = 1.0 + random.uniform(-self.brightness, self.brightness)

        # Saturation scales each pixel's distance from its per-pixel channel mean.
        channel_mean = colour.mean(dim=0, keepdim=True)
        colour = ((colour - channel_mean) * saturation_gain + channel_mean).clamp(0.0, 1.0)
        colour = (colour * brightness_gain).clamp(0.0, 1.0)

        # ----- Salt & Pepper Noise -----
        if self.noise_pixels > 0:
            rows, cols = colour.shape[1], colour.shape[2]

            flat_index = torch.randint(0, rows * cols, (self.noise_pixels,))
            salt_or_pepper = torch.randint(0, 2, (self.noise_pixels,), dtype=colour.dtype)
            # The same 0/1 value is written to all three channels of a pixel.
            colour[:, flat_index // cols, flat_index % cols] = salt_or_pepper[None, :]

        return torch.cat((colour, passthrough), dim=0), target_tensor


class CarSegmentationDataset(Dataset):
    def __init__(
            self,
            root_dir: str | Path,
            image_size: Tuple[int, int] = (360, 640),
            transform=None,
            allowed_scenes: List[str] = None,
            scene_multipliers: dict[str, int] = None
    ):
        self.root_dir = Path(root_dir)
        self.image_size = image_size
        self.transform = transform

        self.samples = []
        scene_multipliers = scene_multipliers or {}

        images_root = self.root_dir / "images"
        segmented_root = self.root_dir / "segmented"
        labels_root = self.root_dir / "labels"

        scene_dirs = [d for d in images_root.iterdir() if d.is_dir()]
        for scene_dir in scene_dirs:
            scene_name = scene_dir.name
            if allowed_scenes and scene_name not in allowed_scenes:
                continue

            image_paths = sorted((images_root / scene_name).glob("*.png"))
            mask_paths = sorted((segmented_root / scene_name).glob("*.png"))
            front_paths = sorted((labels_root / scene_name / "front").glob("*.png"))
            back_paths = sorted((labels_root / scene_name / "back").glob("*.png"))

            if not (len(image_paths) == len(mask_paths) == len(front_paths) == len(back_paths)):
                print(f"[WARNING] Skipping scene {scene_name}: Mismatch in file counts.")
                continue

            scene_samples = list(zip(image_paths, mask_paths, front_paths, back_paths))

            multiplier = scene_multipliers.get(scene_name, 1)
            # Only apply augmentation on multiplied data to keep the others intact.
            for i in range(multiplier):
                is_augmented = i > 0
                for sample in scene_samples:
                    self.samples.append((sample, is_augmented))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        (image_path, mask_path, front_path, back_path), is_augmented = self.samples[index]

        img = cv2.imread(str(image_path))
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        front = cv2.imread(str(front_path), cv2.IMREAD_GRAYSCALE)
        back = cv2.imread(str(back_path), cv2.IMREAD_GRAYSCALE)

        img = cv2.resize(img, self.image_size[::-1])
        mask = cv2.resize(mask, self.image_size[::-1])
        front = cv2.resize(front, self.image_size[::-1])
        back = cv2.resize(back, self.image_size[::-1])

        img = img.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0
        front = front.astype(np.float32) / 255.0
        back = back.astype(np.float32) / 255.0

        assert np.allclose(front + back, mask, atol=0.1), \
            f"Front+back doesn't match mask at {index} -> {Path(image_path).name}"

        mask = np.expand_dims(mask, axis=-1)
        input_4ch = np.concatenate([img, mask], axis=-1)

        target = np.stack([front, back], axis=0)

        input_tensor = torch.from_numpy(input_4ch.transpose(2, 0, 1))  # CHW
        target_tensor = torch.from_numpy(target)

        if self.transform and is_augmented:
            input_tensor, target_tensor = self.transform(input_tensor, target_tensor)

        return input_tensor, target_tensor


def show_batch(inputs, targets, *, bgr=True):
    """
    Visualizes a batch of inputs (colour image + mask) and targets (front/back).

    One figure with four panels is shown per sample and closed afterwards.

    :param inputs: Tensor of shape (B, 4, H, W)
    :param targets: Tensor of shape (B, 2, H, W)
    :param bgr: If True (default), the three colour channels are treated as
        OpenCV BGR -- which is what CarSegmentationDataset produces, since it
        loads frames with cv2.imread -- and are reversed to RGB for display.
        Pass False when the colour channels are already RGB.
    """
    # .numpy() only works on CPU tensors that do not track gradients.
    inputs = inputs.detach().cpu()
    targets = targets.detach().cpu()

    batch_size = inputs.shape[0]
    for i in range(batch_size):
        input_4ch = inputs[i].numpy()  # shape: (4, H, W)
        target_2ch = targets[i].numpy()  # shape: (2, H, W)

        rgb = input_4ch[:3].transpose(1, 2, 0)  # H x W x 3
        if bgr:
            rgb = rgb[..., ::-1]  # matplotlib's imshow expects RGB order
        seg_mask = input_4ch[3]  # H x W
        front_mask = target_2ch[0]
        back_mask = target_2ch[1]

        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(rgb)
        axs[0].set_title("RGB Image")
        axs[1].imshow(seg_mask, cmap='gray')
        axs[1].set_title("Segmentation Mask")
        axs[2].imshow(front_mask, cmap='gray')
        axs[2].set_title("Front Mask")
        axs[3].imshow(back_mask, cmap='gray')
        axs[3].set_title("Back Mask")
        for ax in axs:
            ax.axis("off")
        plt.tight_layout()
        plt.show()
        # Release the figure so large batches do not accumulate open figures.
        plt.close(fig)


def validate_dataset(dataset):
    """Load every sample once and count those that fail the dataset's assertion.

    Each ``AssertionError`` raised while indexing is printed and counted; any
    other exception propagates to the caller.

    :param dataset: object supporting ``len()`` and integer indexing that
        returns a pair.
    :return: number of samples that raised ``AssertionError``.
    """
    try:
        from tqdm import tqdm
    except ImportError:
        # tqdm only draws the progress bar; validation itself does not need it.
        def tqdm(iterable, **_kwargs):
            return iterable

    mismatch_count = 0
    total = len(dataset)

    for idx in tqdm(range(total), desc="Validating dataset"):
        try:
            _, _ = dataset[idx]
        except AssertionError as e:
            print(f"[ERROR] {e}")
            mismatch_count += 1

    print(f"\nValidation finished: {mismatch_count} mismatches out of {total} samples.")
    return mismatch_count


# Manual smoke test: validate the training data, then display one batch.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    data_root_directory = "/home/user/Desktop/work/data/car-follow/train"
    allowed = [
        # "around-car-30-45-60-75-90-high-quality",
        "around-car-90-75-60-45-30-low-quality",
        # "around-car-30-45-60-75-90-low-quality",
        # "just-environment-high-quality",
        # "just-environment-low-quality"
    ]
    # NOTE: not passed to the dataset below (scene_multipliers=None), so every
    # scene is loaded once, no sample is flagged as augmented, and the
    # transform is never applied in this run.
    scene_multi = {
        "around-car-30-45-60-75-90-high-quality": 4,
        "around-car-90-75-60-45-30-low-quality": 4,
        "just-environment-high-quality": 2,
        "just-environment-low-quality": 2
    }

    t = AdvancedTransform()

    dataset = CarSegmentationDataset(
        root_dir=data_root_directory,
        image_size=(360, 640),
        allowed_scenes=allowed,
        transform=t,
        scene_multipliers=None,
    )

    validate_dataset(dataset)
    #
    # data_root_directory = "/home/user/Desktop/work/data/car-follow/validation"
    #
    # dataset = CarSegmentationDataset(
    #     root_dir=data_root_directory,
    #     image_size=(360, 640),
    #     allowed_scenes=None,
    #     transform=t,
    #     scene_multipliers=None,
    # )
    # validate_dataset(dataset)

    # Indexing an empty dataset would fail with a bare IndexError; stop with a
    # readable message instead.
    if len(dataset) == 0:
        raise SystemExit(f"No samples found under {data_root_directory}")

    in_tensor, out_tensor = dataset[0]
    # Print the actual shape; prepending an extra 4 reported a bogus 4-D shape.
    print(f"Input tensor shape: {in_tensor.shape}")  # should be (4, 360, 640)
    print(f"Target tensor shape: {out_tensor.shape}")  # should be (2, 360, 640)
    print(f"Dataset size: {len(dataset)}")
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Show only the first batch.
    for batch_inputs, batch_targets in loader:
        print("Inputs shape:", batch_inputs.shape)  # (B, 4, H, W)
        print("Targets shape:", batch_targets.shape)  # (B, 2, H, W)
        show_batch(batch_inputs, batch_targets)
        break
