import numpy as np
from PIL import Image, ImageEnhance, ImageOps
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from RDUmodel import RDUNetFunc

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, this also seeds all GPUs. It then switches cuDNN
    to deterministic mode and disables autotuning (benchmark mode), so runs
    can be reproduced.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def RDUFunc(img, magnitude, model, device):
    """Run ``model`` on a single channels-last image and return a uint8 image.

    Args:
        img: Array of shape (H, W, C) with values in [0, 255], for example
            the array returned by ``cv2.imread``.
        magnitude: Currently unused; kept so existing callers keep working.
        model: Callable that takes a (1, C, H, W) float32 tensor scaled to
            [0, 1] and returns a tensor of the same layout.
        device: Torch device (or device string) the model runs on.

    Returns:
        Array of shape (H, W, C) and dtype ``uint8``.
    """
    # Use float32 rather than float16: float16 is too coarse for every k/255
    # to survive the round trip back to uint8.
    chw = np.transpose(img.astype(np.float32) / 255.0, (2, 0, 1))
    img_t = torch.from_numpy(np.ascontiguousarray(chw)).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_t = model(img_t)

    # Remove only the batch axis. A bare squeeze() would also drop a
    # singleton channel or height axis and break the transpose below.
    pred = pred_t.squeeze(0).cpu().numpy()
    pred = np.transpose(pred, (1, 2, 0))

    # Clip before casting, because out-of-range float->uint8 casts wrap or
    # are undefined. rint avoids the downward bias of plain truncation.
    return np.rint(np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)


def main():
    """Load the colour RDUNet model, denoise ``check.jpg`` and save the result.

    The denoised image is written to ``DenoiseCheck.jpg``.

    Raises:
        FileNotFoundError: If ``check.jpg`` cannot be read.
        OSError: If ``DenoiseCheck.jpg`` cannot be written.
    """
    print("Load RDUmodel")
    RDUmodel = RDUNetFunc()
    # Fall back to the CPU so the script still runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "model_color.pth"
    # map_location lets a checkpoint saved on a GPU load onto the chosen device.
    state_dict = torch.load(model_path, map_location=device)
    RDUmodel.load_state_dict(state_dict, strict=True)
    RDUmodel = RDUmodel.to(device)
    RDUmodel.eval()

    img = cv2.imread("check.jpg")
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("could not read image 'check.jpg'")
    pred = RDUFunc(img, magnitude=0, model=RDUmodel, device=device)
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite("DenoiseCheck.jpg", pred):
        raise OSError("could not write image 'DenoiseCheck.jpg'")

# Run the denoising check only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()