import os, math, ciso8601, imageio, tqdm, yaml, cosy
import numpy as np
from pyquaternion import Quaternion
from functools import partial
from . import dataset
from georegdata.ground import FrameId, CameraId, Bin32LidarId, Imu

def read_timestamps(file):
    """Parse a timestamps text file into float POSIX timestamps (seconds).

    Every non-blank line (surrounding whitespace ignored) is parsed with
    ``ciso8601.parse_datetime`` and converted with ``datetime.timestamp()``.

    NOTE(review): when a line carries no UTC offset, ``ciso8601`` returns a
    naive datetime, and ``datetime.timestamp()`` interprets naive values in
    the local time zone of the host. Confirm whether the files carry an
    offset; otherwise absolute results depend on the machine's TZ setting.

    Args:
        file: Path to the timestamps file.

    Returns:
        List of float timestamps, one per non-blank line, in file order.
    """
    with open(file, "r") as handle:
        raw_lines = handle.read().split("\n")
    cleaned = (line.strip() for line in raw_lines)
    # ciso8601 is used instead of datetime.strptime because it is much faster.
    return [ciso8601.parse_datetime(entry).timestamp() for entry in cleaned if entry]

# Per-sequence (start, end) frame limits; an end < 0 means "no upper limit".
# Sequences not listed here are not limited.
_FRAME_LIMITS = {
    "2013_05_28_drive_0000_sync": (0, -1),
    "2013_05_28_drive_0002_sync": (0, -1),
    "2013_05_28_drive_0003_sync": (0, -1),
    "2013_05_28_drive_0004_sync": (70, -1),
    "2013_05_28_drive_0005_sync": (0, -1),
    "2013_05_28_drive_0006_sync": (0, -1),
    "2013_05_28_drive_0007_sync": (0, -1),
    "2013_05_28_drive_0009_sync": (0, -1),
    "2013_05_28_drive_0010_sync": (0, -1),
}


def _parse_pose_line(line):
    """Parse one ``poses.txt`` line into ``(imu_to_world, frame_index)``.

    The line holds an integer frame index followed by 12 floats, which are
    reshaped row-major into a 3x4 ``[rotation | translation]`` matrix.
    """
    tokens = line.split()
    frame_index = int(tokens[0])
    m = np.asarray([float(x) for x in tokens[1:]]).reshape(3, 4)
    imu_to_world = cosy.np.Rigid(
        rotation=m[:, :3],
        translation=m[:, 3],
    )
    return imu_to_world, frame_index


def _interpolate_poses(poses):
    """Densify sparse ``(transform, frame_index)`` poses to one entry per frame.

    The returned list is indexed by absolute frame index:
    - entries before the first posed frame are ``None``;
    - frames between two consecutive poses are interpolated with
      ``cosy.np.Rigid.slerp``;
    - the last entry is the final pose itself.

    Raises:
        ValueError: if frame indices are not strictly increasing.
    """
    dense = [None] * poses[0][1]
    for (transform1, frame_index1), (transform2, frame_index2) in zip(poses[:-1], poses[1:]):
        if frame_index1 >= frame_index2:
            raise ValueError(f"Pose frame indices are not strictly increasing: {frame_index1} >= {frame_index2}")
        span = float(frame_index2 - frame_index1)
        for frame_index in range(frame_index1, frame_index2):
            t = (frame_index - frame_index1) / span  # in [0, 1)
            dense.append(cosy.np.Rigid.slerp(transform1, transform2, t))
    dense.append(poses[-1][0])
    return dense


def read_trajectories_3d(path):
    """Read per-sequence IMU trajectories and frame timestamps.

    For each sequence directory under ``<path>/data_poses`` this:
    - reads the sparse poses from ``poses.txt`` and interpolates them to one
      pose per frame;
    - keeps only frames that have an ``image_00/data_rect`` image and a pose,
      and that lie inside the sequence's ``_FRAME_LIMITS`` entry.

    Args:
        path: KITTI-360 root directory.

    Returns:
        Dict mapping sequence name to ``(frame_indices, timestamps, poses)``:
        - ``frame_indices`` is a contiguous list of absolute frame indices;
        - ``timestamps`` are the matching entries of
          ``image_00/timestamps.txt`` (see ``read_timestamps``);
        - ``poses`` are the matching ``cosy.np.Rigid`` imu_to_world
          transforms.

    Raises:
        ValueError: if a sequence has no images, non-contiguous image frame
            indices, an empty ``poses.txt``, or non-increasing pose indices.
    """
    result = {}
    for sequence in sorted(os.listdir(os.path.join(path, "data_poses"))):
        image_dir = os.path.join(path, "data_2d_raw", sequence, "image_00")
        frame_indices = sorted(int(os.path.splitext(f)[0]) for f in os.listdir(os.path.join(image_dir, "data_rect")))
        if not frame_indices:
            raise ValueError(f"No images found for sequence {sequence}")
        first_image_index = frame_indices[0]
        if len(frame_indices) != frame_indices[-1] - first_image_index + 1:
            raise ValueError(f"Image frame indices of sequence {sequence} are not contiguous")

        limit_start, limit_end = _FRAME_LIMITS.get(sequence, (0, -1))
        frames_start_index = max(first_image_index, limit_start)
        frames_end_index = frame_indices[-1] + 1
        if limit_end >= 0:
            frames_end_index = min(frames_end_index, limit_end)

        timestamps = read_timestamps(os.path.join(image_dir, "timestamps.txt"))

        # Read sparse poses
        with open(os.path.join(path, "data_poses", sequence, "poses.txt"), "r") as f:
            lines = [line.strip() for line in f.read().split("\n")]
        poses = [_parse_pose_line(line) for line in lines if line]
        if not poses:
            raise ValueError(f"poses.txt of sequence {sequence} contains no poses")
        frames_start_index = max(frames_start_index, poses[0][1])
        frames_end_index = min(frames_end_index, poses[-1][1] + 1)
        # Keep the range well-formed so all three slices below stay empty together.
        frames_end_index = max(frames_end_index, frames_start_index)

        # Take only frames in valid range
        poses_interpolated = _interpolate_poses(poses)[frames_start_index:frames_end_index]
        # timestamps.txt has one line per image, so line 0 belongs to the first
        # image index rather than to absolute frame 0.
        timestamps = timestamps[frames_start_index - first_image_index:frames_end_index - first_image_index]
        frame_indices = list(range(frames_start_index, frames_end_index))

        result[sequence] = (frame_indices, timestamps, poses_interpolated)

    return result

def _read_calibration(path):
    """Read the shared calibration files from ``<path>/calibration``.

    Returns:
        Tuple ``(cam_to_imu, velo_to_imu, intrinsics, original_image_shapes)``:
        - ``cam_to_imu`` maps each camera name in ``calib_cam_to_pose.txt``
          (e.g. ``"image_00"``) to a ``cosy.np.Rigid``;
        - ``velo_to_imu`` is ``cam_to_imu["image_00"]`` composed with the
          inverted first line of ``calib_cam_to_velo.txt``;
        - ``intrinsics`` maps the suffix after ``P_rect_`` to the left 3x3
          block of that 3x4 matrix;
        - ``original_image_shapes`` maps the suffix after ``S_rect_`` to those
          values as an int array in reversed order.
    """
    calib_dir = os.path.join(path, "calibration")

    with open(os.path.join(calib_dir, "calib_cam_to_velo.txt"), "r") as f:
        first_line = f.readline()
    mat = np.asarray([float(x) for x in first_line.split()]).reshape(3, 4)
    velo_to_cam00 = cosy.np.Rigid(rotation=mat[:3, :3], translation=mat[:3, 3]).inverse()

    cam_to_imu = {}
    with open(os.path.join(calib_dir, "calib_cam_to_pose.txt"), "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            name, _, values = line.partition(":")
            mat = np.asarray([float(x) for x in values.split()]).reshape(3, 4)
            cam_to_imu[name.strip()] = cosy.np.Rigid(rotation=mat[:3, :3], translation=mat[:3, 3])
    velo_to_imu = cam_to_imu["image_00"] * velo_to_cam00

    with open(os.path.join(calib_dir, "perspective.txt"), "r") as f:
        lines = f.readlines()
    intrinsics = {}
    original_image_shapes = {}
    for line in lines:
        key, _, values = line.partition(":")
        if key.startswith("P_rect_"):
            intrinsics[key[7:]] = np.asarray([float(x) for x in values.split()]).reshape(3, 4)[:3, :3]
        elif key.startswith("S_rect_"):
            original_image_shapes[key[7:]] = np.flip(np.asarray([float(x) for x in values.split()]).astype("int"), axis=0)

    return cam_to_imu, velo_to_imu, intrinsics, original_image_shapes


def _read_camera_params(path, intrinsics, original_image_shapes):
    """Return ``(camera_name, intr, original_image_shape)`` for every loaded camera.

    - ``image_00`` (front perspective) uses ``intrinsics["00"]`` and
      ``original_image_shapes["00"]``.
    - The fisheye cameras ``image_02`` and ``image_03`` take ``intr`` from
      ``<camera>-undistorted.yaml``, and their resolution
      (``[image_height, image_width]``) from ``<camera>.yaml``.
    """
    calib_dir = os.path.join(path, "calibration")
    params = [("image_00", intrinsics["00"], original_image_shapes["00"])]
    for camera_name in ["image_02", "image_03"]:
        with open(os.path.join(calib_dir, f"{camera_name}-undistorted.yaml")) as f:
            intr = yaml.safe_load(f)["intr"]
        with open(os.path.join(calib_dir, f"{camera_name}.yaml")) as f:
            s = f.read()
        # The first line is skipped before parsing; presumably it is an OpenCV-style
        # "%YAML:1.0" header that PyYAML rejects -- TODO confirm.
        config = yaml.safe_load(s[s.index("\n") + 1:])
        original_image_shape = np.asarray([config["image_height"], config["image_width"]])
        params.append((camera_name, intr, original_image_shape))
    return params


def _read_oxts(path, sequence):
    """Build IMU interpolators from ``<path>/data_poses/<sequence>/oxts``.

    Returns:
        ``(get_linear_acceleration, get_angular_velocity)``: ``cosy.np.lerp``
        partials over the oxts timestamps. Their values are columns 14:17 and
        20:23 of each oxts record, respectively.

    Raises:
        ValueError: if the number of oxts records and timestamps differ.
    """
    oxts_dir = os.path.join(path, "data_poses", sequence, "oxts")
    data_dir = os.path.join(oxts_dir, "data")
    records = []
    for file in sorted(os.listdir(data_dir)):
        with open(os.path.join(data_dir, file), "r") as f:
            text = f.read().strip()
        if text:
            records.append([float(v) for v in text.split()])
    oxts = np.asarray(records)
    oxts_timestamps = read_timestamps(os.path.join(oxts_dir, "timestamps.txt"))
    if len(oxts) != len(oxts_timestamps):
        raise ValueError(
            f"Sequence {sequence}: {len(oxts)} oxts records but {len(oxts_timestamps)} timestamps"
        )
    get_linear_acceleration = partial(
        cosy.np.lerp,
        xs=oxts_timestamps,
        ys=[o[14:17] for o in oxts],
    )
    get_angular_velocity = partial(
        cosy.np.lerp,
        xs=oxts_timestamps,
        ys=[o[20:23] for o in oxts],
    )
    return get_linear_acceleration, get_angular_velocity


@dataset.cached_load
@dataset.finetuned_load
def load(path):
    """Load the KITTI-360 dataset rooted at ``path``.

    - Builds one scene per sequence returned by ``read_trajectories_3d``.
    - Each frame holds three cameras (``image_00``, ``image_02``,
      ``image_03``), one Velodyne lidar, and IMU readings interpolated at
      the frame timestamp.
    - World coordinates are mapped to EPSG:3857 with a per-sequence
      correction.

    Args:
        path: KITTI-360 root directory.

    Returns:
        ``{"karlsruhe": dataset.Dataset(...)}``.

    Raises:
        ValueError: if no frames are found, or oxts data is inconsistent.
    """
    # Constants
    # See: https://github.com/autonomousvision/kitti360Scripts/blob/master/kitti360scripts/devkits/convertOxtsPose/python/convertOxtsToPose.py#L17
    base_world_to_epsg3857 = cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(np.asarray([48.9843445, 8.4295857]))
    world_to_epsg3857 = { # TODO: this fixes only for bingmaps, googlemaps has different offsets
        "2013_05_28_drive_0000_sync": cosy.np.ScaledRigid(translation=np.asarray([1.5255124608520418, 0.29396687541157007]), rotation=np.asarray([[0.9999999999999999, 1.382167772721297e-17], [-1.382167772721297e-17, 0.9999999999999999]]), scale=np.asarray(1.000000001795384), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0002_sync": cosy.np.ScaledRigid(translation=np.asarray([4781.827932612854, -1693.9196127057076]), rotation=np.asarray([[0.999999691922719, -0.0007849550732795071], [0.0007849550732795071, 0.999999691922719]]), scale=np.asarray(1.0001524593320592), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0003_sync": cosy.np.ScaledRigid(translation=np.asarray([0.6120790672721341, 0.46767802257090807]), rotation=np.asarray([[1.0, -1.3821677774203495e-17], [1.3821677774203495e-17, 1.0]]), scale=np.asarray(0.9999999983956138), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0004_sync": cosy.np.ScaledRigid(translation=np.asarray([-123.75239311018959, -1675.9571686992422]), rotation=np.asarray([[0.9999999998011843, -1.994070040634429e-05], [1.994070040634429e-05, 0.9999999998011843]]), scale=np.asarray(1.000264432919022), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0005_sync": cosy.np.ScaledRigid(translation=np.asarray([-1003.9215552322567, -1474.8830497134477]), rotation=np.asarray([[0.9999999925449646, 0.00012210680065548104], [-0.00012210680065548104, 0.9999999925449646]]), scale=np.asarray(1.0002537848775497), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0006_sync": cosy.np.ScaledRigid(translation=np.asarray([-16444.323554050527, 190.56151132006198]), rotation=np.asarray([[0.999996699326372, 0.002569306591523215], [-0.002569306591523215, 0.999996699326372]]), scale=np.asarray(1.0003593538751618), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0007_sync": cosy.np.ScaledRigid(translation=np.asarray([-1078.3924155478599, 647.217847972177]), rotation=np.asarray([[0.99999998315617, 0.00018354198345355032], [-0.00018354198345355032, 0.99999998315617]]), scale=np.asarray(0.9999242856468162), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0009_sync": cosy.np.ScaledRigid(translation=np.asarray([1641.1508465523366, -2141.8495950819924]), rotation=np.asarray([[0.9999999533244265, -0.000305534195788601], [0.000305534195788601, 0.9999999533244265]]), scale=np.asarray(1.0002960813608466), dtype="float64") * base_world_to_epsg3857,
        "2013_05_28_drive_0010_sync": cosy.np.ScaledRigid(translation=np.asarray([-7644.406242471887, -18213.052549794316]), rotation=np.asarray([[0.999999706059098, 0.0007667344505946155], [-0.0007667344505946155, 0.999999706059098]]), scale=np.asarray(1.0030182058378878), dtype="float64") * base_world_to_epsg3857,
    }
    imu_to_ego =   cosy.np.Rigid(translation=np.asarray([0.0, 0.0, 0.91])) \
                 * cosy.np.Rigid(rotation=Quaternion(axis=(1, 0, 0), radians=math.pi).rotation_matrix) # Ego frame where z points upwards, not downwards
    epsg3857 = cosy.np.proj.CRS("epsg:3857")
    epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")

    cam_to_imu, velo_to_imu, intrinsics, original_image_shapes = _read_calibration(path)
    # Camera calibration does not depend on the sequence, so read it once.
    camera_params = _read_camera_params(path, intrinsics, original_image_shapes)

    scenes = []
    sequences = sorted(read_trajectories_3d(path).items())
    for sequence, (frame_indices, timestamps_frames, trajectory) in tqdm.tqdm(sequences, desc="Kitti-360"):
        get_linear_acceleration, get_angular_velocity = _read_oxts(path, sequence)

        scene = []
        scenes.append(scene)
        for index_in_scene, (frame_index, timestamp_frame, imu_to_world) in enumerate(zip(frame_indices, timestamps_frames, trajectory)):
            file_stem = f"{frame_index:010d}"
            velo_file = os.path.join(path, "data_3d_raw", sequence, "velodyne_points", "data", f"{file_stem}.bin")

            cameras = [
                CameraId(
                    name=camera_name,
                    intr=intr,
                    resolution=original_image_shape,
                    ego_to_camera=cam_to_imu[camera_name].inverse() * imu_to_ego.inverse(),
                    # The perspective camera stores rectified images; the fisheyes store RGB images.
                    image_file=os.path.join(path, "data_2d_raw", sequence, camera_name, "data_rect" if camera_name == "image_00" else "data_rgb", f"{file_stem}.png"),
                )
                for camera_name, intr, original_image_shape in camera_params
            ]

            lidars = [
                Bin32LidarId(
                    name="velo",
                    file=velo_file,
                    map=np.asarray([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]),
                    loaded_to_ego=imu_to_ego * velo_to_imu,
                ),
            ]

            scene.append(FrameId(
                dataset_name="kitti360",
                location="karlsruhe",
                scene_id=sequence,
                index_in_scene=index_in_scene,
                timestamp=int(timestamp_frame * 10 ** 6),  # seconds -> microseconds
                ego_to_world=imu_to_world * imu_to_ego.inverse(),
                world_to_crs=world_to_epsg3857[sequence],
                crs=epsg3857,
                epsg4326_to_crs=epsg4326_to_epsg3857,
                cameras=cameras,
                lidars=lidars,
                imu=Imu(
                    angular_velocity=get_angular_velocity(timestamp_frame),
                    linear_acceleration=get_linear_acceleration(timestamp_frame),
                ),
            ))

    # Image sizes are taken from the cameras of the first available frame.
    first_frame = next((scene[0] for scene in scenes if scene), None)
    if first_frame is None:
        raise ValueError(f"No KITTI-360 frames found under {path}")
    image_sizes = {tuple(camera_id.load_resolution()) for camera_id in first_frame.cameras}

    return {
        "karlsruhe": dataset.Dataset(
            name="kitti360",
            location="karlsruhe",
            scenes=scenes,
            image_sizes=np.asarray(list(image_sizes))
        )
    }
