#!/usr/bin/env python3

import argparse
import os
import shutil
import sys
import threading

import docker
import numpy as np
import open3d as o3d
import rosbag
import tqdm
import yaml
from pyquaternion import Quaternion

import georeg.data.prepare
import georeg.transform.np

# Command-line configuration for downloading and converting the dataset.
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True, help="Directory where dataset will be stored")
parser.add_argument("--image", type=str, default="ford-avdata-convert", help="Name of docker image to be built")
parser.add_argument("--container", type=str, default="ford-avdata-convert", help="Name of docker container to be run")
parser.add_argument("--timeout", type=float, default=30.0, help="Time to wait after last converted message before stopping the docker image")
parser.add_argument("--rate", type=float, default=1.0, help="Rate at which to playback the rosbag file")
parser.add_argument("--queue_size", type=int, default=99999999, help="Size of subscriber and publisher queues of all ros nodes")
parser.add_argument("--width", type=int, default=-1)
parser.add_argument("--height", type=int, default=-1)
args = parser.parse_args()
# --width and --height enable image resizing and must be given together.
# (Was an `assert`, which is stripped under `python -O`; parser.error gives a
# proper usage message and exit code instead.)
if (args.width > 0) != (args.height > 0):
    parser.error("--width and --height must be specified together")

os.makedirs(args.path, exist_ok=True)

# S3 bucket hosting the Ford Multi-AV Seasonal dataset.
base_url = "https://ford-multi-av-seasonal.s3-us-west-2.amazonaws.com"
# The seven camera rig positions: rear/side/front right+left, plus center.
cameras = ["RR", "RL", "SR", "SL", "FR", "FL", "Center"]
# Each entry: (date, [(vehicle, number_of_logs, available_cameras)], number_of_maps).
# Vehicle V3 on 2017-08-04 is missing the rear-left ("RL") camera.
series = [
    ("2017-10-26", [("V2", 6, cameras)], 6),
    ("2017-08-04", [("V2", 6, cameras), ("V3", 6, set(cameras) - {"RL"})], 6),
]
# Vehicles with published calibration archives.
cars = ["V1", "V2", "V3"]

################################## Download ##################################
# Fetch and unpack the per-vehicle calibration archives into the dataset root.
for car in cars:
    archive = os.path.join(args.path, f"Calibration-{car}.tar.gz")
    url = f"{base_url}/Calibration/{os.path.basename(archive)}"
    georeg.data.prepare.download(url, archive)
    georeg.data.prepare.extract(archive, args.path)

# Download and extract every per-camera image archive plus the rosbag for each
# drive listed in `series`.
for date, drives, maps_num in series:
    # for map_index in range(1, maps_num + 1):
    #     download(f"{base_url}/{date}/Maps/{date}-Map{map_index}.tar.gz")
    # NOTE(review): the `cameras` loop variable shadows the module-level
    # `cameras` list; harmless here, but confusing.
    for car, log_num, cameras in drives:
        for log_index in range(1, log_num + 1):
            for camera in cameras:
                # file[:-7] strips ".tar.gz" -> the directory the archive extracts to.
                file = os.path.join(args.path, f"{date}-{car}-Log{log_index}-{camera}.tar.gz")
                # Resize only what was freshly downloaded or extracted this run.
                resize = False
                if not os.path.isdir(file[:-7]):
                    georeg.data.prepare.download(f"{base_url}/{date}/{car}/Log{log_index}/{os.path.basename(file)}", file)
                    resize = True # TODO: remove
                if os.path.isfile(file):
                    georeg.data.prepare.extract(file, file[:-7])
                    resize = True # TODO: remove
                # Downscale images to --height x --width when requested.
                if resize and args.width > 0:
                    georeg.data.prepare.resize(file[:-7], (args.height, args.width))
            # Each log also ships a rosbag holding poses, gps, imu and lidar scans.
            file = os.path.join(args.path, f"{date}-{car}-Log{log_index}.bag")
            georeg.data.prepare.download(f"{base_url}/{date}/{car}/Log{log_index}/{os.path.basename(file)}", file)


# Relocate each extracted per-camera folder into its scene's camera/ subtree:
#   <path>/<date>-<car>-Log<i>-<camera>  ->  <path>/<date>-<car>-Log<i>/camera/<camera>
for date, drives, maps_num in series:
    for car, log_num, cameras in drives:
        for log_index in range(1, log_num + 1):
            scene_path = os.path.join(args.path, f"{date}-{car}-Log{log_index}")
            os.makedirs(scene_path, exist_ok=True)
            for camera in cameras:
                src_path = os.path.join(args.path, f"{date}-{car}-Log{log_index}-{camera}")
                if os.path.isdir(src_path):
                    shutil.move(src_path, os.path.join(scene_path, "camera", camera))




def rospose_to_transform(pose):
    """Build a georeg rigid transform from a geometry_msgs/Pose message."""
    translation = np.asarray([pose.position.x, pose.position.y, pose.position.z])
    # pyquaternion expects (w, x, y, z) element order.
    quat = Quaternion(np.asarray([pose.orientation.w, pose.orientation.x, pose.orientation.y, pose.orientation.z]))
    return georeg.transform.np.Rigid(translation=translation, rotation=quat.rotation_matrix)

def rostransform_to_transform(pose):
    """Build a georeg rigid transform from a geometry_msgs/Transform message."""
    translation = np.asarray([pose.translation.x, pose.translation.y, pose.translation.z])
    # pyquaternion expects (w, x, y, z) element order.
    quat = Quaternion(np.asarray([pose.rotation.w, pose.rotation.x, pose.rotation.y, pose.rotation.z]))
    return georeg.transform.np.Rigid(translation=translation, rotation=quat.rotation_matrix)

def quaternion_to_rotationmatrix(q):
    """Convert a ros quaternion message into a 3x3 rotation matrix."""
    # pyquaternion expects (w, x, y, z) element order.
    elements = np.asarray([q.w, q.x, q.y, q.z])
    return Quaternion(elements).rotation_matrix

def stamp_to_us(stamp):
    """Convert a ros Time stamp (``secs``/``nsecs``) to integer microseconds.

    Uses pure integer arithmetic: the original ``nsecs / 10 ** 3`` float
    division could lose microsecond precision once ``secs * 10 ** 6``
    approaches 2**53 (epoch timestamps are ~1.5e15 us already).
    """
    return stamp.secs * 10 ** 6 + stamp.nsecs // 10 ** 3

print("Saving transforms...")
# Extract the pose, tf, gps and imu streams from every rosbag and store them
# as per-scene .npz files next to the extracted camera/lidar data.
for bagfile in tqdm.tqdm([os.path.join(args.path, f) for f in sorted(os.listdir(args.path)) if f.endswith(".bag")]):
    pose_ground_truth = []
    tf = []
    gps = []
    gps_time = []
    imu = []
    with rosbag.Bag(bagfile, "r") as bag:
        for (topic, msg, ts) in bag.read_messages(topics="/pose_ground_truth"):
            pose_ground_truth.append((stamp_to_us(msg.header.stamp), rospose_to_transform(msg.pose)))
        for (topic, msg, ts) in bag.read_messages(topics="/tf"):
            # A /tf message bundles several stamped transforms. Only timestamp
            # and transform are persisted; frame ids are dropped.
            for stamped in msg.transforms:
                tf.append((stamp_to_us(stamped.header.stamp), rostransform_to_transform(stamped.transform)))
        for (topic, msg, ts) in bag.read_messages(topics="/gps"):
            gps.append((stamp_to_us(msg.header.stamp), (msg.latitude, msg.longitude)))
        for (topic, msg, ts) in bag.read_messages(topics="/gps_time"):
            gps_time.append((stamp_to_us(msg.header.stamp), stamp_to_us(msg.time_ref)))
        for (topic, msg, ts) in bag.read_messages(topics="/imu"):
            imu.append((
                stamp_to_us(msg.header.stamp),
                quaternion_to_rotationmatrix(msg.orientation),
                np.asarray(msg.orientation_covariance).reshape([3, 3]),
                np.asarray([msg.angular_velocity.x, msg.angular_velocity.y, msg.angular_velocity.z]),
                np.asarray(msg.angular_velocity_covariance).reshape([3, 3]),
                np.asarray([msg.linear_acceleration.x, msg.linear_acceleration.y, msg.linear_acceleration.z]),
                np.asarray(msg.linear_acceleration_covariance).reshape([3, 3]),
            ))

    # Deduplicate imu samples by timestamp and order them chronologically.
    imu_timestamp_to_data = {x[0]: x for x in imu}
    imu = [imu_timestamp_to_data[t] for t in sorted(imu_timestamp_to_data)]
    # Fixed: the original `np.all(imu[:-1][0] < imu[1:][0])` indexed the first
    # element of each *slice*, so it compared only the first two tuples.
    # Verify strict ordering of every consecutive timestamp pair instead.
    imu_timestamps = [x[0] for x in imu]
    assert all(a < b for a, b in zip(imu_timestamps, imu_timestamps[1:]))
    assert len(gps) == len(gps_time)
    gps = sorted(gps, key=lambda x: x[0])
    gps_time = sorted(gps_time, key=lambda x: x[0])
    # gps = list(zip([x[1] for x in gps_time], [x[1] for x in gps]))
    # (The original also rebuilt `gps` here via zip of its own first/second
    # elements — an identity operation, removed.)

    # Output directory matches the scene directory created during extraction,
    # but create it defensively for bags not covered by `series`.
    scene_path = bagfile[:-4]
    os.makedirs(scene_path, exist_ok=True)
    np.savez_compressed(
        os.path.join(scene_path, "pose_ground_truth.npz"),
        timestamps=np.asarray([p[0] for p in pose_ground_truth]),
        transforms=np.asarray([p[1].to_matrix() for p in pose_ground_truth]),
    )
    np.savez_compressed(
        os.path.join(scene_path, "tf.npz"),
        timestamps=np.asarray([p[0] for p in tf]),
        transforms=np.asarray([p[1].to_matrix() for p in tf]),
    )
    np.savez_compressed(
        os.path.join(scene_path, "gps.npz"),
        timestamps=np.asarray([p[0] for p in gps]),
        latlons=np.asarray([p[1] for p in gps]),
    )
    np.savez_compressed(
        os.path.join(scene_path, "imu.npz"),
        timestamps=np.asarray([p[0] for p in imu]),
        orientation=np.asarray([p[1] for p in imu]),
        orientation_covariance=np.asarray([p[2] for p in imu]),
        angular_velocity=np.asarray([p[3] for p in imu]),
        angular_velocity_covariance=np.asarray([p[4] for p in imu]),
        linear_acceleration=np.asarray([p[5] for p in imu]),
        linear_acceleration_covariance=np.asarray([p[6] for p in imu]),
    )




# #################### Convert bagfiles to pointcloud files ####################
# Names of the four lidar scanners in the dataset's topic names.
colors = ["red", "blue", "green", "yellow"]
# Directory containing this script; the docker build context lives beside it.
root_path = os.path.dirname(os.path.abspath(sys.argv[0]))

client = docker.from_env()

print(f"Building docker image {args.image}")
image, docker_build_logs = client.images.build(path=os.path.join(root_path, "ford_avdata_util"), tag=args.image)

bagfiles = sorted(os.path.join(args.path, f) for f in os.listdir(args.path) if f.endswith(".bag"))
print(f"Running for {len(bagfiles)} rosbag files")

# Convert each rosbag's lidar scans to pcd files inside a docker container,
# verify all messages arrived, then re-encode the pcd files as float32 .npz.
for bagfile in bagfiles:
    print(f"##################### {os.path.basename(bagfile)} #####################")
    # Paths as seen inside the container: args.path is mounted at /ford-avdata-data.
    docker_output_path = os.path.join("/ford-avdata-data", os.path.basename(bagfile)[:-4])
    docker_bagfile = os.path.join("/ford-avdata-data", os.path.basename(bagfile))
    host_output_path = os.path.join(args.path, os.path.basename(bagfile)[:-4])
    # Fixed: initialize so the KeyboardInterrupt handler does not raise
    # NameError when the interrupt lands before the container is created.
    container = None
    try:
        print(f"Starting container {args.image}")
        for color in colors:
            color_path = os.path.join(host_output_path, "lidar", color)
            if not os.path.isdir(color_path):
                os.makedirs(color_path)
        volumes = {args.path: {"bind": "/ford-avdata-data", "mode": "rw"}}
        container = client.containers.run(image=image, name=args.container, command=f"/bin/bash", volumes=volumes, remove=True, stdin_open=True, detach=True, tty=True)

        print("Checking rosbag contents")
        cmd = f"rosbag info {docker_bagfile} --yaml --key=topics"
        print("> " + cmd)
        exit_code, logs = container.exec_run(f"bash -c 'source /ford-avdata-code/entrypoint.sh && {cmd}'")
        logs = logs.decode()
        if exit_code != 0:
            print(f"Failed command with exit code {exit_code}")
            print(logs)
            container.stop()
            sys.exit(-1)
        # Expected message count per lidar topic (plus ground-truth poses).
        message_nums = {n["topic"]: int(n["messages"]) for n in yaml.safe_load(logs) if n["topic"].startswith("/lidar") or n["topic"] == "/pose_ground_truth"}
        print(f"Expecting messages: {message_nums}")

        print("Converting point clouds from rosbag to pcd files")
        cmd = f"roslaunch /ford-avdata-code/convert.launch bag:={docker_bagfile} output_path:={docker_output_path} rate:={args.rate} queue_size:={args.queue_size}"
        print("> " + cmd)
        _, logs = container.exec_run(f"bash -c 'source /ford-avdata-code/entrypoint.sh && {cmd}'", stream=True)

        # Watchdog: roslaunch never terminates on its own, so treat
        # args.timeout seconds of log silence as "conversion finished".
        exit_event = threading.Event()
        def read():
            # Signal the main thread for every log line the container emits.
            for line in logs:
                exit_event.set()
        read_thread = threading.Thread(target=read)
        read_thread.start()
        # Each wait() returning True means output arrived within the timeout;
        # a False return (timeout elapsed with no output) ends the loop.
        while exit_event.wait(args.timeout):
            exit_event.clear()
        container.stop()
        read_thread.join()
    except KeyboardInterrupt:
        print("Stopping...")
        if container is not None:
            container.stop()
        sys.exit(-1)

    # Check if any messages are missing
    fail = False
    for color in colors:
        color_path = os.path.join(host_output_path, "lidar", color)
        expected_messages = message_nums[f"/lidar_{color}_scan"]
        got_messages = len(os.listdir(color_path))
        if expected_messages != got_messages:
            print(f"Expected {expected_messages} messages for color {color}, but found {got_messages} saved files")
            fail = True
    if fail:
        print("Rosbag conversion failed. Try a smaller --rate or larger --queue_size")
        sys.exit(-1)

    # The bag is no longer needed once the point clouds are on disk.
    os.remove(bagfile)

    print("Converting point clouds from pcd files to npz files")
    paths = []
    for color in colors:
        color_path = os.path.join(host_output_path, "lidar", color)
        paths = paths + [os.path.join(color_path, f) for f in os.listdir(color_path)]
    for pcd_file in tqdm.tqdm(paths):
        pcd = o3d.io.read_point_cloud(pcd_file)
        points = np.asarray(pcd.points)
        # Store raw xyz coordinates as float32 and drop the bulkier pcd file.
        np.savez_compressed(pcd_file[:-4] + ".npz", points.astype("float32"))
        os.remove(pcd_file)
