#!/usr/bin/env python3

import argparse
import os.path as path
import sys

import tensorflow as tf

from sarnn.conversion import convert
from snn.utils import cleanup_tmp, ensure_exists, load_ann_dataset, load_model, register_handlers


def main(args):
    """Convert a trained ANN into an SNN and save it to disk.

    The output is written to ``<args.snn_dir>/<base>[_<args.id>].h5``, where
    ``<base>`` is the ANN filename stripped of its directory and extension.
    The dataset used for normalization is taken from the second
    underscore-delimited piece of ``<base>`` (``conv_mnist_test.h5`` ->
    ``mnist``).

    Args:
        args: The ``argparse.Namespace`` produced by this script's parser.

    Returns:
        An exit status: 0 after a successful conversion or when the output
        file already exists (nothing to do), 1 when the ANN file is missing
        or its name does not encode a dataset.
    """
    register_handlers()

    base = path.splitext(path.basename(args.ann_filename))[0]
    name = base if args.id is None else base + "_" + args.id

    if not path.exists(args.ann_filename):
        print("Could not find {}.".format(args.ann_filename), file=sys.stderr)
        print("Unable to proceed.", file=sys.stderr)
        return 1

    # Parse the dataset from the ANN's own base name, not from ``name``:
    # with the id already appended, an underscore-free base name would make
    # the id itself look like the dataset.
    pieces = base.split("_")
    if len(pieces) < 2 or not pieces[1]:
        print("Cannot determine the dataset from {}; expected a filename "
              "such as conv_mnist_test.h5.".format(args.ann_filename),
              file=sys.stderr)
        print("Unable to proceed.", file=sys.stderr)
        return 1
    dataset = pieces[1]

    ensure_exists(args.snn_dir)
    snn_filename = path.join(args.snn_dir, name + ".h5")
    if path.exists(snn_filename):
        # Refuse to overwrite; this is treated as "nothing to do", not an error.
        print("File {} already exists.".format(snn_filename))
        print("Delete it to convert again.")
        return 0

    data, _ = load_ann_dataset(dataset, "train", n_items=args.n_items)
    ann = load_model(args.ann_filename)

    # A tf.data.Dataset is passed through unchanged; otherwise only data[0]
    # is handed to convert (presumably the inputs of an (x, y) pair --
    # confirm against load_ann_dataset).
    tf_data = isinstance(data, tf.data.Dataset)
    snn = convert(
        ann, data if tf_data else data[0],
        percentile=args.percentile,
        # -N/--neuron-wise switches off the default layer-wise normalization.
        layer_wise=not args.neuron_wise,
        recompute=args.recompute,
        batch_size=args.batch_size,
        time_chunk_size=args.time_chunk_size,
        spiking_input=True,
        sparse_tracking=args.sparse_tracking,
        track_spikes=True,
        enable_clamp=args.enable_clamp,
        reset_mechanism=args.reset_mechanism,
        t_refrac=args.t_refrac,
        buffer_dv=args.buffer_dv)

    snn.save(snn_filename)
    return 0


if __name__ == "__main__":
    # add_help=False so that -h/--help can be registered below with a custom
    # description and in a custom position in the help listing.
    parser = argparse.ArgumentParser(
        description="Converts an ANN into an SNN.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False)

    parser.add_argument(
        "ann_filename",
        help="The filename of the ANN to convert. The dataset is "
             "determined by looking at the second underscore-delimited "
             "piece of the filename. For example, the filename "
             "conv_mnist_test.h5 indicates that the MNIST dataset "
             "should be used. See train.py for a list of valid "
             "datasets.")

    parser.add_argument(
        "-h", "--help", action="help",
        help="Display this help message and exit.")

    parser.add_argument(
        "-I", "--id",
        help="A unique identifier to append to the name of the "
             "converted model.")
    parser.add_argument(
        "-S", "--snn-dir", default=path.join("models", "snn"),
        help="The directory where the SNN should be saved. This is "
             "created if it does not already exist.")

    parser.add_argument(
        "-b", "--batch-size", default=100, type=int,
        help="The batch size to give the SNN.")
    parser.add_argument(
        "-B", "--buffer-dv", action="store_true",
        help="Cause spikes to take one timestep to propagate to "
             "downstream neurons.")
    parser.add_argument(
        "-c", "--time-chunk-size", default=10, type=int,
        help="The temporal chunk size to give the SNN.")
    parser.add_argument(
        "-C", "--enable-clamp", action="store_true",
        help="Enable membrane potential clamping in the SNN. This is a "
             "technique proposed by Rueckauer et al. which may improve "
             "convergence speed.")
    parser.add_argument(
        "-n", "--n-items", default=-1, type=int,
        help="The number of training items to use for normalization. "
             "Negative to use all items.")
    parser.add_argument(
        "-N", "--neuron-wise", action="store_true",
        help="Normalize neurons individually instead of layer-wise.")
    parser.add_argument(
        "-p", "--percentile", default=99.0, type=float,
        help="The activation percentile to use during normalization.")
    parser.add_argument(
        "-r", "--recompute", action="store_true",
        help="Recompute each layer's activations from scratch during "
             "normalization. This should be used if normalization "
             "exhausts the CPU memory.")
    parser.add_argument(
        "-R", "--reset-mechanism",
        default="subtract", choices=["subtract", "zero"],
        help='The post-spike membrane reset mechanism; can be either '
             '"subtract" for reset by subtraction or "zero" for reset '
             'to zero.')
    parser.add_argument(
        "-s", "--sparse-tracking", action="store_true",
        help="Only count spikes if they correspond to a nonzero "
             "synapse weight. Note that this behavior cannot be "
             "changed after conversion.")
    parser.add_argument(
        "-t", "--t-refrac", default=0, type=int,
        help="The duration of the post-spiking refractory period. "
             "During this period the membrane potential is frozen. "
             "This is similar to clamping, but the refractory period "
             "is neuron-specific and not layer-global.")

    # Temporary files are cleaned up even if main (or argument parsing)
    # raises; the exit status is propagated only after cleanup has run.
    try:
        status = main(parser.parse_args())
    finally:
        cleanup_tmp()
    # sys.exit(None) exits with status 0, so a main() that returns nothing
    # still reports success.
    sys.exit(status)
