import pdb
import numpy as np
import os
import argparse
from tqdm import tqdm
import cv2

import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import utils

from natsort import natsorted
from glob import glob
from skimage import img_as_ubyte
from collections import OrderedDict
from basicsr.models.image_restoration_model import ImageCleanModel



def self_ensemble(x, model):
    """Average ``model`` predictions over the 8 flip/rotation variants of ``x``.

    Test-time augmentation: for every combination of (flip along dim -2,
    flip along dim -1, 90-degree rotation over dims (-2, -1)), the input is
    transformed, passed through ``model``, and the output is mapped back with
    the inverse transform. The 8 restored outputs are stacked along a new
    leading dim and averaged.

    Args:
        x: input tensor whose last two dims are spatial (presumably
            (N, C, H, W) -- TODO confirm with callers).
        model: callable taking and returning a tensor; its output must keep
            the spatial layout of its input so the inverse transform applies.

    Returns:
        The element-wise mean of the 8 de-augmented ``model`` outputs.
    """

    def _run_augmented(inp, flip_rows, flip_cols, rotate):
        # Dims to mirror: -2 (rows) and/or -1 (columns). A flip is its own
        # inverse, so the same dims are reused to undo it afterwards.
        flip_dims = [d for d, enabled in ((-2, flip_rows), (-1, flip_cols)) if enabled]
        if flip_dims:
            inp = torch.flip(inp, flip_dims)
        if rotate:
            inp = torch.rot90(inp, dims=(-2, -1))
        out = model(inp)
        # Undo in reverse order: three more quarter turns restore orientation.
        if rotate:
            out = torch.rot90(out, k=3, dims=(-2, -1))
        if flip_dims:
            out = torch.flip(out, flip_dims)
        return out

    predictions = [
        _run_augmented(x, flip_rows, flip_cols, rotate)
        for flip_rows in (False, True)
        for flip_cols in (False, True)
        for rotate in (False, True)
    ]
    return torch.stack(predictions).mean(dim=0)

# Options passed to ``ImageCleanModel`` below. The layout follows a BasicSR-style
# option tree (datasets / network_g / path / train / val / logger / dist_params);
# ``network_g`` describes the architecture the pretrained weights are loaded into.
# Keyword-argument construction keeps insertion order (Python 3.6+), so every
# OrderedDict here has its keys in the order listed.
opt = OrderedDict(
    name='RetinexFormer_NTIRE',
    model_type='ImageCleanModel',
    scale=1,
    num_gpu=8,
    manual_seed=100,
    use_amp=True,
    datasets=OrderedDict(
        train=OrderedDict(
            name='TrainSet',
            type='Dataset_PairedImage',
            dataroot_gt='data/NTIRE/train/target',
            dataroot_lq='data/NTIRE/train/input',
            geometric_augs=True,
            filename_tmpl='{}',
            io_backend=OrderedDict(type='disk'),
            use_shuffle=True,
            num_worker_per_gpu=8,
            batch_size_per_gpu=1,
            mini_batch_sizes=[1],
            iters=[300000],
            gt_size=2000,
            gt_sizes=[2000],
            dataset_enlarge_ratio=1,
            prefetch_mode=None,
            phase='train',
            scale=1,
        ),
        val=OrderedDict(
            name='ValSet',
            type='Dataset_PairedImage',
            dataroot_gt='data/NTIRE/mini_val/target',
            dataroot_lq='data/NTIRE/mini_val/input',
            io_backend=OrderedDict(type='disk'),
            phase='val',
            scale=1,
        ),
    ),
    network_g=OrderedDict(
        type='RetinexFormer',
        in_channels=3,
        out_channels=3,
        n_feat=40,
        stage=1,
        num_blocks=[1, 2, 2],
    ),
    path=OrderedDict(
        pretrain_network_g=None,
        strict_load_g=True,
        resume_state=None,
        root='/opt/ml/code/Retinexformer-master',
        results_root='/opt/ml/code/Retinexformer-master/results/RetinexFormer_NTIRE',
        log='/opt/ml/code/Retinexformer-master/results/RetinexFormer_NTIRE',
        visualization='/opt/ml/code/Retinexformer-master/results/RetinexFormer_NTIRE/visualization',
    ),
    train=OrderedDict(
        total_iter=150000,
        warmup_iter=-1,
        use_grad_clip=True,
        scheduler=OrderedDict(
            type='CosineAnnealingRestartCyclicLR',
            periods=[46000, 104000],
            restart_weights=[1, 1],
            eta_mins=[0.0003, 1e-06],
        ),
        mixing_augs=OrderedDict(
            mixup=True,
            mixup_beta=1.2,
            use_identity=True,
        ),
        optim_g=OrderedDict(
            type='Adam',
            lr=0.0002,
            betas=[0.9, 0.999],
        ),
        pixel_opt=OrderedDict(
            type='L1Loss',
            loss_weight=1,
            reduction='mean',
        ),
    ),
    val=OrderedDict(
        window_size=4,
        val_freq=3000.0,
        save_img=False,
        rgb2bgr=True,
        use_image=False,
        max_minibatch=8,
        metrics=OrderedDict(
            psnr=OrderedDict(
                type='calculate_psnr',
                crop_border=0,
                test_y_channel=False,
            ),
        ),
    ),
    logger=OrderedDict(
        print_freq=500,
        save_checkpoint_freq=1500.0,
        use_tb_logger=True,
        wandb=OrderedDict(
            project='low_light',
            resume_id=None,
        ),
    ),
    dist_params=OrderedDict(
        backend='nccl',
        port=29800,
    ),
    is_train=False,
    dist=False,
)



# Build the generator network described by ``opt`` and load pretrained weights.
RetinexModel = ImageCleanModel(opt).net_g

weight_path = "PretrainModel/NTIRE.pth"

# Resolve the target device before loading so GPU-saved tensors can be mapped
# onto a CPU-only machine instead of failing inside torch.load.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: torch.load unpickles the file -- only point weight_path at trusted
# checkpoints.
checkpoint = torch.load(weight_path, map_location=device)

# The checkpoint stores un-prefixed parameter names under 'params'. Load them
# into the bare network: a DataParallel/DDP wrapper exposes it as ``.module``
# (whose own state-dict keys carry a 'module.' prefix), otherwise net_g is used
# directly. This works whether or not net_g was wrapped by ImageCleanModel.
getattr(RetinexModel, 'module', RetinexModel).load_state_dict(checkpoint['params'])

RetinexModel = RetinexModel.to(device)
# Wrap for multi-GPU inference, but avoid nesting one DataParallel inside
# another if net_g already came back wrapped.
if not isinstance(RetinexModel, nn.DataParallel):
    RetinexModel = nn.DataParallel(RetinexModel)
RetinexModel.eval()

print("Finishing building Retinex model.")

def RetinexBrightUp(img, magnitude, model=RetinexModel, device=device):
    """Enhance a BGR image with the Retinex model and return a BGR uint8 image.

    The image is converted BGR -> RGB, scaled to float32 by 1/255, laid out as
    a (1, C, H, W) tensor on ``device`` and run through ``self_ensemble``
    (8-way flip/rotation test-time augmentation). The prediction is clamped to
    [0, 1], converted back to an H x W x C BGR array and then to uint8.

    Args:
        img: H x W x 3 image in OpenCV BGR channel order (e.g. from
            ``cv2.imread``); presumably uint8, since values are divided by 255
            -- TODO confirm for other dtypes.
        magnitude: currently unused; kept for API compatibility.
        model: network to run; defaults to the module-level ``RetinexModel``.
        device: device the input tensor is moved to; should match ``model``.

    Returns:
        H x W x 3 uint8 image in BGR channel order.
    """
    with torch.inference_mode():
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rgb = np.float32(rgb) / 255.
        # HWC -> CHW, then add a batch dimension.
        batch = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(device)
        # Run the caller-supplied model (defaults to RetinexModel).
        pred = self_ensemble(batch, model)
        pred = torch.clamp(pred, 0, 1).cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        return img_as_ubyte(pred)
    
    
if __name__ == '__main__':
    # Smoke test: enhance check.jpg (resized to 1024x1024) and save the result.
    raw_img = cv2.imread('check.jpg')
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail with a clear message rather than inside cv2.resize.
    if raw_img is None:
        raise FileNotFoundError("Could not read input image: 'check.jpg'")

    raw_img = cv2.resize(raw_img, (1024, 1024))

    retinex = RetinexBrightUp(raw_img, magnitude=0)

    # cv2.imwrite signals failure through its boolean return value.
    if not cv2.imwrite("retinexcheck.jpg", retinex):
        raise OSError("Could not write output image: 'retinexcheck.jpg'")