import os
import torch
from torchvision.transforms import functional as F
import numpy as np
import cv2
from skimage import img_as_ubyte
from MIMOmodel.MIMOUNet import MIMOUNetPlus

# Build the MIMO-UNet+ network once at import time and load the RealBlur
# pretrained weights into it.
DeblurModel = MIMOUNetPlus()
# map_location='cpu': without it torch.load restores tensors onto the device
# they were saved from, which fails on a CPU-only machine if the checkpoint
# holds CUDA tensors. Loading to CPU first and moving the whole model below
# also avoids a transient duplicate copy of the weights on the GPU.
state_dict = torch.load("PretrainModel/RealBlur.pkl", map_location='cpu')
# The weights live under the 'model' key of the checkpoint dict.
DeblurModel.load_state_dict(state_dict['model'])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DeblurModel = DeblurModel.to(device)
# No-op when CUDA was never initialised.
torch.cuda.empty_cache()
# Inference mode for layers such as dropout / batch-norm.
DeblurModel.eval()

print("Finishing building Deblur Model.")

def MIMOdeblur(img, magnitude, model = DeblurModel, device = device):
    """Deblur a single BGR image with the MIMO-UNet+ model.

    Args:
        img: H x W x 3 image in OpenCV's BGR channel order. It is divided by
            255, so 8-bit (uint8) data is expected.
        magnitude: Currently unused; kept only for call-site compatibility.
        model: Network to run. Defaults to the module-level ``DeblurModel``.
        device: Device the input tensor is moved to. Defaults to the
            module-level ``device``; it should match where ``model`` lives.

    Returns:
        The deblurred image as a uint8 BGR array with the same height and
        width as ``img``.
    """
    height, width = img.shape[:2]
    with torch.inference_mode():
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rgb = np.float32(rgb) / 255.
        # HWC -> CHW, then add a batch dimension of one (NCHW).
        batch = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(device)

        # MIMO-UNet+ halves the spatial resolution twice internally, so H and
        # W are assumed to need to be multiples of 4 for its skip connections
        # to line up (NOTE(review): confirm against MIMOmodel.MIMOUNet).
        # Pad the bottom/right edges up to the next multiple and crop back
        # afterwards; inputs that already satisfy this get no padding.
        pad_h = (-height) % 4
        pad_w = (-width) % 4
        if pad_h or pad_w:
            batch = torch.nn.functional.pad(
                batch, (0, pad_w, 0, pad_h), mode='replicate')

        # The model returns a sequence of predictions; index 2 is taken as
        # the final output (presumably the full-resolution one).
        pred = model(batch)[2]
        pred = pred[:, :, :height, :width]
        # No .detach() needed: tensors created under inference_mode carry
        # no autograd history.
        pred = torch.clamp(pred, 0, 1).cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        # float [0, 1] -> uint8 [0, 255]
        return img_as_ubyte(pred)
    
    
if __name__ == '__main__':
    # Demo: deblur one sample image from ./img and save it next to the input
    # with a "Deblur" prefix.
    img_dir = "img"
    img_name = "NJU2k_000004_left.jpg"
    src_path = os.path.join(img_dir, img_name)

    raw_img = cv2.imread(src_path)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of inside resize.
    if raw_img is None:
        raise FileNotFoundError(f"Could not read image: {src_path}")

    # Fixed 1024x1024 network input (aspect ratio is not preserved).
    raw_img = cv2.resize(raw_img, (1024, 1024))

    deblur = MIMOdeblur(raw_img, magnitude = 0)

    dst_path = os.path.join(img_dir, "Deblur" + img_name)
    # cv2.imwrite reports failure via its boolean return value.
    if not cv2.imwrite(dst_path, deblur):
        raise OSError(f"Could not write image: {dst_path}")