#!/usr/bin/env python3

import argparse
import os.path as path
import pickle
import time

import numpy as np
import tensorflow as tf

from sarnn.optimization import ALGORITHMS, optimize
from sarnn.utils import scale
from snn.utils import (
    cleanup_tmp, ensure_exists, evaluate_ann, initialize_mpi, load_model, load_ann_dataset,
    load_snn_dataset, register_handlers)


def main(args):
    """Optimize the firing rates of an SNN and save the optimized model.

    Work is spread over MPI ranks obtained from ``initialize_mpi``. Every rank
    loads the SNN and the SNN-format datasets and takes part in ``optimize``.
    Only rank 0 evaluates the reference ANN, prints messages, and writes
    output files.

    Outputs (rank 0 only):
      - ``<results_dir>/<name>.p``: the pickled return value of ``optimize``.
      - ``<opt_dir>/<name>.h5``: the SNN, with the optimized scales applied
        unless ``--no-optimize-scales`` was given.

    ``<name>`` is the SNN filename without directory or extension, with
    ``_<id>`` appended when ``--id`` is given. The dataset is the second
    underscore-delimited piece of that filename (e.g. ``conv_mnist_test.h5``
    -> ``mnist``).

    Args:
        args: The ``argparse.Namespace`` produced by this script's parser.

    Returns:
        None. Every rank returns early, without writing anything, if the
        dataset cannot be parsed from the SNN filename, if either model file
        is missing, or if the optimized model already exists.
    """
    register_handlers()
    # Seed NumPy and TensorFlow with a fixed value so that runs are repeatable.
    np.random.seed(0)
    tf.random.set_seed(0)

    base_name = path.splitext(path.basename(args.snn_filename))[0]
    name = base_name if args.id is None else base_name + "_" + args.id
    comm, _, rank = initialize_mpi(ranks_per_node=args.ranks_per_node)

    # Parse the dataset from the bare filename, *before* the --id suffix is
    # appended. Otherwise an underscore-free filename plus an ID would make the
    # ID be taken as the dataset name.
    name_parts = base_name.split("_")
    if len(name_parts) < 2:
        if rank == 0:
            print("Could not determine the dataset from {}.".format(args.snn_filename))
            print("Unable to proceed.")
        return
    dataset = name_parts[1]

    for filename in args.snn_filename, args.ann_filename:
        if not path.exists(filename):
            if rank == 0:
                print("Could not find {}.".format(filename))
                print("Unable to proceed.")
            return

    if rank == 0:
        ensure_exists(args.opt_dir)
    opt_filename = path.join(args.opt_dir, name + ".h5")
    # Never overwrite a previously optimized model.
    if path.exists(opt_filename):
        if rank == 0:
            print("File {} already exists.".format(opt_filename))
            print("Delete it to optimize again.")
        return

    # Note the use of test data for validation (during optimization val
    # data is used only for monitoring, not for model selection)
    (x_opt, y_opt), n_opt = load_snn_dataset(dataset, "train", n_items=args.n_opt, val_split=False)
    (x_val, y_val), n_val = load_snn_dataset(dataset, "test", n_items=args.n_val, val_split=False)
    snn = load_model(args.snn_filename)

    # Reference ANN accuracies (element 1 of evaluate_ann's result) are
    # computed on rank 0 only; the other ranks pass None to optimize.
    if rank == 0:
        data_opt, steps_opt = load_ann_dataset(dataset, "train", n_items=args.n_opt)
        data_val, steps_val = load_ann_dataset(dataset, "test", n_items=args.n_val)
        ann = load_model(args.ann_filename)
        acc_ann_opt = evaluate_ann(ann, data_opt, steps_opt)[1]
        acc_ann_val = evaluate_ann(ann, data_val, steps_val)[1]
    else:
        acc_ann_opt, acc_ann_val = None, None

    cache_filename = path.join(args.cache_dir, name + ".p")
    results_filename = path.join(args.results_dir, name + ".p")
    if rank == 0:
        ensure_exists(args.cache_dir)
        ensure_exists(args.results_dir)

    # Rank 0 creates the cache/results directories above. This barrier makes
    # them exist before any rank starts optimizing. It also aligns all ranks
    # at the start of the timed region.
    comm.Barrier()
    t_1 = time.time()
    results = optimize(
        snn, x_opt, y_opt, acc_ann_opt,
        x_val=x_val,
        y_val=y_val,
        acc_ann_val=acc_ann_val,
        algorithm=args.algorithm,
        granularities=tuple(args.granularities),
        max_iterations=tuple(args.max_iterations),
        global_v_initial=args.global_v_initial,
        optimize_v_initial=args.optimize_v_initial,
        optimize_scales=not args.no_optimize_scales,
        n_time_chunks=args.n_time_chunks,
        poisson=args.poisson,
        decay=args.decay,
        clamp_rate=args.clamp_rate,
        n_opt=n_opt,
        n_val=n_val,
        # NOTE(review): 21 is presumably the VOC label to ignore/mask — confirm.
        mask_class=21 if dataset == "voc" else None,
        input_repeat=args.input_repeat,
        input_squash=args.input_squash,
        lambdas=tuple(args.lambdas),
        auto_scale=tuple(bool(term) for term in args.auto_scale),
        verbose=True,
        val_freq=args.val_freq,
        cache_filename=cache_filename,
        cache_freq=10)
    comm.Barrier()
    t_2 = time.time()

    if rank != 0:
        return

    if args.timer:
        print("Elapsed: {:.2f} s".format(t_2 - t_1))

    with open(results_filename, "wb") as results_file:
        pickle.dump(results, results_file)

    if not args.no_optimize_scales:
        # optimize may return a list of results (presumably one per
        # granularity phase — confirm). The last entry is used as the final one.
        final_results = results[-1] if isinstance(results, list) else results
        scale(snn, final_results["scales"])
    snn.save(opt_filename)


if __name__ == "__main__":
    # add_help=False so that -h/--help can be registered below with a custom
    # help message.
    parser = argparse.ArgumentParser(
        description="Optimizes the firing rates of an SNN.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False)

    parser.add_argument(
        "snn_filename",
        help="The filename of the SNN to optimize. The dataset is "
             "determined by looking at the second underscore-delimited "
             "piece of the filename. For example, the filename "
             "conv_mnist_test.h5 indicates that the MNIST dataset "
             "should be used. See train.py for a list of valid "
             "datasets.")
    parser.add_argument(
        "ann_filename",
        help="The filename of the ANN to use to compute the TAC and "
             "PAC metrics.")

    parser.add_argument(
        "-h", "--help", action="help",
        help="Display this help message and exit.")

    parser.add_argument(
        "-a", "--algorithm", default="SUBPLEX", choices=ALGORITHMS.keys(),
        help="The name of the derivative-free optimization algorithm "
             "to use.")
    parser.add_argument(
        "-C", "--cache-dir", default="caches",
        help="The directory where the optimizer should save cached "
             "function evaluations. This is created if it does not "
             "already exist.")
    parser.add_argument(
        "-D", "--decay", type=float,
        help="An exponential decay rate to apply to SNN outputs when "
             "determining predictions. Should be a value in the range "
             "[0, 1). Larger values lead to more sluggish output.")
    parser.add_argument(
        "-g", "--granularities", default=[3], type=int, nargs="+", choices=[1, 2, 3],
        help="The level(s) of granularity at which the optimization "
             "should be performed. 1=network-wise, 2=layer-wise, "
             "3=neuron-wise.")
    parser.add_argument(
        "-i", "--global-v-initial", default=0.5, type=float,
        help="The initial neuron membrane potential. If "
             "-V/--optimize-v-initial, this is used as the optimizer "
             "initialization.")
    parser.add_argument(
        "-I", "--id",
        help="A unique identifier to append to the name of the "
             "optimized model.")
    parser.add_argument(
        "-l", "--lambdas", default=[1e2, 1e1, 1e2], type=float, nargs=3,
        help="The weight to give to each term in the loss function "
             "(lambda_acc, lambda_lat, lambda_pow).")
    parser.add_argument(
        "-L", "--clamp-rate", default=0, type=int,
        help="This value, when multiplied by the depth of a layer, "
             "gives the number of time steps each layer waits for its "
             "input to stabilize before beginning membrane potential "
             "updates. This has no effect unless -C/--enable-clamp was "
             "used during conversion.")
    parser.add_argument(
        "-m", "--max-iterations", default=[10000], type=int, nargs="+",
        help="The maximum number of optimizer iterations for each "
             "granularity phase. One value should be given for each "
             "value of -g/--granularities.")
    parser.add_argument(
        "-n", "--n-opt", default=-1, type=int,
        help="The number of training items to use for evaluation. "
             "Negative to use all items.")
    parser.add_argument(
        "-N", "--n-val", default=-1, type=int,
        help="The number of test items to use for validation. Negative "
             "to use all items.")
    parser.add_argument(
        "-O", "--opt-dir", default=path.join("models", "optimized"),
        help="The directory where the optimized SNN should be saved. "
             "This is created if it does not already exist.")
    parser.add_argument(
        "-p", "--poisson", action="store_true",
        help="Generate binary Poisson spike trains as model input.")
    parser.add_argument(
        "-r", "--ranks-per-node", type=int,
        help="If not None, the GPUs visible to this process will be "
             "set as those satisfying gpu_id %% ranks-per-node == "
             "rank. This is used to keep multiple MPI ranks on the "
             "same node from attempting to use the same GPU. This "
             "should only be used in an MPI context.")
    parser.add_argument(
        "-q", "--input-repeat", default=1, type=int,
        help="The number of SNN time steps for which each input frame "
             "should be repeated. This only makes sense with Poisson "
             "input.")
    parser.add_argument(
        "-Q", "--input-squash", default=1, type=int,
        help="The number of input time steps over which each SNN input "
             "frame should be averaged. This only makes sense with "
             "Poisson input.")
    parser.add_argument(
        "-R", "--results-dir", default=path.join("results", "optimize"),
        help="The directory where an optimization results summary "
             "should be saved. This is created if it does not already "
             "exist.")
    parser.add_argument(
        "-s", "--auto-scale", default=[0, 0, 0], type=int, nargs=3, choices=[0, 1],
        help="Whether each loss term should be automatically scaled "
             "such that its initial contribution to the loss on the "
             "optimization set is equal to lambdas[i].")
    parser.add_argument(
        "-S", "--no-optimize-scales", action="store_true",
        help="Do not optimize over firing rate scales.  If this is "
             "enabled, the option --optimize-v-initial must also be "
             "enabled (otherwise, there is nothing to optimize over).")
    parser.add_argument(
        "-t", "--n-time-chunks", default=5, type=int,
        help="The number of time chunks for which the simulation "
             "should be run. The number of steps per time chunk is "
             "specified during conversion using the "
             "-c/--time-chunk-size option.")
    parser.add_argument(
        "-T", "--timer", action="store_true",
        help="Print the overall time taken by the optimization.")
    parser.add_argument(
        "-v", "--val-freq", type=int,
        help="The frequency with which performance on the validation "
             "set should be evaluated.")
    parser.add_argument(
        "-V", "--optimize-v-initial", action="store_true",
        help="Whether to optimize initial neuron membrane potentials "
             "in addition to firing rates.")

    try:
        args = parser.parse_args()
        # Enforce the documented constraint: with scales disabled, the initial
        # membrane potential is the only thing left to optimize.
        if args.no_optimize_scales and not args.optimize_v_initial:
            parser.error(
                "-S/--no-optimize-scales requires -V/--optimize-v-initial "
                "(otherwise, there is nothing to optimize over).")
        main(args)
    finally:
        cleanup_tmp()
