import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.donn import DiffractiveClassifier_RGB, DiffractiveClassifier_RGB_residual
from models.unet import UNet
import os
import argparse


# Report the PyTorch build in the run log.
print(f"PyTorch version: {torch.__version__}")


def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# DataLoader batch sizes for the training and evaluation splits.
BATCH_SIZE_TRAIN = BATCH_SIZE_TEST = 64

# CelebA attribute columns of interest: smiling and gender.
TASKS = [31, 20]

# Worker processes used by each DataLoader.
NUM_WORKERS = 8

def make_criterion(loss_func):
    """Return the distillation loss ``criterion(student_out, teacher_out)``.

    Args:
        loss_func: one of ``'l1'``, ``'mse'`` or ``'cosine'`` (the ``--loss-func`` choices).

    Returns:
        A callable that maps two same-shaped tensors to a scalar loss tensor.

    Raises:
        ValueError: if ``loss_func`` is not a supported name.
    """
    if loss_func == 'l1':
        return nn.L1Loss()
    if loss_func == 'mse':
        return nn.MSELoss()
    if loss_func == 'cosine':
        cos_criterion = nn.CosineEmbeddingLoss()

        def cosine_loss(x1, x2):
            # Flatten each sample to one vector so similarity is measured over
            # all channels and pixels of that sample at once.
            n = x1.shape[0]
            x1 = x1.reshape(n, -1)
            x2 = x2.reshape(n, -1)
            # Target +1 means "these pairs should be similar"; the loss must be
            # returned so the training loop can call .backward() on it.
            return cos_criterion(x1, x2, torch.ones(n, device=x1.device))

        return cosine_loss
    raise ValueError(f'Unknown loss function: {loss_func!r}')


def build_corner_mask(img_size, mask_size, device):
    """Return a ``(1, 1, img_size, img_size)`` mask covering two opposite corners.

    Entry ``(i, j)`` is 1 when ``i + j <= img_size - 1 - mask_size`` (top-left
    triangle) or ``i + j >= img_size - 1 + mask_size`` (bottom-right triangle)
    and 0 in the band between them. For ``mask_size <= 0`` the two triangles
    overlap and the overlap holds the value 2.
    """
    upper = torch.triu(torch.ones(1, 1, img_size, img_size), mask_size)
    mask_tl = torch.flip(upper, dims=(3,))  # mirror columns -> top-left triangle
    mask_br = torch.flip(upper, dims=(2,))  # mirror rows -> bottom-right triangle
    return (mask_tl + mask_br).to(device)


def teacher_targets(teacher, images, img_size, teacher_is_32, no_clamp):
    """Compute the blurred teacher outputs the student is trained to match.

    If ``teacher_is_32`` is set, the teacher sees a 32x32 bilinear copy of
    ``images`` and its output is upsampled to ``img_size``. Otherwise the
    teacher sees the full images and its output is blurred by a 32x32
    down/up bilinear round trip. No gradients are tracked.
    """
    with torch.no_grad():
        if teacher_is_32:
            small = F.interpolate(images, size=(32, 32), mode='bilinear', align_corners=False)
            target = teacher(small, no_clamp=no_clamp)
        else:
            target = teacher(images, no_clamp=no_clamp)
            target = F.interpolate(target, size=(32, 32), mode='bilinear', align_corners=False)
        return F.interpolate(target, size=(img_size, img_size), mode='bilinear', align_corners=False)


# Script entry point: train the student encoder to imitate the teacher, then save it.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train a student DONN to approximate a UNet teacher')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--encoder', type=str, default='unet-tiny', choices=['donn-10', 'donn-5', 'donn-3',
                                                                             'donn-10-residual', 'donn-5-residual',
                                                                             'unet-tiny', 'unet-mini', 'unet-small', 'unet-standard'], help='Encoder model')
    parser.add_argument('--img-size', type=int, default=128, help='dataset resize size')
    parser.add_argument('--loss-func', type=str, default='mse', choices=['l1', 'mse', 'cosine'], help='Loss function for encoder')
    parser.add_argument('--mask-size', type=int, default=None, help='mask size for donn')
    parser.add_argument('--no_clamp_student', action='store_true', default=False, help='call the student encoder with no_clamp=True')
    parser.add_argument('--no_clamp_teacher', action='store_true', default=False, help='call the teacher encoder with no_clamp=True')
    # Required: the teacher weights are always loaded and the file name is part of the run name.
    parser.add_argument('--teacher', type=str, required=True, help='path to the UNet-tiny teacher state dict')
    parser.add_argument('--teacher_is_32', action='store_true', default=False, help='feed the teacher 32x32 inputs and upsample its output')
    parser.add_argument('--dataset', type=str, default='celeba', help='dataset directory')
    args = parser.parse_args()

    # Validate before creating any output directory; only 128x128 is wired up.
    if args.img_size == 32:
        raise ValueError('Image size 32x32 not supported, DONN will not work with this size. Use 128x128')
    if args.img_size != 128:
        raise ValueError(f'Image size {args.img_size} not supported. Use 128x128')

    exp = f"APRROX_{args.encoder}_{args.img_size}x{args.img_size}_{args.loss_func}_mask{args.mask_size}_no_clamp{args.no_clamp_student}_{args.no_clamp_teacher}_teacher_{args.teacher.split('/')[-1]}"

    print(exp)
    print(args)

    os.makedirs(exp, exist_ok=True)

    DATASET_DIR = args.dataset
    celaba_dir = os.path.join(DATASET_DIR, 'celeba')
    print('Fetching dataset from:', celaba_dir)

    transform_train = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    trainset = torchvision.datasets.CelebA(root=DATASET_DIR, split='train', target_type=["attr"], transform=transform_train, download=True)
    # First 50k training samples (indices start at 0).
    train_subset = Subset(trainset, range(50000))
    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, pin_memory=True, num_workers=NUM_WORKERS)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    epochs = args.epochs

    if args.encoder.startswith('unet'):
        size = args.encoder.split('-')[1]
        encoder = UNet(3, 3, size=size)
    elif args.encoder.startswith('donn'):
        size = int(args.encoder.split('-')[1])
        donn_cls = DiffractiveClassifier_RGB_residual if 'residual' in args.encoder else DiffractiveClassifier_RGB
        encoder = donn_cls(num_layers=size, sys_size=args.img_size,
                           distance=0.3, pixel_size=3.6e-5, pad=100,
                           wavelength=5.32e-7, approx='Fresnel', amp_factor=1.5)

    encoder_teacher = UNet(3, 3, size='tiny')
    # map_location lets a checkpoint saved on another device load here.
    encoder_teacher.load_state_dict(torch.load(args.teacher, map_location=device))

    encoder_teacher.to(device)
    encoder.to(device)
    encoder_teacher.eval()  # the teacher is never trained

    if args.encoder.startswith('unet'):
        optimizer_enc = torch.optim.AdamW(encoder.parameters(), lr=args.lr)
    else:
        optimizer_enc = torch.optim.Adam(encoder.parameters(), lr=args.lr)
    scheduler_enc = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_enc, epochs)

    criterion = make_criterion(args.loss_func)
    mask = build_corner_mask(args.img_size, args.mask_size, device) if args.mask_size is not None else None

    for epoch in range(epochs):
        total_loss = 0.
        encoder.train()

        # Labels are not needed: the student only imitates the teacher's output.
        for images, _ in train_loader:
            images = images.to(device)

            encoded = encoder(images, no_clamp=args.no_clamp_student)
            encoded_teacher = teacher_targets(encoder_teacher, images, args.img_size,
                                              args.teacher_is_32, args.no_clamp_teacher)

            if mask is not None:
                encoded = encoded * mask
                encoded_teacher = encoded_teacher * mask

            loss = criterion(encoded, encoded_teacher)

            optimizer_enc.zero_grad()
            loss.backward()
            optimizer_enc.step()

            total_loss += loss.item()

        loss = total_loss / len(train_loader)
        print(f'{epoch}\t{loss:.5f}')

        scheduler_enc.step()

    save_path_enc = os.path.join(exp, f'donn_{exp}.pt')
    print(f"Saving Encoder to {save_path_enc}")
    torch.save(encoder.cpu().state_dict(), save_path_enc)
         

    