import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.donn import DiffractiveClassifier_Raw, DiffractiveClassifier_RGB, DiffractiveClassifier_RGB_residual
from models.unet import UNet
import os
import argparse

print("PyTorch version:", torch.__version__)


def evaluate(model, loader, encoder=None, 
             using_encoder=False, criterion=None, 
             device='cpu', task_id=None, return_cams=False, mask=None):
    """
    Evaluate a binary classifier on one attribute of a multi-attribute dataset.
    
    Args:
    - model: Classifier exposing `layer4` (its last element gets a forward hook) and
      `linear` (whose weight is used for the CAM); it must produce one logit per sample.
    - loader: Iterable yielding (images, labels) batches; labels are indexed as labels[:, task_id].
    - encoder: Module applied to the images before the classifier (required if using_encoder is True).
    - using_encoder: If True, the classifier is fed encoder(images) instead of images.
    - criterion: Loss applied to (logits, float labels). Defaults to nn.BCEWithLogitsLoss().
    - device: Device on which the model, encoder, mask and batches are placed.
    - task_id: Column of the label tensor to evaluate. Callers must supply it; the default
      None is not a usable column index.
    - return_cams: If True, also return the last batch's images, classifier input and CAM.
    - mask: Optional tensor multiplied into the classifier input (before normalization).
      If its spatial size differs from the input's, the input is bilinearly resized to the
      mask's size first, mirroring what the training loop does.
    
    Returns:
    - (accuracy, loss), where accuracy is in percent and loss is the mean of the per-batch losses.
    - If return_cams is True: (accuracy, loss, images, encoded, cam_mean), all taken from the
      LAST batch only. `encoded` is the un-normalized classifier input (the images themselves
      when no encoder is used) and cam_mean is the absolute value of the CAM.
    
    Notes:
    This function relies on the module-level `normalize` transform that is created in the
    script's main section, so it can only be called after that has been defined.
    """
    if using_encoder:
        assert encoder is not None, "encoder must be provided if using_encoder is True"
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()

    model.to(device)
    model.eval()
    if using_encoder:
        encoder.to(device)
        encoder.eval()
    if mask is not None:
        mask = mask.to(device)

    # Keep the handle so the hook can be removed afterwards; otherwise every call
    # to evaluate() would stack one more hook on the same layer.
    hook_handle = model.layer4[-1].register_forward_hook(get_activation('last_block'))

    correct = 0
    total = 0
    total_loss = 0.
    images = encoded = cam_mean = None
    try:
        with torch.no_grad():
            for images, labels in loader:
                labels = labels[:, task_id].float()
                images, labels = images.to(device), labels.to(device)

                encoded = encoder(images) if using_encoder else images
                if mask is not None:
                    if encoded.shape[-2:] != mask.shape[-2:]:
                        encoded = F.interpolate(encoded, size=tuple(mask.shape[-2:]), mode='bilinear')
                    encoded = encoded * mask
                outputs = model(normalize(encoded)).flatten()

                if return_cams:
                    cam = calc_cam_resnet_binary(activation['last_block'], model.linear.weight).unsqueeze(1)
                    # relu(cam) + relu(-cam) is the element-wise absolute value of the CAM
                    cam_mean = F.relu(cam) + F.relu(cam*-1.0)

                loss = criterion(outputs, labels)
                total_loss += loss.item()

                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        hook_handle.remove()

    accuracy = 100 * correct / total
    loss = total_loss / len(loader)

    if return_cams:
        return accuracy, loss, images, encoded, cam_mean
    return accuracy, loss

# Global store filled by the hooks created below: layer name -> most recent output.
activation = {}
def get_activation(name):
    """
    Build a forward hook that records a layer's output.
    
    Args:
    - name: Key under which the output is saved in the global `activation` dictionary
    
    Returns:
    - A callable with the forward-hook signature (module, inputs, output). Each time it is
      invoked it overwrites activation[name] with `output` and returns None.
    """
    def _record(module, inputs, output):
        activation[name] = output
    return _record

def calc_cam_resnet_binary(activation, fc_weight):
    """
    Compute a class activation map (CAM) for a single-logit classifier.
    
    Args:
    - activation: Tensor of shape (batch_size, num_channels, height, width)
    - fc_weight: Tensor of shape (1, num_channels); only its first row is used
    
    Returns:
    - A tensor of shape (batch_size, height, width): the channel-wise mean of the
      activation weighted by fc_weight[0] / (height * width).
    
    Notes:
    The weighting is written out in closed form for ResNet-like models that end in a
    single fully connected layer, so the map is obtained during the forward pass.
    For other architectures the per-channel weights would have to come from
    autograd.grad(logit, activation) instead, which is slower.
    """
    _, num_channels, height, width = activation.shape
    # Per-channel weight, broadcastable against the (B, C, H, W) activation.
    channel_weight = fc_weight[0] / (height * width)
    channel_weight = channel_weight.view(1, num_channels, 1, 1)
    weighted = activation * channel_weight
    scale = 1 / num_channels
    return scale * torch.sum(weighted, dim=1)

def spatial_regularization_loss(features, labels, target_points):
    """
    Compute the spatial regularization loss (anchor loss).
    
    For every sample and channel, the activation-weighted centroid of the feature map is
    computed in normalized coordinates (x and y both run from 0 to 1) and compared with
    the target point assigned to the sample's label.
    
    Args:
    - features: Tensor of shape (batch_size, num_channels, height, width)
    - labels: Binary labels for each item in the batch, shape (batch_size,)
    - target_points: Dictionary with target points for each class, e.g., {0: (x0, y0), 1: (x1, y1)}
    
    Returns:
    - A scalar tensor: the Euclidean centroid-to-target distance averaged over all
      samples and channels.
    
    Raises:
    - KeyError: if a label has no entry in target_points.
    """
    _, _, height, width = features.shape
    device = features.device

    # Normalized pixel coordinates used to compute the centroid of activations
    xs = torch.linspace(0, 1, width, dtype=torch.float32, device=device)
    ys = torch.linspace(0, 1, height, dtype=torch.float32, device=device)
    Y, X = torch.meshgrid(ys, xs, indexing='ij')

    # Per-(sample, channel) activation mass, shape (batch_size, num_channels).
    # Summing without keepdim avoids a later squeeze(), which would also drop a
    # batch or channel axis of size 1 and silently change the broadcasting.
    total_activation = features.sum(dim=[2, 3])

    # Prevent division by zero
    total_activation = torch.where(total_activation == 0, torch.ones_like(total_activation), total_activation)

    # Activation-weighted centroid, shape (batch_size, num_channels)
    centroid_x = (features * X).sum(dim=[2, 3]) / total_activation
    centroid_y = (features * Y).sum(dim=[2, 3]) / total_activation

    # Look up each sample's target once and build a single (batch_size, 2) tensor
    targets = torch.tensor(
        [(target_points[key][0], target_points[key][1]) for key in labels.reshape(-1).tolist()],
        dtype=torch.float32, device=device,
    ).reshape(-1, 2)
    # Keep a trailing axis of size 1 so the targets broadcast over channels
    targets_x = targets[:, 0:1]
    targets_y = targets[:, 1:2]

    # Euclidean distance between each centroid and its target point
    distances = torch.sqrt((centroid_x - targets_x) ** 2 + (centroid_y - targets_y) ** 2)

    return distances.mean()


def count_parameters(model):
    """Return the total number of elements in the parameters of `model` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Mini-batch sizes: one for the training loader, one shared by the validation and test loaders.
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 64

# Columns of the attribute label tensor used as the two binary tasks (T1, T2).
TASKS = [31, 20] # smiling and gender

# Number of worker processes passed to every DataLoader.
NUM_WORKERS = 8

if __name__ == '__main__':
    # Script overview: for every (diagonal, corner) mask configuration, two fresh
    # ResNet18 binary classifiers (one per entry of TASKS) are trained on the output of a
    # frozen encoder with everything but one image corner zeroed out. The best
    # validation accuracy of each classifier is recorded per configuration.

    # ---- Command-line interface ------------------------------------------------
    parser = argparse.ArgumentParser(description='Train a model on CelebA')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--encoder', type=str, default='unet-tiny',
                        choices=['donn-10', 'donn-5', 'donn-3',
                                 'donn-10-residual', 'donn-5-residual', 'donn-3-residual',
                                 'unet-tiny', 'unet-mini', 'unet-small', 'unet-standard', 'No-encoder'],
                        help='Encoder model')
    parser.add_argument('--enc-img-size', type=int, default=32, help='dataset resize size')
    parser.add_argument('--classifier-img-size', type=int, default=32, help='dataset resize size')
    parser.add_argument('--pretrained-enc', type=str, default=None, help='path to pretrained encoder')
    parser.add_argument('--pretrained-T1', type=str, default=None, help='path to pretrained T1 model')
    parser.add_argument('--pretrained-T2', type=str, default=None, help='path to pretrained T2 model')
    # dataset
    parser.add_argument('--dataset', type=str, default='celeba', help='dataset directory')

    args = parser.parse_args()

    # Diagonal offsets to sweep for each supported classifier resolution. A larger
    # offset leaves a smaller visible corner.
    DIAGONALS = {
        128: [110, 90, 70],  # 1% 5% 10%
        32: [28, 27, 26, 22, 18],
    }
    # Fail before any data is loaded instead of hitting an undefined `diags` later on.
    if args.classifier_img_size not in DIAGONALS:
        parser.error(f"--classifier-img-size must be one of {sorted(DIAGONALS)}; "
                     f"got {args.classifier_img_size}")
    diags = DIAGONALS[args.classifier_img_size]

    exp = f"ATTACK_{args.encoder}_esize_{args.enc_img_size}x{args.enc_img_size}_csize_{args.classifier_img_size}x{args.classifier_img_size}"
    # mkdir exp
    os.makedirs(exp, exist_ok=True)

    print(exp)
    print(args)

    # Use the first GPU when there is one, otherwise run on the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # ---- Data ------------------------------------------------------------------
    DATASET_DIR = args.dataset
    celeba_dir = os.path.join(DATASET_DIR, 'celeba')
    print('Fetching dataset from:', celeba_dir)

    # One pipeline for every encoder input size (the 32 and 128 cases used to be
    # spelled out separately, leaving other sizes without any transform).
    enc_size = (args.enc_img_size, args.enc_img_size)
    transform_train = transforms.Compose([
        transforms.Resize(enc_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(enc_size),
        transforms.ToTensor(),
    ])

    # `normalize` is applied to the classifier input here and inside evaluate(),
    # which reads it as a module-level name.
    mean, std = torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])
    normalize = transforms.Normalize(mean=mean,std=std)
    unnormalize = transforms.Normalize(mean=(-mean / std), std=(1.0 / std))

    trainset = torchvision.datasets.CelebA(root=DATASET_DIR, split='train', target_type=["attr"], transform=transform_train)
    valset = torchvision.datasets.CelebA(root=DATASET_DIR, split='valid', target_type=["attr"], transform=transform_test)
    testset = torchvision.datasets.CelebA(root=DATASET_DIR, split='test', target_type=["attr"], transform=transform_test)

    # NOTE(review): each range starts at 1, so sample 0 of every split is skipped.
    # Kept as-is to stay comparable with earlier runs -- confirm whether intended.
    train_subset = Subset(trainset, torch.arange(1, 50000))
    valid_subset = Subset(valset, torch.arange(1, 2000))
    test_subset = Subset(testset, torch.arange(1, 1000))

    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, pin_memory=pin_memory, num_workers=NUM_WORKERS)
    val_loader = DataLoader(valid_subset, batch_size=BATCH_SIZE_TEST, shuffle=False, pin_memory=pin_memory, num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE_TEST, shuffle=False, pin_memory=pin_memory, num_workers=NUM_WORKERS)

    # ---- Classifier factory ----------------------------------------------------
    if args.classifier_img_size == 32:
        from models.cifar_like.resnet import ResNet18
    else:
        from models.resnet import ResNet18

    def build_classifier(tag, checkpoint):
        """Create a single-logit ResNet18 on `device`, optionally restored from `checkpoint`."""
        net = ResNet18()
        net.linear = nn.Linear(512, 1)
        net.layer4[-1].register_forward_hook(get_activation(f'{tag}_last_block'))
        if checkpoint is not None:
            print(f"Loading pretrained {tag} from {checkpoint}")
            # Load onto the CPU first so checkpoints saved on another device still open.
            net.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        return net.to(device)

    epochs = args.epochs

    # ---- Encoder ---------------------------------------------------------------
    if args.encoder.startswith('unet'):
        size = args.encoder.split('-')[1]
        encoder = UNet(3, 3, size=size)
    elif args.encoder.startswith('donn'):
        size = int(args.encoder.split('-')[1])
        donn_cls = DiffractiveClassifier_RGB_residual if 'residual' in args.encoder else DiffractiveClassifier_RGB
        encoder = donn_cls(num_layers=size, sys_size=args.enc_img_size,
                           distance=0.3, pixel_size=3.6e-5, pad=100,
                           wavelength=5.32e-7, approx='Fresnel', amp_factor=1.5)
    elif args.encoder == 'No-encoder':
        # identity encoder
        encoder = nn.Identity()

    if args.pretrained_enc is not None and args.encoder != 'No-encoder':
        print(f"Loading pretrained encoder from {args.pretrained_enc}")
        encoder.load_state_dict(torch.load(args.pretrained_enc, map_location='cpu'))

    encoder.to(device)

    # Parameter counts are taken before the encoder is frozen, because
    # count_parameters() only counts parameters that require gradients.
    size_of_classifier = count_parameters(build_classifier('T1', None))
    size_of_encoder = count_parameters(encoder)

    ratio = size_of_encoder / size_of_classifier  * 100
    print(f"Size of classifier: {size_of_classifier/1e6:.2f}M, Size of encoder: {size_of_encoder/1e6:.2f}M, Ratio(%): {ratio:.2f}")

    # disable encoder updates
    for param in encoder.parameters():
        param.requires_grad = False

    class MaskedEncoder(nn.Module):
        """Encoder followed by an optional resize to the mask's size and the mask itself.

        Packing these steps into one module guarantees that training and validation
        feed the classifiers identically preprocessed inputs.
        """

        def __init__(self, base, mask):
            super().__init__()
            self.base = base
            self.register_buffer('mask', mask)

        def forward(self, x):
            x = self.base(x)
            if x.shape[-2:] != self.mask.shape[-2:]:
                x = F.interpolate(x, size=tuple(self.mask.shape[-2:]), mode='bilinear')
            return x * self.mask

    criterion = nn.BCEWithLogitsLoss()

    # Axis along which the upper-triangular mask is flipped to land in each corner.
    FLIP_DIM = {'TL': 3, 'BR': 2}
    results = {
        maskloc: {'exp': exp, 'diagonals': diags, 'acc_T1': [], 'acc_T2': []}
        for maskloc in FLIP_DIM
    }

    for diag in diags:
        for maskloc in ['TL', 'BR']:
            print('-'*10)
            print(f'diagonal: {diag}')
            # Mask that blocks everything except the top left (TL) or bottom right (BR) corner
            corner = torch.ones(1, 1, args.classifier_img_size, args.classifier_img_size).triu(diag)
            mask = torch.flip(corner, dims=(FLIP_DIM[maskloc],)).to(device)
            masked_encoder = MaskedEncoder(encoder, mask).to(device)

            # reinitialize the classifiers for this configuration
            model_T1 = build_classifier('T1', args.pretrained_T1)
            model_T2 = build_classifier('T2', args.pretrained_T2)

            optimizer_T1 = torch.optim.AdamW(model_T1.parameters(), lr=0.001)
            optimizer_T2 = torch.optim.AdamW(model_T2.parameters(), lr=0.001)

            scheduler_T1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T1, epochs)
            scheduler_T2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T2, epochs)

            best_acc_T1 = 0.
            best_acc_T2 = 0.
            for epoch in range(epochs):
                sum_loss_T1 = 0.
                sum_loss_T2 = 0.
                hits_T1 = 0
                hits_T2 = 0
                seen = 0
                masked_encoder.eval()
                model_T1.train()
                model_T2.train()
                for i, (images, labels) in enumerate(train_loader):
                    labels_T1, labels_T2 = labels[:, TASKS[0]].float(), labels[:, TASKS[1]].float()
                    images, labels_T1, labels_T2 = images.to(device), labels_T1.to(device), labels_T2.to(device)

                    # The encoder is frozen, so no graph is needed for its output.
                    with torch.no_grad():
                        inputs = normalize(masked_encoder(images))

                    logits_T1 = model_T1(inputs).flatten()
                    logits_T2 = model_T2(inputs).flatten()

                    hits_T1 += ((torch.sigmoid(logits_T1) > 0.5).float() == labels_T1).sum().item()
                    hits_T2 += ((torch.sigmoid(logits_T2) > 0.5).float() == labels_T2).sum().item()
                    seen += images.shape[0]

                    loss_ce_T1 = criterion(logits_T1, labels_T1)
                    loss_ce_T2 = criterion(logits_T2, labels_T2)

                    if (epoch == 0) and (i == 0):
                        print(f"First loss: {loss_ce_T1.item():.3f} {loss_ce_T2.item():.3f} (expected: {-torch.log(torch.tensor(0.5)):.3f})")

                    sum_loss_T1 += loss_ce_T1.item()
                    sum_loss_T2 += loss_ce_T2.item()

                    optimizer_T1.zero_grad()
                    loss_ce_T1.backward()
                    optimizer_T1.step()

                    optimizer_T2.zero_grad()
                    loss_ce_T2.backward()
                    optimizer_T2.step()

                acc_T1 = 100 * hits_T1 / seen
                acc_T2 = 100 * hits_T2 / seen

                # Report the epoch-mean loss of each task rather than the last batch's loss.
                mean_loss_T1 = sum_loss_T1 / len(train_loader)
                mean_loss_T2 = sum_loss_T2 / len(train_loader)
                print(f'{epoch}\t{acc_T1:.2f}% {acc_T2:.2f}% \t{mean_loss_T1:.2f} {mean_loss_T2:.2f}')

                # evaluate model; the mask is applied inside masked_encoder, so
                # evaluate() does not need a separate mask argument
                val_acc_T1, val_loss_T1 = evaluate(model_T1, val_loader, device=device,
                                encoder=masked_encoder, using_encoder=True,
                                criterion=criterion, task_id=TASKS[0])
                print(f'\t Eval T1 accuracy: {val_acc_T1:.2f}%, {val_loss_T1:.3f}')

                val_acc_T2, val_loss_T2 = evaluate(model_T2, val_loader, device=device,
                                encoder=masked_encoder, using_encoder=True,
                                criterion=criterion, task_id=TASKS[1])
                print(f'\t Eval T2 accuracy: {val_acc_T2:.2f}%, {val_loss_T2:.3f}')

                scheduler_T1.step()
                scheduler_T2.step()
                best_acc_T1 = max(best_acc_T1, val_acc_T1)
                best_acc_T2 = max(best_acc_T2, val_acc_T2)

            # Move the finished classifiers off the device before the next pair is built.
            model_T1.to('cpu')
            model_T2.to('cpu')

            results[maskloc]['acc_T1'].append(best_acc_T1)
            results[maskloc]['acc_T2'].append(best_acc_T2)
            print(exp)
            print('TL:', results['TL'])
            print('BR:', results['BR'])

            # Checkpoint saving is currently disabled; only the paths are computed.
            save_path_T1 = os.path.join(exp, f'model_T1_{exp}_{maskloc}_{diag}.pt')
            save_path_T2 = os.path.join(exp, f'model_T2_{exp}_{maskloc}_{diag}.pt')
            # torch.save(model_T1.state_dict(), save_path_T1)
            # torch.save(model_T2.state_dict(), save_path_T2)
