import numpy as np
import matplotlib.pyplot as plt
import argparse
import cv2
import torch
from torch import nn
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim
from models.seco_inr import SecoINR
from utils import utils
import nibabel as nib
import torch.backends.cudnn as cudnn
import random


def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def main():
    """Fit a SecoINR model to a downsampled MRI slice and evaluate super-resolution.

    Steps performed:
      1. Parse hyper-parameters from the command line.
      2. Load the image and mask from the hard-coded NIfTI paths, normalise the
         image to [0, 1] and downsample it by ``--upscale_factor``.
      3. Train the model on the low-resolution (LR) coordinate grid with an MSE
         reconstruction loss, a ReLU penalty on negative coefficients and a
         cross-entropy loss on the per-pixel class output.
      4. Query the trained model on the high-resolution (HR) grid, compute
         MSE / PSNR / SSIM against the masked HR image.
      5. Save the output, error, LR and HR images as PNG files and print the
         final scores.
    """
    # Function-scope import: only this entry point needs the filesystem helpers.
    import os

    parser = argparse.ArgumentParser(description='')

    #  Parameters
    parser.add_argument('--lr',type=float, default=1e-4, help='Learning rate')
    # _str2bool (instead of type=bool) so that "--using_schedular False" really disables it.
    parser.add_argument('--using_schedular', type=_str2bool, default=True, help='Whether to use schedular')
    parser.add_argument('--scheduler_b', type=float, default=0.1, help='Learning rate scheduler')
    parser.add_argument('--maxpoints', type=int, default=256*256, help='Batch size')
    parser.add_argument('--niters', type=int, default=1000, help='Number if iterations')
    parser.add_argument('--upscale_factor', type=float, default=2.0, help='Upscale factor for super-resolution')
    parser.add_argument('--p_coef',type=float, default=0.1993, help='a coeficient')
    parser.add_argument('--q_coef',type=float, default=0.0196, help='b coeficient')
    parser.add_argument('--r_coef',type=float, default=0.0588, help='c coeficient')
    parser.add_argument('--s_coef',type=float, default=0.0269, help='d coeficient')
    parser.add_argument('--beta',type=float, default=1.0, help='beta')

    args = parser.parse_args()

    # The scheduler lambda divides by niters, so zero or negative values cannot work.
    if args.niters < 1:
        parser.error('--niters must be a positive integer')

    # paths to data and save results
    image_path = 'data/sample_MRI_data.nii.gz'
    mask_path = 'data/sample_MRI_mask.nii.gz'
    result_path = 'results/'
    save_name = 'MRI'

    # Create the output directory up front so a missing folder fails before training, not after.
    os.makedirs(result_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # set random seed for reproducibility
    seed = 42
    torch.manual_seed(seed)  # pytorch
    random.seed(seed)  # python
    np.random.seed(seed)  # numpy
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Load the image and mask
    im_hr = nib.load(image_path).get_fdata().astype(np.float32)
    im_hr = cv2.cvtColor(im_hr, cv2.COLOR_BGR2GRAY)
    # Min-max normalise to [0, 1]
    im_hr = (im_hr-im_hr.min())/(im_hr.max()-im_hr.min())
    mask_hr = nib.load(mask_path).get_fdata()
    # Largest label value in the mask; labels are treated as 0..num_classes below
    num_classes = mask_hr.max()
    im_lr = cv2.resize(im_hr, None, fx=1/args.upscale_factor, fy=1/args.upscale_factor, interpolation=cv2.INTER_AREA)
    H_hr, W_hr = im_hr.shape
    H_lr, W_lr = im_lr.shape

    # Cross Entropy Loss for the Pixel Class Representation Network
    ce_loss = nn.CrossEntropyLoss()

    # Conditioner Network Configurations
    MLP_configs={'task': 'image',
                'model': 'resnet34',
                'truncated_layer':5,
                'in_channels': int(num_classes+1),
                'hidden_channels': [64, 32, 4*4],
                'mlp_bias':0.3120,
                'activation_layer': nn.SiLU,
                'GT': torch.tensor(im_lr).to(device)[None,None,...]
                }

    # Adaptive SIREN Network Configurations
    model = SecoINR(in_features=2,
                    out_features=1,
                    hidden_features=256,
                    hidden_layers=3,
                    first_omega_0=30.0,
                    hidden_omega_0=30.0,
                    MLP_configs = MLP_configs,
                    num_classes=num_classes,
                ).to(device)

    # Optimizer
    optim = torch.optim.Adam(lr=args.lr, params=model.parameters())
    # Exponential decay of the learning-rate multiplier from 1 to scheduler_b over niters steps
    scheduler = lr_scheduler.LambdaLR(optim, lambda x: args.scheduler_b ** min(x / args.niters, 1))

    # Initialize lists for PSNR and SSIM
    psnr_values_lr = []
    psnr_values_hr = []
    mse_values_hr = []
    mse_values_lr = []
    ssim_values_hr = []

    # Generate coordinate grid
    coords_lr = utils.get_coords(H_lr, W_lr, dim=2)[None, ...]
    coords_hr = utils.get_coords(H_hr, W_hr, dim=2)[None, ...]

    # Convert input image to a tensor and reshape
    gt_lr = torch.tensor(im_lr[..., None]).reshape(H_lr * W_lr, 1)[None, ...].to(device)
    gt_hr = torch.tensor(im_hr).reshape(H_hr * W_hr, 1)[None, ...].to(device)

    # Initialize a tensor for reconstructed data
    rec_lr = torch.zeros_like(gt_lr)
    seg_lr = torch.zeros_like(gt_lr)
    rec_hr = torch.zeros_like(gt_hr)
    seg_hr = torch.zeros_like(gt_hr)

    # pixel classes map, encoded as 2 * label / num_classes - 1 (label 0 maps to -1)
    mask_lr = cv2.resize(mask_hr, None, fx=1/args.upscale_factor, fy=1/args.upscale_factor, interpolation=cv2.INTER_NEAREST)
    # Use the LR mask to get the HR mask
    mask_lr_up = cv2.resize(mask_lr, None, fx=args.upscale_factor, fy=args.upscale_factor, interpolation=cv2.INTER_NEAREST)
    pixclass_lr = 2*(torch.FloatTensor(mask_lr[:,:,0]).unsqueeze(0).unsqueeze(0)/num_classes)-1
    pixclass_hr = 2*(torch.FloatTensor(mask_lr_up[:,:,0]).unsqueeze(0).unsqueeze(0)/num_classes)-1

    # Binary masks: 1 where the label is non-zero (encoded value differs from -1)
    big_mask_hr = (pixclass_hr[0,0]!=-1.).unsqueeze(-1).float()
    big_mask_lr = (pixclass_lr[0,0]!=-1.).unsqueeze(-1).float()

    pixclass_lr = pixclass_lr.permute(0, 2, 3, 1).reshape(-1, H_lr*W_lr, 1).to(device)
    pixclass_hr = pixclass_hr.permute(0, 2, 3, 1).reshape(-1, (int(H_lr*args.upscale_factor))*(int(W_lr*args.upscale_factor)), 1).to(device)

    big_mask_lr = big_mask_lr.reshape(H_lr * W_lr, 1)[None, ...].to(device)
    big_mask_hr = big_mask_hr.reshape(H_hr * W_hr, 1)[None, ...].to(device)

    # Zero out everything outside the mask in the targets
    gt_lr= gt_lr*big_mask_lr
    gt_hr= gt_hr*big_mask_hr

    im_lr = gt_lr[0].reshape(H_lr, W_lr, 1).detach().cpu().numpy()

    # TRAINING LOOP
    for step in range(args.niters):

        # Randomize the order of data points for each iteration
        indices = torch.randperm(H_lr*W_lr)

        # Process data points in batches
        for b_idx in range(0, H_lr*W_lr, args.maxpoints):
            b_indices = indices[b_idx:min(H_lr*W_lr, b_idx+args.maxpoints)]
            b_coords = coords_lr[:, b_indices, ...].to(device)
            b_pixclass = pixclass_lr[:, b_indices, ...].to(device)
            b_indices = b_indices.to(device)

            # Calculate model output
            model_output, coef, seg_output = model(b_coords)
            model_output = model_output*big_mask_lr[:, b_indices, :]
            seg_output = seg_output*big_mask_lr[:, b_indices, :]

            # Update the reconstructed data
            with torch.no_grad():
                rec_lr[:, b_indices, :] = model_output
                seg_lr[:, b_indices, :] = seg_output.argmax(dim=-1).float().unsqueeze(-1)

            # Calculate the output loss
            output_loss = ((model_output - gt_lr[:, b_indices, :])**2).mean()

            # Calculate cross entropy loss.
            # Decode the [-1, 1] class encoding back to integer labels; round before the
            # cast because float32 round-off would otherwise truncate e.g. 1.9999 to 1.
            labels = torch.round(num_classes*(b_pixclass+1)/2).to(torch.int64)[:,:,0]
            seg_loss = ce_loss(seg_output[0], labels[0])

            p_coef = coef[..., 0]
            q_coef = coef[..., 1]
            r_coef = coef[..., 2]
            s_coef = coef[..., 3]

            # Penalise negative coefficients only (relu(-x) is zero for x >= 0)
            reg_loss = args.p_coef * torch.relu(-p_coef) + \
                    args.q_coef * torch.relu(-q_coef) + \
                    args.r_coef * torch.relu(-r_coef) + \
                    args.s_coef * torch.relu(-s_coef)
            reg_loss = reg_loss.sum()

            # Total loss for 'seco_inr' model
            loss = output_loss + reg_loss + args.beta*seg_loss
            print(f'iteration: {step}, loss {loss.item()}')

            # Perform backpropagation and update model parameters
            optim.zero_grad()
            loss.backward()
            optim.step()

        # Adjust learning rate using a scheduler if applicable
        if args.using_schedular:
            scheduler.step()

    # Calculate and log mean squared error (MSE) and PSNR once, after training
    # NOTE(review): the model is not switched to eval mode here — confirm this is intended.
    with torch.no_grad():
        # rec_lr holds the outputs recorded during the last training iteration
        loss_lr = ((gt_lr - rec_lr)**2).mean()
        psnr_lr = -10*torch.log10(loss_lr)
        psnr_values_lr.append(psnr_lr.item())

        # Super-resolution Inference
        indices_hr = torch.randperm(H_hr*W_hr)

        for b_idx in range(0, H_hr*W_hr, args.maxpoints):
            b_indices_hr = indices_hr[b_idx:min(H_hr*W_hr, b_idx+args.maxpoints)]
            b_coords_hr = coords_hr[:, b_indices_hr, ...].to(device)
            b_indices_hr = b_indices_hr.to(device)

            model_eval, _, seg_output = model(b_coords_hr)
            model_eval = model_eval*big_mask_hr[:, b_indices_hr, :]

            rec_hr[:, b_indices_hr, :] = model_eval
            seg_hr[:, b_indices_hr, :] = seg_output.argmax(dim=-1).float().unsqueeze(-1)

        loss_hr = ((gt_hr - rec_hr)**2).mean()
        psnr_hr = -10*torch.log10(loss_hr)
        psnr_values_hr.append(psnr_hr.item())
        mse_values_hr.append(loss_hr.item())
        mse_values_lr.append(loss_lr.item())

        hr_pred = rec_hr[0, ...].reshape(H_hr, W_hr, 1).detach().cpu().numpy()

        # SSIM (computed against the masked HR target)
        im_hr = gt_hr[0].reshape(H_hr, W_hr, 1).detach().cpu().numpy()
        ms_ssim_val = ssim(torch.tensor(im_hr[None,...]).permute(0, 3, 1, 2),
                            torch.tensor(hr_pred[None, ...]).permute(0, 3, 1, 2),
                            data_range=1, size_average=False)
        ssim_values_hr.append(ms_ssim_val[0].item())

    # Visualize and save the results
    subplot_info = [
        {'title': f'PSNR: {round(max(psnr_values_hr),4)}', 'image': hr_pred, 'save': 'Output',},
        {'title': 'Error', 'image': np.abs(im_hr-hr_pred), 'save': 'Error',},
        {'title': 'LR', 'image': im_lr, 'save': 'LR',},
        {'title': 'HR', 'image': im_hr, 'save': 'HR',}
    ]

    for item in subplot_info:

        plt.imshow(item['image'][:,:,0], cmap='gray', vmax=1.0, vmin=0.0)
        plt.axis('off')

        plt.savefig(os.path.join(result_path, f"{save_name}_{item['save']}.png"),
                    bbox_inches='tight',
                    pad_inches=0,
                    dpi=600,
                    facecolor='black',
                    )
        plt.close()

    print("Final scores:")
    print("RMSE: {:.4f} | SSIM: {:.4f} | PSNR: {:.4f}".format(mse_values_hr[-1]**0.5, ssim_values_hr[-1], psnr_values_hr[-1]))


    
# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()