#!/usr/bin/env python3

import argparse
import os.path as path
import pickle
import time

import numpy as np
import tensorflow as tf

from sarnn.simulation import evaluate
from snn.utils import (
    cleanup_tmp, evaluate_ann, initialize_mpi, load_model, load_ann_dataset, load_snn_dataset,
    register_handlers)


def main(args):
    """Evaluate an SNN on its test split and print the results as CSV.

    The dataset is taken from the second underscore-delimited piece of
    the SNN filename (directory and extension stripped). Every rank
    calls the evaluation routine, but only rank 0 evaluates the optional
    ANN and prints output: one header line followed by one line of
    values, with metrics in sorted order.

    Args:
        args: The parsed command-line namespace. The attributes read are
            snn_filename, ann_filename, ranks_per_node, n_items,
            v_initial, v_initial_filename, threshold, n_time_chunks,
            poisson, decay, clamp_rate, input_repeat, input_squash, and
            timer.

    Raises:
        ValueError: If the SNN filename has no second
            underscore-delimited piece to use as the dataset name.
    """
    register_handlers()
    np.random.seed(0)
    tf.random.set_seed(0)

    name = path.splitext(path.basename(args.snn_filename))[0]
    comm, _, rank = initialize_mpi(ranks_per_node=args.ranks_per_node)

    # Fail with an actionable message instead of a bare IndexError when
    # the filename does not follow the <model>_<dataset>[_...] pattern.
    name_pieces = name.split("_")
    if len(name_pieces) < 2:
        raise ValueError(
            "Cannot determine the dataset from SNN filename {!r}; "
            "expected a name such as conv_mnist_test.h5.".format(
                args.snn_filename))
    dataset = name_pieces[1]

    # ANN accuracy is used to compute derived metrics. Only rank 0
    # computes it; every other rank passes None.
    acc_ann = None
    if rank == 0 and args.ann_filename is not None:
        data, steps = load_ann_dataset(dataset, "test", n_items=args.n_items)
        ann = load_model(args.ann_filename)
        acc_ann = evaluate_ann(ann, data, steps)[1]

    (x_snn, y_snn), n_items = load_snn_dataset(dataset, "test", n_items=args.n_items)
    snn = load_model(args.snn_filename)

    if args.v_initial_filename is not None:
        # NOTE: unpickling can execute arbitrary code, so this file must
        # come from a trusted source. The "v_initial" entry of the last
        # element of the pickled sequence is used.
        with open(args.v_initial_filename, "rb") as pickle_file:
            v_initial = pickle.load(pickle_file)[-1]["v_initial"]
    else:
        v_initial = [args.v_initial]

    # The barriers bracket the evaluation so that the timed interval
    # starts and ends with all ranks synchronized. perf_counter is
    # monotonic, so the interval is immune to system clock adjustments.
    comm.Barrier()
    t_1 = time.perf_counter()
    results = evaluate(
        snn, x_snn, y_snn,
        acc_ann=acc_ann,
        threshold=args.threshold,
        v_initial=v_initial,
        n_time_chunks=args.n_time_chunks,
        poisson=args.poisson,
        decay=args.decay,
        clamp_rate=args.clamp_rate,
        n_items=n_items,
        mask_class=21 if dataset == "voc" else None,
        input_repeat=args.input_repeat,
        input_squash=args.input_squash)
    comm.Barrier()
    t_2 = time.perf_counter()

    # Only rank 0 reports results.
    if rank != 0:
        return

    headers = sorted(results)
    values = [results[header] for header in headers]
    if args.timer:
        headers.append("elapsed_seconds")
        values.append(t_2 - t_1)
    print(",".join(["name"] + headers))
    print(",".join([name] + ["{:#.6g}".format(value) for value in values]))


if __name__ == "__main__":
    # The command-line interface, declared as (flags, options) pairs.
    # Arguments are registered in list order, which is also the order in
    # which they appear in the generated help output.
    cli_spec = [
        (("-h", "--help"), dict(
            action="help",
            help="Display this help message and exit.")),
        (("snn_filename",), dict(
            help="The filename of the SNN to evaluate. The dataset is "
                 "determined by looking at the second underscore-delimited "
                 "piece of the filename. For example, the filename "
                 "conv_mnist_test.h5 indicates that the MNIST dataset "
                 "should be used. See train.py for a list of valid "
                 "datasets.")),
        (("-a", "--ann-filename"), dict(
            help="The filename of the ANN to use to compute the TAC and "
                 "PAC metrics.")),
        (("-D", "--decay"), dict(
            type=float,
            help="An exponential decay rate to apply to SNN outputs when "
                 "determining predictions. Should be a value in the range "
                 "[0, 1). Larger values lead to more sluggish output.")),
        (("-H", "--threshold"), dict(
            type=float,
            help="The threshold value to use to compute the TTA and PTA "
                 "metrics.")),
        (("-i", "--v-initial"), dict(
            default=0.5, type=float,
            help="The value to which neuron membrane potentials should be "
                 "initialized at the start of each simulation.")),
        (("-L", "--clamp-rate"), dict(
            default=0, type=int,
            help="This value, when multiplied by the depth of a layer, "
                 "gives the number of time steps each layer waits for its "
                 "input to stabilize before beginning membrane potential "
                 "updates. This has no effect unless -C/--enable-clamp was "
                 "used during conversion.")),
        (("-n", "--n-items"), dict(
            default=-1, type=int,
            help="The number of test items to use for evaluation. Negative "
                 "to use all items.")),
        (("-p", "--poisson"), dict(
            action="store_true",
            help="Generate binary Poisson spike trains as model input.")),
        (("-r", "--ranks-per-node"), dict(
            type=int,
            help="If not None, the GPUs visible to this process will be "
                 "set as those satisfying "
                 "gpu_id %% ranks-per-node == rank. This is used to keep "
                 "multiple MPI ranks on the same node from attempting to "
                 "use the same GPU. This should only be used in an MPI "
                 "context.")),
        (("-q", "--input-repeat"), dict(
            default=1, type=int,
            help="The number of SNN time steps for which each input frame "
                 "should be repeated. This only makes sense with Poisson "
                 "input.")),
        (("-Q", "--input-squash"), dict(
            default=1, type=int,
            help="The number of input time steps over which each SNN input "
                 "frame should be averaged. This only makes sense with "
                 "Poisson input.")),
        (("-t", "--n-time-chunks"), dict(
            default=5, type=int,
            help="The number of time chunks for which the simulation "
                 "should be run. The number of steps per time chunk is "
                 "specified during conversion using the "
                 "-c/--time-chunk-size option.")),
        (("-T", "--timer"), dict(
            action="store_true",
            help="Print the time taken by evaluation.")),
        (("-V", "--v-initial-filename"), dict(
            help="A pickle filename containing a layer-specific or "
                 "neuron-specific membrane potential initialization "
                 "strategy. This is typically the results file written by "
                 "optimize.py when run with the -V/--optimize-v-initial "
                 "flag. This overrides any value given for -i/--v-initial.")),
    ]

    # The built-in help option is disabled because -h/--help is declared
    # explicitly in the table above with a custom help message.
    parser = argparse.ArgumentParser(
        add_help=False,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Evaluates the performance of an SNN.")
    for flags, options in cli_spec:
        parser.add_argument(*flags, **options)

    # Argument parsing stays inside the try so that cleanup also runs
    # when parsing exits early (e.g. --help or a usage error).
    try:
        main(parser.parse_args())
    finally:
        cleanup_tmp()
