from EdgeModel import DexiNed
import numpy as np
import cv2
import torch
from torchvision import transforms

# Use the GPU when one is available; otherwise fall back to the CPU so the
# module can still be imported on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the DexiNed network and load its weights onto `device`.
# NOTE: this runs at import time and reads 'PretrainModel/Edge.pth' relative
# to the current working directory.
EdgeModel = DexiNed().to(device)
EdgeModel.load_state_dict(torch.load('PretrainModel/Edge.pth', map_location=device))
# Put the network in evaluation mode.
EdgeModel.eval()

print("Finishing building Edge model.")

def EdgeDetection(img, magnitude, model = EdgeModel, device = device):
    """Run the edge model on an image and return a 3-channel uint8 edge map.

    The image is resized to 1024x1024 for the model, and the model output is
    resized back to the input's original height and width.

    Args:
        img: Image array laid out as (height, width, 3). It is converted with
            ``cv2.COLOR_BGR2RGB``, so it is presumably in OpenCV's BGR order
            (e.g. straight from ``cv2.imread``) -- confirm with callers.
        magnitude: Currently unused; kept so existing call sites keep working.
        model: Callable applied to the (1, 3, 1024, 1024) float tensor.
            Defaults to the module-level ``EdgeModel``.
        device: Torch device the input tensor is moved to. Defaults to the
            module-level ``device``.

    Returns:
        A ``uint8`` array of shape (height, width, 3) in which the
        single-channel edge map is replicated across the three channels.
        The model output is not rescaled; it is only clipped to [0, 255]
        before the cast.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]

    # (height, width, channels) -> (1, channels, height, width)
    batch = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    # Fixed working resolution fed to the model.
    resize = transforms.Resize(
        (1024, 1024),
        interpolation=transforms.InterpolationMode.BICUBIC,
        antialias=True,
    )

    with torch.no_grad():
        img_tensor = torch.from_numpy(batch).float().to(device)
        img_tensor = resize(img_tensor)

        edge_tensor = model(img_tensor)

        # Take the first element of the model output and drop singleton
        # dimensions; the result is assumed to be a 2-D (H, W) map -- confirm
        # against the model's output format.
        edge = edge_tensor[0].squeeze().cpu().numpy()

    # cv2.resize expects the target size as (width, height).
    edge = cv2.resize(edge, (width, height))

    # Clamp to the valid 8-bit range first: casting out-of-range floats to
    # uint8 would otherwise wrap around and corrupt the edge map.
    edge = np.clip(edge, 0, 255)

    # Replicate the single channel into a 3-channel image.
    edge = cv2.cvtColor(edge, cv2.COLOR_GRAY2BGR)

    return edge.astype('uint8')
    
if __name__ == '__main__':
    # Smoke test: run the detector on 'check.jpg' and save the edge map.
    raw_img = cv2.imread('check.jpg')

    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; fail here with a clear message rather than with an
    # obscure OpenCV error further down.
    if raw_img is None:
        raise FileNotFoundError("Could not read image 'check.jpg'")

    # The value passed for `magnitude` is an arbitrary placeholder.
    edge_rgb = EdgeDetection(raw_img, magnitude = 0)

    cv2.imwrite("edgecheck.png", edge_rgb)