import os, math, json, tqdm, cosy
import numpy as np
from pyquaternion import Quaternion
from collections import defaultdict
from . import dataset
from georegdata.ground import FrameId, CameraId, NpzLidarId

def json_to_transform(data):
    """Convert a pose dictionary into a ``cosy.np.Rigid`` transform.

    ``data`` must provide a ``"heading"`` entry with the quaternion
    components ``w``, ``x``, ``y``, ``z`` and a ``"position"`` entry with
    the components ``x``, ``y``, ``z``.
    """
    heading = data["heading"]
    # The quaternion components are handed over in (w, x, y, z) order.
    orientation = Quaternion(np.asarray([heading[key] for key in ("w", "x", "y", "z")], dtype="float"))
    rotation = orientation.rotation_matrix

    position = data["position"]
    translation = np.asarray([position[key] for key in ("x", "y", "z")], dtype="float")

    return cosy.np.Rigid(rotation=rotation, translation=translation)

def _read_json(file_path):
    """Parse and return the JSON document stored at ``file_path``."""
    with open(file_path) as f:
        return json.load(f)


@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Load the Pandaset scenes below ``path`` and group them by location.

    Every sub-directory of ``path`` that is not listed in ``skip_scenes`` is
    read as one scene: per-camera poses and intrinsics, lidar poses, the gps
    track and the timestamps. The lidar trajectory of each scene is aligned
    with its gps track (projected to epsg:3857) through a least-squares
    ``cosy.np.ScaledRigid`` fit, and one ``FrameId`` is created per frame.

    Args:
        path: Directory that contains one sub-directory per scene.

    Returns:
        Before the decorators are applied: a dict that maps the location
        name (``"san francisco"`` or ``"palo alto san mateo"``) to a
        ``dataset.Dataset`` with the scenes recorded at that location.

    Raises:
        AssertionError: If the gps track, timestamps, lidar poses and camera
            poses of a scene do not all have the same number of entries.
    """
    # Constants
    lidar_to_ego =   cosy.np.Rigid(translation=np.asarray([0.0, 0.0, 0.25])) \
                   * cosy.np.Rigid(rotation=Quaternion(axis=(0, 0, 1), radians=-math.pi / 2).rotation_matrix)
    # Loop-invariant, so computed once instead of once per frame and camera.
    ego_to_lidar = lidar_to_ego.inverse()
    skip_scenes = [ # TODO: Load all scenes, remove skip_scenes after cache (cp nuscenes/ ford)
        "004", "018", # Too short for gps registration to work # TODO: run model to determine poses
        "014", # Bridge
    ]
    original_image_shape = np.asarray([1080, 1920])
    latitude_split = 37.66016386784476 # Pandaset contains two locations, one above this latitude and one below
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")

    location_to_scenes = defaultdict(list)
    image_sizes = None
    scene_ids = [scene_id for scene_id in sorted(os.listdir(path)) if scene_id not in skip_scenes]
    for scene_id in tqdm.tqdm(scene_ids, desc="Pandaset"):
        scene_path = os.path.join(path, scene_id)
        if not os.path.isdir(scene_path):
            continue

        # os.listdir returns entries in arbitrary order; sort so that the
        # camera order inside every frame is reproducible across machines.
        cameras = []
        camera_root = os.path.join(scene_path, "camera")
        for camera_name in sorted(os.listdir(camera_root)):
            camera_path = os.path.join(camera_root, camera_name)
            if not os.path.isdir(camera_path):
                # Ignore stray files, mirroring the scene-level directory check.
                continue
            camera_to_world = [json_to_transform(data) for data in _read_json(os.path.join(camera_path, "poses.json"))]
            data = _read_json(os.path.join(camera_path, "intrinsics.json"))
            intr = np.asarray([[data["fx"], 0.0, data["cx"]], [0.0, data["fy"], data["cy"]], [0.0, 0.0, 1.0]], dtype="float64")
            cameras.append((camera_name, camera_to_world, intr))

        lidar_to_world = [json_to_transform(data) for data in _read_json(os.path.join(scene_path, "lidar", "poses.json"))]
        latlons = [np.asarray([data["lat"], data["long"]]) for data in _read_json(os.path.join(scene_path, "meta", "gps.json"))]
        # Timestamps are scaled by 10^6 and stored as unsigned integers.
        timestamps = (np.asarray(_read_json(os.path.join(scene_path, "meta", "timestamps.json"))).astype("float64") * (10 ** 6)).astype("uint64")

        assert len(latlons) == len(timestamps), f"Scene {scene_id}: got {len(latlons)} gps entries but {len(timestamps)} timestamps"
        assert len(latlons) == len(lidar_to_world), f"Scene {scene_id}: got {len(latlons)} gps entries but {len(lidar_to_world)} lidar poses"
        for camera_name, camera_to_world, intr in cameras:
            assert len(camera_to_world) == len(latlons), f"Scene {scene_id}: camera {camera_name} has {len(camera_to_world)} poses but there are {len(latlons)} gps entries"

        # Align poses with gps track
        positions_world = [t(0.0)[:2] for t in lidar_to_world]
        positions_epsg3857 = [epsg4326_to_epsg3857(latlon) for latlon in latlons]
        world_to_epsg3857 = cosy.np.ScaledRigid.least_squares(
            from_points=positions_world,
            to_points=positions_epsg3857,
        )

        def make_frame_id(index_in_scene, location):
            # Builds the FrameId of one frame of the current scene. The
            # location is passed explicitly so that the result does not
            # depend on a variable that is reassigned later in the loop.
            return FrameId(
                dataset_name="pandaset",
                location=location if location is not None else "",
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=timestamps[index_in_scene],
                ego_to_world=lidar_to_world[index_in_scene] * ego_to_lidar,
                world_to_crs=world_to_epsg3857,
                crs=epsg3857,
                epsg4326_to_crs=epsg4326_to_epsg3857,
                cameras=[CameraId(
                    name=camera_name,
                    intr=intr,
                    resolution=original_image_shape,
                    ego_to_camera=camera_to_world[index_in_scene].inverse() * lidar_to_world[index_in_scene] * ego_to_lidar,
                    image_file=os.path.join(scene_path, "camera", camera_name, f"{index_in_scene:02d}.jpg"),
                ) for camera_name, camera_to_world, intr in cameras],
                lidars=[NpzLidarId(
                    name="lidar",
                    file=os.path.join(scene_path, "lidar", f"{index_in_scene:02d}.npz"),
                    loaded_to_ego=lidar_to_ego * lidar_to_world[index_in_scene].inverse(),
                )],
            )

        location = None
        scene = []
        for index_in_scene in range(len(latlons)):
            if location is None:
                # The location is only known once a frame's latitude can be
                # queried, so a throw-away frame without location is built
                # first and its latitude compared against the split.
                probe = make_frame_id(index_in_scene, None)
                if probe.latlon[0] > latitude_split:
                    location = "san francisco"
                else:
                    location = "palo alto san mateo"
            scene.append(make_frame_id(index_in_scene, location))

            if image_sizes is None:
                # Image sizes are taken from the very first loaded frame only.
                image_sizes = {tuple(camera_id.load_resolution()) for camera_id in scene[-1].cameras}

        location_to_scenes[location].append(scene)

    return {
        location: dataset.Dataset(
            name="pandaset",
            location=location,
            scenes=scenes,
            image_sizes=np.asarray(list(image_sizes)),
        ) for location, scenes in location_to_scenes.items()
    }
