import os
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, JaccardIndex, F1Score
from tqdm import tqdm

from nn.callbacks import EarlyStopping
from nn.dataset_car_segmentation import CarSegmentationDataset, AdvancedTransform
from nn.losses import PartitionLoss, LossWeightScheduler
from nn.net_mask_splitter import MaskSplitterNet
from nn.saving import save_model, save_model_results


def constrain_to_original_mask(pred_logits, original_mask, method='normalize'):
    """
    Convert raw front/back logits into maps that live inside ``original_mask``.

    The logits go through a sigmoid; channel 0 (dim 1) is treated as "front"
    and channel 1 as "back". Callers in this file pass logits shaped
    (B, 2, H, W) and a single-channel mask shaped (B, 1, H, W).

    Args:
        pred_logits: Raw network output; only channels 0 and 1 are read.
        original_mask: Mask multiplied into both channels.
        method: ``'normalize'`` rescales the masked probabilities so that
            front + back is (up to a 1e-8 stabiliser) equal to the mask;
            ``'winner_takes_all'`` assigns each masked pixel entirely to
            front where front > back, and to back otherwise.

    Returns:
        Tensor with the front and back maps concatenated along dim 1.

    Raises:
        NotImplementedError: If ``method`` is neither of the two above.
    """
    probs = torch.sigmoid(pred_logits)  # (B, 2, H, W)
    front_prob = probs[:, 0:1]
    back_prob = probs[:, 1:2]

    if method == 'winner_takes_all':
        # Hard assignment: back receives whatever part of the mask front did not win.
        front_part = (front_prob > back_prob).float() * original_mask
        back_part = original_mask - front_part
        return torch.cat([front_part, back_part], dim=1)

    if method != 'normalize':
        raise NotImplementedError(f"Method {method} not implemented")

    # Soft assignment: zero everything outside the mask, then rescale.
    front_in_mask = front_prob * original_mask
    back_in_mask = back_prob * original_mask

    # The epsilon keeps the division finite where both masked channels are zero.
    denom = front_in_mask + back_in_mask + 1e-8
    front_share = front_in_mask * original_mask / denom
    back_share = back_in_mask * original_mask / denom
    return torch.cat([front_share, back_share], dim=1)


def compute_metrics_partition(
        pred_logits: torch.Tensor,
        target_masks: torch.Tensor,
        original_mask: torch.Tensor,
        accuracy_metric,
        iou_metric,
        f1_score_metric,
        threshold: float = 0.5
) -> dict:
    """
    Score a front/back partition prediction against its targets.

    The logits are first constrained to ``original_mask`` (default
    'normalize' method), then predictions, targets and the mask are all
    binarised with ``threshold``.

    Args:
        pred_logits: Raw logits; channel 0 is front, channel 1 is back.
        target_masks: Ground-truth maps with the same channel layout.
        original_mask: Single-channel mask (dim 1 is squeezed away here).
        accuracy_metric: Callable ``(pred, target) -> tensor`` used for accuracy.
        iou_metric: Callable ``(pred, target) -> tensor`` used for IoU.
        f1_score_metric: Callable ``(pred, target) -> tensor`` used for F1.
        threshold: Cut-off used for every binarisation.

    Returns:
        Dict of Python floats with the keys:
        ``front_acc/iou/f1``, ``back_acc/iou/f1``, ``avg_acc/iou/f1``
        (mean of front and back), ``perfect_partition`` (fraction of pixels
        where front + back equals the binary mask), ``coverage`` (fraction
        of pixels where front + back >= the binary mask),
        ``partition_precision`` (of the pixels claimed by exactly one part,
        the share that lies inside the mask) and ``overlap_rate`` (fraction
        of pixels claimed by both parts).
    """
    constrained = constrain_to_original_mask(pred_logits, original_mask)
    pred_bin = (constrained > threshold).int()
    target_bin = (target_masks > threshold).int()

    metrics = {}
    for channel, part in enumerate(('front', 'back')):
        part_pred, part_target = pred_bin[:, channel], target_bin[:, channel]
        metrics[f'{part}_acc'] = accuracy_metric(part_pred, part_target).item()
        metrics[f'{part}_iou'] = iou_metric(part_pred, part_target).item()
        metrics[f'{part}_f1'] = f1_score_metric(part_pred, part_target).item()

    for suffix in ('acc', 'iou', 'f1'):
        metrics[f'avg_{suffix}'] = (metrics[f'front_{suffix}'] + metrics[f'back_{suffix}']) / 2

    front_bin, back_bin = pred_bin[:, 0], pred_bin[:, 1]
    combined = front_bin + back_bin  # 0 = unclaimed, 1 = one part, 2 = both parts
    mask_bin = (original_mask.squeeze(1) > threshold).int()

    metrics['perfect_partition'] = (combined == mask_bin).float().mean().item()
    metrics['coverage'] = (combined >= mask_bin).float().mean().item()

    claimed_once = combined == 1
    inside_mask = (claimed_once & (mask_bin == 1)).sum().float()
    # clamp(min=1) guards the ratio when no pixel is claimed by exactly one part.
    claimed_total = claimed_once.sum().clamp(min=1).float()
    metrics['partition_precision'] = (inside_mask / claimed_total).item()

    metrics['overlap_rate'] = (front_bin * back_bin).float().mean().item()

    return metrics


def train_step(model, dataloader, loss_fn, optimizer, device, accuracy_metric, iou_metric, f1_score_metric, metric_fn):
    """
    Run one optimisation pass over ``dataloader`` and return averaged metrics.

    Args:
        model: Network to train; switched to train mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches. Inputs are
            (B, 4, H, W); channel 3 is sliced out as the original mask.
        loss_fn: Callable ``(preds, targets, original_mask)`` returning
            ``(loss, loss_components)`` where ``loss_components`` is a dict.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        accuracy_metric, iou_metric, f1_score_metric: Passed through to ``metric_fn``.
        metric_fn: Callable returning a dict of per-batch metric values.

    Returns:
        Dict with every metric and loss component averaged over the number
        of batches, plus ``'loss'`` (mean batch loss).

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously surfaced
            as a bare ``ZeroDivisionError``).
    """
    model.train()
    total_metrics = {}
    total_loss = 0
    num_batches = 0

    for inputs, targets in tqdm(dataloader, desc="Training"):  # inputs is (B, 4, H, W)
        inputs, targets = inputs.to(device), targets.to(device)
        original_mask = inputs[:, 3:4]

        preds = model(inputs)
        loss, loss_components = loss_fn(preds, targets, original_mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are reporting-only: skip autograd bookkeeping for them.
        with torch.no_grad():
            batch_metrics = metric_fn(preds, targets, original_mask, accuracy_metric, iou_metric, f1_score_metric)
        total_loss += loss.item()

        for key, value in batch_metrics.items():
            total_metrics[key] = total_metrics.get(key, 0) + value
        for key, value in loss_components.items():
            # If a component is a tensor, detach it so the running sum does not
            # keep every batch's autograd graph alive.
            if isinstance(value, torch.Tensor):
                value = value.detach()
            total_metrics[key] = total_metrics.get(key, 0) + value
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_step received an empty dataloader; there are no batches to average over")

    total_metrics = {k: v / num_batches for k, v in total_metrics.items()}
    total_metrics['loss'] = total_loss / num_batches
    return total_metrics


def test_step(model, dataloader, loss_fn, device, accuracy_metric, iou_metric, f1_score_metric, metric_fn):
    """
    Evaluate ``model`` over ``dataloader`` without gradients and return averaged metrics.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable of ``(inputs, targets)`` batches. Channel 3 of
            the inputs is sliced out as the original mask.
        loss_fn: Callable ``(preds, targets, original_mask)`` returning
            ``(loss, loss_components)`` where ``loss_components`` is a dict.
        device: Device the batches are moved to.
        accuracy_metric, iou_metric, f1_score_metric: Passed through to ``metric_fn``.
        metric_fn: Callable returning a dict of per-batch metric values.

    Returns:
        Dict with every metric and loss component averaged over the number
        of batches, plus ``'loss'`` (mean batch loss).

    Raises:
        ValueError: If ``dataloader`` yields no batches (previously surfaced
            as a bare ``ZeroDivisionError``).
    """
    model.eval()
    total_metrics = {}
    total_loss = 0
    num_batches = 0

    with torch.inference_mode():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            original_mask = inputs[:, 3:4]
            preds = model(inputs)
            loss, loss_components = loss_fn(preds, targets, original_mask)

            batch_metrics = metric_fn(
                preds, targets, original_mask,
                accuracy_metric, iou_metric, f1_score_metric
            )
            total_loss += loss.item()

            # Metrics and loss components share one accumulator, keyed by name.
            for key, value in batch_metrics.items():
                total_metrics[key] = total_metrics.get(key, 0) + value
            for key, value in loss_components.items():
                total_metrics[key] = total_metrics.get(key, 0) + value

            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_step received an empty dataloader; there are no batches to average over")

    total_metrics = {k: v / num_batches for k, v in total_metrics.items()}
    total_metrics['loss'] = total_loss / num_batches

    return total_metrics


# The ``test_`` prefix matches pytest's default discovery pattern; opt out so
# the function is not collected as a test when imported into a test module.
test_step.__test__ = False


def train_loop(
        model: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        val_loss_fn: nn.Module,
        device: torch.device,
        epochs: int = 100,
        scheduler: torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau = None,
        metric_fn=compute_metrics_partition,
        save_dir: str = "./checkpoints/",
        model_name: str = "mask_splitter_net",
        loss_weight_scheduler: LossWeightScheduler | None = None,
):
    """
    Train ``model`` for up to ``epochs`` epochs with per-epoch validation and early stopping.

    Each epoch runs ``train_step`` then ``test_step``, records the metrics,
    steps the optional LR ``scheduler``, prints a summary and calls the
    early-stopping callback with the validation loss.

    Args:
        model: Network to train.
        train_dataloader: Batches for the training pass.
        val_dataloader: Batches for the validation pass.
        optimizer: Optimizer used by ``train_step``.
        val_loss_fn: Loss used for validation. Also used for training when
            ``loss_weight_scheduler`` is None.
        device: Device for batches and metric objects.
        epochs: Maximum number of epochs.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped
            with the validation loss, anything else without arguments.
        metric_fn: Per-batch metric function forwarded to the step functions.
        save_dir: Directory handed to ``EarlyStopping``.
        model_name: Prefix of the early-stopping checkpoint name
            (``f"{model_name}_early_stop.pt"``).
        loss_weight_scheduler: Optional scheduler, stepped with the epoch
            index, whose ``loss_fn`` is used for training. When None the
            training pass uses ``val_loss_fn`` instead; previously the
            default of None crashed with ``AttributeError``.

    Returns:
        Dict mapping ``train_<metric>`` / ``val_<metric>`` to a list with one
        value per completed epoch; metrics missing from a step's output are
        recorded as 0.0.
    """
    tracked = (
        'loss',
        'avg_acc', 'avg_iou', 'avg_f1',
        'front_acc', 'back_acc',
        'front_iou', 'back_iou',
        'front_f1', 'back_f1',
        'perfect_partition', 'overlap_rate', 'coverage',
        'individual_loss', 'partition_loss', 'overlap_penalty', 'coverage_loss',
    )
    results = {}
    for metric_name in tracked:
        results[f'train_{metric_name}'] = []
        results[f'val_{metric_name}'] = []

    early_stopping = EarlyStopping(patience=20, verbose=True, save_dir=save_dir)
    accuracy = Accuracy(task="binary").to(device)
    iou = JaccardIndex(task="binary").to(device)
    f1 = F1Score(task="binary").to(device)
    print(f"Training for {epochs} epochs...")

    for epoch in range(epochs):
        if loss_weight_scheduler is not None:
            loss_weight_scheduler.step(epoch)
            train_loss_fn = loss_weight_scheduler.loss_fn
        else:
            # No weight schedule supplied: train with the fixed validation loss.
            train_loss_fn = val_loss_fn

        # The timed span covers both the training and the validation pass.
        start = time.time()
        train_metrics = train_step(
            model, train_dataloader, train_loss_fn, optimizer, device, accuracy, iou, f1, metric_fn
        )
        val_metrics = test_step(model, val_dataloader, val_loss_fn, device, accuracy, iou, f1, metric_fn)
        end = time.time()

        # Every key is "<split>_<metric>"; pick the matching step output.
        for key, history in results.items():
            split, _, metric_key = key.partition('_')
            source = train_metrics if split == 'train' else val_metrics
            history.append(source.get(metric_key, 0.0))

        current_val_loss = val_metrics['loss']
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(current_val_loss)
            else:
                scheduler.step()

        print(f"Epoch {epoch + 1:03d}/{epochs} | Time: {end - start:.2f}s | "
              f"TrLoss: {train_metrics['loss']:.4f} | VaLoss: {val_metrics['loss']:.4f} | "
              f"TrAcc: {train_metrics.get('avg_acc', 0):.4f} | VaAcc: {val_metrics.get('avg_acc', 0):.4f} | "
              f"TrIoU: {train_metrics.get('avg_iou', 0):.4f} | VaIoU: {val_metrics.get('avg_iou', 0):.4f} | "
              f"TrF1: {train_metrics.get('avg_f1', 0):.4f} | VaF1: {val_metrics.get('avg_f1', 0):.4f} | "
              f"Partition: {val_metrics.get('perfect_partition', 0):.4f} | Overlap: {val_metrics.get('overlap_rate', 0):.4f} | "
              f"Individual: {val_metrics.get('individual_loss', 0):.4f} | PartLoss: {val_metrics.get('partition_loss', 0):.4f} | "
              f"OverlapPen: {val_metrics.get('overlap_penalty', 0):.4f}")

        print(
            f"  Front Metrics  - TrAcc: {train_metrics.get('front_acc', 0):.4f} | VaAcc: {val_metrics.get('front_acc', 0):.4f} | "
            f"TrIoU: {train_metrics.get('front_iou', 0):.4f} | VaIoU: {val_metrics.get('front_iou', 0):.4f} | "
            f"TrF1: {train_metrics.get('front_f1', 0):.4f} | VaF1: {val_metrics.get('front_f1', 0):.4f}")

        print(
            f"  Back Metrics   - TrAcc: {train_metrics.get('back_acc', 0):.4f} | VaAcc: {val_metrics.get('back_acc', 0):.4f} | "
            f"TrIoU: {train_metrics.get('back_iou', 0):.4f} | VaIoU: {val_metrics.get('back_iou', 0):.4f} | "
            f"TrF1: {train_metrics.get('back_f1', 0):.4f} | VaF1: {val_metrics.get('back_f1', 0):.4f}")
        if loss_weight_scheduler is not None:
            print(f"Epoch {epoch + 1}: alpha={loss_weight_scheduler.loss_fn.alpha:.3f}, "
                  f"beta={loss_weight_scheduler.loss_fn.beta:.3f}, gamma={loss_weight_scheduler.loss_fn.gamma:.3f}")
        print("-" * 100)

        # The metric objects are shared by both passes; clear their state each epoch.
        accuracy.reset()
        iou.reset()
        f1.reset()

        early_stopping(val_loss=val_metrics.get("loss"), model=model, model_name=f"{model_name}_early_stop.pt")
        if early_stopping.early_stop:
            print(f"Early stopping triggered. Restoring best model with val_loss={early_stopping.best_score:.4f}")
            model.load_state_dict(early_stopping.best_model)
            break

    return results


def run_train(
        allowed,
        scene_multi,
        allowed_val,
        data_dir: str | Path = "/home/user/Desktop/work/data/car-follow/",
        epochs: int = 10,
        batch_size: int = 8,
        lr: float = 1e-4,
        dropout: float = 0.0,
        save_dir: str = "./checkpoints",
        high_quality_train_multi: int = 5,
        low_quality_train_multi: int = 3,
):
    """
    Build datasets, model, loss and optimizer, train, then save the model and its results.

    Args:
        allowed: Passed as ``allowed_scenes`` to the training dataset.
        scene_multi: Passed as ``scene_multipliers`` to the training dataset.
        allowed_val: Passed as ``allowed_scenes`` to the validation dataset.
        data_dir: Root directory containing ``train`` and ``validation`` sub-directories.
        epochs: Number of epochs; also the horizon of the loss-weight schedule.
        batch_size: Batch size for both dataloaders.
        lr: Adam learning rate.
        dropout: Dropout rate handed to the model; also encoded (as a
            percentage) in the model name.
        save_dir: Output directory; created if missing.
        high_quality_train_multi: Only used to build the model name.
        low_quality_train_multi: Only used to build the model name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on {device}")

    os.makedirs(save_dir, exist_ok=True)

    train_data_path = Path(data_dir) / "train"
    val_data_path = Path(data_dir) / "validation"
    train_dataset = CarSegmentationDataset(
        root_dir=train_data_path,
        image_size=(360, 640),
        allowed_scenes=allowed,
        transform=AdvancedTransform(),
        scene_multipliers=scene_multi,
    )
    val_dataset = CarSegmentationDataset(root_dir=val_data_path, image_size=(360, 640), allowed_scenes=allowed_val)
    print(f"Train Dataset size: {len(train_dataset)}")
    print(f"Validation Dataset size: {len(val_dataset)}")

    # Pinned host memory only helps host-to-GPU copies; on CPU it just triggers a warning.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory, num_workers=4)

    # round() rather than int(): 0.29 * 100 is 28.999999999999996, which int() would
    # truncate to 28 and mislabel the checkpoint.
    dropout_percent = round(dropout * 100)
    model_name = (
        f"mask_splitter-epoch_{epochs}-dropout_{dropout_percent}-"
        f"low_x{low_quality_train_multi}-and-high_x{high_quality_train_multi}_quality"
    )
    out_classes = 2
    # RGB (3) + mask (1) => input with 4 channels
    model = MaskSplitterNet(in_channels=4, out_channels=out_classes, dropout_rate=dropout).to(device)
    input_tensor, _ = train_dataset[0]
    model.display_summary((batch_size, *input_tensor.shape))

    # Training loss weights move from start_weights to end_weights via the scheduler;
    # validation uses a fixed loss built with the same values as end_weights.
    tr_loss_fn = PartitionLoss()
    loss_weight_scheduler = LossWeightScheduler(
        tr_loss_fn, total_epochs=epochs,
        start_weights=(0.9, 0.05, 0.05),
        end_weights=(0.4, 0.4, 0.2),
    )
    val_loss_fn = PartitionLoss(0.4, 0.4, 0.2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

    results = train_loop(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        optimizer=optimizer,
        val_loss_fn=val_loss_fn,
        epochs=epochs,
        device=device,
        scheduler=None,
        metric_fn=compute_metrics_partition,
        model_name=model_name,
        save_dir=save_dir,
        loss_weight_scheduler=loss_weight_scheduler
    )

    save_model(model, save_dir, f"{model_name}.pt")
    save_model_results(save_dir, f"{model_name}", results)


if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--data_dir", type=str, default="/home/user/Desktop/work/data/car-follow/")
    cli.add_argument("--epochs", type=int, default=10)
    cli.add_argument("--batch_size", type=int, default=8)
    cli.add_argument("--lr", type=float, default=1e-4)
    cli.add_argument("--dropout", type=float, default=0.0)
    cli.add_argument("--save_dir", type=str, default="./checkpoints/")
    # NOTE(review): --hq_multi / --lq_multi are accepted but their parsed values are
    # never read; the multipliers of each run are hard-coded below.
    cli.add_argument("--hq_multi", type=int, default=0, help="High quality scene multiplier")
    cli.add_argument("--lq_multi", type=int, default=0, help="Low quality scene multiplier")
    args = cli.parse_args()

    # Each entry: (high-quality multiplier, low-quality multiplier,
    #              training scene -> multiplier, validation scenes).
    experiments = []

    # Run 1: low-quality scenes only.
    hq, lq = 0, 2
    experiments.append((
        hq,
        lq,
        {
            "around-car-90-75-60-45-30-low-quality": lq,
            "around-car-30-45-60-75-90-low-quality": lq,
            "just-environment-low-quality": 2,
        },
        [
            "around-car-45-low-quality",
            "around-car-45-low-quality-car-at-45",
        ],
    ))

    # Run 2: all scenes, high and low quality.
    hq, lq = 5, 2
    experiments.append((
        hq,
        lq,
        {
            "around-car-30-45-60-75-90-high-quality": hq,
            "around-car-90-75-60-45-30-low-quality": lq,
            "around-car-30-45-60-75-90-low-quality": lq,
            "just-environment-high-quality": 2,
            "just-environment-low-quality": 2,
        },
        [
            "around-car-45-high-quality",
            "around-car-45-low-quality",
            "around-car-45-low-quality-car-at-45",
        ],
    ))

    for hq, lq, scene_weights, val_scenes in experiments:
        run_train(
            # The allowed training scenes are exactly the keys of the multiplier map, in order.
            allowed=list(scene_weights),
            scene_multi=scene_weights,
            allowed_val=val_scenes,
            data_dir=args.data_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            dropout=args.dropout,
            save_dir=args.save_dir,
            high_quality_train_multi=hq,
            low_quality_train_multi=lq,
        )
