import depth_pro
import cv2
import torch
import pdb
from skimage import img_as_ubyte

# Choose the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
# NOTE(review): `device` is only consumed as the default `device` argument of
# AbsDepthEstimation; it is NOT passed to create_model_and_transforms below,
# so the model stays on depth_pro's default device (CPU per the library's
# signature — confirm for the installed version).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the Depth Pro network together with its matching preprocessing
# transform, then put the network in evaluation mode for inference.
AbsDepthModel, transform = depth_pro.create_model_and_transforms()
AbsDepthModel.eval()

print("Finishing building Absolute depth estimation model.")

def AbsDepthEstimation(img, magnitude, model = AbsDepthModel, device = device, transform = transform):
    """Estimate depth for a BGR image and return a colour-mapped visualisation.

    Args:
        img: Input image as a NumPy array in OpenCV BGR channel order
            (e.g. as returned by ``cv2.imread``).
        magnitude: Unused; kept for compatibility with existing callers.
        model: Depth model whose ``infer(image, f_px=None)`` returns a dict
            containing a ``"depth"`` tensor.
        device: Unused here; the input tensor's device is whatever
            ``transform`` produces. Kept for compatibility with callers.
        transform: Preprocessing callable that turns the RGB NumPy image into
            the tensor ``model.infer`` expects.

    Returns:
        ``uint8`` NumPy colour image: the min-max normalised depth rendered
        with ``cv2.COLORMAP_JET`` and then channel-swapped via
        ``cv2.COLOR_RGB2BGR``. Since ``applyColorMap`` already emits BGR, the
        swap leaves the result in RGB order — NOTE(review): confirm this is
        what consumers such as ``cv2.imwrite`` are meant to receive.
        A constant depth map is rendered as if every pixel had value 0.
    """
    with torch.inference_mode():
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image = transform(rgb)
        prediction = model.infer(image, f_px = None)
        # Copy to host memory so .numpy() also works when inference ran on a GPU.
        # Presumably a 2-D (H, W) map, as applyColorMap requires — TODO confirm.
        depth = prediction["depth"].cpu()  # Depth in [m].
        min_val = depth.min()
        depth_range = depth.max() - min_val
        if depth_range > 0:
            normalized_tensor = (depth - min_val) / depth_range * 255.0
        else:
            # Flat (or non-finite) range: avoid 0/0 -> NaN, whose uint8 cast
            # is undefined, by mapping everything to 0.
            normalized_tensor = torch.zeros_like(depth)
        # Clamp guards the uint8 cast against float round-off past 255.
        depth_image = normalized_tensor.clamp(0.0, 255.0).byte().numpy()
        color_depth_image = cv2.applyColorMap(depth_image, cv2.COLORMAP_JET)
        color_depth_bgr = cv2.cvtColor(color_depth_image, cv2.COLOR_RGB2BGR)
        return img_as_ubyte(color_depth_bgr)
        
        
        
        
if __name__ == '__main__':
    # Demo: colour-map the depth of one image and save it as "AbsDepth_<name>".
    img_name = "100014.png"

    raw_img = cv2.imread(img_name)
    # cv2.imread signals failure by returning None instead of raising.
    if raw_img is None:
        raise FileNotFoundError(f"Could not read image: {img_name}")

    # Fixed 1024x1024 input size used by this demo.
    raw_img = cv2.resize(raw_img, (1024, 1024))

    AbsDepth = AbsDepthEstimation(raw_img, magnitude = 0)

    out_name = "AbsDepth_" + img_name
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(out_name, AbsDepth):
        raise OSError(f"Could not write image: {out_name}")

