import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import matplotlib.pyplot as plt
from models.donn import DiffractiveClassifier_RGB
from models.unet import UNet
from models.cifar_like.resnet import ResNet18
import os
import argparse

print(f"PyTorch version: {torch.__version__}")

def evaluate(model, loader, encoder=None,
             using_encoder=False, criterion=None,
             device='cpu', task_id=None, return_cams=False):
    """
    Evaluate a single-logit binary classifier, optionally behind an encoder.

    Args:
    - model: classifier producing one logit per image; when ``return_cams`` is
      True it must expose ``layer4`` (a sequence of blocks) and ``linear``.
    - loader: iterable of ``(images, labels)`` batches; ``labels[:, task_id]``
      selects this task's binary target.
    - encoder: module applied to the images before the classifier; required
      when ``using_encoder`` is True.
    - using_encoder: whether to pass the images through ``encoder`` first.
    - criterion: loss on ``(logits, float targets)``; defaults to
      ``nn.BCEWithLogitsLoss()`` when None.
    - device: device the model(s) and batches are moved to.
    - task_id: column of ``labels`` to evaluate.
    - return_cams: if True, also return the last batch's images, its encoded
      images (None when no encoder is used) and its CAM magnitude.

    Returns:
    - ``(accuracy_percent, mean_batch_loss)``, or
      ``(accuracy_percent, mean_batch_loss, images, encoded, cam_mean)`` when
      ``return_cams`` is True.

    Raises:
    - ValueError: if ``using_encoder`` is True but no encoder is given, or if
      ``loader`` yields no samples.
    """
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()
    if using_encoder and encoder is None:
        raise ValueError("encoder must be provided if using_encoder is True")

    model.to(device)
    model.eval()
    if using_encoder:
        encoder.to(device)
        encoder.eval()

    # Hook the last residual block only when CAMs are needed, and always remove
    # the hook again: registering a fresh hook on every call made hooks
    # accumulate across epochs and fire on every later forward pass.
    hook_handle = None
    if return_cams:
        hook_handle = model.layer4[-1].register_forward_hook(get_activation('last_block'))

    correct = 0
    total = 0
    total_loss = 0.
    num_batches = 0
    images = encoded = cam_mean = None
    try:
        with torch.no_grad():
            for images, labels in loader:
                labels = labels[:, task_id].float()
                images, labels = images.to(device), labels.to(device)

                if using_encoder:
                    encoded = encoder(images)
                    inputs = encoded
                else:
                    inputs = images
                # `normalize` is the module-level transform set up by the script entry point.
                outputs = model(normalize(inputs)).flatten()

                if return_cams:
                    cam = calc_cam_resnet_binary(activation['last_block'], model.linear.weight).unsqueeze(1)
                    # positive + negative evidence, as in the training loss
                    cam_mean = F.relu(cam) + F.relu(cam * -1.0)

                total_loss += criterion(outputs, labels).item()

                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                num_batches += 1
    finally:
        if hook_handle is not None:
            hook_handle.remove()

    if total == 0:
        raise ValueError("loader yielded no samples")

    accuracy = 100 * correct / total
    loss = total_loss / num_batches

    if return_cams:
        return accuracy, loss, images, encoded, cam_mean
    return accuracy, loss

# Shared registry of captured layer outputs, filled by hooks from get_activation.
activation = {}

def get_activation(name):
    """
    Build a forward hook that records a layer's output.

    Args:
    - name: key under which the output is stored in the module-level
      ``activation`` dict.

    Returns:
    - A callable with the ``(module, input, output)`` forward-hook signature;
      every call overwrites ``activation[name]`` with ``output`` and returns None.
    """
    def _store_output(module, inputs, output):
        activation[name] = output

    return _store_output

def calc_cam_resnet_binary(activation, fc_weight):
    """
    Calculate the class activation map (CAM) for a single-logit classifier.

    Args:
    - activation: Tensor of shape (batch_size, num_channels, height, width)
    - fc_weight: Tensor of shape (1, num_channels)

    Returns:
    - A tensor of shape (batch_size, height, width): the channel-mean of the
      activation weighted by ``fc_weight[0] / (height * width)``.

    Notes:
    The per-channel weight ``w / (h * w)`` is a hand-derived derivative of the
    logit with respect to the activation map. It presumably assumes the head
    is a single linear layer applied after global average pooling (ResNet-like);
    other heads would need ``autograd.grad(logit, activation)``, which is slower.
    """
    num_channels, height, width = activation.shape[1:]
    channel_weights = (fc_weight[0] / (height * width)).view(1, num_channels, 1, 1)
    weighted = activation * channel_weights
    return weighted.sum(dim=1) * (1 / num_channels)

def spatial_regularization_loss(features, labels, target_points, eps=1e-12):
    """
    Compute the spatial regularization loss (anchor loss).

    For every sample and channel, the activation-weighted centroid of the map
    (normalized [0, 1] coordinates: x along width, y along height) is compared
    with the target point of that sample's label; the loss is the mean
    Euclidean distance over all samples and channels.

    Args:
    - features: Tensor of shape (batch_size, num_channels, height, width).
      Presumably non-negative (the callers pass ReLU-ed CAMs). A map whose total
      activation is zero gets the centroid (0, 0).
    - labels: Tensor of labels, shape (batch_size,); every value must be a key
      of ``target_points`` (float labels 0./1. match int keys 0/1).
    - target_points: Dictionary with target points for each class, e.g., {0: (x0, y0), 1: (x1, y1)}
    - eps: small constant added under the square root so the gradient stays
      finite when a centroid coincides exactly with its target.

    Returns:
    - A scalar tensor representing the spatial regularization loss.
    """
    _, _, height, width = features.size()

    # Normalized coordinate grids, each of shape (height, width).
    x = torch.linspace(0, 1, width, dtype=torch.float32, device=features.device)
    y = torch.linspace(0, 1, height, dtype=torch.float32, device=features.device)
    Y, X = torch.meshgrid(y, x, indexing='ij')

    # Total activation per (sample, channel), shape (batch_size, num_channels).
    # No keepdim/.squeeze(): squeezing collapsed the channel dim when
    # num_channels == 1, and (B, 1) / (B,) then broadcast to a (B, B) matrix
    # that divided each sample's weighted sum by every other sample's total.
    total_activation = features.sum(dim=(2, 3))

    # Prevent division by zero
    total_activation = torch.where(total_activation == 0, torch.ones_like(total_activation), total_activation)

    # Activation-weighted centroids, shape (batch_size, num_channels).
    centroid_x = (features * X).sum(dim=(2, 3)) / total_activation
    centroid_y = (features * Y).sum(dim=(2, 3)) / total_activation

    # Look up all targets with a single host transfer instead of one .item()
    # per label. Shape (batch_size, 2); columns are sliced as (batch_size, 1) so
    # they broadcast over the channel dimension.
    targets = torch.tensor(
        [target_points[label] for label in labels.tolist()],
        dtype=torch.float32, device=features.device,
    ).reshape(-1, 2)
    targets_x = targets[:, 0:1]
    targets_y = targets[:, 1:2]

    # Euclidean distance to the targets; eps keeps sqrt differentiable at 0.
    distances = torch.sqrt((centroid_x - targets_x) ** 2 + (centroid_y - targets_y) ** 2 + eps)

    return distances.mean()


def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))


# Experiment parameters
# Root folder passed to torchvision's CelebA loader (the data itself lives in
# its `celeba/` subfolder). Set the DATASET_DIR environment variable to point
# at it instead of editing this placeholder.
DATASET_DIR = os.environ.get('DATASET_DIR', 'PATH_TO_DATASET')
# Dataset folder shown in the entry point's log message. (Spelling of the name
# is kept because the entry point references it.)
celaba_dir = os.path.join(DATASET_DIR, 'celeba')

BATCH_SIZE_TRAIN = 32
BATCH_SIZE_TEST = 100

NUM_WORKERS = 2

# Column indices into CelebA's `attr` target used as the two binary tasks.
TASKS = [31, 20] # [smiling, gender]

# Script entry point. It deliberately runs at module level because `evaluate`
# looks up `normalize` (defined below) as a module-level global.
if __name__ == '__main__':
    # initialize argparser
    parser = argparse.ArgumentParser(description='Train an encoder for CelebA')
    # `choices` rejects bad values at parse time: an unknown --mode used to be
    # detected only inside the first training step, and an unsupported
    # --img-size left the transforms undefined (NameError).
    parser.add_argument('--mode', type=str, default='only-encoder',
                        choices=['only-encoder', 'alternating'], help='Mode of training')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs')
    parser.add_argument('--exp-dir', type=str, default='exps/celeb/', help='Experiment directory')
    parser.add_argument('--exp-name', type=str, default='exp_wacv2025', help='Experiment name')
    parser.add_argument('--encoder', type=str, default='unet-tiny',
                        choices=['donn-10', 'donn-5', 'donn-3',
                                 'unet-tiny', 'unet-mini',
                                 'unet-small', 'unet-standard'], help='Encoder model')
    parser.add_argument('--lambda_enc', type=float, default=1.0, help='Encoder loss weight')
    parser.add_argument('--img-size', type=int, default=32, choices=[32, 128],
                        help='dataset resize size')
    args = parser.parse_args()

    exp = f"{args.exp_name}_{args.encoder}_{args.mode}_{args.lambda_enc}_{args.img_size}x{args.img_size}"

    print(exp)
    print(args)

    print('Fetching dataset from:', celaba_dir)

    # No Normalize in these pipelines: images are normalized *after* the
    # encoder (see `normalize` below).
    if args.img_size == 32:
        transform_train = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.RandomCrop((args.img_size, args.img_size), padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_test = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
        ])
    else:  # args.img_size == 128 (guaranteed by argparse choices)
        transform_train = transforms.Compose([
            transforms.RandomCrop((178, 178)),
            transforms.Resize((128, 128)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_test = transforms.Compose([
            transforms.CenterCrop((178, 178)),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
        ])

    mean, std = torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])

    # normalize after encoding (also read as a global by `evaluate`)
    normalize = transforms.Normalize(mean=mean,std=std)
    unnormalize = transforms.Normalize(mean=(-mean / std), std=(1.0 / std))

    trainset = torchvision.datasets.CelebA(root=DATASET_DIR, split='train', target_type=["attr"], transform=transform_train)
    valset = torchvision.datasets.CelebA(root=DATASET_DIR, split='valid', target_type=["attr"], transform=transform_test)
    testset = torchvision.datasets.CelebA(root=DATASET_DIR, split='test', target_type=["attr"], transform=transform_test)

    # NOTE(review): these subsets start at index 1, so sample 0 of every split
    # is never used. Kept unchanged so results stay comparable with earlier runs.
    train_subset = Subset(trainset, torch.arange(1, 50000)) # use this for more comprehensive training
    # train_subset = Subset(trainset, torch.arange(1, 5000)) # use this for short training
    valid_subset = Subset(valset, torch.arange(1, 2000))
    test_subset = Subset(testset, torch.arange(1, 1000))
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE_TEST, shuffle=False, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE_TEST, shuffle=False, num_workers=NUM_WORKERS)

    # Fall back to the CPU instead of crashing when no GPU is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # setup the models
    if args.img_size == 32:
        resnet_cls = ResNet18
    else:
        # this resnet is better for larger image sizes
        from models.resnet import ResNet18 as resnet_cls

    model_T1 = resnet_cls()
    # replace the last layer to output a single logit
    model_T1.linear = nn.Linear(512, 1)
    # register a hook to get the last block activation
    model_T1.layer4[-1].register_forward_hook(get_activation('T1_last_block'))

    model_T2 = resnet_cls()
    model_T2.linear = nn.Linear(512, 1)
    model_T2.layer4[-1].register_forward_hook(get_activation('T2_last_block'))

    # Pretrained single-task baselines. map_location lets the checkpoints load
    # on the selected device even if they were saved from a GPU.
    baseline_dir = os.path.join('exps/celeb/', 'baselines')
    size_tag = f'{args.img_size}x{args.img_size}'
    model_T1.load_state_dict(torch.load(
        os.path.join(baseline_dir, f'model_task{TASKS[0]}_baseline_{size_tag}.pt'), map_location=device))
    model_T2.load_state_dict(torch.load(
        os.path.join(baseline_dir, f'model_task{TASKS[1]}_baseline_{size_tag}.pt'), map_location=device))

    model_T1.to(device)
    model_T2.to(device)

    # evaluate model
    acc, loss = evaluate(model_T1, val_loader, device=device, criterion=nn.BCEWithLogitsLoss(), task_id=TASKS[0])
    print(f'Eval T1 accuracy: {acc:.2f}%, {loss:.3f}')
    acc, loss = evaluate(model_T2, val_loader, device=device, criterion=nn.BCEWithLogitsLoss(), task_id=TASKS[1])
    print(f'Eval T2 accuracy: {acc:.2f}%, {loss:.3f}')

    mode = args.mode # 'only-encoder', 'alternating'
    epochs = args.epochs

    # setup the encoder
    if args.encoder.startswith('unet'):
        size = args.encoder.split('-')[1]
        encoder = UNet(3, 3, size=size)
    elif args.encoder.startswith('donn'):
        size = int(args.encoder.split('-')[1])
        encoder = DiffractiveClassifier_RGB(num_layers=size, sys_size=args.img_size,
            distance = 0.3, pixel_size = 3.6e-5,pad = 100,
            wavelength = 5.32e-7,approx = 'Fresnel',amp_factor = 1.5)

    encoder.to(device)

    # print the number of parameters for classifier and encoder
    size_of_classifier = count_parameters(model_T1)
    size_of_encoder = count_parameters(encoder)
    ratio = size_of_encoder / size_of_classifier  * 100
    print(f"Size of classifier: {size_of_classifier/1e6:.2f}M, Size of encoder: {size_of_encoder/1e6:.2f}M, Ratio(%): {ratio:.2f}")

    # The encoder is always trained with AdamW + cosine annealing.
    optimizer_enc = torch.optim.AdamW(encoder.parameters(), lr=0.01)
    scheduler_enc = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_enc, epochs)
    if mode == 'alternating':
        # SGD for classifiers because it is already pretrained
        optimizer_T1 = torch.optim.SGD(model_T1.parameters(), lr=0.001, momentum=0.9)
        optimizer_T2 = torch.optim.SGD(model_T2.parameters(), lr=0.001, momentum=0.9)

        scheduler_T1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T1, epochs)
        scheduler_T2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T2, epochs)

    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        total_loss_ce_T1 = 0.
        total_loss_ce_T2 = 0.

        total_loss_enc_T1 = 0.
        total_loss_enc_T2 = 0.
        total_pred_T1 = []
        total_pred_T2 = []

        encoder.train()
        # The classifiers stay in eval mode, even in the alternating classifier
        # phase (presumably to keep their pretrained normalization statistics).
        model_T1.eval()
        model_T2.eval()

        for batch_idx, (images, labels) in enumerate(train_loader):
            labels_T1, labels_T2 = labels[:, TASKS[0]].float(), labels[:, TASKS[1]].float()
            images, labels_T1, labels_T2 = images.to(device), labels_T1.to(device), labels_T2.to(device)

            encoded = encoder(images)
            logits_T1 = model_T1(normalize(encoded)).flatten()
            logits_T2 = model_T2(normalize(encoded)).flatten()

            pred_T1 = ((torch.sigmoid(logits_T1) > 0.5).float() == labels_T1).detach().cpu().tolist()
            pred_T2 = ((torch.sigmoid(logits_T2) > 0.5).float() == labels_T2).detach().cpu().tolist()

            loss_ce_T1 = criterion(logits_T1, labels_T1)
            loss_ce_T2 = criterion(logits_T2, labels_T2)

            cam_T1 = calc_cam_resnet_binary(activation['T1_last_block'], model_T1.linear.weight).unsqueeze(1)
            cam_T2 = calc_cam_resnet_binary(activation['T2_last_block'], model_T2.linear.weight).unsqueeze(1)

            # we do F.relu(cam_T1*-1.0) to get the cam for the negative class since
            # we are doing binary classification, and we add them together because we
            # want to push both positive and negative features for one task to one corner
            cam_T1_mean = F.relu(cam_T1) + F.relu(cam_T1*-1.0)
            cam_T2_mean = F.relu(cam_T2) + F.relu(cam_T2*-1.0)

            # Task 1 evidence is anchored to corner (0, 0), task 2 to (1, 1).
            loss_enc_T1 = spatial_regularization_loss(cam_T1_mean,
                                            labels_T1, target_points={
                                                0: (0, 0), 1: (0, 0)})
            loss_enc_T2 = spatial_regularization_loss(cam_T2_mean,
                                            labels_T2, target_points={
                                                0: (1, 1), 1: (1, 1)})

            if (epoch == 0) and (batch_idx == 0):
                print(f"First loss: {loss_ce_T1.item():.3f} {loss_ce_T2.item():.3f} (expected: {-torch.log(torch.tensor(0.5)):.3f})")

            # combine the losses
            loss = loss_ce_T1 + loss_ce_T2 + args.lambda_enc*(loss_enc_T1 + loss_enc_T2)

            if mode == 'only-encoder':
                if batch_idx == 0:
                    print('Training only encoder')
                optimizer_enc.zero_grad()
                loss.backward()
                optimizer_enc.step()

            elif mode == 'alternating':
                optimizer_enc.zero_grad()
                optimizer_T1.zero_grad()
                optimizer_T2.zero_grad()

                if epoch % 2 == 0:
                    if batch_idx == 0:
                        print('Alternating to encoder')
                    loss.backward()
                    optimizer_enc.step()
                else:
                    if batch_idx == 0:
                        print('Alternating to classifiers')
                    # Each CE loss depends only on its own classifier (plus the
                    # shared encoder), so one backward on the sum yields the same
                    # classifier gradients as two separate passes, without
                    # retain_graph. Encoder grads are discarded (zeroed above).
                    (loss_ce_T1 + loss_ce_T2).backward()
                    optimizer_T1.step()
                    optimizer_T2.step()
            else:
                raise ValueError(f"Invalid mode: {mode}")

            total_loss_ce_T1 += loss_ce_T1.item()
            total_loss_ce_T2 += loss_ce_T2.item()

            total_loss_enc_T1 += loss_enc_T1.item()
            total_loss_enc_T2 += loss_enc_T2.item()

            total_pred_T1 += pred_T1
            total_pred_T2 += pred_T2

        train_acc_T1 = torch.tensor(total_pred_T1).float().mean() * 100
        train_acc_T2 = torch.tensor(total_pred_T2).float().mean() * 100

        num_batches = len(train_loader)
        total_loss_ce_T1 = total_loss_ce_T1 / num_batches
        total_loss_ce_T2 = total_loss_ce_T2 / num_batches
        total_loss_enc_T1 = total_loss_enc_T1 / num_batches
        total_loss_enc_T2 = total_loss_enc_T2 / num_batches

        print(f'{epoch}\t{train_acc_T1:.2f}% {train_acc_T2:.2f}% \t{total_loss_ce_T1:.2f} {total_loss_ce_T2:.2f}')
        print(f'\t{total_loss_enc_T1:.2f} {total_loss_enc_T2:.2f}')

        # evaluate model (the returned last-batch images/encodings are unused here)
        val_acc_T1, val_loss_T1, _, _, cam_T1_mean = evaluate(model_T1, val_loader, device=device, encoder=encoder, using_encoder=True,
                            criterion=nn.BCEWithLogitsLoss(), task_id=TASKS[0], return_cams=True)
        print(f'\t Eval T1 accuracy: {val_acc_T1:.2f}%, loss: {val_loss_T1:.3f}')
        val_acc_T2, val_loss_T2, _, _, cam_T2_mean = evaluate(model_T2, val_loader, device=device, encoder=encoder, using_encoder=True,
                            criterion=nn.BCEWithLogitsLoss(), task_id=TASKS[1], return_cams=True)
        print(f'\t Eval T2 accuracy: {val_acc_T2:.2f}%, loss: {val_loss_T2:.3f}')

        scheduler_enc.step()
        if mode == 'alternating':
            scheduler_T1.step()
            scheduler_T2.step()

    exp_dir = args.exp_dir
    exp_name = args.exp_name
    os.makedirs(os.path.join(exp_dir, exp_name), exist_ok=True)

    save_path_T1 = os.path.join(exp_dir, exp_name, f'model_T1_{exp}.pt')
    save_path_T2 = os.path.join(exp_dir, exp_name, f'model_T2_{exp}.pt')
    save_path_enc = os.path.join(exp_dir, exp_name, f'encoder_{exp}.pt')
    print(f"Saving Model_T1 to {save_path_T1}")
    torch.save(model_T1.cpu().state_dict(), save_path_T1)
    print(f"Saving Model_T2 to {save_path_T2}")
    torch.save(model_T2.cpu().state_dict(), save_path_T2)
    print(f"Saving Encoder to {save_path_enc}")
    torch.save(encoder.cpu().state_dict(), save_path_enc)
