import json
from nuscenes import NuScenes
import numpy as np
from typing import Dict, Tuple

# nuScenes devkit handle used by the rest of this script; constructed once at import time.
version_, verbose_, dataroot_ = (
    'v1.0-trainval',                           # dataset release passed to NuScenes(version=...)
    1,                                         # passed through as NuScenes(verbose=...)
    '/project_data/ramanan/shubham/nuscenes',  # dataset root passed to NuScenes(dataroot=...)
)
nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_)

from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.detection.render import visualize_sample
from nuscenes.eval.common.utils import boxes_to_sensor
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.common.loaders import load_prediction, load_gt, add_center_dist
from typing import Callable
from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
from nuscenes.eval.detection.data_classes import DetectionMetricData

import numpy as np
from pyquaternion import Quaternion

from nuscenes import NuScenes
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.detection.data_classes import DetectionBox
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.eval.tracking.data_classes import TrackingBox
from nuscenes.eval.tracking.utils import category_to_tracking_name
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import points_in_box, view_points, transform_matrix
from nuscenes.utils.splits import create_splits_scenes
from nuscenes.eval.detection.constants import TP_METRICS
import project

# Ground-truth DetectionBoxes for the 'val' split, loaded via the devkit's load_gt and then
# passed through the devkit's add_center_dist.
gt_boxes = add_center_dist(nusc_, load_gt(nusc_, 'val', DetectionBox, verbose=True))

# Sample tokens that the projection loop below iterates over.
sample_tokens = gt_boxes.sample_tokens

def load_json(filepath):
    """
    Read a JSON file from disk and return the decoded value.
    :param filepath: Path of the JSON file to read.
    :return: The parsed JSON value (dict, list, str, number, bool or None).
    """
    # JSON text is UTF-8 (RFC 8259); pin the encoding instead of relying on the
    # platform's locale default, which can mis-decode non-ASCII paths/strings.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def add_center_dist_(nusc: NuScenes,
                     eval_boxes: EvalBoxes):
    """
    Attach to every box its translation offset from the ego vehicle.

    For each sample token, the ego pose of that sample's LIDAR_TOP sample_data record is
    looked up, and each box receives ``ego_translation`` = (box translation - ego pose
    translation) over x, y and z. The full 3-D offset is stored; any xy-only (cylindrical)
    distance is presumably derived from it by the box class — not shown here.

    :param nusc: The NuScenes instance used to look up sample, sample_data and ego_pose records.
    :param eval_boxes: A set of boxes, either GT or predictions; modified in place.
    :return: The same eval_boxes object, with ego_translation set on every box.
    :raises NotImplementedError: If a box is neither a DetectionBox nor a TrackingBox.
    """
    supported = (DetectionBox, TrackingBox)
    for token in eval_boxes:
        lidar_sd_token = nusc.get('sample', token)['data']['LIDAR_TOP']
        ego_pose = nusc.get('ego_pose', nusc.get('sample_data', lidar_sd_token)['ego_pose_token'])
        ego_xyz = ego_pose['translation']

        for eval_box in eval_boxes[token]:
            # Box and ego pose are both expressed in the global frame, so a plain
            # per-axis difference gives the ego-relative offset. (The upstream comment
            # states the ego pose z is 0 — not verified here.)
            offset = tuple(eval_box.translation[axis] - ego_xyz[axis] for axis in range(3))
            if not isinstance(eval_box, supported):
                raise NotImplementedError
            eval_box.ego_translation = offset

    return eval_boxes

import os

# Map from nuScenes sample token to the path of that sample's LiDAR point-cloud file.
token_to_lidar_path = load_json("/project_data/ramanan/shubham/results/scene_samples_mapping/code_0_pts/sample_to_lidar_path.json")

# Projected points are written under this directory, relative to the current working
# directory. open(..., 'wb') does not create missing parent directories, so ensure it
# exists up front instead of failing on the first sample.
output_dir = 'project_points'
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): `all_cams_from_lidar` and `all_cams_intrinsic` are not defined anywhere in the
# visible code, so the first iteration raises NameError unless they are provided elsewhere —
# TODO define them. They are also reused unchanged for every sample; if they are meant to be
# per-sample camera extrinsics/intrinsics (ego pose differs between nuScenes samples), they must
# be computed inside the loop — confirm against what project.projectionV2 expects.
for i, sample_token in enumerate(sample_tokens):
    print(i)  # progress indicator
    lidar_path = token_to_lidar_path[sample_token]
    points = project.read_file(lidar_path)
    projected_points = project.projectionV2(
        project.to_tensor(points),
        project.to_batch_tensor(all_cams_from_lidar),
        project.to_batch_tensor(all_cams_intrinsic),
    ).cpu().numpy()
    with open(os.path.join(output_dir, f'{sample_token}.npy'), 'wb') as f:
        np.save(f, projected_points)