import os, math, yaml, tqdm, cosy
import numpy as np
from pyquaternion import Quaternion
from collections import defaultdict
from functools import partial
from . import dataset
from georegdata.ground import FrameId, CameraId, NpzLidarId, Imu

def rostransform_to_transform(pose):
    """Convert a ROS-style transform dict into a ``cosy.np.Rigid``.

    Args:
        pose: mapping with ``pose["translation"]`` holding keys ``x``/``y``/``z`` and
            ``pose["rotation"]`` holding a quaternion under keys ``w``/``x``/``y``/``z``.

    Returns:
        ``cosy.np.Rigid`` whose translation is the ``(x, y, z)`` vector and whose rotation is
        the 3x3 rotation matrix of the quaternion (built via ``pyquaternion``, scalar first).
    """
    translation_msg = pose["translation"]
    rotation_msg = pose["rotation"]
    xyz = np.asarray([translation_msg[axis] for axis in ("x", "y", "z")])
    # pyquaternion expects the scalar part first: (w, x, y, z)
    wxyz = np.asarray([rotation_msg[component] for component in ("w", "x", "y", "z")])
    rotation_matrix = Quaternion(wxyz).rotation_matrix
    return cosy.np.Rigid(translation=xyz, rotation=rotation_matrix)

# Half-open ``[start, end)`` frame-index ranges per scene that are kept when
# ``load(..., train_frames_only=True)`` is used.
_GOOD_FRAME_RANGES = {
    "2017-08-04-V2-Log1": [(0, 1803), (1855, 2050), (2095, 2345), (2405, 6475), (6570, 6610), (6675, 6685)],
    "2017-08-04-V2-Log2": [(500, 950), (1150, 1240), (1350, 2400), (2840, 5440), (5490, 5620), (5670, 5820), (5870, 6820), (6890, 9780)],
    "2017-08-04-V2-Log3": [(0, 80), (190, 2600), (2690, 8506)],
    "2017-08-04-V2-Log4": [(0, 6615)],
    "2017-08-04-V2-Log5": [(0, 2700), (2800, 8125)],
    "2017-08-04-V2-Log6": [(0, 5739)],
    "2017-08-04-V3-Log1": [(0, 1900), (1970, 2180), (2220, 2500), (2550, 7040), (7160, 7220)],
    "2017-08-04-V3-Log2": [(600, 980), (1180, 1270), (1370, 2530), (2900, 5660), (5720, 5850), (5910, 6080), (6130, 7080), (7120, 10280)],
    "2017-08-04-V3-Log3": [(250, 2757), (2870, 8992)],
    "2017-08-04-V3-Log4": [(0, 6856)],
    "2017-08-04-V3-Log5": [(0, 2850), (2930, 8660)],
    "2017-08-04-V3-Log6": [(0, 5866)],
    "2017-10-26-V2-Log1": [(0, 1550), (1600, 1880), (1940, 6080), (6190, 6250), (7060, 7758)],
    "2017-10-26-V2-Log2": [(0, 170), (325, 410), (500, 2100), (3020, 3070), (3400, 6250), (6310, 6450), (6510, 6670), (6720, 7645), (7700, 10813)],
    "2017-10-26-V2-Log3": [(0, 40), (150, 2240), (2320, 8694)],
    "2017-10-26-V2-Log4": [(0, 4980)],
    "2017-10-26-V2-Log5": [(0, 3800), (3870, 4405)],
    "2017-10-26-V2-Log6": [(0, 4880)],
}

# Per-scene frame slices applied when ``load(..., valid_starts_only=True)`` is used.
# The V2/V3 logs of the same date and log number are trimmed identically.
_VALID_START_SLICES = {
    "2017-08-04-V2-Log1": slice(None, -500),
    "2017-08-04-V2-Log2": slice(500, None),
    "2017-08-04-V3-Log1": slice(None, -500),
    "2017-08-04-V3-Log2": slice(500, None),
    "2017-10-26-V2-Log1": slice(None, -1300),
}


def load(path, train_frames_only=False, valid_starts_only=False, **kwargs):
    """Load the Ford AV dataset and optionally restrict the frames of each scene.

    Args:
        path: dataset root directory, forwarded to ``load_``.
        train_frames_only: keep only the frame indices listed in ``_GOOD_FRAME_RANGES``
            for the scenes that appear there.
        valid_starts_only: trim the scenes listed in ``_VALID_START_SLICES``.
        **kwargs: forwarded to ``load_`` (consumed by its decorators).

    Returns:
        The ``{"detroit": Dataset}`` mapping returned by ``load_``, with
        ``["detroit"].scenes`` modified in place.

    Raises:
        ValueError: if both ``train_frames_only`` and ``valid_starts_only`` are set.
    """
    if train_frames_only and valid_starts_only:
        raise ValueError("train_frames_only and valid_starts_only are mutually exclusive")

    location_to_dataset = load_(path, **kwargs)
    # NOTE(review): scenes are replaced in place on the object returned by ``load_``; if
    # ``cached_load`` hands out a shared in-memory instance the trimming would compound
    # across calls — confirm.
    detroit = location_to_dataset["detroit"]

    if train_frames_only or valid_starts_only:
        for i, scene in enumerate(detroit.scenes):
            if not scene:
                continue
            scene_id = scene[0].scene_id
            if valid_starts_only and scene_id in _VALID_START_SLICES:
                detroit.scenes[i] = scene[_VALID_START_SLICES[scene_id]]
            elif train_frames_only and scene_id in _GOOD_FRAME_RANGES:
                detroit.scenes[i] = [
                    scene[j]
                    for start, end in _GOOD_FRAME_RANGES[scene_id]
                    for j in range(start, end)
                ]

    return location_to_dataset

# Camera names as they appear in the intrinsics file names ("camera<Name>Intrinsics.yaml")
# -> short camera folder names used inside each scene's "camera" directory.
_CAMERA_LONG_TO_SHORT = {
    "Center": "Center",
    "FrontLeft": "FL",
    "FrontRight": "FR",
    "RearLeft": "RL",
    "RearRight": "RR",
    "SideLeft": "SL",
    "SideRight": "SR",
}

# Short camera folder names -> suffix of the camera's calibration frame id ("camera_<suffix>").
_CAMERA_SHORT_TO_FRAME_SUFFIX = {
    "Center": "center",
    "FL": "front_left",
    "FR": "front_right",
    "RL": "rear_left",
    "RR": "rear_right",
    "SL": "side_left",
    "SR": "side_right",
}


def _sorted_timestamps(directory):
    """Return the integer timestamps encoded in the file names of ``directory``, sorted ascending."""
    return np.asarray(sorted(int(os.path.splitext(name)[0]) for name in os.listdir(directory)))


def _nearest_timestamps(candidates, queries):
    """For every value in ``queries``, return the value in sorted ``candidates`` closest to it.

    Ties resolve to the earlier candidate. Uses binary search, so memory is O(len(queries))
    rather than the O(len(candidates) * len(queries)) of a dense distance matrix.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    candidates = np.asarray(candidates)
    queries = np.asarray(queries)
    if candidates.size == 0:
        raise ValueError("Cannot match timestamps against an empty sensor directory")
    insert = np.searchsorted(candidates, queries, side="left")
    left = np.clip(insert - 1, 0, len(candidates) - 1)
    right = np.clip(insert, 0, len(candidates) - 1)
    take_left = np.abs(queries - candidates[left]) <= np.abs(candidates[right] - queries)
    return np.where(take_left, candidates[left], candidates[right])


def _synchronized_timestamps(sensor_dir, names, reference_timestamps):
    """Return a ``[len(names), len(reference_timestamps)]`` array holding, for each sensor
    sub-directory in ``names`` and each reference timestamp, the sensor's nearest timestamp."""
    rows = [
        _nearest_timestamps(_sorted_timestamps(os.path.join(sensor_dir, name)), reference_timestamps)
        for name in names
    ]
    # reshape keeps the 2-D layout even when ``names`` is empty
    return np.asarray(rows).reshape(len(names), len(reference_timestamps))


def _load_calibration(path):
    """Read the per-vehicle calibration YAML files in the ``V*`` sub-directories of ``path``.

    Files containing "Intensity" in their name are ignored.

    Returns:
        transforms: ``{car: {(child_frame, parent_frame): cosy.np.Rigid}}`` from files with a
            ``transform`` entry.
        camera_params: ``{car: {short_camera_name: (K (3x3), (height, width))}}`` from files
            with a ``height`` entry.
    """
    transforms = defaultdict(dict)
    camera_params = defaultdict(dict)
    for car in os.listdir(path):
        if not car.startswith("V"):
            continue
        config_path = os.path.join(path, car)
        for file_name in os.listdir(config_path):
            if "Intensity" in file_name or not file_name.endswith(".yaml"):
                continue
            with open(os.path.join(config_path, file_name), "r") as f:
                config = yaml.safe_load(f) or {}
            if "transform" in config:
                child_frame = config["child_frame_id"]
                parent_frame = config["header"]["frame_id"]
                transforms[car][(child_frame, parent_frame)] = rostransform_to_transform(config["transform"])
            if "height" in config:
                assert file_name.startswith("camera") and file_name.endswith("Intrinsics.yaml"), file_name
                long_name = file_name[len("camera"):-len("Intrinsics.yaml")]
                camera = _CAMERA_LONG_TO_SHORT[long_name]
                intr = np.asarray(config["K"]).reshape([3, 3])
                resolution = (int(config["height"]), int(config["width"]))
                camera_params[car][camera] = (intr, resolution)
    return transforms, camera_params


@dataset.cached_load
@dataset.finetuned_load
def load_(path):
    """Build the Ford AV dataset from the raw directory layout under ``path``.

    Expected layout (as read below): vehicle calibration folders named ``V*`` and scene
    folders whose names contain ``Log``. Each scene folder holds ``imu.npz``,
    ``pose_ground_truth.npz``, ``lidar/<name>/<timestamp>.npz`` and
    ``camera/<name>/<timestamp>.png``. One frame is created per "red" lidar sweep; every
    other lidar and every camera contributes its nearest-in-time file.

    The decorators come from the ``dataset`` module; presumably they add caching and
    fine-tuning behaviour — confirm there.

    Returns:
        ``{"detroit": dataset.Dataset}``.
    """
    # Constants
    world_to_epsg3857 = cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(np.asarray([42.294319, -83.223275]))
    oldworld_to_world = cosy.np.Rigid(rotation=np.asarray([[0, 1, 0], [1, 0, 0], [0, 0, -1]], dtype="float32"))
    body_to_ego = cosy.np.Rigid(translation=np.asarray([0.0, 0.0, 0.29])) \
        * cosy.np.Rigid(rotation=Quaternion(axis=(1, 0, 0), radians=math.pi).rotation_matrix)
    body_to_ego_inverse = body_to_ego.inverse()
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")

    transforms, camera_params = _load_calibration(path)

    scene_paths = [
        os.path.join(path, d)
        for d in sorted(os.listdir(path))
        if "Log" in d and not d.endswith(".bag")
    ]
    scenes = []
    for scene_path in tqdm.tqdm(scene_paths, desc="Ford-AVData"):
        scene_id = os.path.basename(scene_path)
        car = scene_id[11:-5]  # "YYYY-MM-DD-<car>-LogN" -> "<car>"
        car_transforms = transforms[car]

        # IMU (no covariance given in dataset); close the npz file once arrays are read
        with np.load(os.path.join(scene_path, "imu.npz")) as imu:
            imu_timestamps = imu["timestamps"]
            imu_angular_velocity_samples = imu["angular_velocity"]
            imu_linear_acceleration_samples = imu["linear_acceleration"]
        imu_angular_velocity = partial(cosy.np.lerp, xs=imu_timestamps, ys=imu_angular_velocity_samples)
        imu_linear_acceleration = partial(cosy.np.lerp, xs=imu_timestamps, ys=imu_linear_acceleration_samples)
        imu_to_ego = body_to_ego * car_transforms[("imu", "body")]
        # TODO: imu data smoothing, since it is subsampled later? also other datasets?

        # Ground-truth poses: the stored matrices are inverted to obtain body -> oldworld
        with np.load(os.path.join(scene_path, "pose_ground_truth.npz")) as pose_ground_truth:
            pose_matrices = pose_ground_truth["transforms"]
            poses_timestamps = pose_ground_truth["timestamps"]
        body_to_oldworld = partial(
            cosy.np.lerp,
            xs=poses_timestamps,
            ys=[cosy.np.Rigid.from_matrix(m).inverse() for m in pose_matrices],
            lerp2=cosy.np.Rigid.slerp,
        )

        # Sorted so the sensor order does not depend on the filesystem
        lidar_names = sorted(os.listdir(os.path.join(scene_path, "lidar")))
        camera_names = sorted(os.listdir(os.path.join(scene_path, "camera")))
        red_index = lidar_names.index("red")

        # Synchronize every sensor to the red lidar: [sensor, frame] arrays of timestamps
        lidar_red_timestamps = _sorted_timestamps(os.path.join(scene_path, "lidar", "red"))
        lidar_timestamps = _synchronized_timestamps(os.path.join(scene_path, "lidar"), lidar_names, lidar_red_timestamps)
        camera_timestamps = _synchronized_timestamps(os.path.join(scene_path, "camera"), camera_names, lidar_red_timestamps)
        assert np.all(lidar_timestamps[red_index, :-1] < lidar_timestamps[red_index, 1:]), "red lidar timestamps must be unique"

        lidar_to_body = {lidar: car_transforms[(f"lidar_{lidar}", "body")] for lidar in lidar_names}
        camera_to_body = {
            camera: car_transforms[(f"camera_{_CAMERA_SHORT_TO_FRAME_SUFFIX[camera]}", "body")]
            for camera in camera_names
        }

        scene = []
        frames = enumerate(zip(lidar_timestamps.T, camera_timestamps.T))
        for index_in_scene, (frame_lidar_timestamps, frame_camera_timestamps) in frames:
            lidar_red_timestamp = frame_lidar_timestamps[red_index]
            body_to_oldworld_tl = body_to_oldworld(lidar_red_timestamp)
            oldworld_to_body_tl = body_to_oldworld_tl.inverse()
            ego_to_oldworld_tl = body_to_oldworld_tl * body_to_ego_inverse

            # Each lidar is expressed in the ego frame at the red lidar's timestamp
            lidars = [
                NpzLidarId(
                    name="lidar",
                    file=os.path.join(scene_path, "lidar", lidar, f"{lidar_timestamp}.npz"),
                    loaded_to_ego=body_to_ego * oldworld_to_body_tl * body_to_oldworld(lidar_timestamp) * lidar_to_body[lidar],
                )
                for lidar, lidar_timestamp in zip(lidar_names, frame_lidar_timestamps)
            ]

            cameras = []
            for camera, camera_timestamp in zip(camera_names, frame_camera_timestamps):
                intr, resolution = camera_params[car][camera]
                # ego(t_lidar) -> oldworld -> body(t_camera) -> camera
                ego_to_camera = camera_to_body[camera].inverse() * body_to_oldworld(camera_timestamp).inverse() * ego_to_oldworld_tl
                cameras.append(CameraId(
                    name=camera,
                    intr=intr,
                    resolution=resolution,
                    ego_to_camera=ego_to_camera,
                    image_file=os.path.join(scene_path, "camera", camera, f"{camera_timestamp}.png"),
                ))

            scene.append(FrameId(
                dataset_name="ford-avdata",
                location="detroit",
                scene_id=scene_id,
                index_in_scene=index_in_scene,
                timestamp=lidar_red_timestamp,
                ego_to_world=oldworld_to_world * ego_to_oldworld_tl,
                world_to_crs=world_to_epsg3857,
                crs=epsg3857,
                epsg4326_to_crs=epsg4326_to_epsg3857,
                cameras=cameras,
                lidars=lidars,
                imu=Imu(
                    angular_velocity=imu_to_ego.rotation @ imu_angular_velocity(lidar_red_timestamp),
                    linear_acceleration=imu_to_ego.rotation @ imu_linear_acceleration(lidar_red_timestamp),
                ),
            ))
        scenes.append(scene)

    image_sizes = sorted({tuple(camera_id.load_resolution()) for camera_id in scenes[0][0].cameras})
    return {
        "detroit": dataset.Dataset(
            name="ford-avdata",
            location="detroit",
            scenes=scenes,
            image_sizes=image_sizes,
        )
    }
