from lyft_dataset_sdk.lyftdataset import LyftDataset
from pyquaternion import Quaternion
import numpy as np
import os, tqdm, cosy
from collections import defaultdict
from . import dataset
from georegdata.ground import FrameId, CameraId, Bin32LidarId

def rec_to_transform(rec):
    """Build a ``cosy.np.Rigid`` transform from a dataset record.

    The record's ``"rotation"`` entry is passed to ``pyquaternion.Quaternion``
    and converted to a rotation matrix; its ``"translation"`` entry is
    forwarded unchanged.
    """
    rotation_matrix = Quaternion(rec["rotation"]).rotation_matrix
    return cosy.np.Rigid(rotation=rotation_matrix, translation=rec["translation"])

@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Index the Lyft dataset under ``path`` and group its scenes by location.

    Both the ``train`` and ``test`` splits are opened (each with its JSON
    tables in a ``data`` subdirectory). For every scene, the chain of samples
    is followed from ``first_sample_token`` through the ``next`` links, and
    each sample becomes one ``FrameId`` holding six ``CameraId`` entries and a
    single ``Bin32LidarId``.

    Returns a dict mapping each log location to a ``dataset.Dataset`` that
    holds all scenes recorded at that location. ``image_sizes`` is collected
    once, from the cameras of the first frame that is processed.

    NOTE(review): this function is wrapped by ``dataset.cached_load`` and
    ``dataset.finetuned_load``; what they add is not visible from this block.
    """
    # Constants
    camera_names = ["CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT", "CAM_FRONT", "CAM_FRONT_LEFT", "CAM_FRONT_RIGHT"]
    # Fixed scaled-rigid transform from the dataset's world frame to EPSG:3857.
    # NOTE(review): the numbers look fitted offline -- confirm their provenance.
    world_to_epsg3857 = cosy.np.ScaledRigid(translation=np.asarray([-13595406.596121203, 4500223.6031663325]), rotation=np.asarray([[-0.9595302532609804, -0.28160556293674094], [0.28160556293674094, -0.9595302532609804]]), scale=np.asarray(1.2573292495370103), dtype="float64")
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")
    # Offsets the dataset's original ego frame by -0.09 along z.
    # NOTE(review): presumably a height correction -- confirm the intent.
    origego_to_ego = cosy.np.Rigid(
        translation=np.asarray([0.0, 0.0, -0.09])
    )
    ego_to_origego = origego_to_ego.inverse()

    train = LyftDataset(data_path=os.path.join(path, "train"), json_path=os.path.join(path, "train", "data"), verbose=False)
    test = LyftDataset(data_path=os.path.join(path, "test"), json_path=os.path.join(path, "test", "data"), verbose=False)

    # One job per scene: all train scenes first, then test, each sorted by name.
    jobs = [
        (split, scene_rec)
        for split in (train, test)
        for scene_rec in sorted(split.scene, key=lambda s: s["name"])
    ]

    scenes_by_location = defaultdict(list)
    image_sizes = None

    for split, scene_rec in tqdm.tqdm(jobs, desc="Lyft"):
        # Follow the linked list of samples until the "next" token is empty.
        cursor = split.get("sample", scene_rec["first_sample_token"])
        samples = [cursor]
        while cursor["next"] != "":
            cursor = split.get("sample", cursor["next"])
            samples.append(cursor)

        scene_id = scene_rec["name"]
        location = split.get("log", scene_rec["log_token"])["location"]

        frames = []
        for index_in_scene, sample in enumerate(samples):
            sensor_tokens = sample["data"]

            # The lidar sweep provides the frame's timestamp and reference ego pose.
            lidar_rec = split.get("sample_data", sensor_tokens["LIDAR_TOP"])
            timestamp_us = lidar_rec["timestamp"] # us
            origego_to_world_at_lidar = rec_to_transform(split.get("ego_pose", lidar_rec["ego_pose_token"]))
            lidar_to_origego = rec_to_transform(split.get("calibrated_sensor", lidar_rec["calibrated_sensor_token"]))

            cameras = []
            for camera_name in camera_names:
                cam_rec = split.get("sample_data", sensor_tokens[camera_name])

                # Each camera image has its own ego pose record, separate from the lidar's.
                origego_to_world_at_cam = rec_to_transform(split.get("ego_pose", cam_rec["ego_pose_token"]))
                calib_rec = split.get("calibrated_sensor", cam_rec["calibrated_sensor_token"])
                cam_to_origego = rec_to_transform(calib_rec)

                # Chain (read right to left, as the a_to_b naming suggests):
                # ego -> original ego at lidar time -> world
                #     -> original ego at camera time -> camera.
                ego_to_camera = (
                    cam_to_origego.inverse()
                    * origego_to_world_at_cam.inverse()
                    * origego_to_world_at_lidar
                    * ego_to_origego
                )

                cameras.append(CameraId(
                    name=camera_name,
                    intr=np.array(calib_rec["camera_intrinsic"]),
                    resolution=np.asarray([cam_rec["height"], cam_rec["width"]]),
                    ego_to_camera=ego_to_camera,
                    image_file=os.path.join(split.data_path, cam_rec["filename"]),
                ))

            lidar_file = os.path.join(split.data_path, lidar_rec["filename"])
            lidar = Bin32LidarId(
                name="lidar",
                file=lidar_file,
                map=np.asarray([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]),
                loaded_to_ego=origego_to_ego * lidar_to_origego,
                crop_to_size="host-a011_lidar1_1233090652702363606" in lidar_file, # https://github.com/lyft/nuscenes-devkit/issues/50
            )

            frames.append(FrameId(
                dataset_name="lyft",
                location=location,
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=int(timestamp_us),
                ego_to_world=origego_to_world_at_lidar * ego_to_origego,
                world_to_crs=world_to_epsg3857,
                crs=epsg3857,
                epsg4326_to_crs=epsg4326_to_epsg3857,
                cameras=cameras,
                lidars=[lidar],
            ))

            # Image sizes are sampled only once, from the very first frame.
            if image_sizes is None:
                image_sizes = {tuple(camera_id.load_resolution()) for camera_id in cameras}

        scenes_by_location[location].append(frames)

    return {
        location: dataset.Dataset(
            name="lyft",
            location=location,
            scenes=location_scenes,
            image_sizes=np.asarray(list(image_sizes)),
        )
        for location, location_scenes in scenes_by_location.items()
    }
