#!/usr/bin/env python3

import argparse
import os
import shlex
import shutil
import sys

import numpy as np
import tqdm
from pyarrow import feather

# Command-line interface. --path is the root under which the
# sensor-dataset/ and mapchange-dataset/ directories are created;
# --width/--height are an optional resize target and must be given together.
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True,
                    help="root directory for the downloaded datasets")
parser.add_argument("--width", type=int, default=-1,
                    help="resize target width (-1: no resizing)")
parser.add_argument("--height", type=int, default=-1,
                    help="resize target height (-1: no resizing)")
args = parser.parse_args()
# Use parser.error instead of assert: asserts are stripped under `python -O`,
# whereas parser.error prints usage and exits with status 2.
if (args.width > 0) != (args.height > 0):
    parser.error("--width and --height must be given together (both positive or neither)")

# NOTE(review): imported after argument parsing, presumably so that --help
# and argument errors do not pay georegdata's import cost; confirm.
import georegdata as grd

# s5cmd performs the S3 downloads below; fail early if it is missing.
# `sys` must be imported at file level; it previously was not, so this
# branch died with a NameError instead of printing the message.
if shutil.which("s5cmd") is None:
    print("Please install s5cmd (see argoverse-v2 github repository)", file=sys.stderr)
    sys.exit(-1)


# Download the Argoverse 2 sensor dataset into <path>/sensor-dataset.
sensor_dataset_path = os.path.join(args.path, "sensor-dataset")
os.makedirs(sensor_dataset_path, exist_ok=True)
# The local destination is shell-quoted so paths with spaces or shell
# metacharacters survive; shlex.quote leaves plain paths unchanged.
# NOTE(review): assumes grd.prepare.run parses the command like a shell
# (shell=True or shlex.split); the S3 URL is passed exactly as before.
grd.prepare.run(
    "s5cmd --no-sign-request cp s3://argoai-argoverse/av2/sensor/* "
    f"{shlex.quote(sensor_dataset_path)}"
)

# Download the TbV map-change archives into <path>/mapchange-dataset,
# extract each archive into that directory, then delete the archive.
mapchange_dataset_path = os.path.join(args.path, "mapchange-dataset")
os.makedirs(mapchange_dataset_path, exist_ok=True)
# Local paths are shell-quoted (a no-op for plain paths).
# NOTE(review): assumes grd.prepare.run parses the command like a shell.
grd.prepare.run(
    "s5cmd --no-sign-request cp s3://argoai-argoverse/av2/tars/tbv/*.tar.gz "
    f"{shlex.quote(mapchange_dataset_path)}"
)
archive_paths = [
    os.path.join(mapchange_dataset_path, name)
    for name in os.listdir(mapchange_dataset_path)
    if name.endswith(".tar.gz")
]
for archive_path in archive_paths:
    grd.prepare.run(
        f"tar -xvzf {shlex.quote(archive_path)} "
        f"--directory {shlex.quote(mapchange_dataset_path)}"
    )
    # NOTE(review): the archive is deleted unconditionally; if
    # grd.prepare.run does not raise on a failed tar, data is lost. Confirm.
    os.remove(archive_path)

# Resizing is opt-in: it runs only when a positive --height was supplied,
# and receives the target size as a (height, width) tuple.
resize_requested = args.height > 0
if resize_requested:
    target_size = (args.height, args.width)
    grd.prepare.resize(args.path, target_size)


def _subdirectories(parent):
    """Return the full paths of the immediate subdirectories of *parent*."""
    candidates = (os.path.join(parent, name) for name in os.listdir(parent))
    return [path for path in candidates if os.path.isdir(path)]


# Scene directories live at sensor-dataset/<split>/<scene> and
# mapchange-dataset/<scene>.
scene_paths = [
    scene_path
    for split_path in _subdirectories(sensor_dataset_path)
    for scene_path in _subdirectories(split_path)
]
scene_paths.extend(_subdirectories(mapchange_dataset_path))

# Gather every LiDAR sweep stored as a .feather file under <scene>/sensors/lidar.
tasks = []
for scene_path in scene_paths:
    lidar_path = os.path.join(scene_path, "sensors", "lidar")
    if not os.path.isdir(lidar_path):
        # A scene without a LiDAR directory has nothing to convert. Skip it
        # rather than abort the whole run with FileNotFoundError.
        continue
    tasks.extend(
        os.path.join(lidar_path, name)
        for name in os.listdir(lidar_path)
        if name.endswith(".feather")
    )

# Convert each sweep to a compressed .npz that holds only the x/y/z columns as
# float64 (stored under numpy's default key "arr_0"), then delete the .feather
# file. The .npz is written before the source is removed, so an interrupted
# run leaves the original file in place for a rerun.
xyz_columns = ["x", "y", "z"]
for lidar_file in tqdm.tqdm(tasks):
    # Read only the needed columns instead of the whole table.
    data = feather.read_feather(lidar_file, columns=xyz_columns)
    points = data[xyz_columns].to_numpy().astype("float64")
    np.savez_compressed(os.path.splitext(lidar_file)[0] + ".npz", points)
    os.remove(lidar_file)
