import numpy as np
import cv2
import torch
from skimage import img_as_ubyte
from ImageDemoireModel.nets import my_model

# Prefer CUDA, but fall back to CPU so the module can still be imported (and
# the model run, slowly) on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Demoireing network. NOTE(review): these hyper-parameters must match the
# ones the checkpoint below was trained with, otherwise load_state_dict
# (strict by default) rejects the state dict; presumably this is the
# "large" UHDM configuration -- confirm.
ImageDemoireingModel = my_model(en_feature_num=48,
                     en_inter_num=32,
                     de_feature_num=64,
                     de_inter_num=32,
                     sam_number=2,
                     ).to(device)

# map_location places the saved tensors directly on `device`, so a checkpoint
# saved from a GPU run still loads when only the CPU is available.
model_state_dict = torch.load("PretrainModel/uhdm_large_checkpoint.pth",
                              map_location=device)

ImageDemoireingModel.load_state_dict(model_state_dict)
# Put the network in inference mode so layers such as dropout / batch-norm
# (if it has any) behave deterministically; torch.inference_mode() used at
# call time only disables autograd and does not do this.
ImageDemoireingModel.eval()

print("Finishing building ImageDemoire model")

def _model_device(model):
    """Return the device holding ``model``'s parameters, defaulting to CUDA.

    Falls back to CUDA (the previously hard-coded target) for callables that
    are not ``nn.Module`` instances or that have no parameters.
    """
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device("cuda")


def ImageDemoireing(img, magnitude, model = ImageDemoireingModel):
    """Remove moire patterns from a single image with the demoireing network.

    Args:
        img: OpenCV-style image in BGR channel order, shape (H, W, 3). Pixel
            values are divided by 255, so an 8-bit image is presumably
            expected -- confirm for other dtypes.
        magnitude: Currently unused; kept for interface compatibility.
        model: Network to run. It must return a 3-tuple whose first element
            is the restored image as a (1, 3, H', W') tensor in RGB order.
            Defaults to the module-level ``ImageDemoireingModel``.

    Returns:
        The restored image as a uint8 numpy array in BGR order, (H', W', 3).

    NOTE(review): the ``__main__`` block resizes inputs to 1024x1024 before
    calling this; the network may require particular input sizes -- confirm.
    """
    # Send the input to wherever the model's weights live instead of
    # assuming CUDA, so a CPU-resident model also works.
    target_device = _model_device(model)
    with torch.inference_mode():
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rgb = np.float32(rgb) / 255.
        # HWC -> NCHW (batch of one), the layout the network consumes.
        batch = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(target_device)
        # Only the first output is used; the other two are discarded.
        pred, _, _ = model(batch)
        # Clamp to the valid [0, 1] range, then NCHW -> HWC on the CPU.
        out = torch.clamp(pred, 0, 1).cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
        out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
        return img_as_ubyte(out)
    
    
    
if __name__ == '__main__':
    raw_img = cv2.imread('ycimage1.png')
    # cv2.imread signals a missing/unreadable file by returning None instead
    # of raising; fail with a clear message rather than inside cv2.resize.
    if raw_img is None:
        raise FileNotFoundError("Could not read input image 'ycimage1.png'")

    # NOTE(review): fixed 1024x1024 input size (aspect ratio is not kept);
    # presumably chosen to suit the network -- confirm.
    raw_img = cv2.resize(raw_img, (1024, 1024))

    demoired_img = ImageDemoireing(raw_img, magnitude = 0)

    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite("ImageDemoireingYcImage1.jpg", demoired_img):
        raise OSError("Could not write output image 'ImageDemoireingYcImage1.jpg'")