from nuscenes.nuscenes import NuScenes
from pyquaternion import Quaternion
from nuscenes.utils.geometry_utils import view_points
import numpy as np
import os, cosy, tqdm
from collections import defaultdict
from . import dataset
from georegdata.ground import FrameId, CameraId, Bin32LidarId

def rec_to_transform(rec):
    """Convert a nuScenes pose record into a ``cosy.np.Rigid`` transform.

    ``rec`` is a record such as ``ego_pose`` or ``calibrated_sensor``; only its
    ``rotation`` (a quaternion accepted by ``pyquaternion.Quaternion``) and
    ``translation`` entries are read.
    """
    rotation = Quaternion(rec["rotation"]).rotation_matrix
    translation = rec["translation"]
    return cosy.np.Rigid(rotation=rotation, translation=translation)

def load(path, train_frames_only=False, **kwargs):
    """Load the nuScenes datasets, optionally removing scenes/frames unfit for training.

    Args:
        path: nuScenes data root, forwarded to ``load_``.
        train_frames_only: If True, every scene listed in ``bad_scenes`` with
            value ``None`` is dropped entirely, and scenes listed with a
            collection of frame indices lose exactly those frames. A scene left
            without frames is dropped.
        **kwargs: Forwarded unchanged to ``load_``.

    Returns:
        The dict returned by ``load_`` (location -> dataset). When filtering,
        each dataset's ``scenes`` attribute is replaced by the filtered list.

    Raises:
        IndexError: If a listed frame index is out of range for its scene.
    """
    # Constants
    # Scene name -> None (drop the whole scene) or an iterable of frame indices
    # (list-style, negatives allowed) to drop from that scene.
    bad_scenes = {
        "scene-0303": None,
        "scene-0304": None, # TODO: list(range(0, 15)), but that causes problems with finetuned_pointspace_to_crs
        "scene-0514": None,
        "scene-0586": None,
        "scene-0587": None,
        "scene-0700": None,
        "scene-0877": None,
        "scene-0884": None,
        "scene-0885": None,
        "scene-0906": None,
    }

    location_to_dataset = load_(path, **kwargs)

    if train_frames_only:
        # Named ``ds`` so the module imported as ``dataset`` is not shadowed.
        for ds in location_to_dataset.values():
            new_scenes = []
            for scene in ds.scenes:
                frame_indices = bad_scenes.get(scene[0].scene_id, ())
                if frame_indices is None:
                    continue
                if frame_indices:
                    # range(...)[i] normalizes negative indices and raises IndexError
                    # when out of range (like list indexing); the set makes duplicate
                    # indices harmless instead of deleting shifted frames.
                    drop = {range(len(scene))[i] for i in frame_indices}
                    # Build a filtered copy rather than deleting from the original list.
                    scene = [frame for i, frame in enumerate(scene) if i not in drop]
                    if not scene:
                        continue
                new_scenes.append(scene)
            ds.scenes = new_scenes

    return location_to_dataset

def _sample_records(nusc, nusc_scene):
    """Return the scene's ``sample`` records in order, following each record's ``next`` token."""
    recs = [nusc.get("sample", nusc_scene["first_sample_token"])]
    while recs[-1]["next"] != "":
        recs.append(nusc.get("sample", recs[-1]["next"]))
    return recs

def _camera_ids(nusc, sample_rec, camera_names, origego_to_world_tl, ego_to_origego):
    """Build one ``CameraId`` per name in ``camera_names`` for the given ``sample`` record.

    ``origego_to_world_tl`` is the ego pose stored with the sample's lidar
    sweep; ``ego_to_origego`` maps the adjusted ego frame to the nuScenes ego
    frame. The camera's own ego pose is used to reach the camera frame.
    """
    cameras = []
    for camera_name in camera_names:
        camera_rec = nusc.get("sample_data", sample_rec["data"][camera_name])
        # Ego pose stored with this camera image ("tc"), which may differ from the lidar one.
        origego_to_world_tc = rec_to_transform(nusc.get("ego_pose", camera_rec["ego_pose_token"]))
        camera_pose_rec = nusc.get("calibrated_sensor", camera_rec["calibrated_sensor_token"])
        cam_to_origego_tc = rec_to_transform(camera_pose_rec)

        cameras.append(CameraId(
            name=camera_name,
            intr=np.array(camera_pose_rec["camera_intrinsic"]),
            resolution=np.asarray([camera_rec["height"], camera_rec["width"]]),
            # ego -> nuScenes ego (lidar pose) -> world -> nuScenes ego (camera pose) -> camera
            ego_to_camera=cam_to_origego_tc.inverse() * origego_to_world_tc.inverse() * origego_to_world_tl * ego_to_origego,
            image_file=os.path.join(nusc.dataroot, camera_rec["filename"]),
        ))
    return cameras

@dataset.cached_load
@dataset.finetuned_load
def load_(path):
    """Build per-location datasets of frame descriptors from a nuScenes v1.0-trainval root.

    Every ``sample`` of every scene becomes one ``FrameId`` carrying the six
    surround cameras and the top lidar. Scenes are visited sorted by name and
    frames keep their order within the scene.

    Args:
        path: nuScenes data root, passed to ``NuScenes(dataroot=...)``.

    Returns:
        Dict mapping each location name to a ``dataset.Dataset`` whose
        ``scenes`` are lists of ``FrameId``.
    """
    # Constants
    # See: https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/scripts/export_poses.py
    world_origins_latlon = {
        "boston-seaport": [42.336849169438615, -71.05785369873047],
        "singapore-onenorth": [1.2882100868743724, 103.78475189208984],
        "singapore-hollandvillage": [1.2993652317780957, 103.78217697143555],
        "singapore-queenstown": [1.2782562240223188, 103.76741409301758],
    }
    world_to_epsg3857 = {location: cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(origin) for location, origin in world_origins_latlon.items()}
    camera_names = ["CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT", "CAM_FRONT", "CAM_FRONT_LEFT", "CAM_FRONT_RIGHT"]
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")
    # NOTE(review): offsets the nuScenes ego frame by -0.01 along z; the reason
    # is not visible in this file — confirm before changing.
    origego_to_ego = cosy.np.Rigid(
        translation=np.asarray([0.0, 0.0, -0.01])
    )
    # Loop-invariant: computed once instead of once per frame and per camera.
    ego_to_origego = origego_to_ego.inverse()

    nusc = NuScenes(version="v1.0-trainval", dataroot=path, verbose=False)

    image_sizes = None
    location_to_scenes = defaultdict(list)
    for nusc_scene in tqdm.tqdm(sorted(nusc.scene, key=lambda nusc_scene: nusc_scene["name"]), desc="Nuscenes"):
        scene_id = nusc_scene["name"]
        location = nusc.get("log", nusc_scene["log_token"])["location"]

        scene = []
        for index_in_scene, rec in enumerate(_sample_records(nusc, nusc_scene)):
            lidar_rec = nusc.get("sample_data", rec["data"]["LIDAR_TOP"])
            # Ego pose stored with the lidar sweep ("tl") and the lidar extrinsics.
            origego_to_world_tl = rec_to_transform(nusc.get("ego_pose", lidar_rec["ego_pose_token"]))
            lidar_to_origego_tl = rec_to_transform(nusc.get("calibrated_sensor", lidar_rec["calibrated_sensor_token"]))

            cameras = _camera_ids(nusc, rec, camera_names, origego_to_world_tl, ego_to_origego)

            scene.append(FrameId(
                dataset_name="nuscenes",
                location=location,
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=int(lidar_rec["timestamp"]), # us
                ego_to_world=origego_to_world_tl * ego_to_origego,
                world_to_crs=world_to_epsg3857[location],
                crs=epsg3857,
                epsg4326_to_crs=epsg4326_to_epsg3857,
                cameras=cameras,
                lidars=[Bin32LidarId(
                    name="lidar",
                    file=os.path.join(nusc.dataroot, lidar_rec["filename"]),
                    # 3x5 selection matrix: keeps the first three of five values
                    # per point (presumably x, y, z — confirm against Bin32LidarId).
                    map=np.asarray([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]),
                    loaded_to_ego=origego_to_ego * lidar_to_origego_tl,
                )],
            ))

            if image_sizes is None:
                # Taken from the very first frame only; assumes all frames share
                # the same camera resolutions — TODO confirm.
                image_sizes = {tuple(camera_id.load_resolution()) for camera_id in cameras}

        location_to_scenes[location].append(scene)

    return {
        location: dataset.Dataset(
            name="nuscenes",
            location=location,
            scenes=scenes,
            image_sizes=np.asarray(list(image_sizes)),
        ) for location, scenes in location_to_scenes.items()
    }
