import os
import csv
import random
from copy import deepcopy
from pathlib import Path
from typing import Iterable, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Locate config/paths.py relative to this file: <directory of this file>/../config/paths.py.
_this_dir = Path(__file__).resolve().parent
_paths_py = (_this_dir.parent / "config" / "paths.py").resolve()
if not _paths_py.exists():
    raise FileNotFoundError(f"config/paths.py not found at: {_paths_py}")
# Load that file by path (under the module name "paths") instead of through a
# regular import statement, and expose the executed module as `paths_mod`.
import importlib.util as _importlib_util
_spec = _importlib_util.spec_from_file_location("paths", str(_paths_py))
paths_mod = _importlib_util.module_from_spec(_spec)
_spec.loader.exec_module(paths_mod)

# Side length (pixels) every image is resized to by the transform pipelines.
img_size = 224
# Per-channel mean / std handed to A.Normalize in both pipelines.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std  = (0.229, 0.224, 0.225)
# Batch size shared by the train, validation and test DataLoaders.
batch_size = 64
# Upper bound on training epochs per model.
epochs = 50
# Number of consecutive epochs without a validation-accuracy improvement
# after which training of the current model stops early.
patience = 5
# Learning rate passed to the Adam optimizer.
lr = 1e-4
# Seed used for the RNGs and for the stratified train/validation split.
seed = 42
# Worker process count for each DataLoader.
num_workers = 4

# Architectures to train; each name is looked up as an attribute of
# `torchvision.models` and must be handled by `adapt_model_for_fer`.
model_names = [
    "resnet18", "resnet34", "resnet50",
    "vgg16", "vgg19",
    "densenet121", "densenet161",
    "mobilenet_v2",
]

# Training pipeline: resize, then three random augmentations (each applied
# with probability 0.5), then normalization and conversion to a tensor.
transform_train = A.Compose([
    A.Resize(img_size, img_size),
    A.HorizontalFlip(p=0.5),
    A.Rotate(p=0.5, limit=(-30, 30)),
    A.RandomBrightnessContrast(p=0.5, brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), brightness_by_max=False),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
    ToTensorV2(),
])
# Evaluation pipeline (validation and test): same resize / normalize / tensor
# steps as training, without the random augmentations.
transform_eval = A.Compose([
    A.Resize(img_size, img_size),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
    ToTensorV2(),
])

def _seed_everything(s: int = seed):
    """Seed the Python, NumPy and PyTorch random number generators with ``s``.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with benchmarking disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(s)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(s)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _write_csv_row(csv_path: Path, header: Iterable[str], row: Iterable):
    """Append one row to a CSV file, writing the header first when needed.

    Args:
        csv_path: Destination file. Missing parent directories are created.
        header: Column names; written only when the file does not exist yet
            or exists but is empty.
        row: Values for the appended row.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # An existing zero-byte file (e.g. pre-created with touch) still needs a
    # header, so check the size as well as existence.
    needs_header = not csv_path.exists() or csv_path.stat().st_size == 0
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(list(header))
        writer.writerow(list(row))

class FERImageFolder(Dataset):
    """Dataset of image files listed in a table with 'image_id' and 'label' columns.

    Args:
        csv_or_df: Either a path (``str`` / ``Path``) to a CSV file, or
            anything ``pd.DataFrame`` accepts (its index is reset).
        root_dir: Directory that the ``image_id`` values are relative to.
        transform: Optional callable invoked as ``transform(image=ndarray)``
            and expected to return a mapping with an ``"image"`` entry.

    Raises:
        ValueError: If the table lacks an 'image_id' or 'label' column.
    """

    def __init__(self, csv_or_df, root_dir: Path, transform=None):
        if isinstance(csv_or_df, (str, Path)):
            self.data = pd.read_csv(csv_or_df)
        else:
            self.data = pd.DataFrame(csv_or_df).reset_index(drop=True)
        if {"image_id", "label"} - set(self.data.columns):
            raise ValueError("csv/dataframe must include 'image_id' and 'label'.")
        self.root_dir = Path(root_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for row ``idx``.

        The image is loaded as RGB. With a transform it is the transform's
        ``"image"`` output; without one it is the PIL image itself.

        Raises:
            FileNotFoundError: If the image file for this row is missing.
        """
        row = self.data.iloc[idx]
        img_path = self.root_dir / str(row["image_id"])
        if not img_path.exists():
            raise FileNotFoundError(f"missing image file: {img_path}")
        # convert() returns an independent in-memory image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        label = int(row["label"])
        # Compare against None explicitly: a transform object that defines
        # __len__ (e.g. an empty pipeline) would be falsy under a plain
        # truthiness test and would be skipped by mistake.
        if self.transform is not None:
            image = self.transform(image=np.array(image))["image"]
        return image, label

def _prepare_loaders(base_path: Path) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test DataLoaders for the dataset at ``base_path``.

    ``new_train.csv``, ``new_val.csv`` and ``new_test.csv`` must all exist
    under ``base_path``. The validation set is a stratified 15% hold-out of
    ``new_train.csv``; ``new_val.csv`` is only checked for existence here and
    is not read. Only the training loader shuffles and uses the augmenting
    transform.

    Raises:
        FileNotFoundError: If any of the three CSV files is missing.
    """
    csv_files = {split: base_path / f"new_{split}.csv" for split in ("train", "val", "test")}
    for csv_file in csv_files.values():
        if not csv_file.exists():
            raise FileNotFoundError(f"missing csv: {csv_file}")

    # Carve the validation rows out of the training table, stratified by label.
    full_train = pd.read_csv(csv_files["train"])
    stratified = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=seed)
    fit_rows, holdout_rows = next(stratified.split(full_train["image_id"], full_train["label"]))
    fit_df = full_train.iloc[fit_rows].reset_index(drop=True)
    holdout_df = full_train.iloc[holdout_rows].reset_index(drop=True)

    data_dir = base_path / "data"
    train_root = data_dir / "FER2013Train"
    test_root = data_dir / "FER2013Test"

    # (dataset, shuffle) pairs in train / val / test order.
    datasets = (
        (FERImageFolder(fit_df, train_root, transform=transform_train), True),
        (FERImageFolder(holdout_df, train_root, transform=transform_eval), False),
        (FERImageFolder(csv_files["test"], test_root, transform=transform_eval), False),
    )
    train_loader, val_loader, test_loader = [
        DataLoader(ds, batch_size=batch_size, shuffle=do_shuffle, num_workers=num_workers, pin_memory=True)
        for ds, do_shuffle in datasets
    ]
    return train_loader, val_loader, test_loader

def adapt_model_for_fer(model: nn.Module, name: str, num_classes: int) -> nn.Module:
    """Replace the final linear layer of ``model`` with a ``num_classes``-way one.

    Args:
        model: Network to modify in place.
        name: Architecture name; selects which attribute holds the head.
        num_classes: Output size of the new linear layer.

    Returns:
        The same ``model`` object, with its head replaced.

    Raises:
        ValueError: If ``name`` is not one of the supported architectures.
    """
    def fresh_head(old_linear):
        # Same input width as the layer being replaced, new output width.
        return nn.Linear(old_linear.in_features, num_classes)

    if name in ("resnet18", "resnet34", "resnet50"):
        model.fc = fresh_head(model.fc)
        return model
    if name in ("vgg16", "vgg19"):
        model.classifier[6] = fresh_head(model.classifier[6])
        return model
    if name in ("densenet121", "densenet161"):
        model.classifier = fresh_head(model.classifier)
        return model
    if name != "mobilenet_v2":
        raise ValueError(f"unsupported model: {name}")

    # mobilenet_v2: the classifier is either a single layer or a Sequential
    # in which the last Linear layer is the one to replace.
    head = model.classifier
    if not isinstance(head, nn.Sequential):
        model.classifier = fresh_head(head)
        return model
    for pos in range(len(head) - 1, -1, -1):
        if isinstance(head[pos], nn.Linear):
            head[pos] = fresh_head(head[pos])
            break
    return model

def _train_one_epoch(model, loader, criterion, optimizer, device) -> Tuple[float, float, float]:
    """Run one optimization pass over ``loader`` and return ``(loss, accuracy, auc)``.

    The loss is the per-sample average over ``len(loader.dataset)``. Accuracy
    and macro one-vs-one ROC AUC are computed from the softmax outputs
    collected during the pass; the AUC is NaN when ``roc_auc_score`` raises
    ``ValueError``.
    """
    model.train()
    loss_sum = 0.0
    all_probs, all_targets = [], []

    for batch_x, batch_y in tqdm(loader, desc="train"):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the final division by the
        # dataset size yields a per-sample mean.
        loss_sum += batch_loss.item() * batch_x.size(0)
        all_probs.extend(logits.softmax(dim=1).detach().cpu().numpy())
        all_targets.extend(batch_y.detach().cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    predicted = np.argmax(all_probs, axis=1)
    accuracy = accuracy_score(all_targets, predicted)
    try:
        macro_auc = roc_auc_score(all_targets, all_probs, multi_class="ovo", average="macro")
    except ValueError:
        macro_auc = float("nan")
    return mean_loss, accuracy, macro_auc

@torch.no_grad()
def _evaluate(model, loader, device, phase: str) -> Tuple[float, float]:
    """Score ``model`` on ``loader`` without gradients and return ``(accuracy, auc)``.

    ``phase`` labels the progress bar and the printed summary line. The AUC
    is the macro one-vs-one ROC AUC over the softmax outputs, or NaN when
    ``roc_auc_score`` raises ``ValueError``.
    """
    model.eval()
    all_probs, all_targets = [], []

    for batch_x, batch_y in tqdm(loader, desc=phase):
        logits = model(batch_x.to(device, non_blocking=True))
        all_probs.extend(logits.softmax(dim=1).cpu().numpy())
        # Targets are never moved to the device, so they convert directly.
        all_targets.extend(batch_y.numpy())

    predicted = np.argmax(all_probs, axis=1)
    accuracy = accuracy_score(all_targets, predicted)
    try:
        macro_auc = roc_auc_score(all_targets, all_probs, multi_class="ovo", average="macro")
    except ValueError:
        macro_auc = float("nan")

    print(f"{phase} accuracy: {accuracy:.4f}, auc: {macro_auc:.4f}")
    return accuracy, macro_auc

def _count_label_classes(base_path: Path) -> int:
    """Count the distinct integer labels across the split CSVs under ``base_path``.

    CSV files that do not exist or have no ``label`` column are skipped.
    """
    distinct = set()
    for csv_name in ("new_train.csv", "new_val.csv", "new_test.csv"):
        csv_file = base_path / csv_name
        if not csv_file.exists():
            continue
        frame = pd.read_csv(csv_file)
        if "label" in frame.columns:
            distinct.update(int(value) for value in frame["label"].unique().tolist())
    return len(distinct)


def train_ferplus_models():
    """Fine-tune every architecture in ``model_names`` and log the results.

    For each model name the function:

    1. builds the pretrained torchvision model and replaces its head via
       ``adapt_model_for_fer``;
    2. trains with Adam, label-smoothed cross-entropy and a
       ``ReduceLROnPlateau`` scheduler driven by validation accuracy, saving
       the weights whenever validation accuracy improves and stopping early
       after ``patience`` epochs without improvement;
    3. reloads the best checkpoint and evaluates it on the test loader;
    4. appends one row to ``<results>_<model>/metrics.csv`` and to
       ``summary_ferplus.csv``, and pickles the per-epoch history.

    All train/validation columns in the CSV rows (including ``train_loss``)
    describe the best epoch, i.e. the one whose weights are evaluated on the
    test set.

    Raises:
        FileNotFoundError: If the dataset folder (or a CSV required by
            ``_prepare_loaders``) is missing.
        ValueError: If fewer than two label classes are found.
    """
    import pickle  # only needed for the history dump below

    _seed_everything(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    paths = paths_mod.resolve_paths(dataset="ferplus", data_root=None, out_root="./outputs")

    base_path = paths["dataset_dir"] / "fer-pytorch" / "fer_pytorch" / "dataset"
    if not base_path.exists():
        raise FileNotFoundError(f"fer+ dataset folder not found: {base_path}")

    train_loader, val_loader, test_loader = _prepare_loaders(base_path)

    num_classes = _count_label_classes(base_path)
    if num_classes < 2:
        raise ValueError(f"invalid class count: {num_classes}")

    base_results: Path = paths["results_dir"]
    base_results.mkdir(parents=True, exist_ok=True)
    summary_csv = base_results / "summary_ferplus.csv"
    summary_header = [
        "model", "best_epoch", "train_loss", "train_acc", "val_acc", "test_acc",
        "train_auc", "val_auc", "test_auc", "weights_path"
    ]

    weights_root: Path = paths["dataset_dir"] / "models_refined"
    weights_root.mkdir(parents=True, exist_ok=True)

    for name in model_names:
        print("\n--------------------")
        print(name)
        print("--------------------\n")

        # Per-model outputs live in a sibling directory of the results dir.
        model_results = base_results.parent / f"{base_results.name}_{name}"
        model_results.mkdir(parents=True, exist_ok=True)
        metrics_csv = model_results / "metrics.csv"
        history_path = model_results / "history.pkl"
        weights_path = weights_root / f"{name}.pth"

        ctor = getattr(models, name)
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is so the loaded weights do not change.
        model = ctor(pretrained=True)
        model = adapt_model_for_fer(model, name, num_classes).to(device)

        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=2, factor=0.5)

        best_val = float("-inf")
        best_epoch = -1
        best_train_loss = float("nan")
        best_train_acc = float("nan")
        best_train_auc = float("nan")
        best_val_auc = float("nan")
        no_improve = 0

        history: Dict[str, list] = {
            "train_loss": [], "train_acc": [], "train_auc": [],
            "val_acc": [], "val_auc": [], "test_acc": [], "test_auc": [],
        }

        for epoch in range(1, epochs + 1):
            print(f"\nepoch {epoch}/{epochs}")
            tr_loss, tr_acc, tr_auc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
            va_acc, va_auc = _evaluate(model, val_loader, device, phase="val")

            history["train_loss"].append(tr_loss)
            history["train_acc"].append(tr_acc)
            history["train_auc"].append(tr_auc)
            history["val_acc"].append(va_acc)
            history["val_auc"].append(va_auc)

            scheduler.step(va_acc)

            if va_acc > best_val:
                best_val = va_acc
                best_val_auc = va_auc
                # Record the loss of the same epoch as the other "best"
                # metrics so the reported row is internally consistent.
                best_train_loss = tr_loss
                best_train_acc = tr_acc
                best_train_auc = tr_auc
                best_epoch = epoch
                # torch.save serializes the tensors immediately, so no copy
                # of the state dict is needed first.
                torch.save(model.state_dict(), weights_path)
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= patience:
                    print(f"early stopping at epoch {epoch}")
                    break

        # Evaluate the best checkpoint rather than the last-epoch weights.
        if weights_path.exists():
            model.load_state_dict(torch.load(weights_path, map_location=device))
        te_acc, te_auc = _evaluate(model, test_loader, device, phase="test")
        history["test_acc"].append(te_acc)
        history["test_auc"].append(te_auc)

        _write_csv_row(
            metrics_csv,
            header=["train_loss", "train_acc", "val_acc", "test_acc", "train_auc", "val_auc", "test_auc", "best_epoch", "weights_path"],
            row=[f"{best_train_loss:.4f}", f"{best_train_acc:.4f}", f"{best_val:.4f}", f"{te_acc:.4f}",
                 f"{best_train_auc:.4f}", f"{best_val_auc:.4f}", f"{te_auc:.4f}", best_epoch, str(weights_path)]
        )

        with open(history_path, "wb") as f:
            pickle.dump(history, f)

        _write_csv_row(
            summary_csv,
            header=summary_header,
            row=[name, best_epoch,
                 f"{best_train_loss:.4f}", f"{best_train_acc:.4f}",
                 f"{best_val:.4f}", f"{te_acc:.4f}",
                 f"{best_train_auc:.4f}", f"{best_val_auc:.4f}", f"{te_auc:.4f}",
                 str(weights_path)]
        )

        # The optimizer (and the scheduler through it) keeps references to the
        # model parameters and its own per-parameter state; drop them together
        # with the model so the memory can be released before the next
        # architecture is built.
        del model, optimizer, scheduler
        torch.cuda.empty_cache()

# Run the full training sweep only when this file is executed as a script.
if __name__ == "__main__":
    train_ferplus_models()
