import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import models
from pathlib import Path


# Expected calibration error (ECE) only — no reliability-diagram output
def compute_ece(model, loader, device, n_bins=15):
    """Return the expected calibration error (ECE) of ``model`` over ``loader``.

    A sample's confidence is its largest softmax probability, and its
    prediction is the class attaining it. Confidences are split into
    ``n_bins`` equal-width bins spanning [0, 1]. Each bin is open on the left
    and closed on the right. ECE is the sum over bins of
    |bin accuracy - bin mean confidence|, weighted by the fraction of samples
    falling in that bin.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device the inputs are moved to before the forward pass.
        n_bins: number of confidence bins.

    Returns:
        The ECE as a NumPy float.
    """
    model.eval()
    conf_values, pred_values, true_values = [], [], []

    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(loader, desc="computing reliability data"):
            batch_probs = torch.softmax(model(batch_inputs.to(device)), dim=1)
            batch_conf, batch_pred = batch_probs.max(dim=1)
            conf_values.extend(batch_conf.cpu().numpy())
            pred_values.extend(batch_pred.cpu().numpy())
            true_values.extend(batch_targets.cpu().numpy())

    confidences = np.array(conf_values)
    predictions = np.array(pred_values)
    labels = np.array(true_values)

    edges = np.linspace(0, 1, n_bins + 1)
    bin_acc = np.zeros(n_bins)
    bin_conf = np.zeros(n_bins)
    bin_weight = np.zeros(n_bins)

    for b, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (confidences > lower) & (confidences <= upper)
        bin_weight[b] = np.mean(in_bin)
        # Empty bins keep accuracy/confidence at 0 and carry zero weight.
        if bin_weight[b] > 0:
            bin_acc[b] = np.mean(predictions[in_bin] == labels[in_bin])
            bin_conf[b] = np.mean(confidences[in_bin])

    return np.sum(np.abs(bin_acc - bin_conf) * bin_weight)

# top-1 accuracy
def compute_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    Args:
        model: classifier whose outputs are argmax-ed along dim 1.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device inputs and targets are moved to.

    Returns:
        Fraction of samples whose argmax prediction equals the target.
        An empty loader raises ``ZeroDivisionError``.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_inputs, batch_targets in loader:
            batch_targets = batch_targets.to(device)
            top1 = model(batch_inputs.to(device)).argmax(dim=1)
            n_correct += (top1 == batch_targets).sum().item()
            n_seen += batch_targets.shape[0]

    return n_correct / n_seen


# compute conditional entropy
def compute_conditional_entropy_per_image(model, loader, device):
    """Return the entropy (natural log) of each sample's softmax distribution.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, _)`` batches; labels are ignored.
        device: device the inputs are moved to before the forward pass.

    Returns:
        1-D NumPy array with one entropy value per sample, in loader order.
    """
    model.eval()
    per_batch = []

    with torch.no_grad():
        for batch_inputs, _ in tqdm(loader, desc="computing conditional entropy"):
            p = torch.softmax(model(batch_inputs.to(device)), dim=1).cpu().numpy()
            # The 1e-12 offset keeps log() finite if a probability underflows to 0.
            per_batch.append(-(p * np.log(p + 1e-12)).sum(axis=1))

    return np.concatenate(per_batch) if per_batch else np.array([])


def get_model(name, dataset, device, weights_root="data/ferplus/models_refined"):
    """Build a torchvision architecture with a dataset-sized head and load saved weights.

    The classification head of ``models.<name>`` is replaced by a fresh
    ``nn.Linear`` with ``len(dataset.classes)`` outputs. The state dict stored at
    ``<weights_root>/<name>.pth`` is then loaded into the model.

    Args:
        name: torchvision model constructor name, e.g. ``"resnet18"``.
        dataset: object exposing a ``classes`` sequence; its length sizes the head.
        device: device the checkpoint is mapped to and the model is moved to.
        weights_root: directory containing the ``<name>.pth`` checkpoints.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        ValueError: if ``name`` is not a supported architecture family, or its
            ``nn.Sequential`` classifier contains no ``nn.Linear`` to replace.
    """
    model_fn = getattr(models, name)
    # ``pretrained`` is torchvision's legacy keyword (newer releases prefer
    # ``weights=None``); kept so older torchvision versions keep working.
    model = model_fn(pretrained=False)
    num_classes = len(dataset.classes)

    if "resnet" in name or name.startswith("resnext"):
        # ResNet, Wide-ResNet and ResNeXt share the same ``fc`` head.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif "vgg" in name:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif "densenet" in name:
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif name.startswith(("mobilenet", "efficientnet")):
        if isinstance(model.classifier, nn.Sequential):
            # Replace the last Linear layer in the classifier stack.
            for i in reversed(range(len(model.classifier))):
                if isinstance(model.classifier[i], nn.Linear):
                    in_features = model.classifier[i].in_features
                    model.classifier[i] = nn.Linear(in_features, num_classes)
                    break
            else:
                raise ValueError(f"no nn.Linear layer found in the classifier of {name!r}")
        else:
            in_features = model.classifier.in_features
            model.classifier = nn.Linear(in_features, num_classes)
    else:
        # Otherwise the architecture's default head would be kept, and the
        # checkpoint load below would most likely fail with a less informative
        # state-dict mismatch.
        raise ValueError(f"unsupported architecture for head replacement: {name!r}")

    weights_path = Path(weights_root) / f"{name}.pth"
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    print(name, "loaded from disk.")
    model = model.to(device)
    return model


# LAC (least ambiguous set-valued classifier): score = 1 - p(true class)
def compute_nonconformity_scores(model, loader, device):
    """Compute LAC nonconformity scores ``1 - p(true class)`` on a labelled loader.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        device: device the inputs and targets are moved to.

    Returns:
        Tuple ``(scores, labels)`` of 1-D NumPy arrays, aligned and in loader order.
    """
    model.eval()
    nonconformity_scores, true_labels = [], []

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="computing nonconformity scores"):
            inputs, targets = inputs.to(device), targets.to(device)
            probabilities = torch.softmax(model(inputs), dim=1)
            # Create the row index on the targets' device so the advanced index
            # does not mix host and device tensors.
            rows = torch.arange(len(targets), device=targets.device)
            true_probs = probabilities[rows, targets]
            scores = 1 - true_probs
            nonconformity_scores.extend(scores.cpu().numpy())
            true_labels.extend(targets.cpu().numpy())

    return np.array(nonconformity_scores), np.array(true_labels)


def determine_threshold(nonconformity_scores, alpha):
    """Return the split-conformal threshold for miscoverage level ``alpha``.

    The threshold is the k-th smallest calibration score, with
    ``k = ceil((n + 1) * (1 - alpha))`` (the finite-sample-corrected quantile).
    When ``k > n`` (very small ``n`` or ``alpha``), the largest score is returned
    instead. The nominal ``1 - alpha`` guarantee is then not formally ensured.

    Args:
        nonconformity_scores: 1-D array-like of calibration scores.
        alpha: target miscoverage rate in [0, 1].

    Returns:
        The selected calibration score.

    Raises:
        ValueError: if ``nonconformity_scores`` is empty or ``alpha`` is not in [0, 1].
    """
    scores = np.asarray(nonconformity_scores)
    n = len(scores)
    if n == 0:
        raise ValueError("nonconformity_scores must not be empty")
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")

    k = int(np.ceil((n + 1) * (1 - alpha)))
    # Clamp into [0, n - 1]. k > n falls back to the max score. k == 0
    # (alpha == 1) maps to the min instead of wrapping to the max via index -1.
    index = min(max(k - 1, 0), n - 1)
    return np.sort(scores)[index]


def predict_with_lac(model, loader, threshold, device):
    """Build LAC prediction sets from a calibrated threshold.

    A class is kept when its softmax probability is at least ``1 - threshold``.
    This is equivalent to its LAC score ``1 - p`` being at most ``threshold``.
    Sets may be empty.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, _)`` batches; labels are ignored.
        threshold: calibrated nonconformity threshold.
        device: device the inputs are moved to before the forward pass.

    Returns:
        List with one list of class indices per sample, in loader order.
    """
    model.eval()
    min_prob = 1 - threshold
    sets = []

    with torch.no_grad():
        for batch_inputs, _ in tqdm(loader, desc="predicting"):
            probs = torch.softmax(model(batch_inputs.to(device)), dim=1)
            keep = probs >= min_prob
            sets.extend(row.nonzero(as_tuple=True)[0].cpu().tolist() for row in keep)

    return sets


# APS (adaptive prediction sets): randomized cumulative-probability scores
def compute_aps_scores(model, loader, device, random_state=None):
    """Compute randomized APS conformity scores on a labelled calibration loader.

    Classes are ranked by descending softmax probability. The score of a sample
    is the cumulative probability up to and including its true class, minus
    ``u * p(true class)``, where ``u`` is one uniform draw per sample.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        device: device the inputs are moved to before the forward pass.
        random_state: seed (or generator) passed to ``np.random.default_rng``.

    Returns:
        1-D NumPy array with one score per sample, in loader order.
    """
    rng = np.random.default_rng(random_state)
    model.eval()
    scores = []

    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(loader, desc="aps calibration"):
            probs = torch.softmax(model(batch_inputs.to(device)), dim=1).cpu().numpy()
            order = np.argsort(-probs, axis=1)
            cum_sorted = np.cumsum(-np.sort(-probs, axis=1), axis=1)

            for row, label in enumerate(batch_targets.cpu().numpy()):
                position = np.flatnonzero(order[row] == label)[0]
                p_label = probs[row, label]
                # One uniform draw per sample, in loader order (keeps seeding reproducible).
                u = rng.uniform()
                scores.append(cum_sorted[row, position] - u * p_label)

    return np.array(scores)


def predict_with_aps(model, loader, threshold, device):
    """Build APS prediction sets from a calibrated threshold.

    Classes are added in descending-probability order, up to and including the
    first class whose cumulative probability reaches ``threshold``. If none
    does, every class is included.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, _)`` batches; labels are ignored.
        threshold: calibrated APS threshold.
        device: device the inputs are moved to before the forward pass.

    Returns:
        List with one list of class indices per sample, ordered by descending probability.
    """
    model.eval()
    sets = []

    with torch.no_grad():
        for batch_inputs, _ in tqdm(loader, desc="aps prediction"):
            probs = torch.softmax(model(batch_inputs.to(device)), dim=1).cpu().numpy()
            ranked = np.argsort(-probs, axis=1)
            mass = np.cumsum(-np.sort(-probs, axis=1), axis=1)
            last = ranked.shape[1] - 1

            for row_mass, row_classes in zip(mass, ranked):
                stop = min(int(np.searchsorted(row_mass, threshold, side="left")), last)
                sets.append(row_classes[: stop + 1].tolist())

    return sets


# RAPS (regularized adaptive prediction sets): APS scores plus a rank penalty
def compute_raps_scores(
    model,
    loader,
    device,
    lambda_reg=0.1,
    k_reg=5,
    beta=1.0,
    calibration_fraction=0.9,
    random_state=None,
):
    """Compute randomized RAPS conformity scores on a labelled calibration loader.

    Classes are ranked by descending softmax probability. With ``rank`` the
    0-based position of the true class, the score of a sample is::

        cumsum(sorted_probs)[rank] - u * p_true
            + lambda_reg * max(rank + 1 - k_reg, 0) ** beta

    where ``u`` is one uniform draw per sample.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        device: device the inputs are moved to before the forward pass.
        lambda_reg: weight of the rank penalty.
        k_reg: number of top ranks exempt from the penalty.
        beta: exponent applied to the rank excess.
        calibration_fraction: NOT applied. Every sample from ``loader`` is
            scored. Kept for interface compatibility.
        random_state: seed (or generator) passed to ``np.random.default_rng``.

    Returns:
        1-D NumPy array with one score per sample, in loader order. It is empty
        if the loader yields no batches.
    """
    rng = np.random.default_rng(random_state)
    model.eval()
    prob_batches, target_batches = [], []

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="collecting calibration data"):
            outputs = model(inputs.to(device))
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
            target_batches.append(targets.cpu().numpy())

    if not prob_batches:
        # np.vstack raises on an empty list; an empty loader simply has no scores.
        return np.array([])

    all_probs = np.vstack(prob_batches)
    all_targets = np.concatenate(target_batches)

    # The permutation itself is unused (calibration_fraction is not applied).
    # It is still drawn so that a fixed random_state yields the same u values.
    rng.permutation(len(all_targets))

    scores = []
    for probs, y in zip(all_probs, all_targets):
        sorted_indices = np.argsort(-probs)
        cumulative_probs = np.cumsum(probs[sorted_indices])
        rank = np.where(sorted_indices == y)[0][0]

        u = rng.uniform()
        aps_score = cumulative_probs[rank] - u * probs[y]
        penalty = lambda_reg * max((rank + 1) - k_reg, 0) ** beta
        scores.append(aps_score + penalty)

    return np.array(scores)


def predict_with_raps(model, loader, threshold, device, lambda_reg=0.1, k_reg=5, beta=1.0):
    """Build RAPS prediction sets, applying the same rank penalty as calibration.

    Classes are considered in descending-probability order. The regularized mass
    at 0-based position ``j`` is::

        cumsum(sorted_probs)[j] + lambda_reg * max(j + 1 - k_reg, 0) ** beta

    This is the same penalty term used by ``compute_raps_scores``. The set holds
    every class up to and including the first position whose regularized mass
    reaches ``threshold``; if none does, it holds all classes. Without the
    penalty, a threshold calibrated on penalized scores would be compared
    against unpenalized mass, inflating sets toward the full label set.

    Args:
        model: classifier whose outputs are treated as logits along dim 1.
        loader: iterable of ``(inputs, _)`` batches; labels are ignored.
        threshold: threshold calibrated on RAPS scores.
        device: device the inputs are moved to before the forward pass.
        lambda_reg: rank-penalty weight; must match the calibration value.
        k_reg: number of penalty-free top ranks; must match the calibration value.
        beta: penalty exponent; must match the calibration value. Assumed
            non-negative so the regularized mass is non-decreasing, which
            ``np.searchsorted`` requires.

    Returns:
        List with one list of class indices per sample, ordered by descending probability.
    """
    model.eval()
    prediction_sets = []

    with torch.no_grad():
        for inputs, _ in tqdm(loader, desc="raps prediction"):
            outputs = model(inputs.to(device))
            probabilities = torch.softmax(outputs, dim=1).cpu().numpy()

            sorted_indices = np.argsort(-probabilities, axis=1)
            sorted_probs = np.sort(probabilities, axis=1)[:, ::-1]
            cumulative_probs = np.cumsum(sorted_probs, axis=1)

            num_classes = probabilities.shape[1]
            # Penalty for occupying 1-based rank j + 1, identical to calibration.
            ranks = np.arange(1, num_classes + 1)
            penalties = lambda_reg * np.maximum(ranks - k_reg, 0) ** beta
            regularized = cumulative_probs + penalties

            last = num_classes - 1
            for reg_mass, indices in zip(regularized, sorted_indices):
                idx = min(int(np.searchsorted(reg_mass, threshold, side="left")), last)
                prediction_sets.append(indices[: idx + 1].tolist())

    return prediction_sets


# metrics
def compute_coverage(prediction_sets, y_true):
    """Return the fraction of samples whose true label lies in its prediction set.

    Args:
        prediction_sets: sequence of per-sample label collections.
        y_true: true labels, paired with ``prediction_sets`` position by position.

    Returns:
        Empirical coverage, or 0.0 when ``prediction_sets`` is empty.
    """
    if not prediction_sets:
        return 0.0
    return np.mean([label in members for members, label in zip(prediction_sets, y_true)])


def compute_size_stratified_coverage(prediction_sets, y_true):
    """Return the worst coverage across groups of equally sized prediction sets.

    Samples are grouped by ``len(prediction_set)``. For each group, coverage is
    the fraction of samples whose true label lies in its set. The minimum over
    groups is returned.

    Args:
        prediction_sets: sequence of per-sample label collections.
        y_true: indexable true labels aligned with ``prediction_sets``.

    Returns:
        Minimum per-size-group coverage, or 0.0 when there are no prediction
        sets (matching ``compute_coverage`` on empty input).
    """
    hits_by_size = {}
    for i, pred_set in enumerate(prediction_sets):
        hits_by_size.setdefault(len(pred_set), []).append(y_true[i] in pred_set)

    if not hits_by_size:
        # min() over an empty sequence would raise ValueError.
        return 0.0
    return min(np.mean(hits) for hits in hits_by_size.values())


def compute_average_width(prediction_sets):
    """Return the mean number of labels per prediction set (NaN for no sets)."""
    return np.mean(list(map(len, prediction_sets)))
