#!/usr/bin/env python3

import argparse
import os
import shlex
import shutil
import sys
from functools import partial

import cv2
import cv2.omnidir
import numpy as np
import yaml
from tqdm import tqdm

# Command-line interface. --width/--height are optional but must be given
# together; when they are, images are resized and the fisheye cameras are
# additionally undistorted (see the `if args.height > 0` section below).
parser = argparse.ArgumentParser()
parser.add_argument("--path", type=str, required=True,
                    help="Folder containing the downloaded KITTI-360 archives; the dataset is prepared in place.")
parser.add_argument("--width", type=int, default=-1,
                    help="Target image width (must be given together with --height).")
parser.add_argument("--height", type=int, default=-1,
                    help="Target image height (must be given together with --width).")
parser.add_argument("--undistort-focal-length-factor", type=float, default=0.25,
                    help="Scale applied to the fisheye focal lengths for the undistorted perspective view.")
args = parser.parse_args()
# Validate with parser.error rather than assert: asserts are stripped under
# `python -O`, and parser.error prints usage plus a readable message.
if (args.width > 0) != (args.height > 0):
    parser.error("--width and --height must be given together")

import georeg

# Archives the user must download from the KITTI-360 website and place
# directly inside --path before running this script.
files = ["calibration.zip", "data_poses.zip"] + [
    f"download_{kind}.zip"
    for kind in ("2d_perspective", "2d_fisheye", "3d_velodyne")
]

# Abort early, with download instructions, if any required archive is absent
# from --path.
missing_files = [file for file in files if not os.path.isfile(os.path.join(args.path, file))]
if missing_files:
    print(f"Missing in {args.path}: {', '.join(missing_files)}")
    print("Please download from the KITTI-360 website and place into the folder specified by --path:")
    print("    Perspective Images for Train & Val")
    print("    Fisheye Images")
    print("    Raw Velodyne Scans")
    print("    Calibrations")
    print("    Vehicle Poses")
    print("    OXTS Measurements")  # TODO: extract this
    sys.exit(-1)

# Unpack every archive into --path.
for file in files:
    georeg.data.prepare.extract(os.path.join(args.path, file), args.path)

# Run the download scripts from inside --path. The path is interpolated into
# a shell command line (`cd ... && ...`), so quote it to survive spaces and
# shell metacharacters. The scripts are assumed to come out of the archives
# extracted above -- NOTE(review): confirm.
quoted_path = shlex.quote(args.path)
for script in ("download_2d_perspective.sh", "download_2d_fisheye.sh", "download_3d_velodyne.sh"):
    georeg.data.prepare.run(f"cd {quoted_path} && ./{script}")

print("Changing directory structure")
# Hoist everything under <--path>/KITTI-360 up into --path itself, then
# remove the now-empty KITTI-360 folder.
path = os.path.join(args.path, "KITTI-360")
for entry in os.listdir(path):
    source = os.path.join(path, entry)
    target = os.path.join(args.path, entry)
    shutil.move(source, target)
shutil.rmtree(path)

def undistort(image, file, intr, intr_new, distortion_parameters, xi):
    """Rectify an omnidirectional (fisheye) image to a perspective view.

    Args:
        image: The distorted input image.
        file: Unused here; presumably passed by the `preprocess` callback
            protocol of `georeg.data.prepare.resize` -- TODO confirm.
        intr: 3x3 intrinsic matrix of the fisheye camera.
        intr_new: 3x3 intrinsic matrix of the rectified perspective view.
        distortion_parameters: Distortion coefficients (k1, k2, p1, p2).
        xi: Mirror parameter of the omnidirectional camera model.

    Returns:
        The image produced by `cv2.omnidir.undistortImage` in
        `RECTIFY_PERSPECTIVE` mode.
    """
    xi_array = np.asarray(xi)
    return cv2.omnidir.undistortImage(
        image,
        intr,
        distortion_parameters,
        xi_array,
        cv2.omnidir.RECTIFY_PERSPECTIVE,
        Knew=intr_new,
    )

def _read_opencv_yaml(file):
    """Load an OpenCV-style YAML file with PyYAML.

    Such files may start with a `%YAML:1.0` directive line, which PyYAML
    cannot parse. That line is dropped only when it is actually present, so
    a plain YAML file keeps its first entry.
    """
    with open(file, encoding="utf-8") as stream:
        text = stream.read()
    if text.startswith("%YAML"):
        text = text.partition("\n")[2]
    return yaml.safe_load(text)


def _fisheye_camera_params(camera):
    """Read `calibration/<camera>.yaml` and derive the undistorted intrinsics.

    Returns:
        A tuple `(intr, intr_new, distortion_parameters, xi)`:
        - `intr` is the 3x3 float32 fisheye intrinsic matrix.
        - `intr_new` is the 3x3 intrinsic matrix of the rectified view. Its
          focal lengths are scaled by `--undistort-focal-length-factor` and
          it keeps the same principal point.
        - `distortion_parameters` is the float32 array (k1, k2, p1, p2).
        - `xi` is the mirror parameter.
    """
    config = _read_opencv_yaml(os.path.join(args.path, "calibration", f"{camera}.yaml"))
    projection = config["projection_parameters"]
    distortion = config["distortion_parameters"]
    intr = np.asarray([
        [projection["gamma1"], 0, projection["u0"]],
        [0, projection["gamma2"], projection["v0"]],
        [0, 0, 1],
    ], dtype="float32")
    distortion_parameters = np.asarray(
        [distortion["k1"], distortion["k2"], distortion["p1"], distortion["p2"]],
        dtype="float32",
    )
    xi = config["mirror_parameters"]["xi"]
    factor = args.undistort_focal_length_factor
    # For a fixed image size, a smaller focal length keeps a wider field of view.
    intr_new = np.asarray([
        [intr[0, 0] * factor, 0, intr[0, 2]],
        [0, intr[1, 1] * factor, intr[1, 2]],
        [0, 0, 1],
    ])
    return intr, intr_new, distortion_parameters, xi


if args.height > 0:
    # Per fisheye camera: (intr, intr_new, distortion_parameters, xi).
    camera_params = {}
    for camera in ["image_02", "image_03"]:
        camera_params[camera] = _fisheye_camera_params(camera)
        intr_new = camera_params[camera][1]
        # NOTE(review): intr_new refers to the undistorted image at the
        # fisheye's native resolution, presumably before the resize below --
        # confirm the consumer rescales it to (--height, --width).
        with open(os.path.join(args.path, "calibration", f"{camera}-undistorted.yaml"), "w", encoding="utf-8") as f:
            yaml.dump({"intr": intr_new.tolist()}, f, default_flow_style=False)

    # Bind the undistortion parameters once per camera instead of once per scene.
    preprocessors = {
        camera: partial(undistort, intr=intr, intr_new=intr_new, distortion_parameters=distortion_parameters, xi=xi)
        for camera, (intr, intr_new, distortion_parameters, xi) in camera_params.items()
    }

    for scene in sorted(os.listdir(os.path.join(args.path, "data_2d_raw"))):
        scene_path = os.path.join(args.path, "data_2d_raw", scene)
        # image_00 / image_01 are resized as-is (presumably the perspective cameras).
        for camera in ["image_00", "image_01"]:
            georeg.data.prepare.resize(os.path.join(scene_path, camera), (args.height, args.width))

        # image_02 / image_03 are undistorted to a perspective view via the
        # preprocess hook, then resized.
        for camera, preprocess in preprocessors.items():
            georeg.data.prepare.resize(os.path.join(scene_path, camera), (args.height, args.width), preprocess=preprocess)
