#!/usr/bin/env python3

import sys, os, cosy, math
import numpy as np
import georegdata as grd

# Root directory holding the "track-kitti360*" run directories evaluated at the bottom of this script.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} RESULTS_DIR")
path = sys.argv[1]

def into_single_folder(p):
    """Return the path of the only subdirectory of directory ``p``.

    Regular files inside ``p`` are ignored.

    Raises:
        ValueError: if ``p`` does not contain exactly one subdirectory.
    """
    subdirs = [os.path.join(p, name) for name in os.listdir(p)]
    subdirs = [d for d in subdirs if os.path.isdir(d)]
    # An explicit check (rather than `assert`) still fires when Python runs with -O.
    if len(subdirs) != 1:
        raise ValueError(f"Expected exactly one subdirectory in {p!r}, found {len(subdirs)}")
    return subdirs[0]

# Projections between EPSG:4326 (WGS 84 lat/lon) and EPSG:3857 (Web/Pseudo-Mercator),
# built once at import time and shared by the functions below.
epsg4326_to_epsg3857, epsg3857_to_epsg4326 = (
    cosy.np.proj.Transformer(src_crs, dst_crs)
    for src_crs, dst_crs in (("epsg:4326", "epsg:3857"), ("epsg:3857", "epsg:4326"))
)

def curr_to_epsg3857(latlon, bearing):
    """Build the pose of a (lat/lon, bearing) sample as a ``cosy.np.Rigid`` in EPSG:3857.

    ``bearing`` is in degrees: it is converted with ``math.radians`` and then mapped
    into the EPSG:3857 frame via ``epsg4326_to_epsg3857.transform_angle``.
    ``latlon`` is projected to EPSG:3857 for the translation. Both the rotation
    matrix and the translation are cast to float64.
    """
    angle_3857 = epsg4326_to_epsg3857.transform_angle(math.radians(bearing))
    rotation = cosy.np.angle_to_rotation_matrix(angle_3857).astype("float64")
    translation = epsg4326_to_epsg3857(latlon).astype("float64")
    return cosy.np.Rigid(rotation=rotation, translation=translation)

# Error thresholds at which per-scene recall is reported (printed with an "m" suffix).
_RECALL_THRESHOLDS = (0.25, 0.5, 1.0, 2.0, 3.0, 5.0)


def _recall(errors, q):
    """Return the fraction of ``errors`` that are <= ``q`` (NaN when ``errors`` is empty)."""
    errors = np.asarray(errors)
    if errors.size == 0:
        return float("nan")
    return float(np.count_nonzero(errors <= q)) / errors.size


def _align_latlons(pred_latlons, gt_latlons):
    """Rigidly align predicted lat/lons onto the ground truth and return them as lat/lons.

    Both trajectories are projected to EPSG:3857. A rigid transform from prediction
    to ground truth is fitted with ``cosy.np.Rigid.least_squares``, applied to the
    predicted positions, and the result is projected back to EPSG:4326.
    """
    pred_positions_world = [epsg4326_to_epsg3857(latlon) for latlon in pred_latlons]
    gt_positions_world = [epsg4326_to_epsg3857(latlon) for latlon in gt_latlons]
    pred_to_gt = cosy.np.Rigid.least_squares(
        from_points=pred_positions_world,
        to_points=gt_positions_world,
    )
    return [epsg3857_to_epsg4326(pos) for pos in pred_to_gt(pred_positions_world)]


def _lon_lat_errors(pred_latlons, pred_bearings, gt_latlons, gt_bearings):
    """Return per-frame ``(lon_errors, lat_errors)`` of predictions relative to ground-truth poses.

    Each (lat/lon, bearing) pose is turned into an EPSG:3857 rigid transform with
    ``curr_to_epsg3857``. The prediction is expressed in the ground-truth pose frame,
    and the absolute components 0 ("lon") and 1 ("lat") of its translation are collected.
    """
    # NOTE(review): "lon"/"lat" are presumably longitudinal/lateral w.r.t. the GT heading — confirm.
    # NOTE(review): EPSG:3857 units are Pseudo-Mercator metres, which overstate ground distance by
    # roughly 1/cos(latitude). Unless cosy compensates, these are not directly comparable to l2.
    lon_errors = []
    lat_errors = []
    for pred_latlon, pred_bearing, gt_latlon, gt_bearing in zip(pred_latlons, pred_bearings, gt_latlons, gt_bearings):
        gt_to_epsg3857 = curr_to_epsg3857(gt_latlon, gt_bearing)
        pred_to_epsg3857 = curr_to_epsg3857(pred_latlon, pred_bearing)
        pred_in_gt = gt_to_epsg3857.inverse() * pred_to_epsg3857
        # Applying the relative transform to 0 (presumably the origin) yields its translation.
        translation_error = np.abs(pred_in_gt(0))
        lon_errors.append(translation_error[0])
        lat_errors.append(translation_error[1])
    return lon_errors, lat_errors


def evaluate(scenes, pseudolabels, path, stride, align=False):
    """Print trajectory-tracking errors for every scene directory under ``path``.

    For each subdirectory ``<path>/<scene_name>``, loads ``trajectory.npz`` (arrays
    ``latlons`` and ``bearings``) from its single subfolder. The prediction is compared
    with the frames ``scenes[scene_name]``, skipping the first frame and taking every
    ``stride``-th frame after it. Per scene it prints the mean ``cosy.np.geo.distance``
    error ("l2"), the mean absolute "lon"/"lat" errors and recall at fixed thresholds.
    At the end it prints the means over all scenes, followed by a blank line.

    Args:
        scenes: mapping from scene name to its sequence of frame ids.
        pseudolabels: mapping ``(frame name, "googlemaps-zoom20") -> ego_to_world`` pose;
            only consulted when ``align`` is False.
        path: directory of per-scene result folders. If it is not a directory,
            nothing is evaluated or printed.
        stride: frame subsampling stride of the evaluated run.
        align: if True, rigidly align predicted positions to the ground truth before
            measuring errors. If False, replace each frame's ``ego_to_world`` with its
            pseudolabel pose.
    """
    if not os.path.isdir(path):
        return

    def augment(frame_id):
        # NOTE(review): pseudolabel poses are substituted only when *not* aligning — confirm intended.
        if align:
            return frame_id
        params = frame_id.get_params()
        params["ego_to_world"] = pseudolabels[(frame_id.name, "googlemaps-zoom20")]
        return grd.ground.FrameId(**params)

    all_errors = []
    all_lon_errors = []
    all_lat_errors = []

    for scene_name in sorted(os.listdir(path)):
        scene_dir = os.path.join(path, scene_name)
        if not os.path.isdir(scene_dir):
            continue
        print(f"Scene: {scene_name}")

        # Every frame is augmented; frame 0 is then skipped and the rest subsampled by `stride`.
        frame_ids = [augment(f) for f in scenes[scene_name]]
        sampled = frame_ids[1::stride]
        gt_latlons = [f.latlon for f in sampled]
        gt_bearings = [f.bearing for f in sampled]

        # Close the .npz file handle once the arrays have been read into memory.
        with np.load(os.path.join(into_single_folder(scene_dir), "trajectory.npz")) as trajectory:
            pred_latlons = trajectory["latlons"]
            pred_bearings = trajectory["bearings"]

        # The prediction may be shorter than the ground truth; compare only the frames it covers.
        gt_latlons = gt_latlons[:len(pred_latlons)]
        gt_bearings = gt_bearings[:len(pred_latlons)]
        assert len(pred_latlons) == len(gt_latlons)
        assert len(pred_bearings) == len(pred_latlons)

        if align:
            # NOTE(review): only positions are aligned; pred_bearings are not rotated by the
            # fitted transform, which may bias the lon/lat errors below — confirm intended.
            pred_latlons = _align_latlons(pred_latlons, gt_latlons)

        lon_errors, lat_errors = _lon_lat_errors(pred_latlons, pred_bearings, gt_latlons, gt_bearings)
        errors = np.asarray([cosy.np.geo.distance(p, g) for p, g in zip(pred_latlons, gt_latlons)])

        all_errors.extend(errors)
        all_lon_errors.extend(lon_errors)
        all_lat_errors.extend(lat_errors)
        print(f"Error l2={np.mean(errors)} lon={np.mean(lon_errors)} lat={np.mean(lat_errors)}")
        # `:g` renders the thresholds as 0.25, 0.5, 1, 2, 3, 5.
        print("Recall: " + " ".join(f"r@{q:g}m={_recall(errors, q)}" for q in _RECALL_THRESHOLDS))

    print(f"Overall-Error l2={np.mean(all_errors)} lon={np.mean(all_lon_errors)} lat={np.mean(all_lat_errors)}")
    print()

# Pseudolabel file (path taken from $PSEUDOLABELS): one comma-separated row per entry, made of
# two key fields followed by 16 floats that are reshaped (row-major) into a 4x4 ego_to_world
# matrix. evaluate() looks entries up as (frame name, "googlemaps-zoom20").
with open(os.environ["PSEUDOLABELS"], "r", encoding="utf-8") as f:
    lines = [line.strip().split(",") for line in f if line.strip()]
pseudolabel_kitti360_ego_to_world = {
    (line[0], line[1]): cosy.np.Rigid.from_matrix(np.asarray([float(x) for x in line[2:2 + 4 * 4]]).reshape((4, 4)))
    for line in lines
}

kitti360 = grd.ground.kitti360.load(os.path.join(os.environ["GROUND_DATA"], "kitti360"), cache=True)
# Index the Karlsruhe scenes by the scene_id of their first frame.
kitti360 = {s[0].scene_id: s for s in kitti360["karlsruhe"].scenes}

# Sorted so that runs are reported in a deterministic order.
for p in sorted(os.listdir(path)):
    if not p.startswith("track-kitti360"):
        continue
    print(p)
    # The stride is encoded in the run directory name. Markers are checked in the order 2, 3, 1
    # to keep the precedence of the original if/elif chain.
    stride = next((s for s in (2, 3, 1) if f"stride{s}" in p), None)
    if stride is None:
        # Explicit error: an `assert` here would vanish under -O and reuse a stale stride.
        raise ValueError(f"Cannot determine stride from run directory name {p!r}")
    evaluate(kitti360, pseudolabel_kitti360_ego_to_world, os.path.join(path, p), stride=stride, align=True)
