#!/usr/bin/env python3

import argparse, tinylogdir, tqdm

# Command-line interface: data/model paths, tracking knobs, and output toggles.
parser = argparse.ArgumentParser()
add = parser.add_argument

# Input/output selection.
add("--output", type=str, required=True)
add("--tensorrt", action="store_true")
add("--stride", type=int, default=1)
add("--tiles", type=str, default="bingmaps")
add("--max-dist", type=float, default=None)
add("--pseudolabels", type=str, default=None)
add("--no-imu", action="store_true")
add("--predict-stride", type=int, default=1)

# Toggles for the artifacts written per scene.
add("--no-map", action="store_true")
add("--no-images", action="store_true")
add("--no-video", action="store_true")

add("--scene-name", type=str, default=None)
add("--dataset-name", type=str, default="ford")

# Kalman-tracker tuning parameters.
add("--pred-prob-power", type=float, default=0.3)
add("--update-covariance-multiplier", type=float, default=3.0)
add("--process-noise-multiplier", type=float, default=1.0)

add("--train-dir", type=str, required=True)

args = parser.parse_args()




log = tinylogdir.LogDir(args.output)

import georeg, cosy, sys, os, imageio, yaml, tfcv, math, cv2
import tiledwebmaps as twm
import georegdata as grd
import numpy as np
import tensorflow as tf
from functools import partial
import tinypl as pl


train_dir = args.train_dir

# Record the run configuration next to the outputs for reproducibility.
info = {
    "train-dir": train_dir,
    "tiles": args.tiles,
    "pseudolabels": args.pseudolabels,
    "predict-stride": args.predict_stride,
    "max-dist": args.max_dist,
}

with open(os.path.join(log.dir(), "config.yaml"), "w") as f:
    yaml.dump(info, f, default_flow_style=False)

# Optionally load a specific checkpoint epoch (e.g. "00006") instead of the final weights.
checkpoint_epoch = None # "00006"

print("Loading datasets...")

# Root directory containing the downloaded aerial tile sets.
maps_path = os.environ["AERIAL_DATA"]
# NOTE(review): wait_after_error and retries appear unused in this file — confirm before removing.
wait_after_error = 5.0
retries = 100

tileloaders = {}

# Aerial tile loaders; zoom/name metadata is attached through the instance __dict__.
googlemaps = twm.Directory(layout=twm.Layout.XYZ((256, 256)), path=os.path.join(maps_path, "googlemaps"))
vars(googlemaps)["zoom"] = 20
vars(googlemaps)["name"] = "googlemaps"
tileloaders["googlemaps"] = googlemaps

bingmaps = twm.Directory(layout=twm.Layout.XYZ((256, 256)), path=os.path.join(maps_path, "bingmaps"))
vars(bingmaps)["zoom"] = 20
vars(bingmaps)["name"] = "bingmaps"
tileloaders["bingmaps"] = bingmaps

# The tile source actually used for tracking (selected via --tiles).
tileloader = tileloaders[args.tiles]

cache = True # "rebuild"

# Select the ground dataset and pair it with the chosen tile loader.
if args.dataset_name == "ford":
    ford_avdata = grd.ground.ford_avdata.load(os.path.join(os.environ["GROUND_DATA"], "ford-avdata"), cache=cache, valid_starts_only=True)
    datasets = [
        (ford_avdata["detroit"], [tileloader])
    ]
elif args.dataset_name == "kitti360":
    kitti360 = grd.ground.kitti360.load(os.path.join(os.environ["GROUND_DATA"], "kitti360"), cache=cache)
    datasets = [
        (kitti360["karlsruhe"], [tileloader])
    ]
else:
    # Fail loudly on unknown dataset names. The previous `assert False` would
    # be stripped under `python -O`, silently leaving `datasets` undefined.
    raise ValueError(f"Unknown --dataset-name: {args.dataset_name!r}")


assert len(tf.config.list_logical_devices("GPU")) == 1
def get_memory_usage():
    """Return the peak GPU memory usage in bytes since the last call, then reset the counter."""
    peak_bytes = tf.config.experimental.get_memory_info("GPU:0")["peak"]
    tf.config.experimental.reset_memory_stats("GPU:0")
    return peak_bytes
def format_bytes(size):
    """Render a byte count as a human-readable string with one decimal (B/KB/MB/GB/TB)."""
    units = ("B", "KB", "MB", "GB", "TB")
    value = float(size)
    unit_index = 0
    # Scale down by 1024 until the value fits the unit, capping at TB.
    while value >= 1024 and unit_index < len(units) - 1:
        value /= 1024
        unit_index += 1
    return f"{value:.1f}{units[unit_index]}"

# Load the training run's model configuration (safe_load avoids arbitrary object construction).
with open(os.path.join(train_dir, "mlog", "config.yaml")) as f:
    config = yaml.safe_load(f)

# Candidate rotation offsets in radians: 21 samples over +-10 degrees, spaced
# quadratically (sign(t) * |t|^2) so resolution is finest near zero.
angles_range = 10.0
angles_num = 21
_t = np.linspace(-1.0, 1.0, num=angles_num)
# t * |t| equals sign(t) * |t|^2 exactly for floats.
angles = (_t * np.abs(_t)) * angles_range / 180.0 * math.pi

print("Building model...")
# Load the trained model; optionally override the weights with a specific checkpoint epoch.
model = tf.keras.models.load_model(os.path.join(train_dir, "saved_model"), compile=False)
if not checkpoint_epoch is None:
    model.load_weights(os.path.join(train_dir, "checkpoints", f"model_checkpoint_epoch{checkpoint_epoch}.h5"))
preprocess_aerial = tfcv.model.pretrained.facebookresearch.preprocess
preprocess_ground = tfcv.model.pretrained.facebookresearch.preprocess

# Find the highest-indexed correlation stage present in the model (indices start at 3 here).
layer_names = [l.name for l in model.layers]
i = 3
while f"correlation-logits-{i + 1}" in layer_names:
    i += 1

# Truncate the model so it outputs only the deepest stage's correlation logits and mask.
model = tf.keras.Model(
    inputs=model.inputs,
    outputs=[
        model.get_layer(f"correlation-logits-{i}").output,
        model.get_layer(f"correlation-mask-{i}").output,
    ],
)

# Optional TensorRT optimization: round-trip the model through a SavedModel,
# convert it to an FP16 TF-TRT graph, and reload the converted model.
if args.tensorrt:
    print("Saving model...")
    tf.keras.models.save_model(model, log.dir("saved_model"), include_optimizer=False)

    print("Importing TRT...")
    from tensorflow.python.compiler.tensorrt import trt_convert as trt

    print("Creating TRT parameters...")
    conversion_params = trt.TrtConversionParams(precision_mode="FP16")

    print("Creating TRT converter...")
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=log.dir("saved_model"),
        conversion_params=conversion_params,
    )
    print("Converting model...")
    converter.convert()
    print("Saving model...")
    converter.save(log.dir("saved_model_tensorrt"))

    print("Loading model...")
    model = tf.keras.models.load_model(log.dir("saved_model_tensorrt"), compile=False)

# Geometry/attention constants of the trained model, taken from its training config.
final_meters_per_pixel = config["model"]["final-meters-per-pixel"]
aerial_final_shape = config["model"]["final-aerial-shape"]
bev_final_shape = config["model"]["final-bev-shape"]
ground_attn_strides = config["model"]["ground-attn-strides"]
aerial_attn_strides = config["model"]["aerial-attn-strides"]
aerial_stride = config["model"]["aerial-stride"]

max_cameras = 9

model_constants = georeg.model.base.ModelConstants(
    bev_final_shape=bev_final_shape,
    aerial_final_shape=aerial_final_shape,
    final_meters_per_pixel=final_meters_per_pixel,
    ground_attn_strides=ground_attn_strides,
    aerial_attn_strides=aerial_attn_strides,
    aerial_stride=aerial_stride,
    max_cameras=max_cameras,
)

# If --max-dist is given, restrict the correlation window to +-max_dist meters,
# rounding the side length up to the next multiple of 8 pixels.
if not args.max_dist is None:
    multiplier = 8
    corr_shape = (int(2 * args.max_dist / final_meters_per_pixel) + multiplier - 1) // multiplier * multiplier
    corr_shape = np.asarray([corr_shape, corr_shape])
    print(f"Using corr_shape {corr_shape}")
else:
    corr_shape = None

predictor = georeg.track.Predictor(
    model=model,
    model_constants=model_constants,
    angles=angles,
    preprocess_aerial=preprocess_aerial,
    preprocess_ground=preprocess_ground,
    config=config,
    corr_shape=corr_shape,
)

def compute_groundtruth(scene):
    """Estimate per-frame motion groundtruth for a scene via finite differences.

    Frame poses are expressed in a local east/north metric frame anchored at
    the first frame, differentiated with respect to the frame timestamps
    (central differences inside, one-sided at the boundaries), and smoothed
    with Gaussian filters.

    Returns:
        (velocities, yawrates, accelerations): arrays with one entry per frame,
        in m/s, rad/s and m/s^2 respectively.
    """
    import scipy.ndimage

    # Local metric frame (east/north meters) centered at the first frame's latlon.
    world_to_epsg3857 = cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(scene[0].latlon)
    def epsg4326_to_transform(latlon, bearing):
        # Rigid vehicle pose expressed in the local metric frame.
        return cosy.np.Rigid(
            translation=world_to_epsg3857.inverse()(predictor.epsg4326_to_epsg3857(latlon)).astype("float32"),
            rotation=cosy.np.angle_to_rotation_matrix(predictor.epsg4326_to_epsg3857.transform_angle(math.radians(bearing))).astype("float32"),
        )

    def default_diff(x2, x1):
        # Elementwise difference for rank-1 input, per-row Euclidean distance for rank-2.
        x2 = np.asarray(x2)
        x1 = np.asarray(x1)
        if x2.shape != x1.shape:
            raise ValueError(f"Tensors must have the same shape, but got {x2.shape} and {x1.shape}")
        if len(x1.shape) == 1:
            d = x2 - x1
        elif len(x1.shape) == 2:
            d = np.linalg.norm(x1 - x2, axis=1)
        else:
            raise ValueError(f"Tensors must have rank 1 or 2, but got shape {x1.shape}")
        return d

    def diff(xs, ys, x_diff=default_diff, y_diff=default_diff):
        # Finite-difference derivative dy/dx: one-sided at both ends, central in between.
        if len(xs) != len(ys):
            raise ValueError(f"xs and ys must have the same number of elements, but got {len(xs)} in xs and {len(ys)} in ys")
        if len(xs) <= 1:
            raise ValueError(f"Must have at least 2 elements, but got {len(xs)}")
        ds = np.concatenate([
            np.asarray(y_diff(ys[1:2], ys[:1])) / np.asarray(x_diff(xs[1:2], xs[:1])),
            np.asarray(y_diff(ys[2:], ys[:-2])) / np.asarray(x_diff(xs[2:], xs[:-2])),
            np.asarray(y_diff(ys[-1:], ys[-2:-1])) / np.asarray(x_diff(xs[-1:], xs[-2:-1])),
        ], axis=0)
        assert len(ds) == len(ys)
        return ds

    def normalize_angle(a, positive=False):
        # Wrap angles into [0, 2*pi) when positive=True, else into [-pi, pi].
        a = np.fmod(a, 2 * math.pi)
        if positive:
            # Bug fix: the result of np.where was previously discarded, so
            # negative angles were never wrapped when positive=True.
            a = np.where(a < 0.0, a + 2 * math.pi, a)
        else:
            a = np.where(a < -math.pi, a + 2 * math.pi, a)
            a = np.where(a > +math.pi, a - 2 * math.pi, a)
        return a



    # Collect poses and timestamps (microseconds -> seconds); frames must be
    # strictly increasing in time.
    frames_to_world = []
    timestamps = []
    for ground_frame_id in scene:
        frames_to_world.append(epsg4326_to_transform(ground_frame_id.latlon, ground_frame_id.bearing))
        timestamps.append(ground_frame_id.timestamp * 1e-6)
    timestamps = np.asarray(timestamps)
    assert np.all(timestamps[:-1] < timestamps[1:])

    # Signed forward velocity: x-component of the relative translation between
    # poses (assumes x is the ego forward axis — TODO confirm against cosy/grd).
    p_diff1 = diff(
        xs=timestamps,
        ys=frames_to_world,
        y_diff=lambda y2, y1: [(a.inverse() * b).translation[0] for b, a in zip(y2, y1)],
    )
    # Yaw rate: wrapped heading-angle differences over time.
    a_diff1 = diff(
        xs=timestamps,
        ys=[cosy.np.rotation_matrix_to_angle(f.rotation) for f in frames_to_world],
        y_diff=lambda y2, y1: [normalize_angle(b - a) for b, a in zip(y2, y1)],
    )

    # Smooth the first derivatives before differentiating again.
    p_diff1 = scipy.ndimage.gaussian_filter1d(p_diff1, sigma=2.5)
    a_diff1 = scipy.ndimage.gaussian_filter1d(a_diff1, sigma=2.5)

    p_diff2 = diff(
        xs=timestamps,
        ys=p_diff1,
    )

    p_diff2 = scipy.ndimage.gaussian_filter1d(p_diff2, sigma=0.8)

    velocities = p_diff1
    yawrates = a_diff1
    accelerations = p_diff2

    return velocities, yawrates, accelerations


# Optionally load pseudolabel poses: each CSV line is
# "<frame-name>,<tileloader-key>,<16 row-major 4x4 matrix entries>".
if args.pseudolabels is not None:
    with open(args.pseudolabels, "r") as f:
        stripped = [raw.strip() for raw in f]
    lines = [entry.split(",") for entry in stripped if entry]
    pseudolabel_ego_to_world = {}
    for entry in lines:
        pose_matrix = np.asarray([float(v) for v in entry[2:2 + 4 * 4]]).reshape((4, 4))
        pseudolabel_ego_to_world[(entry[0], entry[1])] = cosy.np.Rigid.from_matrix(pose_matrix)

def augment(frame_id, tileloader):
    """Rebuild frame_id, substituting the pseudolabel pose when pseudolabels are enabled."""
    params = frame_id.get_params()
    if args.pseudolabels is not None:
        key = (frame_id.name, f"{tileloader.name}-zoom{tileloader.zoom}")
        params["ego_to_world"] = pseudolabel_ego_to_world[key]
    return grd.ground.FrameId(**params)

# Main tracking loop: for every (dataset, tileloader, scene) combination, run the
# Kalman tracker frame by frame and write image/video/trajectory artifacts.
for dataset, tileloaders in datasets:
    for tileloader in tileloaders:
        # Wrap the tile loader in an LRU cache and carry the metadata attributes over.
        tileloader2 = twm.LRUCached(tileloader, 100)
        vars(tileloader2)["name"] = tileloader.name
        vars(tileloader2)["zoom"] = tileloader.zoom
        tileloader = tileloader2
        for scene in dataset.scenes:
            scene_name = scene[0].scene_id
            # Optionally restrict the run to a single named scene.
            if not args.scene_name is None and scene_name != args.scene_name:
                continue

            # Substitute pseudolabel poses (if enabled) and subsample frames.
            scene = [augment(f, tileloader) for f in scene]
            path = log.dir(f"{dataset.name}-{scene[0].scene_id}-{tileloader.name}")
            print(f"Tracking scene {scene[0].scene_id} in dataset {dataset.name} with tileloader {tileloader.name}")

            scene = scene[::args.stride]

            gt_velocities, gt_yawrates, gt_accelerations = compute_groundtruth(scene)

            print(f"Initial velocity={gt_velocities[0]}m/s (ignoring yawrate={math.degrees(gt_yawrates[0])}deg/s acceleration={gt_accelerations[0]}m/s^2)")

            # Initialize the tracker at the first frame's pose with the groundtruth
            # initial velocity; the remaining state priors are broad Gaussians.
            tracker = georeg.track.KalmanTracker(
                predictor=predictor,
                tileloader=tileloader,
                filter=georeg.track.kalman.CTRA,
                timestamp=scene[0].timestamp,
                latlon=scene[0].latlon,
                bearing=scene[0].bearing,
                recalibrate=True,
                process_noise_multiplier=args.process_noise_multiplier,
                position={"std": 5.0},
                heading={"std": 0.2},
                velocity={"mean": gt_velocities[0], "std": 3.0},
                yawrate={"mean": 0.0, "std": 1.0},
                acceleration={"mean": 0.0, "std": 10.0},
            )

            print(f"latlon={tracker.latlon} curr_bearing={tracker.bearing}")

            frames_path = os.path.join(path, "frames")
            os.makedirs(frames_path)

            # Video writer is created lazily once the first composed frame size is known.
            video = None

            # Tracked trajectory accumulated over the scene.
            latlons = []
            bearings = []

            def save_map():
                # Render the accumulated trajectory onto aerial tiles and save it.
                image = grd.visualize.draw_trajectories(
                    [latlons],
                    tileloader,
                    zoom=tileloader.zoom,
                    bearings=[bearings],
                    tile_padding=2,
                    downsample=1,
                    bearing_stride=max(5 // args.stride, 1),
                    bearing_length=0.5,
                    verbose=False,
                    sync_tile_loader=False,
                )
                imageio.imwrite(os.path.join(path, "trajectory.jpg"), image)


            # The first frame only seeded the tracker; process the remainder.
            scene = scene[1:]
            stream = iter(scene)
            stream, iterations_order = pl.order.save(stream)
            stream = pl.sync(stream) # Concurrent processing starts after this point

            def load_frame(frame_id):
                # NOTE(review): `frame` is unused; the first load may act as a
                # cache warm-up — confirm before removing.
                frame = frame_id.load()
                return grd.ground.AlignedFrameId(frame_id, frame_id.latlon, frame_id.bearing, model_constants.meters_per_pixel[-1]).load()
            stream = pl.map(load_frame, stream)
            # Prefetch frames with worker threads, then restore the original order.
            stream = pl.queued(stream, workers=8, maxsize=4)
            stream = pl.order.load(stream, iterations_order)

            for frame_index, ground_frame in enumerate(tqdm.tqdm(stream, total=len(scene))):
                print(dataset.name, scene_name)
                # Run the network prediction only every predict-stride frames.
                use_prediction = frame_index % args.predict_stride == 0
                tracking_state = tracker(
                    ground_frame,
                    update_covariance_multiplier=[1.0, args.update_covariance_multiplier, args.update_covariance_multiplier],
                    weak_prior_covariance_multiplier=[1.0, 1.0, 1.0],
                    pred_prob_power=args.pred_prob_power,
                    use_imu=not args.no_imu,
                    use_prediction=use_prediction,
                )

                print(f"latlon={tracker.latlon} curr_bearing={tracker.bearing}")
                print(f"linear_acceleration={ground_frame.imu.linear_acceleration} angular_velocity={ground_frame.imu.angular_velocity}")

                print("Saving...", end="")
                sys.stdout.flush()
                if not args.no_images:
                    # Save the aerial view without and with the lidar point overlay.
                    image = georeg.track.visualize.draw_aerial_image(
                        tracking_state,
                        tracker,
                        points_alpha=0.0,
                    )
                    imageio.imwrite(os.path.join(frames_path, f"frame{frame_index}-0-color.jpg"), image)
                    image = georeg.track.visualize.draw_aerial_image(
                        tracking_state,
                        tracker,
                        points_alpha=1.0,
                    )
                    imageio.imwrite(os.path.join(frames_path, f"frame{frame_index}-1-lidar.jpg"), image)
                    # One additional image per tracked probability distribution.
                    for i, (name, probability_distribution) in enumerate(tracking_state.probability_distributions.items()):
                        image = georeg.track.visualize.draw_aerial_image(
                            tracking_state,
                            tracker,
                            # NOTE(review): both branches of the conditional are 1.0,
                            # so this power is currently a no-op.
                            probs=np.power(probability_distribution.probs, 1.0 if "calib" in name else 1.0),
                            points_alpha=0.0,
                            recalibrate=True,
                        )
                        imageio.imwrite(os.path.join(frames_path, f"frame{frame_index}-{i + 2}-{name}.jpg"), image)


                if not args.no_video:
                    # Compose the video frame: aerial view (flipped in both axes) on top.
                    image = georeg.track.visualize.draw_aerial_image(
                        tracking_state,
                        tracker,
                        # probs=tracking_state.probability_distributions["posterior"].probs,
                        points_alpha=0.7,
                    )[::-1, ::-1]

                    # Collect cameras whose optical axis points within 10 degrees of forward.
                    front_cams = []
                    for camera_id in ground_frame.frame_id.base_frame_id.get_params()["cameras"]:
                        cam_yaw = camera_id.ego_to_camera.inverse().rotation @ np.asarray([0.0, 0.0, 1.0])
                        cam_yaw = cosy.np.angle(np.asarray([1.0, 0.0]), cam_yaw[:2])
                        if abs(math.degrees(cam_yaw)) < 10.0:
                            front_cams.append(camera_id)
                    assert len(front_cams) > 0
                    def fov(camera_id):
                        # resolution = 0.5 * (camera_id.resolution[0] + camera_id.resolution[1])
                        # focal_length = 0.5 * (camera_id.intr[0, 0] + camera_id.intr[1, 1])
                        # Monotonic proxy for the camera's field of view from intrinsics.
                        resolution = camera_id.resolution[1]
                        focal_length = camera_id.intr[0, 0]
                        return 0.5 * math.atan(0.5 * resolution / focal_length)
                    # Use the forward camera with the largest field of view.
                    front_cams = sorted(front_cams, key=fov)
                    cam_name = front_cams[-1].name

                    front_cam_image = [c for c in tracking_state.frame.ground_frame.cameras if cam_name in c.name][0].image

                    # Resize the camera image to the aerial image's width and stack it below.
                    dest_width = image.shape[1]
                    dest_height = front_cam_image.shape[0] * dest_width / front_cam_image.shape[1]
                    shape = np.asarray([dest_height, dest_width])

                    front_cam_image = tfcv.image.resize_to(shape)((front_cam_image, "color"))
                    image = np.concatenate([image, front_cam_image], axis=0)

                    if video is None:
                        # Open the writer lazily; fps is derived from the frame spacing.
                        video = cv2.VideoWriter(
                            os.path.join(path, f"{scene[0].scene_id}.avi"),
                            cv2.VideoWriter_fourcc(*"mp4v"),
                            1.0 / ((scene[1].timestamp - scene[0].timestamp) * 1e-6), # fps
                            (image.shape[1], image.shape[0]),
                        )
                    video.write(image[:, :, ::-1])  # RGB -> BGR for OpenCV

                latlons.append(tracker.latlon)
                bearings.append(tracker.bearing)
                # Periodically checkpoint the trajectory (and map image) to disk.
                if frame_index % 1000 == 0:
                    if not args.no_map:
                        save_map()
                    np.savez(os.path.join(path, "trajectory.npz"), latlons=latlons, bearings=bearings)
                print(" done")

                print(f"Mem: {format_bytes(get_memory_usage())}")
            if not video is None:
                video.release()

            # Final trajectory artifacts for this scene.
            if not args.no_map:
                save_map()
            np.savez(os.path.join(path, "trajectory.npz"), latlons=latlons, bearings=bearings)
