import argparse
import torch
import modelFunc
import cv2
import numpy as np
import pdb

def setup_parser(args_list=None):
    """Build the IPT option parser and return the parsed options.

    The process command line (``sys.argv``) is never consulted: options are
    parsed from ``args_list`` so that importing this module is not affected
    by whatever arguments the host program was started with.

    Args:
        args_list: Optional sequence of argument strings, in the same form
            as ``sys.argv[1:]``. When ``None`` (the default) the built-in
            preset ``--scale 1 --test_only --sigma 10`` is used.

    Returns:
        argparse.Namespace holding every option defined below. String values
        equal to ``'True'`` / ``'False'`` are converted to real booleans
        after parsing (this is what makes ``--shift_mean False`` falsy).
    """

    def _float_pair(text):
        # argparse hands the raw string to ``type``; ``tuple`` would split it
        # into single characters, so parse "b1,b2" explicitly instead.
        parts = tuple(float(part) for part in text.split(','))
        if len(parts) != 2:
            raise argparse.ArgumentTypeError(
                "expected two comma-separated floats, e.g. 0.9,0.999")
        return parts

    # Initialize the argument parser
    parser = argparse.ArgumentParser(description='Configuration for machine learning model.')

    # Add arguments with their default values
    parser.add_argument('--debug', action='store_true',
                        help='Enables debug mode')
    parser.add_argument('--template', default='.',
                        help='You can set various templates in option.py')

    # Hardware specifications
    parser.add_argument('--n_threads', type=int, default=6,
                        help='number of threads for data loading')
    parser.add_argument('--cpu', action='store_true',
                        help='use cpu only')
    parser.add_argument('--n_GPUs', type=int, default=1,
                        help='number of GPUs')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')

    # Data specifications
    parser.add_argument('--dir_data', type=str, default='/cache/data/',
                        help='dataset directory')
    parser.add_argument('--dir_demo', type=str, default='../test',
                        help='demo image directory')
    parser.add_argument('--data_train', type=str, default='DIV2K',
                        help='train dataset name')
    parser.add_argument('--data_test', type=str, default='DIV2K',
                        help='test dataset name')
    parser.add_argument('--data_range', type=str, default='1-800/801-810',
                        help='train/test data range')
    parser.add_argument('--ext', type=str, default='sep',
                        help='dataset file extension')
    parser.add_argument('--scale', type=str, default='4',
                        help='super resolution scale')
    parser.add_argument('--patch_size', type=int, default=48,
                        help='output patch size')
    parser.add_argument('--rgb_range', type=int, default=255,
                        help='maximum value of RGB')
    parser.add_argument('--n_colors', type=int, default=3,
                        help='number of color channels to use')
    parser.add_argument('--no_augment', action='store_true',
                        help='do not use data augmentation')

    # Model specifications
    parser.add_argument('--model', default='ipt',
                        help='model name')
    parser.add_argument('--n_feats', type=int, default=64,
                        help='number of feature maps')
    parser.add_argument('--shift_mean', default=True,
                        help='subtract pixel mean from the input')
    parser.add_argument('--precision', type=str, default='single',
                        choices=('single', 'half'),
                        help='FP precision for test (single | half)')

    # Training specifications
    parser.add_argument('--reset', action='store_true',
                        help='reset the training')
    parser.add_argument('--test_every', type=int, default=1000,
                        help='do test per every N batches')
    parser.add_argument('--epochs', type=int, default=300,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='input batch size for training')
    parser.add_argument('--test_batch_size', type=int, default=1,
                        help='input batch size for testing')
    parser.add_argument('--crop_batch_size', type=int, default=64,
                        help='input batch size for training')
    parser.add_argument('--split_batch', type=int, default=1,
                        help='split the batch into smaller chunks')
    parser.add_argument('--self_ensemble', action='store_true',
                        help='use self-ensemble method for test')
    parser.add_argument('--test_only', action='store_true',
                        help='set this option to test the model')
    parser.add_argument('--gan_k', type=int, default=1,
                        help='k value for adversarial loss')

    # Optimization specifications
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--decay', type=str, default='200',
                        help='learning rate decay type')
    parser.add_argument('--gamma', type=float, default=0.5,
                        help='learning rate decay factor for step decay')
    parser.add_argument('--optimizer', default='ADAM',
                        choices=('SGD', 'ADAM', 'RMSprop'),
                        help='optimizer to use (SGD | ADAM | RMSprop)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--betas', type=_float_pair, default=(0.9, 0.999),
                        help='ADAM beta')
    parser.add_argument('--epsilon', type=float, default=1e-8,
                        help='ADAM epsilon for numerical stability')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay')
    parser.add_argument('--gclip', type=float, default=0,
                        help='gradient clipping threshold (0 = no clipping)')

    # Loss specifications
    parser.add_argument('--loss', type=str, default='1*L1',
                        help='loss function configuration')
    parser.add_argument('--skip_threshold', type=float, default=1e8,
                        help='skipping batch that has large error')

    # Log specifications
    parser.add_argument('--save', type=str, default='ipt/',
                        help='file name to save')
    parser.add_argument('--load', type=str, default='',
                        help='file name to load')
    parser.add_argument('--resume', type=int, default=0,
                        help='resume from specific checkpoint')
    parser.add_argument('--save_models', action='store_true',
                        help='save all intermediate models')
    parser.add_argument('--print_every', type=int, default=100,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save_results', action='store_true',
                        help='save output results')
    parser.add_argument('--save_gt', action='store_true',
                        help='save low-resolution and high-resolution images together')

    # Cloud
    parser.add_argument('--moxfile', type=int, default=1)
    parser.add_argument('--data_url', type=str, help='path to dataset')
    parser.add_argument('--train_url', type=str, help='train_dir')
    parser.add_argument('--pretrain', type=str, default='')
    parser.add_argument('--load_query', type=int, default=0)

    # Transformer
    parser.add_argument('--patch_dim', type=int, default=3)
    parser.add_argument('--num_heads', type=int, default=12)
    parser.add_argument('--num_layers', type=int, default=12)
    parser.add_argument('--dropout_rate', type=float, default=0)
    parser.add_argument('--no_norm', action='store_true')
    parser.add_argument('--freeze_norm', action='store_true')
    parser.add_argument('--post_norm', action='store_true')
    parser.add_argument('--no_mlp', action='store_true')
    parser.add_argument('--pos_every', action='store_true')
    parser.add_argument('--no_pos', action='store_true')
    parser.add_argument('--num_queries', type=int, default=1)

    # Denoise
    parser.add_argument('--denoise', action='store_true')
    parser.add_argument('--sigma', type=float, default=30)

    # Derain
    parser.add_argument('--derain', action='store_true')
    parser.add_argument('--derain_test', type=int, default=1)

    # Deblur
    parser.add_argument('--deblur', action='store_true')
    parser.add_argument('--deblur_test', type=int, default=1)

    # Built-in preset that stands in for the command line when the caller
    # does not supply one.
    if args_list is None:
        args_list = [
            '--scale', '1',
            '--test_only',
            '--sigma', '10',
        ]

    # argparse requires every argv entry to be a string, so parse first...
    args = parser.parse_args(args_list)

    # ...and only then turn the literal strings 'True' / 'False' into real
    # booleans on the parsed values (options without a ``type``, such as
    # --shift_mean, would otherwise keep the always-truthy string 'False').
    bool_literals = {'True': True, 'False': False}
    for name, value in list(vars(args).items()):
        if isinstance(value, str) and value in bool_literals:
            setattr(args, name, bool_literals[value])

    return args


# Get the argument object with values
IPTArgs = setup_parser()

# The parser leaves --scale as the string '1'; replace it with a list.
# NOTE(review): presumably modelFunc.Model expects a list of scales — confirm.
IPTArgs.scale = [1]

# Set up the denoise model.
# Checkpoints are deserialized onto the CPU (map_location='cpu') so loading
# does not depend on the CUDA device they were saved from; load_state_dict
# then copies the values into the model's own parameters.
# strict=False: keys missing from / unexpected in the checkpoint are not
# treated as errors.

DenoiseModel = modelFunc.Model(IPTArgs)

denoise_state_dict = torch.load("PretrainModel/IPT_denoise30.pt", map_location='cpu')

DenoiseModel.model.load_state_dict(denoise_state_dict, strict=False)

# Set up the derain model (same options, different pretrained weights).

DerainModel = modelFunc.Model(IPTArgs)

derain_state_dict = torch.load("PretrainModel/IPT_derain.pt", map_location='cpu')

DerainModel.model.load_state_dict(derain_state_dict, strict=False)

# Device on which input tensors are placed by IPTdenoise / IPTderain.

device = torch.device('cuda')

def IPTdenoise(img, magnitude, args = IPTArgs, model = DenoiseModel, device = device):
    """Add Gaussian noise to a BGR image, then run it through the denoise model.

    Args:
        img: Image array in OpenCV BGR channel order, laid out as
            (height, width, channels).
        magnitude: Added to the base noise standard deviation of 25; the
            injected noise is ``randn * (25 + magnitude)``.
        args: Unused by this function; kept for interface compatibility.
        model: Callable invoked as ``model(tensor, 0)`` on the noisy
            (1, channels, height, width) float tensor.
        device: Torch device the input and noise tensors are moved to.

    Returns:
        uint8 image array in BGR channel order. Model outputs outside
        [0, 255] are clamped to that range before the cast.
    """
    with torch.no_grad():

        # OpenCV delivers BGR; the model is fed RGB.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # (height, width, channels) -> (1, channels, height, width)
        batch = np.expand_dims(np.transpose(rgb, (2, 0, 1)), axis=0)

        Noisy_tensor = torch.from_numpy(batch).float().to(device)

        # Noise is sampled on the CPU and then moved to the requested device
        # (rather than an unconditional .cuda()), so the `device` argument
        # is honoured.
        noisy_level = 25 + magnitude
        noise = torch.randn(Noisy_tensor.size()).mul_(noisy_level).to(device)
        Noisy_tensor = (Noisy_tensor + noise).clamp(0, 255)

        # Use the model passed by the caller, not the module-level default.
        # NOTE(review): the second argument 0 is presumably a scale index —
        # confirm against modelFunc.Model.
        Denoise_tensor = model(Noisy_tensor, 0)

        # Drop the batch dimension and go back to (height, width, channels).
        Denoise = Denoise_tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()

        # Clamp before casting: casting out-of-range floats to uint8 would
        # wrap around and produce speckle artefacts.
        Denoise = np.clip(Denoise, 0, 255).astype('uint8')

        # Back to BGR for OpenCV consumers.
        return cv2.cvtColor(Denoise, cv2.COLOR_RGB2BGR)

    
def IPTderain(img, magnitude, args = IPTArgs, model = DerainModel, device = device):
    """Run a BGR image through the derain model and return the BGR result.

    Args:
        img: Image array in OpenCV BGR channel order, laid out as
            (height, width, channels).
        magnitude: Unused by this function.
        args: Unused by this function; kept for interface compatibility.
        model: Callable invoked as ``model(tensor, 0)`` on the
            (1, channels, height, width) float tensor.
        device: Torch device the input tensor is moved to.

    Returns:
        uint8 image array in BGR channel order. Model outputs outside
        [0, 255] are clamped to that range before the cast.
    """
    with torch.no_grad():

        # OpenCV delivers BGR; the model is fed RGB.
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # (height, width, channels) -> (1, channels, height, width)
        batch = np.expand_dims(np.transpose(rgb, (2, 0, 1)), axis=0)

        Rain_tensor = torch.from_numpy(batch).float().to(device)

        # Use the model passed by the caller, not the module-level default.
        # NOTE(review): the second argument 0 is presumably a scale index —
        # confirm against modelFunc.Model.
        Derain_tensor = model(Rain_tensor, 0)

        # Drop the batch dimension and go back to (height, width, channels).
        Derain = Derain_tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()

        # Clamp before casting: casting out-of-range floats to uint8 would
        # wrap around and produce speckle artefacts.
        Derain = np.clip(Derain, 0, 255).astype('uint8')

        # Back to BGR for OpenCV consumers.
        return cv2.cvtColor(Derain, cv2.COLOR_RGB2BGR)