import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

#%% Helper functions
def depth_read(filename):
    """Decode a depth map stored in the colour channels of a PNG file.

    The first three channels of the image are combined into one 24-bit integer
    (channel 0 is the least significant byte, channel 2 the most significant),
    normalised by ``2**24 - 1`` and scaled by 5000.
    NOTE(review): the 5000 scale presumably is the far plane in metres of the
    SYNTHIA depth encoding -- confirm against the dataset documentation.

    Parameters
    ----------
    filename : str or path-like
        Path of an image with at least three channels.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape ``(height, width)``; values above 255 are
        saturated to 255.
    """
    # The context manager releases the file handle even when decoding fails;
    # np.asarray copies the pixel data out before the file is closed.
    with Image.open(filename) as my_file:
        depth_png = np.asarray(my_file, dtype=int)[:, :, :3]

    depth = 5000.0 * (depth_png[:, :, 0]
                      + depth_png[:, :, 1] * 256.0
                      + depth_png[:, :, 2] * 256.0 * 256.0) / (256.0 * 256.0 * 256.0 - 1)

    # NOTE(review): the original followed this clip with
    # `depth[depth > 255.] = 0.`, which can never match once the values are
    # clipped. The dead statement was dropped; if far points were meant to be
    # zeroed instead of saturated, replace the clip with that masking.
    return depth.astype(np.float32).clip(max=255.)

#%% Lidar like pattern
def _angle_grid(lowest, highest, step):
    """Return the multiples of ``step`` (degrees) covering ``[lowest, highest]``, negatives first."""
    return np.append(np.arange(-step, lowest - step, -step),
                     np.arange(0, highest + step, step))


def _snap_to_grid(values, grid, atol):
    """Select the entries of ``values`` lying within ``atol`` of an angle in ``grid``.

    Returns ``(indices, snapped)``: the positions in ``values`` that matched a
    grid angle (using ``np.isclose`` with its default relative tolerance) and
    the grid angle each of them was snapped to.
    """
    index_parts = []
    angle_parts = []
    for reference in grid:
        hits = np.flatnonzero(np.isclose(values, reference, atol=atol))
        if hits.size:
            index_parts.append(hits)
            angle_parts.append(np.full(hits.size, reference, dtype=float))
    if not index_parts:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=float)
    return np.concatenate(index_parts), np.concatenate(angle_parts)


def sparse_depth(depth_image, fx=895.6921, fy=895.6921, px=None, py=None):
    """Sub-sample a dense depth image with a lidar-like angular scan pattern.

    Every pixel is converted to spherical coordinates (range, elevation,
    azimuth) using a pinhole model. Only pixels whose angles, rounded to two
    decimals, fall on a regular angular grid are kept: elevation every
    0.4 degrees (+-0.1 degree tolerance) and azimuth every 0.08 degrees
    (+-0.02 degree tolerance). The kept points are snapped to the grid angles,
    projected back into the image, and the depth found at the resulting pixel
    is copied into an otherwise all-zero output image.

    Parameters
    ----------
    depth_image : numpy.ndarray
        2-D array of depth values, shape ``(height, width)``.
    fx, fy : float, optional
        Focal lengths in pixels.
    px, py : float, optional
        Principal point in pixels. Default to the image centre
        (``width / 2`` and ``height / 2``), i.e. 320 and 240 for a 640x480
        image.

    Returns
    -------
    numpy.ndarray
        ``float64`` array with the shape of ``depth_image``: the input depth
        at the sampled pixels and 0 everywhere else.
    """
    depth_image = np.asarray(depth_image)
    height, width = depth_image.shape
    lidar_like_depth_array = np.zeros((height, width))
    if depth_image.size == 0:
        return lidar_like_depth_array

    if px is None:
        px = width / 2.
    if py is None:
        py = height / 2.

    # Normalised image-plane coordinates. The vertical axis runs from
    # height-1 down to 0 so that the positive Y axis points upwards.
    u_vec = (np.arange(0, width) - px) / fx
    v_vec = (np.arange(height - 1, -1, -1) - py) / fy
    u_img = np.repeat(np.expand_dims(u_vec, axis=0), repeats=height, axis=0)
    v_img = np.repeat(np.expand_dims(v_vec, axis=1), repeats=width, axis=1)

    X = u_img * depth_image
    Y = v_img * depth_image

    # Cartesian -> spherical, angles in degrees. eps keeps zero-range pixels
    # from dividing by zero.
    r = np.sqrt(np.square(X) + np.square(Y) + np.square(depth_image))
    elev_img = ((np.arccos(Y / (r + np.finfo(float).eps))) - np.pi/2.) * 180/np.pi
    azi_img = (np.arctan2(X, depth_image)) * 180/np.pi

    range_flat = r.flatten()
    elev_flat = np.round(elev_img.flatten(), 2)
    azi_flat = np.round(azi_img.flatten(), 2)

    azi_resolution = 0.08   # matched with a tolerance of +-0.02 degree
    elev_resolution = 0.4   # matched with a tolerance of +-0.1 degree
    # The grids span the full field of view; points that project outside the
    # image are discarded further down.
    desired_elev_list = _angle_grid(elev_img.min(), elev_img.max(), elev_resolution)
    desired_azi_list = _angle_grid(azi_img.min(), azi_img.max(), azi_resolution)

    # Keep the points on a scan line (elevation), then, among those, the
    # points on a firing direction (azimuth). Selections are gathered and
    # concatenated once instead of growing a DataFrame with the removed
    # DataFrame.append.
    elev_rows, elev_snapped = _snap_to_grid(elev_flat, desired_elev_list, atol=0.1)
    azi_rows, azi_snapped = _snap_to_grid(azi_flat[elev_rows], desired_azi_list, atol=0.02)
    if azi_rows.size == 0:
        return lidar_like_depth_array

    R = range_flat[elev_rows][azi_rows]
    E = elev_snapped[azi_rows]
    A = azi_snapped

    # Spherical -> Cartesian with the snapped angles, then pinhole projection.
    # Zero-range points give 0/0 and rounding can make the radicand slightly
    # negative; both produce NaN, which the validity mask rejects.
    polar = (np.pi/2) + (E * np.pi/180)
    with np.errstate(divide="ignore", invalid="ignore"):
        Y = R * np.cos(polar)
        X = R * np.sin(A * np.pi/180) * np.sin(polar)
        Z = np.sqrt(np.square(R) - np.square(X) - np.square(Y))
        u = (fx * X/Z) + px
        v = (fy * Y/Z) + py
        valid_pixels = (u >= 0) & (u < width) & (v >= 0) & (v < height) & (Z >= 0)

    cols = u[valid_pixels].astype('int')
    rows = height - 1 - v[valid_pixels].astype('int')  # flip back to image row order

    lidar_like_depth_array[rows, cols] = depth_image[rows, cols]

    return lidar_like_depth_array

#%% Custom scan patterns in camera coordinates
def scan_pattern_camera_coord(depth_image, pattern="transposed_cosine"):
    """Sample a depth image along a cosine-shaped scan curve in pixel coordinates.

    Parameters
    ----------
    depth_image : numpy.ndarray
        2-D array of depth values, shape ``(height, width)``.
    pattern : {"transposed_cosine", "cosine"}, optional
        ``"transposed_cosine"`` walks down the rows in half-pixel steps while
        the column oscillates as a cosine of the row; ``"cosine"`` walks
        across the columns in half-pixel steps while the row oscillates as a
        cosine of the column.

    Returns
    -------
    numpy.ndarray
        ``float64`` array with the shape of ``depth_image``: the input depth
        at the pixels on the scan curve and 0 everywhere else.

    Raises
    ------
    ValueError
        If ``pattern`` is not one of the supported names.
    """
    width = depth_image.shape[1]
    height = depth_image.shape[0]

    if pattern == "transposed_cosine":
        v = np.arange(0, height, 0.5)
        u = ((width-1)/2) * (1 + np.cos(v/10))
    elif pattern == "cosine":
        u = np.arange(0, width, 0.5)
        v = ((height-1)/2) * (1 + np.cos(u/10))
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on `v`.
        raise ValueError(
            "Unknown scan pattern {!r}; expected 'transposed_cosine' or 'cosine'.".format(pattern)
        )

    # Truncate the sub-pixel curve to integer pixel indices.
    rows = v.astype("int32")
    cols = u.astype("int32")

    scanned_image = np.zeros((height, width))
    scanned_image[rows, cols] = depth_image[rows, cols]

    return scanned_image

#%%
if __name__ == "__main__":
    # Demo: convert the first depth image found in the dataset to a
    # lidar-like sparse depth map and display it.
    root_dir = "../../dataset/synthia/train/"
    input_files = glob.glob(root_dir + "*/*/Depth/*.png", recursive=True)
    if not input_files:
        # Without this check the lookup below fails with an opaque IndexError.
        raise FileNotFoundError(
            "No depth images matching '*/*/Depth/*.png' found under {!r}.".format(root_dir)
        )
    input_files = pd.DataFrame(input_files)

    idx = 0
    filename = input_files.iloc[idx, 0]
    depth_image = depth_read(filename)
    output_image = sparse_depth(depth_image)
    plt.imshow(output_image, cmap="jet")
    # Needed when run as a plain script; otherwise the figure is never shown.
    plt.show()