#!/usr/bin/env python3

import os, sys, math, argparse, random, re, time, types, zlib, pickle
import imageio, tqdm
import tinylogdir, tinyobserver, timerun, cosy, georeg
import numpy as np
import tinypl as pl
from collections import defaultdict
from functools import partial
from columnar import columnar

parser = argparse.ArgumentParser()
parser.add_argument("--log", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

log = tinylogdir.LogDir(args.log, mode="timestamp")

# These imports require tensorflow and write lots of logging output
import tiledwebmaps as twm
import georegdata as grd
import tfcv
import tensorflow as tf
import tensorflow_addons as tfa

import mlog.client
mlog_session = mlog.client.Session(log.dir("mlog"))

# tf.random.set_seed(42)
tf.keras.utils.set_random_seed(42)
# tf.config.experimental.enable_op_determinism()


print("Loading datasets...")
final_meters_per_pixel = float(os.environ["METERS_PER_PIXEL"])
mlog_session["model"]["final-meters-per-pixel"] = final_meters_per_pixel

aerial_final_shape = np.asarray([int(os.environ["AERIAL_FINAL_SHAPE"]), int(os.environ["AERIAL_FINAL_SHAPE"])])
mlog_session["model"]["final-aerial-shape"] = aerial_final_shape.tolist()

bev_final_shape = np.asarray([int(os.environ["BEV_FINAL_SHAPE"]), int(os.environ["BEV_FINAL_SHAPE"])])
mlog_session["model"]["final-bev-shape"] = bev_final_shape.tolist()

ground_attn_strides = np.asarray([int(s) for s in os.environ["GROUND_ATTN_STRIDES"].split(",")])
mlog_session["model"]["ground-attn-strides"] = ground_attn_strides.tolist()

aerial_attn_strides = np.asarray([int(s) for s in os.environ["AERIAL_ATTN_STRIDES"].split(",")])
mlog_session["model"]["aerial-attn-strides"] = aerial_attn_strides.tolist()

aerial_stride = int(os.environ["AERIAL_STRIDE"])
mlog_session["model"]["aerial-stride"] = aerial_stride




train_max_offset_meters = float(os.environ["MAX_OFFSET_METERS"]) if "MAX_OFFSET_METERS" in os.environ else 30.0
val_max_offset_meters = train_max_offset_meters
def check(max_offset_meters):
    max_offset_pixels = int(math.ceil(max_offset_meters / final_meters_per_pixel))
    if max_offset_pixels > min(aerial_final_shape[0], aerial_final_shape[1]) / 2:
        print(f"Pixel offset {max_offset_pixels} is more than half of aerial embedding size {aerial_final_shape}")
        sys.exit(-1)
check(train_max_offset_meters)
check(val_max_offset_meters)
mlog_session["training"]["offset"] = f"{train_max_offset_meters:.1f}m"
mlog_session["validation"]["offset"] = f"{val_max_offset_meters:.1f}m"

corr_add_pixels = float(os.environ["CORR_ADD_PIXELS"]) if "CORR_ADD_PIXELS" in os.environ else 0
mlog_session["corr-add-pixels"] = corr_add_pixels

train_corr_shape = 2 * np.ceil(np.asarray([train_max_offset_meters, train_max_offset_meters]) / final_meters_per_pixel + corr_add_pixels).astype("int32")
val_corr_shape = 2 * np.ceil(np.asarray([val_max_offset_meters, val_max_offset_meters]) / final_meters_per_pixel + corr_add_pixels).astype("int32")
if np.any(train_corr_shape == 0) or np.any(val_corr_shape == 0):
    print("Got correlation shape 0")
    sys.exit(-1)
mlog_session["training"]["corr_shape"] = train_corr_shape.tolist()
mlog_session["validation"]["corr_shape"] = val_corr_shape.tolist()

train_points_num = 2 ** 16
mlog_session["training"]["points_num"] = train_points_num
val_points_num = 2 ** 16
mlog_session["validation"]["points_num"] = val_points_num

maps_path = os.environ["AERIAL_DATA"]
wait_after_error = 5.0
retries = 100

# Terms of Use:
# DCGIS: https://octo.dc.gov/node/1165116
# MASSGIS: https://wiki.openstreetmap.org/wiki/MassGIS
# Stratmap: https://tnris.org/stratmap/
# Googlemaps: https://about.google/brand-resource-center/products-and-services/geo-guidelines/
# Bingmaps: https://www.microsoft.com/en-us/maps/product

def load_tile_loader(name, zoom):
    # All aerial sources share the same 256px XYZ tile layout; zoom and name
    # are attached as plain attributes since they are read later (e.g. for
    # pseudolabel score keys and mlog dataset names).
    tile_loader = twm.Directory(layout=twm.Layout.XYZ((256, 256)), path=os.path.join(maps_path, name))
    vars(tile_loader)["zoom"] = zoom
    vars(tile_loader)["name"] = name
    return tile_loader

googlemaps = load_tile_loader("googlemaps", zoom=20)
bingmaps = load_tile_loader("bingmaps", zoom=20)
massgis21 = load_tile_loader("massgis21", zoom=20)
massgis19 = load_tile_loader("massgis19", zoom=20)
stratmap18 = load_tile_loader("stratmap18", zoom=20)
stratmap19 = load_tile_loader("stratmap19", zoom=20)
stratmap21 = load_tile_loader("stratmap21", zoom=20)
dcgis17 = load_tile_loader("dcgis17", zoom=19)
dcgis19 = load_tile_loader("dcgis19", zoom=19)
dcgis21 = load_tile_loader("dcgis21", zoom=19)

cache = True

argoverse_v1 = grd.ground.argoverse_v1.load(os.path.join(os.environ["GROUND_DATA"], "argoverse-v1"), cache=cache)
nuscenes = grd.ground.nuscenes.load(os.path.join(os.environ["GROUND_DATA"], "nuscenes"), cache=cache)
ford_avdata = grd.ground.ford_avdata.load(os.path.join(os.environ["GROUND_DATA"], "ford-avdata"), cache=cache)
argoverse_v2 = grd.ground.argoverse_v2.load(os.path.join(os.environ["GROUND_DATA"], "argoverse-v2"), cache=cache)
lyft = grd.ground.lyft.load(os.path.join(os.environ["GROUND_DATA"], "lyft"), cache=cache)
pandaset = grd.ground.pandaset.load(os.path.join(os.environ["GROUND_DATA"], "pandaset"), cache=cache)

train_datasets = [
    (nuscenes["boston-seaport"], [bingmaps, googlemaps, massgis19, massgis21]),
    (argoverse_v1["PIT"], [bingmaps, googlemaps]),
    (argoverse_v1["MIA"], [bingmaps, googlemaps]),
    (argoverse_v2["MIA"], [bingmaps, googlemaps]),
    (argoverse_v2["PIT"], [bingmaps, googlemaps]),
    (argoverse_v2["DTW"], [bingmaps, googlemaps]),
    (argoverse_v2["WDC"], [bingmaps, googlemaps, dcgis17, dcgis19, dcgis21]),
    (argoverse_v2["ATX"], [bingmaps, googlemaps, stratmap18, stratmap19, stratmap21]),
    (ford_avdata["detroit"].slice_scenes(slice(0, 18)), [bingmaps, googlemaps]),
]

if "VALIDATE" in os.environ and int(os.environ["VALIDATE"]) == 1:
    n = 8
    val_datasets = [
        # (lyft["Palo Alto"].slice_frames_per_scene(slice(None, None, 144 // n)), [bingmaps]),
        # (pandaset["palo alto san mateo"].slice_frames_per_scene(slice(None, None, 8 // n)), [bingmaps]),
        # (pandaset["san francisco"].slice_frames_per_scene(slice(None, None, 16 // n)), [bingmaps]),
        # (argoverse_v2["PAO"].slice_frames_per_scene(slice(None, None, 32 // n)), [bingmaps]),
    ]
else:
    val_datasets = [
    ]

train_datasets = [(x, [t for t in l if t is not None]) for x, l in train_datasets]
val_datasets = [(x, [t for t in l if t is not None]) for x, l in val_datasets]

# Upper bound on cameras per frame; every frame in every dataset must fit.
max_cameras = 9
assert max_cameras >= max(len(frame_id.cameras) for d, _ in train_datasets + val_datasets for scene in d.scenes for frame_id in scene)

model_constants = georeg.model.base.ModelConstants(
    bev_final_shape=bev_final_shape,
    aerial_final_shape=aerial_final_shape,
    final_meters_per_pixel=final_meters_per_pixel,
    ground_attn_strides=ground_attn_strides,
    aerial_attn_strides=aerial_attn_strides,
    aerial_stride=aerial_stride,
    max_cameras=max_cameras,
)

mlog_session["data"]["training"]["datasets"] = [dataset.fullname + "-" + tile_loader.name for dataset, tile_loaders in train_datasets for tile_loader in tile_loaders]
mlog_session["data"]["validation"]["datasets"] = [dataset.fullname + "-" + tile_loader.name for dataset, tile_loaders in val_datasets for tile_loader in tile_loaders]

meters_per_chunk = float(os.environ["METERS_PER_CHUNK"]) if "METERS_PER_CHUNK" in os.environ else 75.0
mlog_session["data"]["meters_per_chunk"] = meters_per_chunk
chunk_order = str(os.environ["CHUNK_ORDER"]) if "CHUNK_ORDER" in os.environ else "random-uniform"
mlog_session["data"]["chunk-order"] = chunk_order

print("|SPLIT DATASET      LOCATION                       #SCENES #FRAMES #MAPS #CHUNKS|")
for d, split in [(train_datasets, "train"), (val_datasets, "val")]:
    for dataset, tile_loaders in d:
        frame_ids = [frame_id for scene in dataset.scenes for frame_id in scene]
        chunks = len(georeg.data.AerialSampler(frame_ids, meters_per_chunk))
        print(f"|{split:<5} {dataset.name:<12} {dataset.location:<30} {len(dataset.scenes):>7} {len(frame_ids):>7} {len(tile_loaders):>5} {chunks:>7}|")

pseudolabel_file = str(os.environ["PSEUDOLABELS"]) if "PSEUDOLABELS" in os.environ else None
if pseudolabel_file is not None:
    print("Loading pseudolabels...")
    with open(pseudolabel_file, "r") as f:
        lines = [l.strip() for l in f.readlines()]
        lines = [l.split(",") for l in lines if len(l) > 0]
    # Each line: ground frame name, aerial frame name, a row-major 4x4 ego-to-world matrix, and a score.
    pseudolabel_ego_to_world = {(line[0], line[1]): cosy.np.Rigid.from_matrix(np.asarray([float(x) for x in line[2:2 + 4 * 4]]).reshape((4, 4))) for line in lines}
    pseudolabel_scores = {(line[0], line[1]): float(line[-1]) for line in lines}

    if "PSEUDOLABEL_USEPOSE" in os.environ and int(os.environ["PSEUDOLABEL_USEPOSE"]) == 0:
        pseudolabel_ego_to_world = None
else:
    pseudolabel_scores = None
    pseudolabel_ego_to_world = None

if pseudolabel_ego_to_world is not None:
    mlog_session["data"]["pseudolabels-egotoworld"] = pseudolabel_file

print("Preparing epochs...")

train_angles_range = math.radians(float(os.environ["TRAIN_MODEL_ANGLES_RANGE"]) if "TRAIN_MODEL_ANGLES_RANGE" in os.environ else 10.0)
train_angles_num = int(os.environ["TRAIN_MODEL_ANGLES_NUM"]) if "TRAIN_MODEL_ANGLES_NUM" in os.environ else 3
mlog_session["training"]["angles"]["num"] = train_angles_num
mlog_session["training"]["angles"]["range"] = math.degrees(train_angles_range)

val_angles_range = math.radians(float(os.environ["VAL_MODEL_ANGLES_RANGE"]) if "VAL_MODEL_ANGLES_RANGE" in os.environ else 10.0)
val_angles_num = int(os.environ["VAL_MODEL_ANGLES_NUM"]) if "VAL_MODEL_ANGLES_NUM" in os.environ else 21
mlog_session["validation"]["angles"]["num"] = val_angles_num
mlog_session["validation"]["angles"]["range"] = math.degrees(val_angles_range)

train_batchsize = int(os.environ["TRAIN_BATCHSIZE"]) if "TRAIN_BATCHSIZE" in os.environ else 1
val_batchsize = int(os.environ["VAL_BATCHSIZE"]) if "VAL_BATCHSIZE" in os.environ else 4
mlog_session["training"]["batches"] = train_batchsize
mlog_session["validation"]["batches"] = val_batchsize

debug_period = (100 if args.debug or "geo_train_REMOVE" in args.log else 1000) // train_batchsize # iterations
report_period = 500 // train_batchsize # iterations
val_period = (int(os.environ["VAL_PERIOD"]) if "VAL_PERIOD" in os.environ else 3000) // train_batchsize # iterations
samples = int(os.environ["SAMPLES"]) if "SAMPLES" in os.environ else 100000
mlog_session["training"]["samples"] = samples

train_cams = str(os.environ["TRAIN_CAMS"]) if "TRAIN_CAMS" in os.environ else "all"
val_cams = str(os.environ["VAL_CAMS"]) if "VAL_CAMS" in os.environ else "all"
mlog_session["training"]["cams"] = train_cams
mlog_session["validation"]["cams"] = val_cams
def filter_cameras(frame_id, mode):
    if mode == "all":
        return frame_id
    elif mode == "front":
        params = frame_id.get_params()
        front_cams = []
        for camera_id in params["cameras"]:
            cam_yaw = camera_id.ego_to_camera.inverse().rotation @ np.asarray([0.0, 0.0, 1.0])
            cam_yaw = cosy.np.angle(np.asarray([1.0, 0.0]), cam_yaw[:2])
            if abs(math.degrees(cam_yaw)) < 10.0:
                front_cams.append(camera_id)
        assert len(front_cams) > 0
        def fov(camera_id):
            # resolution = 0.5 * (camera_id.resolution[0] + camera_id.resolution[1])
            # focal_length = 0.5 * (camera_id.intr[0, 0] + camera_id.intr[1, 1])
            resolution = camera_id.resolution[1]
            focal_length = camera_id.intr[0, 0]
            # Monotonic in the true horizontal FOV (= 2 * atan(resolution / (2 * focal_length))),
            # which suffices since it is only used as a sort key.
            return 0.5 * math.atan(0.5 * resolution / focal_length)
        front_cams = sorted(front_cams, key=fov)
        params["cameras"] = [front_cams[-1]] # keep the widest front-facing camera
        return grd.ground.FrameId(**params)
    else:
        assert False, f"Unknown camera filter mode: {mode}"
def preprocess_frame_id(frame_id, train):
    # Apply the camera filter and, if enabled, replace the ego-to-world pose
    # with the pseudolabel pose for this (ground, aerial) frame pair.
    if pseudolabel_ego_to_world is not None:
        ego_to_world = pseudolabel_ego_to_world[(frame_id.ground_frame_id.name, frame_id.aerial_frame_id.name)]
    def preprocess_ground_frame_id(ground_frame_id):
        ground_frame_id = filter_cameras(ground_frame_id, mode=train_cams if train else val_cams)

        params = ground_frame_id.get_params()
        if pseudolabel_ego_to_world is not None:
            params["ego_to_world"] = ego_to_world
        return grd.ground.FrameId(**params)
    def preprocess_aligned_ground_frame_id(aligned_ground_frame_id):
        return grd.ground.AlignedFrameId(
            preprocess_ground_frame_id(aligned_ground_frame_id.base_frame_id),
            aligned_ground_frame_id.latlon,
            aligned_ground_frame_id.bearing,
            aligned_ground_frame_id.meters_per_pixel,
        )
    return grd.FrameId(
        preprocess_aligned_ground_frame_id(frame_id.ground_frame_id),
        frame_id.aerial_frame_id,
    )
train_max_bg_offset_degrees = float(os.environ["TRAIN_MAX_BG_DEGREES"]) if "TRAIN_MAX_BG_DEGREES" in os.environ else 0.0
train_max_ba_offset_degrees = float(os.environ["TRAIN_MAX_BA_DEGREES"]) if "TRAIN_MAX_BA_DEGREES" in os.environ else 10.0
val_max_ba_offset_degrees = float(os.environ["VAL_MAX_BA_DEGREES"]) if "VAL_MAX_BA_DEGREES" in os.environ else 10.0
assert train_max_bg_offset_degrees == 0
mlog_session["training"]["bg-offset-angle"] = f"{train_max_bg_offset_degrees:.1f}deg"
mlog_session["training"]["ba-offset-angle"] = f"{train_max_ba_offset_degrees:.1f}deg"
mlog_session["validation"]["ba-offset-angle"] = f"{val_max_ba_offset_degrees:.1f}deg"
def str_to_float(s):
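    # Deterministically hash a string to [0, 1); yields reproducible per-frame
    # "random" draws that are stable across runs and workers.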
    return float(zlib.crc32(s.encode("utf-8")) & 0xffffffff) / (2 ** 32)
def random_latlon_offset(latlon, train, seed=None):
    max_offset_meters = train_max_offset_meters if train else val_max_offset_meters
    if seed is None:
        bearing = random.uniform(0.0, 1.0)
        distance = random.uniform(0.0, 1.0)
    else:
        bearing = str_to_float(seed + "1")
        distance = str_to_float(seed + "2")
    bearing = bearing * 360.0
    distance = distance * max_offset_meters

    latlon = cosy.np.geo.move_from_latlon(latlon, bearing, distance)
    return latlon
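# Note: the offset distance is uniform in [0, max_offset_meters], so sampled
# positions are denser near the true location than a uniform-over-the-disk
# distribution (which would take the square root of the uniform draw).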
def random_bearing_offset(bearing, max_offset_degrees, seed=None):
    if seed is None:
        offset = random.uniform(0.0, 1.0)
    else:
        offset = str_to_float(seed)
    offset = (offset * 2 - 1) * max_offset_degrees

    return bearing + offset
augment_camera_angle = math.radians(float(os.environ["AUGMENT_CAMERA_ANGLE"])) if "AUGMENT_CAMERA_ANGLE" in os.environ else 0.0
augment_camera_scale = float(os.environ["AUGMENT_CAMERA_SCALE"]) if "AUGMENT_CAMERA_SCALE" in os.environ else 1.0
fix_intr = float(os.environ["FIX_INTR"]) if "FIX_INTR" in os.environ else None
assert augment_camera_angle >= 0
assert augment_camera_scale >= 1
mlog_session["data"]["augment"]["camera"]["angle"] = f"{math.degrees(augment_camera_angle)}deg"
mlog_session["data"]["augment"]["camera"]["scale"] = augment_camera_scale
mlog_session["data"]["augment"]["camera"]["fix-intr"] = fix_intr
# Fixed intrinsics cannot be combined with camera augmentation.
assert fix_intr is None or (augment_camera_angle == 0 and augment_camera_scale == 1)
def augment_camera_params(frame_id, train):
    params = frame_id.get_params()
    if fix_intr is not None:
        params["cameras"] = [grd.ground.FixedIntrCameraId(camera_id, focal_length=fix_intr) for camera_id in frame_id.cameras]
    else:
        if train:
            def angle():
                angle = random.uniform(-augment_camera_angle, augment_camera_angle)
                return angle
            def scale():
                log_scale = math.log(augment_camera_scale)
                assert log_scale >= 0
                scale = random.uniform(-log_scale, log_scale)
                scale = math.exp(scale)
                return scale
        else:
            def angle():
                return 0.0
            def scale():
                return 1.0

        params["cameras"] = [grd.ground.AugmentedCameraId(camera_id, angle=angle(), scale=scale()) for camera_id in frame_id.cameras]
    return grd.ground.FrameId(**params)
def datasets_to_frame_ids(datasets, train, epochs=None, samples=None):
    # Expand (dataset, tile_loaders) pairs into per-epoch lists of frame ids,
    # optionally pruning by pseudolabel score and (for training) chunking by
    # aerial location.
    if samples == 0 or epochs == 0:
        return []
    if train and pseudolabel_scores is not None:
        dataprune_drop_easy = float(os.environ["DATAPRUNE_DROP_EASY"]) if "DATAPRUNE_DROP_EASY" in os.environ else 0.0
        dataprune_drop_hard = float(os.environ["DATAPRUNE_DROP_HARD"]) if "DATAPRUNE_DROP_HARD" in os.environ else 0.0
        mlog_session["data"]["prune"]["drop-easy"] = dataprune_drop_easy
        mlog_session["data"]["prune"]["drop-hard"] = dataprune_drop_hard

        dataprune_scores = pseudolabel_scores
    else:
        dataprune_scores = None


    num_without_dataprune_score = 0
    wrapped = []
    for dataset, tile_loaders in datasets:
        for tile_loader in tile_loaders:
            for scene in dataset.scenes:
                for ground_frame_id in scene:
                    if dataprune_scores is not None:
                        score_key = (ground_frame_id.name, f"{tile_loader.name}-zoom{tile_loader.zoom}")
                        if score_key not in dataprune_scores:
                            num_without_dataprune_score += 1
                            continue
                        score = dataprune_scores[score_key]
                    else:
                        score = None

                    wrapped.append(types.SimpleNamespace(
                        latlon=ground_frame_id.latlon,
                        ground_frame_id=ground_frame_id,
                        tile_loader=tile_loader,
                        score=score,
                    ))
    if train:
        original_num = len(wrapped)
        if dataprune_scores is not None:
            if num_without_dataprune_score > 0:
                print(f"Training: Dropping {num_without_dataprune_score} samples without pruning score")

            wrapped = sorted(wrapped, key=lambda w: w.score)
            drop_easy = int(len(wrapped) * dataprune_drop_easy)
            drop_hard = int(len(wrapped) * dataprune_drop_hard)
            wrapped = wrapped[drop_easy:len(wrapped) - drop_hard]
            print(f"Training: Pruning {drop_easy} easy and {drop_hard} hard samples from {original_num} input samples")

        original_num = len(wrapped)
        wrapped = georeg.data.AerialSampler(wrapped, meters_per_chunk, order=chunk_order)
        print(f"Training: Using {original_num} frames split into {len(wrapped)} aerial chunks at {meters_per_chunk}x{meters_per_chunk} meters")
    else:
        assert num_without_dataprune_score == 0
        print(f"Validation: Using {len(wrapped)} frames")

    result = []
    r = random.Random(13)
    epoch = 0
    sample = 0
    while True:
        result_epoch = []
        for w in (wrapped if not train else r.sample(wrapped, len(wrapped))):
            ground_latlon = w.ground_frame_id.latlon
            ground_bearing = w.ground_frame_id.bearing # random_bearing_offset(w.ground_frame_id.bearing, max_offset_degrees=train_max_bg_offset_degrees if train else 0.0, seed=w.ground_frame_id.name + "bearing2_" + str(epoch))

            max_ba_offset_degrees = train_max_ba_offset_degrees if train else val_max_ba_offset_degrees
            if train:
                aerial_latlon = random_latlon_offset(ground_latlon, train, seed=w.ground_frame_id.name + "latlon_" + str(epoch))
                aerial_bearing = random_bearing_offset(ground_bearing, max_offset_degrees=max_ba_offset_degrees, seed=w.ground_frame_id.name + "bearing1_" + str(epoch))
            else:
                aerial_latlon = random_latlon_offset(ground_latlon, train, seed=w.ground_frame_id.name + "latlon")
                aerial_bearing = random_bearing_offset(ground_bearing, max_offset_degrees=max_ba_offset_degrees, seed=w.ground_frame_id.name + "bearing")

            result_epoch.append(grd.FrameId(
                grd.ground.AlignedFrameId(augment_camera_params(w.ground_frame_id, train=train), ground_latlon, ground_bearing, model_constants.meters_per_pixel[-1]),
                grd.aerial.FrameId(w.tile_loader, w.tile_loader.name, w.tile_loader.zoom, aerial_latlon, aerial_bearing, model_constants.meters_per_pixel[-1] / model_constants.aerial_stride, model_constants.aerial_image_shape),
            ))

            sample += 1
            if samples is not None and sample == samples:
                break
        epoch += 1
        result.append(result_epoch)
        if (epochs is not None and epoch == epochs) or (samples is not None and sample == samples):
            break

    epochs = epoch if sample == 0 else (epoch + 1)
    if not train:
        result = epochs * result

    return result

schedule = types.SimpleNamespace()
schedule.train = types.SimpleNamespace()
schedule.train.iteration = tinyobserver.counter.Counter(1)
train_frame_ids = datasets_to_frame_ids(train_datasets, train=True, samples=samples)
schedule.train.iterations = sum(len(l) // train_batchsize for l in train_frame_ids)
schedule.train.epoch = 0

if len(val_datasets) > 0:
    schedule.val = types.SimpleNamespace()
    schedule.val.iteration = tinyobserver.counter.Counter(1)
    schedule.val.epoch = tinyobserver.counter.Counter(val_period) # start_value=val_period - 1
    schedule.val.epoch_begin = tinyobserver.Observable()
    schedule.val.epoch_end = tinyobserver.Observable()
    schedule.train.iteration.subscribe(lambda frame: schedule.val.epoch(), precedence=9999)
    schedule.val.epochs = schedule.train.iterations // val_period + 1
    mlog_session["validation"]["epochs"] = schedule.val.epochs
    val_frame_ids = datasets_to_frame_ids(val_datasets, epochs=schedule.val.epochs, train=False)
    schedule.val.samples_per_epoch = len(val_frame_ids[0])
    schedule.val.iterations_per_epoch = (schedule.val.samples_per_epoch + val_batchsize - 1) // val_batchsize
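    # e.g. with SAMPLES=100000, TRAIN_BATCHSIZE=1 and VAL_PERIOD=3000 (all
    # hypothetical values), validation runs every 3000 training iterations,
    # i.e. 100000 // 3000 + 1 = 34 validation epochs in total.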







print("Building model...")
variant = str(os.environ["VARIANT"])
if variant == "lidar":
    variant = georeg.model.variant.Lidar(
        model_constants=model_constants,
        ground_point_pooling=str(os.environ["POINT_POOL"]) if "POINT_POOL" in os.environ else "mean",
        model_params=mlog_session["model"],
    )
elif variant == "point-pillars":
    variant = georeg.model.variant.PointPillar(
        model_constants=model_constants,
        model_params=mlog_session["model"],
    )
else:
    assert False, f"Unknown VARIANT: {variant}"
model, train_model, preprocess_aerial, preprocess_ground, loss_fn, debug_fn = georeg.model.base.build(
    model_constants=model_constants,
    variant=variant,
    model_params=mlog_session["model"],
)
num_parameters = np.sum([np.prod(v.get_shape().as_list()) for v in model.trainable_variables])
print(f"Model has {num_parameters} trainable parameters")

checkpoint_train_dir = str(os.environ["CHECKPOINT"]) if "CHECKPOINT" in os.environ else None
if checkpoint_train_dir is not None:
    checkpoint_model = tf.keras.models.load_model(os.path.join(checkpoint_train_dir, "saved_model"), compile=False)
    checkpoint_variables = {v.name: v for v in checkpoint_model.trainable_variables}
    used_names = set()
    for v in model.trainable_variables:
        v.assign(checkpoint_variables[v.name])
        used_names.add(v.name)
    assert used_names == set(checkpoint_variables.keys())



print("Starting data pipeline...")

augment_color = int(os.environ["AUGMENT_COLOR"]) == 1 if "AUGMENT_COLOR" in os.environ else True
mlog_session["training"]["augment-color"] = augment_color

train_stream, train_epoch_end_marker, train_data_pipeline_state = georeg.data.stream(
    frame_ids=train_frame_ids,
    frames_to_model_input=partial(georeg.model.io.ModelInput.from_frames,
        angles=np.linspace(-train_angles_range, train_angles_range, num=train_angles_num) if train_angles_num > 1 else [0.0],
        align_bearing=True,
        augment_aerial_image=partial(georeg.data.augment_aerial_image, get_epoch=lambda: schedule.train.epoch) if augment_color else None,
        augment_ground_image=partial(georeg.data.augment_ground_image, get_epoch=lambda: schedule.train.epoch) if augment_color else None,
        shuffle_points=True,
        points_num=train_points_num,
        corr_shape=train_corr_shape,
        training_mask=True,
    ),
    frames_to_loss_input=partial(georeg.model.io.LossInput.from_frames),
    preprocess=partial(preprocess_frame_id, train=True),
    batchsize=train_batchsize,
    split="training",
    mlog_session=mlog_session,
    save_first_num=16,
    save_first_path=log.dir("ready_frames/training"),
    workers=[(12, 6), (4, 2)],
)
mean_train_time = tf.keras.metrics.Mean()
def run():
    while schedule.train.epoch < len(train_frame_ids):
        for batch_data in pl.marker.until(train_stream, marker=train_epoch_end_marker):
            frames, model_input, loss_input = batch_data
            with timerun.Timer() as timer:
                schedule.train.iteration(batch_data)
            if schedule.train.iteration.get() > 10:
                mean_train_time.update_state(timer.duration.timedelta.total_seconds() / len(frames))
        schedule.train.epoch += 1
        save_model()
    save_model()
    if len(val_datasets) > 0:
        validate()

if len(val_datasets) > 0:
    val_frames_to_model_input = partial(georeg.model.io.ModelInput.from_frames,
        angles=np.linspace(-val_angles_range, val_angles_range, num=val_angles_num) if val_angles_num > 1 else [0.0],
        align_bearing=False,
        augment_aerial_image=None,
        augment_ground_image=None,
        points_num=val_points_num,
        corr_shape=val_corr_shape,
        training_mask=False,
        shuffle_points=True,
    )
    val_frames_to_loss_input = partial(georeg.model.io.LossInput.from_frames)
    val_stream, val_epoch_end_marker, val_data_pipeline_state = georeg.data.stream(
        frame_ids=val_frame_ids,
        frames_to_model_input=val_frames_to_model_input,
        frames_to_loss_input=val_frames_to_loss_input,
        preprocess=partial(preprocess_frame_id, train=False),
        batchsize=val_batchsize,
        split="validation",
        mlog_session=mlog_session,
        save_first_num=16,
        save_first_path=log.dir("ready_frames/validation"),
        workers=[(18, 6), (4, 2)],
    )
    mean_val_time = tf.keras.metrics.Mean()
    def validate():
        with timerun.Timer() as timer:
            schedule.val.epoch_begin()
            for frame in pl.marker.until(val_stream, marker=val_epoch_end_marker):
                schedule.val.iteration(frame)
            schedule.val.epoch_end()
        mean_val_time.update_state(timer.duration.timedelta.total_seconds())
        mlog_session.commit()
    schedule.val.epoch.subscribe(validate)





trainable_variables = train_model.trainable_variables

layer_decay = float(os.environ["LAYER_DECAY"]) if "LAYER_DECAY" in os.environ else 1.0
mlog_session["training"]["layer_decay"] = layer_decay

ground_layer_schedule_multiplier = float(os.environ["GROUND_LR_FACTOR"]) if "GROUND_LR_FACTOR" in os.environ else 1.0
aerial_layer_schedule_multiplier = float(os.environ["AERIAL_LR_FACTOR"]) if "AERIAL_LR_FACTOR" in os.environ else 1.0
mlog_session["training"]["ground_layer_schedule_multiplier"] = ground_layer_schedule_multiplier
mlog_session["training"]["aerial_layer_schedule_multiplier"] = aerial_layer_schedule_multiplier

pretrained_layer_schedule_multiplier = float(os.environ["PRETRAINED_LR_FACTOR"]) if "PRETRAINED_LR_FACTOR" in os.environ else 1.0
mlog_session["training"]["pretrained_layer_schedule_multiplier"] = pretrained_layer_schedule_multiplier

ground_blocks = [l for l in train_model.layers if re.search("ground/.*/unit([0-9]*)$", l.name)]
aerial_blocks = [l for l in train_model.layers if re.search("aerial/.*/unit([0-9]*)$", l.name)]
def layer_schedule_multiplier(v):
    multiplier = 1.0
    if layer_decay < 1.0:
        assert len(ground_blocks) > 1 and len(aerial_blocks) > 1
        match = re.search("(ground-backbone/.*/unit[0-9]*)", v.name)
        if match:
            block_index = next(i for i, l in enumerate(ground_blocks) if match.group(1) in l.name)
            multiplier = layer_decay ** (len(ground_blocks) - block_index - 1)
        else:
            match = re.search("(aerial-backbone/.*/unit[0-9]*)", v.name)
            if match:
                block_index = next(i for i, l in enumerate(aerial_blocks) if match.group(1) in l.name)
                multiplier = layer_decay ** (len(aerial_blocks) - block_index - 1)
    if "ground" in v.name:
        multiplier = multiplier * ground_layer_schedule_multiplier
    elif "aerial" in v.name:
        multiplier = multiplier * aerial_layer_schedule_multiplier
    if "ground-backbone" in v.name or "aerial-backbone" in v.name:
        multiplier = multiplier * pretrained_layer_schedule_multiplier
    return multiplier
layer_schedule_multipliers = [layer_schedule_multiplier(v) for v in trainable_variables]
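# Example (hypothetical): with LAYER_DECAY=0.9 and 12 ground blocks, the
# earliest block's gradients are scaled by 0.9**11 ~= 0.31 while the last
# block keeps 1.0; the GROUND/AERIAL/PRETRAINED LR factors multiply on top.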

model_ema_decay = float(os.environ["EMA_DECAY"]) if "EMA_DECAY" in os.environ else 0.999
mlog_session["model"]["ema-decay"] = model_ema_decay
if model_ema_decay > 0:
    with tf.device("CPU:0"):
        model_ema = tf.train.ExponentialMovingAverage(decay=model_ema_decay)
        model_ema.apply(trainable_variables)
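    # In-place swap without a temporary tensor, e.g. a=3, b=5:
    # a += b -> a=8; b = a-b -> b=3; a -= b -> a=5.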
    def swap(a, b):
        a.assign_add(b)
        b.assign(a - b)
        a.assign_sub(b)
    def model_ema_swap():
        for v in trainable_variables:
            swap(v, model_ema.average(v))
else:
    model_ema = None
    def model_ema_swap():
        pass

assert len(tf.config.list_logical_devices("GPU")) == 1
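# Peak GPU memory since the last reset; resetting after every read makes each
# report a per-interval peak rather than a global maximum.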
def get_memory_usage():
    num_bytes = tf.config.experimental.get_memory_info("GPU:0")["peak"]
    tf.config.experimental.reset_memory_stats("GPU:0")
    return num_bytes
def format_bytes(size):
    size = float(size)
    power = 2 ** 10
    n = 0
    power_labels = {0 : "B", 1: "KB", 2: "MB", 3: "GB", 4: "TB"}
    while size >= power and n < len(power_labels) - 1:
        size /= power
        n += 1
    return f"{size:.1f}{power_labels[n]}"

mlog_session.commit()











print("Building tensorflow functions...")
lr_schedule = str(os.environ["SCHEDULE"]) if "SCHEDULE" in os.environ else "constant"
if False:
    def schedule_multiplier(step, total):
        if step < 5000:
            return 1.0
        elif step < 10000:
            return 1.0 / 10
        else:
            return 1.0 / 100
    mlog_session["training"]["schedule_multiplier"] = "drop-at-5000-10000"
if False:
    def schedule_multiplier(step, total):
        return (0.1 ** (float(step) / 1000))
    mlog_session["training"]["schedule_multiplier"] = "exp0.1-every1000"
if lr_schedule == "drop":
    drop = 10000
    def schedule_multiplier(step, total):
        if step < drop:
            return 1.0
        else:
            return 1.0 / 10
    mlog_session["training"]["schedule_multiplier"] = f"drop-at-{drop}"
elif lr_schedule == "constant":
    def schedule_multiplier(step, total):
        return 1.0
    mlog_session["training"]["schedule_multiplier"] = "constant"
elif lr_schedule == "cosine":
    def schedule_multiplier(step, total):
        return 0.5 * (math.cos(math.pi * float(step) / total) + 1.0)
    mlog_session["training"]["schedule_multiplier"] = "cosine_decay"
elif lr_schedule == "drop2500-drop10000-drop40000":
    def schedule_multiplier(step, total):
        if step < 2500:
            return 1.0
        elif step < 10000:
            return 0.3
        elif step < 40000:
            return 0.1
        else:
            return 0.03
    mlog_session["training"]["schedule_multiplier"] = "drop2500-drop10000-drop40000"
elif (match := re.match("poly-(.*)", lr_schedule)):
    power = float(match.group(1))
    def schedule_multiplier(step, total):
        return (1.0 - float(step) / float(total)) ** power
    mlog_session["training"]["schedule_multiplier"] = lr_schedule
else:
    assert False, f"Unknown SCHEDULE: {lr_schedule}"

# Warmup
warmup_steps = int(os.environ["WARMUP"]) if "WARMUP" in os.environ else 0
if warmup_steps > 0:
    warmup_min = 1e-2
    inner_schedule_multiplier = schedule_multiplier
    def schedule_multiplier(step, total, inner=inner_schedule_multiplier):
        if step < warmup_steps:
            return warmup_min + (float(step + 1) / warmup_steps) * (1.0 - warmup_min)
        else:
            return inner(step - warmup_steps, total - warmup_steps)
mlog_session["training"]["warmup"] = warmup_steps

min_schedule_multiplier = 1e-4
max_schedule_multiplier = 1.0
inner_schedule_multiplier = schedule_multiplier
def schedule_multiplier(inner=inner_schedule_multiplier):
    return inner(schedule.train.iteration.get(), max(schedule.train.iterations, 1)) * (max_schedule_multiplier - min_schedule_multiplier) + min_schedule_multiplier
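# Example: with SCHEDULE=cosine and WARMUP=1000 (hypothetical values), the
# multiplier ramps from ~0.01 to 1.0 over the first 1000 iterations, then
# decays as 0.5 * (cos(pi * step / total) + 1), floored at
# min_schedule_multiplier = 1e-4.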

base_learning_rate = float(os.environ["LR"])
mlog_session["training"]["learning_rate"] = base_learning_rate
def learning_rate():
    return schedule_multiplier() * base_learning_rate
optimizer = str(os.environ["OPT"]).lower() if "OPT" in os.environ else "radam"
mlog_session["training"]["optimizer"] = optimizer
if optimizer == "adam":
    optimizer = tf.keras.optimizers.Adam()
elif optimizer == "radam":
    optimizer = tfa.optimizers.RectifiedAdam()
else:
    assert False, f"Unknown OPT: {optimizer}"

weight_decay = float(os.environ["WEIGHT_DECAY"]) if "WEIGHT_DECAY" in os.environ else 1e-3
mlog_session["training"]["weight_decay"] = weight_decay

# Apply weight decay only to kernels; biases and normalization parameters are exempt.
layer_weight_decay = []
for v in trainable_variables:
    w = (weight_decay if "/kernel:" in v.name else 0.0)
    layer_weight_decay.append(w)




def get_report_step():
    return schedule.train.iteration.get() * train_batchsize

def display_metric(k):
    return (
        k in ["translation-error-discrete", "angle-error-discrete"]
        or k == "loss"
        or any(s in k for s in ["misclassified-ratio", "invalid-sample", "prob", "md", "std"])
    )

first_batch = True
consecutive_invalid_samples = 0
train_translation_errors = []
train_metrics_ema_decay = 0.99
@pl.unpack
def train_step(frames, model_input, loss_input):
    global first_batch, consecutive_invalid_samples
    if first_batch:
        print("Got first training batch")
        first_batch = False

    if debug_fn is not None and schedule.train.iteration.get() % debug_period == 0:
        debug_names = [f"f{f}" for f in range(len(frames))]
        debug_fn(frames, model_input, loss_input, path=log.dir(os.path.join("model-debug", "training", f"it{schedule.train.iteration.get():06}")), names=debug_names)

    if schedule.train.iteration.get() % 5000 == 0:
        save_model()

    metrics, valid_sample = tf_train_step(schedule_multiplier(), learning_rate(), model_input.to_list(), loss_input.to_list())
    metrics["learning-rate"] = [float(learning_rate())] * len(frames)

    if not np.any(valid_sample):
        consecutive_invalid_samples += 1
        if consecutive_invalid_samples >= 500:
            print("Got too many consecutive invalid samples, aborting...")
            sys.exit(-1)
    else:
        consecutive_invalid_samples = 0

    for b in range(len(frames)):
        if valid_sample[b]:
            for k, v in metrics.items():
                mlog_session["metrics"]["training"].OrderedPairs(k, ema_decay=train_metrics_ema_decay)[get_report_step()] = float((v.numpy() if "numpy" in dir(v) else v)[b])
            mlog_session["metrics"]["training"].OrderedPairs("dataset")[get_report_step()] = frames[b].ground_frame.dataset_name # TODO: doesn't work for batchsize > 1
        # Mirrors the validation step: record invalid-sample for every b, not only valid ones.
        mlog_session["metrics"]["training"].OrderedPairs("invalid-sample", ema_decay=train_metrics_ema_decay)[get_report_step()] = 0.0 if valid_sample[b] else 1.0

    memory_usage = get_memory_usage()
    mlog_session["metrics"]["training"].OrderedPairs("mem")[get_report_step()] = memory_usage

    if schedule.train.iteration.get() % report_period == 0:
        mlog_session.commit()

    reports = []
    reports.append(("Stage", "Training"))
    reports.append(("Epoch", schedule.train.epoch))
    reports.append(("Batch", schedule.train.iteration.get() % len(train_frame_ids[0])))
    reports.append(("Mem", format_bytes(memory_usage)))
    for k, v in mlog_session["metrics"]["training"].items():
        if display_metric(k):
            reports.append((k, f"{v.ema:.6f}"))
    reports.append(("Train sec/sample", f"{mean_train_time.result().numpy():.3f}"))
    if len(val_datasets) > 0:
        reports.append(("Val sec/epoch", f"{mean_val_time.result().numpy():.1f}"))
    for name, fill in train_data_pipeline_state:
        reports.append((f"Q.{name}", f"{100.0 * fill():.3f}"))
    if len(val_datasets) > 0:
        reports.append(("Next val in", schedule.val.epoch.remaining()))
    print(columnar([[v for k, v in reports]], [k for k, v in reports], no_borders=True))
schedule.train.iteration.subscribe(train_step)

def tf_train_step(schedule_multiplier, learning_rate, model_input, loss_input):
    model_input = georeg.model.io.ModelInput(*model_input)
    loss_input = georeg.model.io.LossInput(*loss_input)
    with tf.GradientTape() as tape:
        loss, metrics, valid_sample = loss_fn(model_input, loss_input, schedule_multiplier, training=True)
        loss = tf.where(valid_sample, loss, 0.0)
        metrics["loss"] = loss
        assert len(loss.shape) == 1
        loss = tf.reduce_mean(loss)

    if tf.reduce_any(valid_sample):
        tf.debugging.assert_all_finite(loss, message="Got non-finite loss")

        gradients = tape.gradient(loss, trainable_variables)

        # all_gradients = tf.math.abs(tf.concat([tf.reshape(g, [-1]) for g in gradients if not g is None], axis=0))
        # print("Min abs gradient=", tf.reduce_min(all_gradients))
        # print("Max abs gradient=", tf.reduce_max(all_gradients))

        all_finite_gradients = True
        for gradient, variable in reversed(list(zip(gradients, trainable_variables))):
            if gradient is None:
                # With tf.function, this prints once at trace time rather than every step.
                print(f"WARNING: Got None gradient for variable {variable.name}")
            else:
                all_finite_gradients = tf.math.logical_and(all_finite_gradients, tf.math.reduce_all(tf.math.is_finite(gradient)))
                # tf.debugging.assert_all_finite(gradient, "Non-finite gradient for " + variable.name)

        valid_sample = tf.math.logical_and(valid_sample, all_finite_gradients)

        if tf.reduce_all(valid_sample):
            # Remove None gradients
            # gradients, trainable_variables = tuple(zip(*[(k, v) for k, v in zip(gradients, trainable_variables) if not k is None]))

            # Gradient clipping
            # gradients, global_norm = tf.clip_by_global_norm(gradients, 5.0)
            # metrics["global-gradient-norm"] = global_norm

            pairs = [(g * s, v) for g, s, v in zip(gradients, layer_schedule_multipliers, trainable_variables) if g is not None]
            optimizer.learning_rate.assign(learning_rate)
            optimizer.apply_gradients(pairs)

            # Decoupled weight decay (AdamW-style): shrink kernels after the
            # optimizer step, scaled by the learning rate and per-layer multiplier.
            for v, w, s in zip(trainable_variables, layer_weight_decay, layer_schedule_multipliers):
                if w > 0:
                    v.assign(v * (1 - w * learning_rate * s))

            # Update model weights exponential moving average
            if model_ema is not None:
                model_ema.apply(trainable_variables)
    # else:
    #     valid_sample = False

    return metrics, valid_sample
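# Compiling with a fixed input_signature builds the graph once and avoids
# per-batch retracing; with --debug the step stays in eager mode for easier
# inspection.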
if not args.debug:
    tf_train_step = tf.function(tf_train_step, input_signature=
          [tf.TensorSpec((), dtype="float32", name="schedule_multiplier"), tf.TensorSpec((), dtype="float32", name="learning_rate")]
        + [tuple(georeg.model.io.ModelInput.tf_signature(train_points_num))]
        + [tuple(georeg.model.io.LossInput.tf_signature())]
    )





if len(val_datasets) > 0:
    val_debug_frame_names = []
    val_counter_to_next_new_debug_frame = 0
    num_val_debug_frames = 8
    @pl.unpack
    def val_step(frames, model_input, loss_input):
        global val_counter_to_next_new_debug_frame
        if debug_fn is not None:
            debug_frames = []
            for frame in frames:
                if frame.name in val_debug_frame_names:
                    debug_frames.append(frame)
                elif len(val_debug_frame_names) < num_val_debug_frames:
                    if val_counter_to_next_new_debug_frame == 0:
                        val_debug_frame_names.append(frame.name)
                        debug_frames.append(frame)
                        val_counter_to_next_new_debug_frame = int(schedule.val.samples_per_epoch // num_val_debug_frames)
                    else:
                        val_counter_to_next_new_debug_frame -= 1

            if len(debug_frames) > 0:
                debug_fn(
                    debug_frames,
                    model_input=val_frames_to_model_input(debug_frames),
                    loss_input=val_frames_to_loss_input(debug_frames),
                    path=log.dir(os.path.join("model-debug", "validation", f"it{schedule.train.iteration.get():06}")),
                    names=[f"f{val_debug_frame_names.index(frame.name):03}" for frame in debug_frames],
                )

        metrics, valid_sample = tf_val_step(schedule_multiplier(), model_input.to_list(), loss_input.to_list())
        for b in range(len(frames)):
            if valid_sample[b]:
                for k, v in metrics.items():
                    mlog_session["metrics"]["validation"].OrderedPairs(k, mlog.client.session.List)[get_report_step()].add(float(v[b].numpy() if "numpy" in dir(v) else v[b]))
            mlog_session["metrics"]["validation"].OrderedPairs("invalid-sample", mlog.client.session.List)[get_report_step()].add(0.0 if valid_sample[b] else 1.0)

        memory_usage = get_memory_usage()
        mlog_session["metrics"]["validation"].OrderedPairs("mem", mlog.client.session.List)[get_report_step()].add(memory_usage)

        reports = []
        reports.append(("Stage", "Validation"))
        reports.append(("Train Epoch", schedule.train.epoch))
        reports.append(("Val Epoch", schedule.val.epoch.get()))
        reports.append(("Batch", schedule.val.iteration.get() % schedule.val.iterations_per_epoch))
        reports.append(("Mem", format_bytes(memory_usage)))
        for k, v in mlog_session["metrics"]["validation"].items():
            if display_metric(k):
                reports.append((k, f"{v[get_report_step()].mean:.6f}"))
        reports.append(("Train sec/sample", f"{mean_train_time.result().numpy():.3f}"))
        reports.append(("Val sec/epoch", f"{mean_val_time.result().numpy():.1f}"))
        for name, fill in val_data_pipeline_state:
            reports.append((f"Q.{name}", f"{100.0 * fill():.3f}"))
        print(columnar([[v for k, v in reports]], [k for k, v in reports], no_borders=True))
    schedule.val.iteration.subscribe(val_step)

    def tf_val_step(schedule_multiplier, model_input, loss_input):
        model_input = georeg.model.io.ModelInput(*model_input)
        loss_input = georeg.model.io.LossInput(*loss_input)
        loss, metrics, valid_sample = loss_fn(model_input, loss_input, schedule_multiplier, training=False)
        metrics["loss"] = loss
        return metrics, valid_sample
    if not args.debug:
        tf_val_step = tf.function(tf_val_step, input_signature=
              [tf.TensorSpec((), dtype="float32", name="schedule_multiplier")]
            + [tuple(georeg.model.io.ModelInput.tf_signature(val_points_num))]
            + [tuple(georeg.model.io.LossInput.tf_signature())]
        )

    schedule.val.epoch_end.subscribe(mlog_session.commit)

    if model_ema is not None:
        schedule.val.epoch_begin.subscribe(model_ema_swap)
        schedule.val.epoch_end.subscribe(model_ema_swap)





def save_model(force=False):
    # Swap in the EMA weights, write the checkpoint, then swap back.
    model_ema_swap()

    if len(val_datasets) > 0 and (force or (schedule.val.epoch.get() % 1 == 0 or schedule.val.epoch.get() == schedule.val.epochs - 1)):
        # Note: "% 1 == 0" is always true, i.e. a weights file is written at
        # every validation epoch; raise the modulus to checkpoint less often.
        model.save_weights(os.path.join(log.dir("checkpoints"), f"model_checkpoint_valepoch{schedule.val.epoch.get():05d}.h5"))
    tf.keras.models.save_model(model, log.dir("saved_model"), include_optimizer=False)

    model_ema_swap()
if len(val_datasets) > 0:
    schedule.val.epoch_end.subscribe(save_model)
if args.debug:
    save_model()

mlog_session.commit()

print("########## TRAINING START ##########")
run()

save_model(force=True)
print("########## TRAINING DONE ##########")
os._exit(0)
