from collections import defaultdict
import os, json, tqdm, cosy
import numpy as np
from pyquaternion import Quaternion
import tinypl as pl
from functools import partial
from . import dataset
from georegdata.ground import FrameId, CameraId, NpzLidarId


from pyarrow import feather

def read_pose(data_frame):
    """Convert one pose record into a ``cosy.np.Rigid``.

    Values are selected by label with ``.loc``, so ``data_frame`` is presumably a
    single row (pandas Series) — TODO confirm against callers. It must provide the
    quaternion components ``qw, qx, qy, qz`` and the translation ``tx_m, ty_m, tz_m``.

    Returns:
        cosy.np.Rigid whose rotation is the quaternion's rotation matrix.
    """
    quaternion = data_frame.loc[["qw", "qx", "qy", "qz"]].to_numpy().squeeze()
    rotation_matrix = Quaternion(quaternion).rotation_matrix
    position = data_frame.loc[["tx_m", "ty_m", "tz_m"]].to_numpy().squeeze()
    return cosy.np.Rigid(rotation=rotation_matrix, translation=position)

def _find_scenes(path):
    """Return sorted ``(scene_id, scene_path)`` pairs for every log under ``path``.

    Logs are looked up directly inside ``<path>/mapchange-dataset`` and inside
    every sub-directory (split) of ``<path>/sensor-dataset``. Entries that are
    not directories are ignored.
    """
    sensor_root = os.path.join(path, "sensor-dataset")
    split_paths = [os.path.join(path, "mapchange-dataset")]
    for split in os.listdir(sensor_root):
        split_path = os.path.join(sensor_root, split)
        if os.path.isdir(split_path):
            split_paths.append(split_path)

    scenes = []
    for split_path in sorted(split_paths):
        for scene_id in sorted(os.listdir(split_path)):
            scene_path = os.path.join(split_path, scene_id)
            # Stray files would otherwise crash the per-scene loader later on.
            if os.path.isdir(scene_path):
                scenes.append((scene_id, scene_path))
    return scenes


def _read_city_name(scene_path):
    """Return the city code of a log, parsed from its ``map/log_map_archive*`` file name.

    The city code is the text after the first ``"____"`` up to the next ``"_"``.

    Raises:
        ValueError: If the map directory does not contain exactly one such file.
    """
    map_dir = os.path.join(scene_path, "map")
    archives = [f for f in os.listdir(map_dir) if f.startswith("log_map_archive")]
    if len(archives) != 1:
        raise ValueError(f"Expected exactly one log_map_archive file in {map_dir}, found {len(archives)}")
    return archives[0].split("____")[1].split("_")[0]


def _rigids_from_dataframe(data_frame):
    """Convert every row (``qw, qx, qy, qz`` + ``tx_m, ty_m, tz_m``) into a ``cosy.np.Rigid``, in row order."""
    quats = data_frame.loc[:, ["qw", "qx", "qy", "qz"]].to_numpy()
    translations = data_frame.loc[:, ["tx_m", "ty_m", "tz_m"]].to_numpy()
    return [
        cosy.np.Rigid(rotation=Quaternion(q).rotation_matrix, translation=t)
        for q, t in zip(quats, translations)
    ]


def _read_intrinsics(scene_path):
    """Read ``calibration/intrinsics.feather`` of a log.

    Returns:
        Tuple ``(intr, original_image_shapes)``: both dicts are keyed by sensor
        name. ``intr`` holds 3x3 float64 pinhole matrices built from
        ``fx_px, fy_px, cx_px, cy_px``; ``original_image_shapes`` holds the
        ``(height_px, width_px)`` row of each sensor.
    """
    intr_data_frame = feather.read_feather(os.path.join(scene_path, "calibration", "intrinsics.feather"))
    sensor_names = intr_data_frame["sensor_name"]
    focals = intr_data_frame.loc[:, ["fx_px", "fy_px"]].to_numpy()
    centers = intr_data_frame.loc[:, ["cx_px", "cy_px"]].to_numpy()
    shapes = intr_data_frame.loc[:, ["height_px", "width_px"]].to_numpy()

    intr = {}
    original_image_shapes = {}
    for name, (fx, fy), (cx, cy), shape in zip(sensor_names, focals, centers, shapes):
        intr[name] = np.asarray([
            [fx, 0.0, cx],
            [0.0, fy, cy],
            [0.0, 0.0, 1.0],
        ], dtype="float64")
        original_image_shapes[name] = shape
    return intr, original_image_shapes


@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Load Argoverse V2 (sensor and map-change datasets) as one Dataset per city.

    Every lidar sweep (``sensors/lidar/*.npz``) becomes one FrameId. For each
    ring camera, the frame references the image whose timestamp is closest to
    the sweep's timestamp.

    Args:
        path: Root directory containing ``sensor-dataset`` and ``mapchange-dataset``.

    Returns:
        Dict mapping city code (e.g. ``"PIT"``) to a ``dataset.Dataset``.

    Raises:
        ValueError: If a log does not have exactly one map archive, or its
            extrinsic calibration lacks one of the ring cameras.
    """
    # Constants, see https://github.com/argoai/av2-api/blob/bfb19a465814e81cd9632a7cf8869945471d8665/src/av2/geometry/utm.py
    city_latlon_origin = {
        "ATX": (30.27464237939507, -97.7404457407424),
        "DTW": (42.29993066912924, -83.17555750783717),
        "MIA": (25.77452579915163, -80.19656914449405),
        "PAO": (37.416065, -122.13571963362166),
        "PIT": (40.44177902989321, -80.01294377242584),
        "WDC": (38.889377, -77.0355047439081),
    }
    # EPSG code of the UTM zone used for each city.
    city_epsg = {
        "ATX": 32614,  # UTM 14N
        "DTW": 32617,  # UTM 17N
        "MIA": 32617,  # UTM 17N
        "PAO": 32610,  # UTM 10N
        "PIT": 32617,  # UTM 17N
        "WDC": 32618,  # UTM 18N
    }
    # Single source of truth for the zones: one CRS object per distinct zone,
    # shared by all cities in that zone.
    crs_by_epsg = {code: cosy.np.proj.CRS(f"epsg:{code}") for code in sorted(set(city_epsg.values()))}
    crs = {city: crs_by_epsg[code] for city, code in city_epsg.items()}
    epsg4326_to_crs = {
        city: cosy.np.proj.Transformer("epsg:4326", f"epsg:{code}")
        for city, code in city_epsg.items()
    }
    world_to_crs = {
        city: cosy.np.Rigid(translation=epsg4326_to_crs[city](city_latlon_origin[city]))
        for city in epsg4326_to_crs.keys()
    }
    # Offset from the original AV2 ego frame ("origego") to the ego frame used by
    # this loader: a pure translation of 0.33 along z.
    # NOTE(review): meaning of the 0.33 offset is not documented here — confirm.
    origego_to_ego = cosy.np.Rigid(
        translation=np.asarray([0.0, 0.0, 0.33])
    )

    # Only ring cameras are loaded; the stereo cameras
    # ("stereo_front_left", "stereo_front_right") are deliberately left out.
    camera_names = ["ring_front_center", "ring_front_left", "ring_front_right", "ring_rear_left", "ring_rear_right", "ring_side_left", "ring_side_right"]

    location_to_scenes = defaultdict(list)
    image_sizes = None

    for scene_id, scene_path in tqdm.tqdm(_find_scenes(path), desc="Argoverse V2"):
        city_name = _read_city_name(scene_path)

        # Sorted image timestamps (ns) per camera, parsed from the image file names.
        all_cam_timestamps = {
            camera_name: np.asarray(sorted(
                int(f.split(".")[0])
                for f in os.listdir(os.path.join(scene_path, "sensors", "cameras", camera_name))
            ))
            for camera_name in camera_names
        }

        # Ego poses in the city frame, interpolated over time with Rigid.slerp.
        pose_data_frame = feather.read_feather(os.path.join(scene_path, "city_SE3_egovehicle.feather"))
        get_pose = partial(
            cosy.np.lerp,
            xs=pose_data_frame.loc[:, ["timestamp_ns"]].to_numpy()[:, 0],  # ns
            ys=_rigids_from_dataframe(pose_data_frame),
            lerp2=cosy.np.Rigid.slerp,
        )

        # Camera extrinsics (sensor -> original ego frame).
        cam_to_origego_path = os.path.join(scene_path, "calibration", "egovehicle_SE3_sensor.feather")
        extrinsics = feather.read_feather(cam_to_origego_path)
        cam_to_origego = dict(zip(extrinsics["sensor_name"], _rigids_from_dataframe(extrinsics)))
        # Validate once per scene, before the worker pipeline starts, so a bad log
        # fails with a clear exception in the calling thread.
        missing_cameras = [name for name in camera_names if name not in cam_to_origego]
        if missing_cameras:
            raise ValueError(f"Cameras {missing_cameras} not in {cam_to_origego_path}")

        intr, original_image_shapes = _read_intrinsics(scene_path)

        lidar_root = os.path.join(scene_path, "sensors", "lidar")
        lidar_files = sorted(os.path.join(lidar_root, f) for f in os.listdir(lidar_root) if f.endswith(".npz"))

        # Closes over this iteration's scene variables. That is safe because the
        # stream below is fully consumed before the loop advances.
        @pl.unpack
        def load_frame(index_in_scene, lidar_file):
            lidar_timestamp = int(os.path.basename(lidar_file).split(".")[0])  # ns
            origego_to_world_tl = get_pose(lidar_timestamp)

            camera_ids = []
            for camera_name in camera_names:
                cam_timestamps = all_cam_timestamps[camera_name]
                # Image closest in time to the lidar sweep.
                cam_timestamp = cam_timestamps[int(np.argmin(np.abs(cam_timestamps - lidar_timestamp)))]
                origego_to_world_tc = get_pose(cam_timestamp)

                camera_ids.append(CameraId(
                    name=camera_name,
                    intr=intr[camera_name],
                    resolution=original_image_shapes[camera_name],
                    # ego -> origego -> world (at lidar time) -> origego (at camera
                    # time) -> camera: accounts for ego motion between the two timestamps.
                    ego_to_camera=cam_to_origego[camera_name].inverse() * origego_to_world_tc.inverse() * origego_to_world_tl * origego_to_ego.inverse(),
                    image_file=os.path.join(scene_path, "sensors", "cameras", camera_name, f"{cam_timestamp}.jpg"),
                ))

            return FrameId(
                dataset_name="argoverse-v2",
                location=city_name,
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=lidar_timestamp // 10 ** 3,  # ns -> us
                ego_to_world=origego_to_world_tl * origego_to_ego.inverse(),
                world_to_crs=world_to_crs[city_name],
                crs=crs[city_name],
                epsg4326_to_crs=epsg4326_to_crs[city_name],
                cameras=camera_ids,
                lidars=[NpzLidarId(
                    name="lidar",
                    file=lidar_file,
                    loaded_to_ego=origego_to_ego,
                )],
            )

        stream = pl.sync(iter(enumerate(lidar_files)))
        stream = pl.map(load_frame, stream)
        stream = pl.queued(stream, workers=4, maxsize=4)
        # Workers may finish out of order; restore sweep order.
        scene = sorted(stream, key=lambda frame: frame.index_in_scene)

        if not scene:
            # A log without lidar sweeps yields no frames (and has no frame to
            # sample image sizes from), so it is not added to the dataset.
            continue

        if image_sizes is None:
            # Image sizes are sampled once, from the last frame of the first non-empty scene.
            image_sizes = {tuple(camera_id.load_resolution()) for camera_id in scene[-1].cameras}

        location_to_scenes[city_name].append(scene)

    return {
        location: dataset.Dataset(
            name="argoverse-v2",
            location=location,
            scenes=scenes,
            # Sorted so the array order does not depend on set iteration order.
            image_sizes=np.asarray(sorted(image_sizes)),
        ) for location, scenes in location_to_scenes.items()
    }
