import pickle
from functools import reduce

import numpy as np
import torch
from nuscenes import NuScenes
from nuscenes.utils.geometry_utils import transform_matrix
from pyquaternion import Quaternion

@torch.no_grad()
def projectionV2(points, all_cams_from_lidar, all_cams_intrinsic, H=900, W=1600, device=None):
    """Project lidar points into every camera image.

    Args:
        points: 2-D tensor (N, C); only the first three columns (x, y, z in the
            lidar frame) are used.
        all_cams_from_lidar: tensor (num_cams, 4, 4) mapping homogeneous lidar
            coordinates into each camera frame.
        all_cams_intrinsic: per-camera intrinsic matrices, e.g. (num_cams, 3, 3).
        H, W: image height and width in pixels.
        device: device for the output/intermediate tensors. Defaults to
            ``points.device``.

    Returns:
        Tensor (num_cams, N, 4). For every (camera, point) pair whose projection
        falls inside the image ([0, W) x [0, H)) with positive depth, the last
        dim holds (floor(u), floor(v), depth, 1). All other entries are zero,
        so channel 3 acts as a validity flag.
    """
    if device is None:
        device = points.device

    num_lidar_point = points.shape[0]
    num_camera = len(all_cams_from_lidar)

    projected_points = torch.zeros((num_camera, num_lidar_point, 4), device=device)

    # Homogeneous lidar coordinates, shape (4, N).
    point_padded = torch.cat([
        points.transpose(1, 0)[:3, :],
        torch.ones(1, num_lidar_point, dtype=points.dtype, device=device),
    ], dim=0)

    # (num_cams x 4 x 4) x (4 x N) -> (num_cams x 3 x N): points in each camera frame.
    transform_points = torch.einsum('abc,cd->abd', all_cams_from_lidar, point_padded)[:, :3, :]
    depths = transform_points[:, 2]

    # Perspective projection -> (num_cams x N x 2) pixel coordinates, floored to pixel indices.
    points_2d = batch_view_points(
        transform_points, all_cams_intrinsic, normalize=True, device=device
    )[:, :2].transpose(2, 1)
    points_2d = torch.floor(points_2d)
    points_x, points_y = points_2d[..., 0].long(), points_2d[..., 1].long()

    # Valid pixel indices are [0, W) x [0, H). The depth test drops points at or
    # behind the camera plane, whose divided coordinates are meaningless.
    valid_mask = (points_x >= 0) & (points_x < W) & (points_y >= 0) & (points_y < H) & (depths > 0)

    valid_projected_points = projected_points[valid_mask]
    valid_projected_points[:, :2] = points_2d[valid_mask]
    valid_projected_points[:, 2] = depths[valid_mask]
    valid_projected_points[:, 3] = 1  # flag: this point has a valid projection
    projected_points[valid_mask] = valid_projected_points

    return projected_points
def to_tensor(x, device='cuda:0', dtype=torch.float32):
    """Build a new tensor from ``x`` with the requested ``dtype`` on ``device``.

    ``x`` may be anything ``torch.tensor`` accepts (scalars, nested lists,
    NumPy arrays, ...). ``torch.tensor`` always copies the data.
    """
    result = torch.tensor(x, device=device, dtype=dtype)
    return result

def get_obj(path):
    """Load and return the single object pickled in the file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file. Only call this on files you trust.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)

def to_batch_tensor(tensor, device='cuda:0', dtype=torch.float32):
    """Convert each item of ``tensor`` with ``to_tensor`` and stack along a new dim 0.

    ``torch.stack`` requires all converted items to have the same shape.
    """
    items = tuple(to_tensor(item, device=device, dtype=dtype) for item in tensor)
    return torch.stack(items, dim=0)

def batch_view_points(points, view, normalize, device=None):
    """Apply a batch of view matrices (at most 4x4) to batches of 3-D points.

    Each ``view`` matrix is embedded in the top-left corner of a 4x4 identity
    and applied to the points in homogeneous coordinates. Only the first three
    rows of the result are kept.

    Args:
        points: tensor (B, 3, N).
        view: tensor (B, r, c) with r, c <= 4, e.g. (B, 3, 3) camera intrinsics.
        normalize: if True, divide every row by the third row, so row 2 becomes 1.
        device: device for the padding tensors. Defaults to ``points.device``.

    Returns:
        Tensor (B, 3, N) with the same dtype as ``points``.
    """
    if device is None:
        device = points.device
    batch_size, _, nbr_points = points.shape

    # Pad each view matrix to 4x4. The padding uses the dtype of `points`
    # because bmm requires both operands to share a dtype.
    viewpad = torch.eye(4, dtype=points.dtype, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    viewpad[:, :view.shape[1], :view.shape[2]] = view

    ones = torch.ones((batch_size, 1, nbr_points), dtype=points.dtype, device=device)
    points = torch.cat((points, ones), dim=1)

    # (B x 4 x 4) x (B x 4 x N) -> (B x 4 x N); keep the first three rows.
    points = torch.bmm(viewpad, points)[:, :3]

    if normalize:
        # The (B x 1 x N) depth row broadcasts across the three rows.
        points = points / points[:, 2:3]

    return points


def read_file(path, num_point_feature=4):
    """Read a raw float32 point file and return its leading feature columns.

    The file is read as consecutive rows of 5 float32 values. Only the first
    ``num_point_feature`` columns of each row are returned.
    """
    raw = np.fromfile(path, dtype=np.float32)
    rows = raw.reshape(-1, 5)
    return rows[:, :num_point_feature]


def get_lidar_to_image_transform(nusc, pointsensor, camera_sensor):
    """Compute lidar->camera transforms, intrinsics and image paths for six cameras.

    Args:
        nusc: ``NuScenes`` instance used to look up records.
        pointsensor: the lidar ``sample_data`` record whose points will be projected.
        camera_sensor: dict mapping each channel in ``CAM_CHANS`` to that
            camera's ``sample_data`` record.

    Returns:
        ``(tms, intrinsics, cam_paths)``: three lists in ``CAM_CHANS`` order.
        For each camera they hold the 4x4 matrix taking homogeneous lidar-frame
        points into that camera frame, the camera intrinsic matrix, and the
        image file path.
    """
    CAM_CHANS = ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT']

    # Points travel lidar -> ego (lidar timestamp) -> global -> ego (image
    # timestamp) -> camera. The first two steps depend only on the lidar
    # record, so they are computed once, outside the per-camera loop.
    lidar_cs_record = nusc.get('calibrated_sensor', pointsensor['calibrated_sensor_token'])
    car_from_lidar = transform_matrix(
        lidar_cs_record["translation"], Quaternion(lidar_cs_record["rotation"]), inverse=False
    )
    lidar_poserecord = nusc.get('ego_pose', pointsensor['ego_pose_token'])
    global_from_car = transform_matrix(
        lidar_poserecord["translation"], Quaternion(lidar_poserecord["rotation"]), inverse=False
    )

    tms = []
    intrinsics = []
    cam_paths = []
    for chan in CAM_CHANS:
        cam = camera_sensor[chan]

        # global -> ego vehicle frame at the timestamp of the image.
        cam_poserecord = nusc.get('ego_pose', cam['ego_pose_token'])
        car_from_global = transform_matrix(
            cam_poserecord["translation"], Quaternion(cam_poserecord["rotation"]), inverse=True
        )

        # ego vehicle -> camera.
        cam_cs_record = nusc.get('calibrated_sensor', cam['calibrated_sensor_token'])
        cam_from_car = transform_matrix(
            cam_cs_record["translation"], Quaternion(cam_cs_record["rotation"]), inverse=True
        )

        # Multiplied left to right:
        # ((cam_from_car @ car_from_global) @ global_from_car) @ car_from_lidar.
        tm = reduce(np.dot, [cam_from_car, car_from_global, global_from_car, car_from_lidar])

        tms.append(tm)
        # The intrinsic and path are read directly instead of through
        # nusc.get_sample_data, which also builds the sample's annotation
        # boxes only for them to be ignored here.
        intrinsics.append(np.array(cam_cs_record['camera_intrinsic']))
        cam_paths.append(nusc.get_sample_data_path(cam['token']))

    return tms, intrinsics, cam_paths

def get_all_cams(sample_token, nusc):
    """Return lidar->camera transforms, intrinsics and image paths for a sample.

    Args:
        sample_token: token of a nuScenes ``sample`` record.
        nusc: ``NuScenes`` instance used to look up records.

    Returns:
        The ``(all_cams_from_lidar, all_cams_intrinsic, all_cams_path)`` lists
        produced by ``get_lidar_to_image_transform``. They are computed for the
        sample's ``LIDAR_TOP`` sweep and its six camera ``sample_data`` records.
    """
    sample = nusc.get('sample', sample_token)
    ref_sd_rec = nusc.get("sample_data", sample["data"]["LIDAR_TOP"])

    # Camera sample_data record for every channel of this sample.
    CAM_CHANS = ['CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT']
    ref_cams = {
        cam_chan: nusc.get('sample_data', sample['data'][cam_chan])
        for cam_chan in CAM_CHANS
    }

    return get_lidar_to_image_transform(nusc, pointsensor=ref_sd_rec, camera_sensor=ref_cams)

'''
Most of the code is adapted from the MVP repository:
https://github.com/tianweiy/MVP

@article{yin2021multimodal,
  title={Multimodal Virtual Point 3D Detection},
  author={Yin, Tianwei and Zhou, Xingyi and Kr{\"a}henb{\"u}hl, Philipp},
  journal={NeurIPS},
  year={2021},
}

'''