#!/usr/bin/env python3

import argparse
import os.path as path
import time

import tensorflow as tf

from sarnn.utils import compute_sparsity, count_model_synapses
from snn.utils import cleanup_tmp, evaluate_ann, load_ann_dataset, load_model, register_handlers


def main(args):
    register_handlers()

    name = path.splitext(path.basename(args.ann_filename))[0]

    data, steps = load_ann_dataset(name.split("_")[1], "test", n_items=args.n_items)
    ann = load_model(args.ann_filename)

    t_1 = time.time()
    results = evaluate_ann(ann, data, steps)
    t_2 = time.time()

    tf_data = isinstance(data, tf.data.Dataset)
    activation_sparsity = compute_sparsity(ann, data if tf_data else data[0])
    nonzero_synapses = count_model_synapses(ann, sparse_counting=True, epsilon=args.epsilon)
    synapse_sparsity = 1.0 - nonzero_synapses / count_model_synapses(ann)
    combined_sparsity = 1.0 - (1.0 - activation_sparsity) * (1.0 - synapse_sparsity)

    headers = ann.metrics_names + ["activation_sparsity", "synapse_sparsity", "combined_sparsity"]
    values = results + [activation_sparsity, synapse_sparsity, combined_sparsity]
    if args.timer:
        headers.append("elapsed_seconds")
        values.append(t_2 - t_1)
    print(",".join(["name"] + headers))
    print(",".join([name] + ["{:#.6g}".format(value) for value in values]))


def _build_parser():
    """Create the command-line parser for this evaluation script.

    The built-in -h is disabled so that a custom help entry can be registered
    in a fixed position among the other options.
    """
    cli = argparse.ArgumentParser(
        add_help=False,
        description="Evaluates the performance and sparsity of an ANN.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Positional argument: the saved model to evaluate.
    cli.add_argument(
        "ann_filename",
        help=("The filename of the ANN to evaluate. "
              "The dataset is determined by looking at the second "
              "underscore-delimited piece of the filename. For example, "
              "the filename conv_mnist_test.h5 indicates that the MNIST "
              "dataset should be used. See train.py for a list of valid datasets."),
    )

    cli.add_argument(
        "-h", "--help",
        action="help",
        help="Display this help message and exit.",
    )

    # Optional tuning and reporting flags.
    cli.add_argument(
        "-e", "--epsilon",
        type=float,
        help=("Weights whose absolute value are less than epsilon are considered "
              "to be zero when counting nonzero synapses."),
    )
    cli.add_argument(
        "-n", "--n-items",
        type=int,
        default=-1,
        help=("The number of test items to use for evaluation. "
              "Negative to use all items."),
    )
    cli.add_argument(
        "-T", "--timer",
        action="store_true",
        help="Print the time taken by evaluation.",
    )
    return cli


if __name__ == "__main__":
    cli = _build_parser()
    # Temporary files are cleaned up even if parsing exits or main() raises.
    try:
        main(cli.parse_args())
    finally:
        cleanup_tmp()
