from collections import defaultdict
import os, json, tqdm, cosy
import numpy as np
from pyquaternion import Quaternion
import tinypl as pl
from functools import partial
from . import dataset
from georegdata.ground import FrameId, CameraId, NpzLidarId

def load_pose(path):
    """Read a pose JSON file and return it as a ``cosy.np.Rigid`` transform.

    The JSON document must provide a ``"rotation"`` entry (quaternion
    coefficients, handed to ``pyquaternion.Quaternion``) and a
    ``"translation"`` entry.
    """
    with open(path, "r") as pose_file:
        raw = json.load(pose_file)

    # Convert the stored quaternion into a rotation matrix.
    orientation = Quaternion(np.array(raw["rotation"]))
    position = np.array(raw["translation"])
    return cosy.np.Rigid(
        rotation=orientation.rotation_matrix,
        translation=position,
    )

@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Index the Argoverse V1 tracking data found below ``path``.

    Expected layout: ``<path>/argoverse-tracking/<split>/<scene_id>/`` where
    every scene folder contains ``city_info.json``,
    ``vehicle_calibration_info.json``, a ``poses`` folder, a ``lidar`` folder
    with ``.npz`` sweeps and one image folder per camera.

    One ``FrameId`` is created per lidar sweep. For each calibrated camera the
    image whose timestamp is closest to the sweep is attached to the frame.

    Args:
        path: Root folder that contains the ``argoverse-tracking`` folder.

    Returns:
        A dict mapping each city name to a ``dataset.Dataset`` holding all
        scenes recorded in that city.
    """
    # Constants
    world_to_utm17n = {
        "PIT": cosy.np.ScaledRigid(translation=np.asarray([606.2203524421202, -100.43711466901004]), rotation=np.asarray([[0.9999999907705954, -0.00013586319921505012], [0.00013586319921505012, 0.9999999907705954]]), scale=1.0000047172396858, dtype="float64") \
                * cosy.np.Rigid(translation=(583710.0070, 4477259.9999)), # See argoverse paper
        "MIA": cosy.np.ScaledRigid(translation=np.asarray([-325.1866905745119, -120.37690092204139]), rotation=np.asarray([[0.999999994881483, 0.0001011782277592008], [-0.0001011782277592008, 0.999999994881483]]), scale=1.0000625825743747, dtype="float64") \
                * cosy.np.Rigid(translation=(580560.0088, 2850959.999)), # See argoverse paper
    }
    utm17n = cosy.np.proj.CRS("epsg:32617") # UTM Zone 17 as specified by argoverse paper
    epsg4326_to_utm17n = cosy.np.proj.Transformer("epsg:4326", "epsg:32617")
    ring_camera_names = ["ring_front_center", "ring_front_left", "ring_front_right", "ring_rear_left", "ring_rear_right", "ring_side_left", "ring_side_right"]
    stereo_camera_names = ["stereo_front_left", "stereo_front_right"]
    camera_names = ring_camera_names + stereo_camera_names
    original_image_shapes = {**{name: (1200, 1920) for name in ring_camera_names}, **{name: (2056, 2464) for name in stereo_camera_names}}
    # Offset between the ego frame stored in the dataset ("origego") and the
    # ego frame exposed by this loader: 0.32 along z (presumably meters -
    # TODO confirm).
    origego_to_ego = cosy.np.Rigid(
        translation=np.asarray([0.0, 0.0, 0.32])
    )

    def timestamp_of(file_name):
        # File names have the form "<prefix>_<timestamp>.<ext>"; the
        # timestamp is given in nanoseconds.
        return int(file_name.split(".")[0].split("_")[-1])

    location_to_scenes = defaultdict(list)
    path = os.path.join(path, "argoverse-tracking")
    image_sizes = None

    # Collect all scene folders up front so that the progress bar knows the total.
    jobs = []
    for split in sorted(os.listdir(path)):
        split_path = os.path.join(path, split)
        if not os.path.isdir(split_path):
            continue
        for scene_id in sorted(os.listdir(split_path)):
            scene_path = os.path.join(split_path, scene_id)
            # Stray files inside a split folder are not scenes; skip them the
            # same way non-directory entries are skipped at the split level.
            if not os.path.isdir(scene_path):
                continue
            jobs.append((scene_id, scene_path))

    for scene_id, scene_path in tqdm.tqdm(jobs, desc="Argoverse V1"):
        # Load location
        with open(os.path.join(scene_path, "city_info.json"), "r") as f:
            city_name = json.load(f)["city_name"]

        # Load camera timestamps
        all_cam_timestamps = {
            camera_name: np.asarray(sorted(
                timestamp_of(f) for f in os.listdir(os.path.join(scene_path, camera_name))
            )) # ns
            for camera_name in camera_names
        }

        # Load poses
        pose_timestamps = np.asarray(sorted(
            timestamp_of(f) for f in os.listdir(os.path.join(scene_path, "poses"))
        )) # ns
        poses = [load_pose(os.path.join(scene_path, "poses", f"city_SE3_egovehicle_{ts}.json")) for ts in pose_timestamps]
        # Camera timestamps generally do not coincide with a stored pose, so
        # poses are interpolated between the recorded ones.
        get_pose = partial(
            cosy.np.lerp,
            xs=pose_timestamps,
            ys=poses,
            lerp2=cosy.np.Rigid.slerp,
        )

        # Load camera calibrations
        with open(os.path.join(scene_path, "vehicle_calibration_info.json"), "r") as f:
            calib_data = json.load(f)
        cameras = {}
        for camera_config in calib_data["camera_data_"]:
            # Keys look like "image_raw_<camera_name>".
            camera_name = camera_config["key"][len("image_raw_"):]
            camera_config = camera_config["value"]
            # The file stores camera -> vehicle; invert to get vehicle -> camera.
            origego_to_camera = cosy.np.Rigid(
                rotation=Quaternion(np.array(camera_config["vehicle_SE3_camera_"]["rotation"]["coefficients"])).rotation_matrix,
                translation=np.array(camera_config["vehicle_SE3_camera_"]["translation"])
            ).inverse()
            intr = np.asarray([
                [camera_config["focal_length_x_px_"], camera_config["skew_"], camera_config["focal_center_x_px_"]],
                [0.0, camera_config["focal_length_y_px_"], camera_config["focal_center_y_px_"]],
                [0.0, 0.0, 1.0]
            ], dtype="float64")
            cameras[camera_name] = (origego_to_camera, intr)

        ply_root = os.path.join(scene_path, "lidar")
        stream = iter(enumerate(sorted([os.path.join(ply_root, f) for f in os.listdir(ply_root) if f.endswith(".npz")])))
        stream = pl.sync(stream)

        # Named differently from the enclosing function to avoid shadowing it.
        # The closure reads per-scene variables; this is safe because the
        # stream is fully consumed below, before the next loop iteration.
        @pl.unpack
        def load_frame(index_in_scene, lidar_file):
            lidar_timestamp = timestamp_of(os.path.basename(lidar_file)) # ns
            # Ego pose at the lidar timestamp (tl); read directly from the
            # pose file with the same timestamp.
            origego_to_world_tl = load_pose(os.path.join(scene_path, "poses", f"city_SE3_egovehicle_{lidar_timestamp}.json"))

            camera_ids = []
            for camera_name, (origego_to_camera, intr) in cameras.items():
                # Pick the image captured closest in time to the lidar sweep.
                cam_frame_index = int(np.argmin(np.abs(all_cam_timestamps[camera_name] - lidar_timestamp)))
                cam_timestamp = all_cam_timestamps[camera_name][cam_frame_index]
                camera_image_path = os.path.join(scene_path, camera_name, f"{camera_name}_{cam_timestamp}.jpg")
                # Ego pose at the camera timestamp (tc).
                origego_to_world_tc = get_pose(cam_timestamp)

                camera_ids.append(CameraId(
                    name=camera_name,
                    intr=intr,
                    resolution=original_image_shapes[camera_name],
                    # Read right to left: ego (lidar time) -> origego (lidar
                    # time) -> world -> origego (camera time) -> camera. This
                    # accounts for ego motion between the two timestamps.
                    ego_to_camera=origego_to_camera * origego_to_world_tc.inverse() * origego_to_world_tl * origego_to_ego.inverse(),
                    image_file=camera_image_path,
                ))

            return FrameId(
                dataset_name="argoverse-v1",
                location=city_name,
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=lidar_timestamp // 10 ** 3, # ns -> us
                ego_to_world=origego_to_world_tl * origego_to_ego.inverse(),
                world_to_crs=world_to_utm17n[city_name],
                crs=utm17n,
                epsg4326_to_crs=epsg4326_to_utm17n,
                cameras=camera_ids,
                lidars=[NpzLidarId(
                    name="lidar",
                    file=lidar_file,
                    loaded_to_ego=origego_to_ego,
                )],
            )
        stream = pl.map(load_frame, stream)
        stream = pl.queued(stream, workers=12, maxsize=12)
        # Frames are re-sorted afterwards, so the order in which the queue
        # yields them does not matter.
        scene = sorted(stream, key=lambda f: f.index_in_scene)

        # Image sizes are probed once, from the last frame of the first
        # non-empty scene. A scene without lidar sweeps has no frame to probe.
        if image_sizes is None and scene:
            image_sizes = {tuple(camera_id.load_resolution()) for camera_id in scene[-1].cameras}

        location_to_scenes[city_name].append(scene)

    if image_sizes is None:
        # No frame was found in any scene.
        image_sizes = set()

    return {
        location: dataset.Dataset(
            name="argoverse-v1",
            location=location,
            scenes=scenes,
            image_sizes=np.asarray(list(image_sizes)),
        ) for location, scenes in location_to_scenes.items()
    }
