import cv2
import torch
import numpy as np
import pdb

from depth_anything_v2.dpt import DepthAnythingV2

# Choose the torch device: CUDA when available, otherwise MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# Keyword arguments passed to DepthAnythingV2 for each encoder variant.
model_configs = {
    name: {'encoder': name, 'features': features, 'out_channels': out_channels}
    for name, features, out_channels in (
        ('vits', 64, [48, 96, 192, 384]),
        ('vitb', 128, [96, 192, 384, 768]),
        ('vitl', 256, [256, 512, 1024, 1024]),
        ('vitg', 384, [1536, 1536, 1536, 1536]),
    )
}

encoder = 'vitl'  # or 'vits', 'vitb', 'vitg'

# Build the network, load its checkpoint onto the CPU, then move the model to
# the selected device and switch it to eval mode.
DepthAnythingmodel = DepthAnythingV2(**model_configs[encoder])
DepthAnythingmodel.load_state_dict(
    torch.load(
        f'PretrainModel/depth_anything_v2_{encoder}.pth',
        map_location='cpu',
    )
)
DepthAnythingmodel = DepthAnythingmodel.to(DEVICE).eval()

print("Finishing building Depth, Normal model.")

def DepthAnything(img, magnitude, model = DepthAnythingmodel):
    """Render the model's depth prediction for ``img`` as an 8-bit BGR image.

    The depth map is scaled so that its maximum maps to 255, cast to uint8,
    and expanded to three channels with ``cv2.COLOR_GRAY2BGR``.

    Args:
        img: Image passed unchanged to ``model.infer_image``.
        magnitude: Accepted for signature compatibility; not used here.
        model: Object exposing ``infer_image``. Defaults to the module-level
            ``DepthAnythingmodel``.

    Returns:
        The uint8 three-channel image produced by ``cv2.cvtColor``. It is all
        zeros when the depth map has no positive maximum.
    """
    depth = model.infer_image(img)  # HxW raw depth map in numpy
    max_value = np.max(depth)
    if max_value > 0:
        depth = (depth / max_value * 255).astype(np.uint8)
    else:
        # Dividing by a zero (or NaN/negative) maximum would feed NaN/inf or
        # out-of-range values into the uint8 cast; emit a black image instead.
        depth = np.zeros(np.shape(depth), dtype=np.uint8)
    return cv2.cvtColor(depth, cv2.COLOR_GRAY2BGR)


def _halved_difference(depth, axis):
    """Return neighbour differences of ``depth`` along ``axis``, divided by 2.

    Interior samples use the central difference ``(d[k+1] - d[k-1]) / 2``.
    The first and last samples use a one-sided difference that is also divided
    by 2, so every sample shares the same scaling. An axis with fewer than two
    samples yields zeros.
    """
    grad = np.zeros_like(depth)
    if depth.shape[axis] < 2:
        return grad
    # moveaxis returns views, so writes through ``out`` land in ``grad``.
    src = np.moveaxis(depth, axis, 0)
    out = np.moveaxis(grad, axis, 0)
    out[1:-1] = src[2:] - src[:-2]
    out[0] = src[1] - src[0]
    out[-1] = src[-1] - src[-2]
    out /= 2.0
    return grad


def NormalAnything(img, magnitude, model = DepthAnythingmodel):
    """Render a surface-normal image derived from the model's depth prediction.

    For every pixel the vector ``(-dxdz, -dydz, 1)`` is built, where ``dxdz``
    is the halved depth difference along rows (axis 0) and ``dydz`` the one
    along columns (axis 1). The vector is normalised to unit length, mapped
    from [-1, 1] to [0, 255], cast to uint8, and finally converted with
    ``cv2.COLOR_RGB2BGR``.

    Differences at the first row/column use the forward neighbour and those at
    the last row/column use the backward neighbour, so no value is read from
    the opposite edge of the image.

    Args:
        img: Image passed unchanged to ``model.infer_image``.
        magnitude: Accepted for signature compatibility; not used here.
        model: Object exposing ``infer_image``. Defaults to the module-level
            ``DepthAnythingmodel``.

    Returns:
        The uint8 three-channel image produced by ``cv2.cvtColor``.

    Raises:
        ValueError: If the predicted depth map is not two-dimensional.
    """
    # Work in float64 so the differences and the normalisation share one dtype.
    depth = np.asarray(model.infer_image(img), dtype=np.float64)
    if depth.ndim != 2:
        raise ValueError(
            f"expected a 2-D depth map, got an array with shape {depth.shape}"
        )

    dxdz = _halved_difference(depth, axis=0)
    dydz = _halved_difference(depth, axis=1)

    # The third component is always 1, so the norm is >= 1 and the division
    # below can never hit zero.
    normals = np.stack((-dxdz, -dydz, np.ones_like(depth)), axis=-1)
    normals /= np.linalg.norm(normals, axis=-1, keepdims=True)

    normals = ((normals + 1) / 2 * 255).astype(np.uint8)
    return cv2.cvtColor(normals, cv2.COLOR_RGB2BGR)
            

if __name__ == '__main__':
    # Demo entry point: read check.jpg, compute its normal image, save it.
    raw_img = cv2.imread('check.jpg')
    if raw_img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside the model.
        raise OSError("Could not read input image 'check.jpg'")

    normal_rgb = NormalAnything(raw_img, magnitude = 0)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite("normal_rgb.jpg", normal_rgb):
        raise OSError("Could not write output image 'normal_rgb.jpg'")