#!/usr/bin/env python3

import argparse, tinylogdir, tqdm

# Command-line interface. Parsing happens before the heavy imports further
# down, so a bad invocation fails before those modules are loaded.
parser = argparse.ArgumentParser(
    description="Evaluate a trained model on the Ford AV scenes starting from perturbed pose priors and log the resulting translation errors.",
)
parser.add_argument(
    "--output", type=str, required=True,
    help="Directory handed to tinylogdir.LogDir; the offsets-<scene>.txt, errors-<scene>.npz and errors.txt files are written there.",
)
parser.add_argument(
    "--offsets", type=str, default=None,
    help="Either 'shietal' (derive the prior offsets from the ground-truth files in $SHIETAL) or a directory with offsets-<scene>.txt files to replay. Random offsets are drawn when omitted.",
)
parser.add_argument(
    "--rotation", type=float, required=True,
    help="Maximum orientation offset of the prior in degrees. Values >= 360 search all orientations.",
)
parser.add_argument(
    "--translation", type=float, required=True,
    help="Maximum translation offset of the prior in meters.",
)
parser.add_argument(
    "--train-dir", type=str, required=True,
    help="Training directory containing mlog/config.yaml and the saved_model folder.",
)
args = parser.parse_args()

# Output directory for everything this run writes.
log = tinylogdir.LogDir(args.output)

import georeg, cosy, sys, os, imageio, yaml, tfcv, math, cv2, pickle
import tiledwebmaps as twm
import georegdata as grd
import numpy as np
import tensorflow as tf
from functools import partial
import tinypl as pl

train_dir = args.train_dir

print("Loading datasets...")

maps_path = os.environ["AERIAL_DATA"]
wait_after_error = 5.0
retries = 100

# Aerial tiles are read from an XYZ-layout directory with 256x256 tiles. The
# zoom level and a name are attached to the instance through its __dict__.
googlemaps = twm.Directory(
    layout=twm.Layout.XYZ((256, 256)),
    path=os.path.join(maps_path, "googlemaps"),
)
googlemaps_attrs = vars(googlemaps)
googlemaps_attrs["zoom"] = 20
googlemaps_attrs["name"] = "googlemaps"

# Ground data: keep only the "detroit" scenes, indexed by the scene id of
# their first frame.
ford_avdata = grd.ground.ford_avdata.load(
    os.path.join(os.environ["GROUND_DATA"], "ford-avdata"),
    cache=True,
)
ford_avdata = {frames[0].scene_id: frames for frames in ford_avdata["detroit"].scenes}

print("Loading pseudolabels...", end="")
sys.stdout.flush()

# Pseudolabel file: one comma-separated record per non-empty line. The first
# two fields form the lookup key, the next 16 fields are the entries of a 4x4
# matrix in row-major order.
with open(os.environ["PSEUDOLABELS"], "r") as f:
    lines = [raw.strip() for raw in f]
lines = [record.split(",") for record in lines if len(record) > 0]

pseudolabel_ego_to_world = {}
for line in lines:
    matrix = np.asarray([float(x) for x in line[2:2 + 4 * 4]]).reshape((4, 4))
    pseudolabel_ego_to_world[(line[0], line[1])] = cosy.np.Rigid.from_matrix(matrix)
def augment(frame_id):
    """Return a new FrameId whose ``ego_to_world`` is the pseudolabel pose.

    The pose is looked up in ``pseudolabel_ego_to_world`` under the key
    ``(frame_id.name, "googlemaps-zoom20")``; a missing entry raises KeyError.
    All other constructor parameters are taken from ``frame_id.get_params()``.
    """
    kwargs = frame_id.get_params()
    lookup_key = (frame_id.name, "googlemaps-zoom20")
    kwargs.update(ego_to_world=pseudolabel_ego_to_world[lookup_key])
    return grd.ground.FrameId(**kwargs)
print(" done")

# get_memory_usage() below queries the device "GPU:0" only, so the script
# insists on exactly one logical GPU. The message reports what was found.
num_logical_gpus = len(tf.config.list_logical_devices("GPU"))
assert num_logical_gpus == 1, f"Expected exactly one logical GPU, found {num_logical_gpus}"
def get_memory_usage():
    """Return the "peak" memory entry reported for GPU:0, then reset its stats.

    Because the statistics are reset on every call, each call reports the
    peak recorded since the previous call.
    """
    device = "GPU:0"
    peak_bytes = tf.config.experimental.get_memory_info(device)["peak"]
    tf.config.experimental.reset_memory_stats(device)
    return peak_bytes
def format_bytes(size):
    """Format a byte count with one decimal and a binary (1024-based) unit.

    ``size`` is converted with ``float()``. Units run from "B" to "TB";
    values of 1024 TB or more stay in "TB".
    """
    units = ("B", "KB", "MB", "GB", "TB")
    step = 2 ** 10
    value = float(size)
    for unit in units[:-1]:
        # Written as "not >=" so that NaN stops at the first unit.
        if not value >= step:
            return f"{value:.1f}{unit}"
        value /= step
    return f"{value:.1f}{units[-1]}"

# Training configuration stored next to the saved model.
with open(os.path.join(train_dir, "mlog", "config.yaml")) as f:
    config = yaml.safe_load(f)

# Orientation hypotheses (in radians, relative to the prior bearing) that are
# handed to the predictor.
if args.rotation >= 360:
    # Full circle: 360 hypotheses at 0, 1, ..., 359 degrees.
    print("No rotation information")
    angles = np.linspace(0.0, 359.0, num=360) / 180.0 * math.pi
else:
    angles_range = args.rotation
    angles_num = 2 * int(angles_range) + 1
    if angles_num == 1:
        # np.linspace(-1.0, 1.0, num=1) returns [-1.0], which would place the
        # only hypothesis at -rotation instead of at the prior bearing. This
        # happens for any rotation below one degree, so use a zero offset.
        angles = np.zeros((1,))
    else:
        # Symmetric samples in [-rotation, +rotation] that include zero.
        angles = np.linspace(-1.0, 1.0, num=angles_num)
        # Disabled alternative (quadratic spacing, denser around zero):
        # angles = np.sign(angles) * np.power(np.abs(angles), 2)
        angles = angles * angles_range / 180.0 * math.pi

print("Building model...")
saved_model_path = os.path.join(train_dir, "saved_model")
model = tf.keras.models.load_model(saved_model_path, compile=False)

# Aerial and ground images share the same preprocessing function.
preprocess_aerial = tfcv.model.pretrained.facebookresearch.preprocess
preprocess_ground = tfcv.model.pretrained.facebookresearch.preprocess

# Find the highest index N (starting at 3) for which a layer named
# "correlation-logits-N" exists.
layer_names = [layer.name for layer in model.layers]
last_level = 3
while f"correlation-logits-{last_level + 1}" in layer_names:
    last_level += 1

# Re-wrap the network so that it outputs the logits and the mask of that level.
logits_output = model.get_layer(f"correlation-logits-{last_level}").output
mask_output = model.get_layer(f"correlation-mask-{last_level}").output
model = tf.keras.Model(inputs=model.inputs, outputs=[logits_output, mask_output])

# Model geometry as stored in the training configuration.
model_config = config["model"]
final_meters_per_pixel = model_config["final-meters-per-pixel"]
aerial_final_shape = model_config["final-aerial-shape"]
bev_final_shape = model_config["final-bev-shape"]
ground_attn_strides = model_config["ground-attn-strides"]
aerial_attn_strides = model_config["aerial-attn-strides"]
aerial_stride = model_config["aerial-stride"]

max_cameras = 9

model_constants = georeg.model.base.ModelConstants(
    max_cameras=max_cameras,
    aerial_stride=aerial_stride,
    aerial_attn_strides=aerial_attn_strides,
    ground_attn_strides=ground_attn_strides,
    final_meters_per_pixel=final_meters_per_pixel,
    aerial_final_shape=aerial_final_shape,
    bev_final_shape=bev_final_shape,
)

# Square correlation window: 2 * (translation + 1) meters converted to pixels
# (truncated), then rounded up to the next multiple of `multiplier`.
multiplier = 8
window_pixels = int(2 * (args.translation + 1) / final_meters_per_pixel)
corr_side = -(-window_pixels // multiplier) * multiplier
corr_shape = np.asarray([corr_side] * 2)
print(f"Using corr_shape {corr_shape}")

predictor = georeg.track.Predictor(
    model=model,
    model_constants=model_constants,
    angles=angles,
    preprocess_aerial=preprocess_aerial,
    preprocess_ground=preprocess_ground,
    config=config,
    corr_shape=corr_shape,
)


# Scenes to evaluate, paired with the name of their "shietal" ground-truth file
# inside $SHIETAL (None when no file is listed for the scene). The environment
# variable is only read when --offsets shietal is actually requested.
scenes = [
    ("2017-08-04-V2-Log1", "grd_sat_quaternion_latlon_test_log1.txt"),
    ("2017-08-04-V2-Log2", "grd_sat_quaternion_latlon_test_log2.txt"),
    ("2017-08-04-V2-Log3", None),
    ("2017-08-04-V2-Log4", None),
    ("2017-08-04-V2-Log5", None),
    ("2017-08-04-V2-Log6", None),
]


def _load_shietal_offsets(gt_path, rotation, translation):
    """Read prior offsets from a "shietal" ground-truth file.

    Every non-empty line holds 12 space-separated fields:
    ``grd_name q0 q1 q2 q3 g_lat g_lon s_lat s_lon gt_shift_u gt_shift_v theta``.
    Only ``grd_name``, ``gt_shift_u``, ``gt_shift_v`` and ``theta`` are used.
    The integer timestamp is ``grd_name`` without its last four characters
    (presumably a file extension -- TODO confirm). Rows are sorted by it.

    Returns:
        A tuple ``(timestamps, bearing_offsets, translation_offsets)`` of
        numpy arrays, where ``bearing_offsets = theta * rotation`` and
        ``translation_offsets = translation / sqrt(2) * [gt_shift_u, gt_shift_v]``.

    Raises:
        ValueError: if a line does not have exactly 12 fields or a field
            cannot be converted to a number.
    """
    with open(gt_path, "r") as f:
        rows = [raw.strip() for raw in f]
    rows = [row for row in rows if len(row) > 0]
    rows = sorted(rows, key=lambda row: int(row.split(" ")[0][:-4]))

    timestamps = []
    bearing_offsets = []
    translation_offsets = []
    for row in rows:
        fields = row.split(" ")
        if len(fields) != 12:
            raise ValueError(f"Expected 12 space-separated fields in {gt_path}, got {len(fields)}: {row!r}")
        grd_name = fields[0]
        gt_shift_u, gt_shift_v, theta = fields[9:12]
        timestamps.append(int(grd_name[:-4]))
        bearing_offsets.append(float(theta) * rotation)
        translation_offsets.append(translation / math.sqrt(2) * np.asarray([float(gt_shift_u), float(gt_shift_v)]))
    return np.asarray(timestamps), np.asarray(bearing_offsets), np.asarray(translation_offsets)


def _load_offsets_file(path):
    """Read an ``offsets-<scene>.txt`` file written by a previous run.

    The first line (the header) is skipped. Each remaining non-empty line is
    ``frame,translation_bearing,translation_distance,rotation_offset``.

    Returns:
        A dict mapping the frame name to the three offsets as floats.
    """
    with open(path, "r") as f:
        rows = f.readlines()[1:]
    rows = [row.strip() for row in rows]
    rows = [row.split(",") for row in rows if len(row) > 0]
    return {tokens[0]: (float(tokens[1]), float(tokens[2]), float(tokens[3])) for tokens in rows}


@pl.unpack
def load_frame(frame_index, frame_id):
    """Load the aligned frame for one ``(frame_index, frame_id)`` stream item."""
    # NOTE(review): the result of this call was never used in the original
    # code either. It is kept in case load() has side effects (e.g. caching);
    # drop it if it turns out to be redundant.
    frame_id.load()
    return grd.ground.AlignedFrameId(frame_id, frame_id.latlon, frame_id.bearing, model_constants.meters_per_pixel[-1]).load()


def recall(errors, q):
    """Return the fraction of ``errors`` that are <= ``q`` (NaN if empty)."""
    errors = np.asarray(errors)
    if errors.size == 0:
        return float("nan")
    return float(np.count_nonzero(errors <= q)) / np.prod(errors.shape)


def recall_str(errors, name):
    """Format the recall of ``errors`` at 0.25, 0.5, 1, 2, 3 and 5 meters."""
    return f"Recall {name}: r@0.25m={recall(errors, 0.25)} r@0.5m={recall(errors, 0.5)} r@1m={recall(errors, 1.0)} r@2m={recall(errors, 2.0)} r@3m={recall(errors, 3.0)} r@5m={recall(errors, 5.0)}"


# None never equals "shietal", so no separate "is not None" test is needed.
use_shietal = args.offsets == "shietal"
result_parts = []
for scene_name, gt_filename in scenes:
    print(scene_name)
    if use_shietal and gt_filename is None:
        # No ground-truth file for this scene: skip it before doing any work.
        continue

    scene = [augment(f) for f in ford_avdata[scene_name]]

    # Decide which frames are evaluated and where their prior offsets come from.
    loaded_offsets = None
    if use_shietal:
        gt_file = os.path.join(os.environ["SHIETAL"], gt_filename)
        ha_timestamps, ha_bearing_offsets, ha_translation_offsets = _load_shietal_offsets(gt_file, args.rotation, args.translation)

        # Match every ground-truth row to the frame with the closest timestamp.
        my_timestamps = np.asarray([f.timestamp for f in scene])
        frame_indices = np.argmin(np.abs(ha_timestamps[:, np.newaxis] - my_timestamps[np.newaxis, :]), axis=1)
        ha_frame_ids = [scene[i] for i in frame_indices]
        print(f"Using {len(ha_timestamps)} FL frames and {len(set(frame_indices))} lidar frames")
    elif args.offsets is not None:
        loaded_offsets = _load_offsets_file(os.path.join(args.offsets, f"offsets-{scene_name}.txt"))
        ha_frame_ids = scene
    else:
        ha_frame_ids = scene

    tileloader = twm.LRUCached(googlemaps, 100)
    vars(tileloader)["name"] = googlemaps.name
    vars(tileloader)["zoom"] = googlemaps.zoom

    # Frame loading runs in worker threads; the order bookkeeping restores the
    # original frame order before the results are consumed.
    stream = iter(list(enumerate(ha_frame_ids)))
    stream, iterations_order = pl.order.save(stream)
    stream = pl.sync(stream) # Concurrent processing starts after this point
    stream = pl.map(load_frame, stream)
    stream = pl.queued(stream, workers=8, maxsize=4)
    stream = pl.order.load(stream, iterations_order)

    l2_errors = []
    lon_errors = []
    lat_errors = []
    offsets = []
    for frame_index, ground_frame in enumerate(tqdm.tqdm(stream, total=len(ha_frame_ids))):
        if use_shietal:
            shift = ha_translation_offsets[frame_index]
            translation_bearing = ground_frame.bearing + -math.degrees(math.atan2(shift[0], shift[1])) + 180.0
            translation_distance = np.linalg.norm(shift)
            rotation_offset = ha_bearing_offsets[frame_index]
        elif loaded_offsets is None:
            # Draw order (bearing, distance, rotation) is kept so that seeded
            # runs stay reproducible.
            translation_bearing = np.random.uniform(-180.0, 180.0)
            translation_distance = np.random.uniform(0.0, args.translation)
            rotation_offset = np.random.uniform(-args.rotation, args.rotation)
        else:
            translation_bearing, translation_distance, rotation_offset = loaded_offsets[ground_frame.name]

        offsets.append((ground_frame.name, translation_bearing, translation_distance, rotation_offset))

        # Perturb the true pose to obtain the prior that the predictor refines.
        prior_latlon = cosy.np.geo.move_from_latlon(
            ground_frame.latlon,
            bearing=translation_bearing,
            distance=translation_distance,
        )
        prior_bearing = ground_frame.bearing + rotation_offset
        offset = cosy.np.geo.distance(prior_latlon, ground_frame.latlon)
        assert offset <= args.translation
        prediction, ag_frame = predictor(ground_frame, tileloader, prior_latlon, prior_bearing, silent=True)
        pred_aerial_to_bev = prediction.discrete_argmax(unit="pixels").inverse()
        true_aerial_to_bev = ag_frame.bevpixels_to_aerialpixels.inverse()

        # Pixel difference between true and predicted translation, in meters.
        translation_error = (true_aerial_to_bev.translation - pred_aerial_to_bev.translation) * prediction.meters_per_pixel

        l2_errors.append(np.linalg.norm(translation_error))
        lon_errors.append(abs(translation_error[0]))
        lat_errors.append(abs(translation_error[1]))

    l2_errors = np.asarray(l2_errors)
    lon_errors = np.asarray(lon_errors)
    lat_errors = np.asarray(lat_errors)

    # Store the offsets that were used so that a later run can replay them
    # through --offsets <directory>.
    with open(os.path.join(log.dir(), f"offsets-{scene_name}.txt"), "w") as f:
        f.write("Frame,Translation-Bearing(degree),Translation-Distance,Orientation-Offset(degree)\n")
        for o in offsets:
            f.write(", ".join([str(s) for s in o]) + "\n")

    means = f"Mean: l2={np.mean(l2_errors)} lon={np.mean(lon_errors)} lat={np.mean(lat_errors)}\nMedian: l2={np.quantile(l2_errors, q=0.5)} lon={np.quantile(lon_errors, q=0.5)} lat={np.quantile(lat_errors, q=0.5)}"
    # Each recall line is computed once and used for both console and report.
    recall_lines = [
        recall_str(l2_errors, "l2"),
        recall_str(lon_errors, "lon"),
        recall_str(lat_errors, "lat"),
    ]

    print("Errors:")
    print(means)
    for recall_line in recall_lines:
        print(recall_line)

    result_parts.append(f"{scene_name}\n")
    result_parts.append(means + "\n")
    result_parts.extend(f"{recall_line}\n" for recall_line in recall_lines)

    np.savez(os.path.join(log.dir(), f"errors-{scene_name}.npz"), l2_errors=l2_errors, lon_errors=lon_errors, lat_errors=lat_errors)

result = "".join(result_parts)
with open(os.path.join(log.dir(), "errors.txt"), "w") as f:
    f.write(result)
