#!/usr/bin/env python3

import argparse
import os
import os.path as path
from datetime import datetime

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

from snn.models import *
from snn.utils import (
    cleanup_tmp, ensure_exists, evaluate_ann, load_ann_dataset, load_model, register_handlers)
from snn.voc import MeanAccuracy, MeanIOU, PixelAccuracy, WeightedMeanIOU, voc_loss


def _checkpoint_kind(filename, name):
    """Classify a file from the checkpoint directory for the model ``name``.

    A file belongs to the model when its root (the filename without its
    extension) is ``name`` followed by one more underscore-separated piece.

    Returns:
        ``("best", suffix)`` if that last piece contains "best",
        ``("epoch", n)`` if it parses as the integer ``n``, and ``None`` if
        the file belongs to another model or has an unrecognized suffix.
    """
    pieces = path.splitext(filename)[0].split("_")
    if name != "_".join(pieces[:-1]):
        return None
    suffix = pieces[-1]
    if "best" in suffix:
        return "best", suffix
    try:
        return "epoch", int(suffix)
    except ValueError:
        # Not one of our checkpoints (e.g. "<name>_notes.txt"); leave it alone
        # instead of crashing the whole run.
        return None


def _find_latest_checkpoint(checkpoint_dir, name):
    """Find the newest per-epoch checkpoint and set aside any "best" model.

    An existing ``<name>_best`` file is renamed to the first unused
    ``<name>_best<i>`` so that the upcoming run does not overwrite it.

    Returns:
        ``(initial_epoch, checkpoint)``: the highest epoch number found and
        the corresponding filename, or ``(0, None)`` if there is none.
    """
    initial_epoch = 0
    checkpoint = None
    for filename in os.listdir(checkpoint_dir):
        kind = _checkpoint_kind(filename, name)
        if kind is None:
            continue
        tag, value = kind
        if tag == "epoch":
            if value > initial_epoch:
                initial_epoch = value
                checkpoint = filename
        elif value == "best":
            # Don't overwrite any existing best model
            root, ext = path.splitext(filename)
            i = 0
            target = path.join(checkpoint_dir, "{}{}{}".format(root, i, ext))
            while path.exists(target):
                i += 1
                target = path.join(checkpoint_dir, "{}{}{}".format(root, i, ext))
            os.rename(path.join(checkpoint_dir, filename), target)
    return initial_epoch, checkpoint


def _check_learning_rates(learning_rates, boundaries):
    """Raise ValueError if a multi-rate schedule has the wrong boundary count."""
    if len(learning_rates) > 1 and len(boundaries) != len(learning_rates) - 1:
        raise ValueError(
            "Expected {} learning rate boundaries for {} learning rates, "
            "but got {}.".format(
                len(learning_rates) - 1, len(learning_rates), len(boundaries)))


def _compile_model(ann, args, train_steps, tf_data):
    """Compile ``ann`` with the metrics, loss, optimizer and learning rate from ``args``."""
    if args.dataset == "voc":
        metrics = [PixelAccuracy(), MeanAccuracy(), MeanIOU(), WeightedMeanIOU()]
    elif args.dataset in ["cifar100", "imagenet"]:
        metrics = ["accuracy", "top_k_categorical_accuracy"]
    else:
        metrics = ["accuracy"]
    ann.compile(
        optimizer=getattr(tf.keras.optimizers, args.optimizer)(),
        loss=voc_loss if args.dataset == "voc" else "categorical_crossentropy",
        metrics=metrics)
    if len(args.learning_rates) > 1:
        # Boundaries are given in epochs; scale them by train_steps (divided
        # by the batch size when the data is not a tf.data.Dataset).
        # NOTE(review): presumably train_steps is a sample count in the
        # non-Dataset case -- confirm against load_ann_dataset.
        batch_factor = 1 if tf_data else 32  # 32 is the default batch size
        ann.optimizer.learning_rate = PiecewiseConstantDecay(
            [train_steps / batch_factor * x for x in args.learning_rate_boundaries],
            args.learning_rates)
    else:
        ann.optimizer.learning_rate = args.learning_rates[0]


def _export_best_model(checkpoint_dir, ann_filename, name, val_data, val_steps):
    """Save the lowest-loss "best" checkpoint to ``ann_filename`` and delete checkpoints.

    Every "best" checkpoint of the model (including ones set aside from
    earlier runs) is re-evaluated on the validation data. Only files that are
    recognized as checkpoints of this model are removed.
    """
    best_loss = np.inf
    for filename in os.listdir(checkpoint_dir):
        kind = _checkpoint_kind(filename, name)
        if kind is None:
            continue
        filepath = path.join(checkpoint_dir, filename)
        if kind[0] == "best":
            # Is this the best of the "best" models?
            ann = load_model(filepath)
            # The first value returned by evaluate_ann is treated as the loss.
            loss = evaluate_ann(ann, val_data, val_steps)[0]
            if loss < best_loss:
                best_loss = loss
                ann.save(ann_filename)
        os.remove(filepath)


def main(args):
    """Train an ANN, resuming from the newest checkpoint when one exists.

    The model is named ``<model>_<dataset>[_<id>]``. If ``<ann_dir>/<name>.h5``
    already exists the function prints a notice and returns without training.
    Otherwise it trains (writing per-epoch and "best" checkpoints to the
    checkpoint directory), then saves the "best" checkpoint with the lowest
    validation loss to ``<ann_dir>/<name>.h5`` and removes the checkpoints.

    Args:
        args: Namespace with the attributes ``model``, ``dataset``, ``id``,
            ``ann_dir``, ``checkpoint_dir``, ``tensorboard_dir``,
            ``weight_file``, ``beta_penalty``, ``l1_synapse_penalty``,
            ``l2_decay``, ``epochs``, ``learning_rates``,
            ``learning_rate_boundaries`` and ``optimizer``.

    Raises:
        ValueError: If a new model is being compiled with more than one
            learning rate and the number of boundaries is not one fewer than
            the number of rates.
    """
    register_handlers()
    np.random.seed(0)
    tf.random.set_seed(0)

    name = "{}_{}".format(args.model, args.dataset)
    if args.id is not None:
        name += "_" + args.id
    print((" " + name + " ").center(79, "="))

    ensure_exists(args.ann_dir)
    ann_filename = path.join(args.ann_dir, name + ".h5")
    if path.exists(ann_filename):
        print("File {} already exists.".format(ann_filename))
        print("Delete it to train again.")
        return

    # Search for an existing checkpoint
    ensure_exists(args.checkpoint_dir)
    initial_epoch, checkpoint = _find_latest_checkpoint(args.checkpoint_dir, name)

    # Load or create the model
    if checkpoint is not None:
        # A loaded checkpoint is used as-is; the optimizer and learning-rate
        # arguments are not applied again.
        ann = load_model(path.join(args.checkpoint_dir, checkpoint))
        need_compile = False
    else:
        # Fail before building the model and loading the dataset.
        _check_learning_rates(args.learning_rates, args.learning_rate_boundaries)
        ann = _build_model(args.model,
                           args.dataset,
                           args.beta_penalty,
                           args.l1_synapse_penalty,
                           args.l2_decay)
        if args.weight_file is not None:
            ann.load_weights(args.weight_file)
        need_compile = True

    # Load the dataset (train and validation splits)
    train_data, train_steps = load_ann_dataset(args.dataset, "train")
    val_data, val_steps = load_ann_dataset(args.dataset, "val")

    # Prepare the model for training
    tf_data = isinstance(train_data, tf.data.Dataset)
    if need_compile:
        _compile_model(ann, args, train_steps, tf_data)

    # Set up training callbacks (checkpointing and TensorBoard)
    best_filename = path.join(args.checkpoint_dir, name + "_best.h5")
    callbacks = [
        ModelCheckpoint(path.join(args.checkpoint_dir, name + "_{epoch:d}.h5")),
        ModelCheckpoint(best_filename, save_best_only=True)]
    if args.tensorboard_dir != "":
        ensure_exists(args.tensorboard_dir)
        now_str = datetime.now().strftime("_%Y-%m-%d_%H-%M-%S")
        callbacks.append(TensorBoard(
            log_dir=path.join(args.tensorboard_dir, name + now_str)))

    # A tf.data.Dataset is repeated and consumed by step count; otherwise
    # train_data is indexed as an (inputs, targets) pair.
    ann.fit(x=train_data.repeat() if tf_data else train_data[0],
            y=None if tf_data else train_data[1],
            epochs=args.epochs,
            callbacks=callbacks,
            validation_data=val_data,
            initial_epoch=initial_epoch,
            steps_per_epoch=train_steps if tf_data else None,
            validation_steps=val_steps if tf_data else None)

    # Finalize and clean up
    ensure_exists(args.ann_dir)
    _export_best_model(args.checkpoint_dir, ann_filename, name, val_data, val_steps)


def _build_model(model, dataset, beta_penalty, l1_synapse_penalty, l2_decay):
    """Construct a new network for a supported model/dataset combination.

    Args:
        model: Architecture name ("toy", "dense", "pool", "conv", "fcn32" or
            "mobilenet").
        dataset: Dataset name. For "conv", any name containing "cifar" is
            accepted and the class count is read from the digits that follow
            (e.g. "cifar10" gives 10 classes).
        beta_penalty: Forwarded to the model constructor as ``beta_penalty``.
        l1_synapse_penalty: Forwarded as ``l1_synapse_penalty``.
        l2_decay: Forwarded as ``l2_decay``.

    Returns:
        The model returned by the matching constructor.

    Raises:
        ValueError: If no constructor exists for the requested combination.
            (Previously this case silently returned ``None``.)
    """
    kwargs = {
        "beta_penalty":       beta_penalty,
        "l1_synapse_penalty": l1_synapse_penalty,
        "l2_decay":           l2_decay,
    }
    if model == "conv" and "cifar" in dataset:
        n_classes = int(dataset.replace("cifar", ""))
        return conv_cifar(n_classes, **kwargs)
    elif model == "conv" and dataset == "mnist":
        return conv_mnist(**kwargs)
    elif model == "dense" and dataset == "mnist":
        return dense_mnist(128, 2, **kwargs)
    elif model == "pool" and dataset == "mnist":
        return dense_mnist(128, 2, pool=True, **kwargs)
    elif model == "fcn32" and dataset == "voc":
        return fcn32_voc(**kwargs)
    elif model == "toy" and dataset == "mnist":
        return dense_mnist(16, 8, **kwargs)
    elif model == "mobilenet" and dataset == "imagenet":
        return mobilenet(**kwargs)
    # Fail here with a clear message rather than handing None to the caller.
    raise ValueError(
        "Unsupported model/dataset combination: {!r} with {!r}.".format(
            model, dataset))


if __name__ == "__main__":
    # Command-line entry point: build the parser, parse the arguments and run
    # main() inside a MirroredStrategy scope.
    # The automatic -h/--help is disabled and re-added by hand below so that
    # its help text can be customized.
    parser = argparse.ArgumentParser(
        description="Trains a conventional ANN.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False)

    # Positional arguments
    parser.add_argument(
        "model", choices=["toy", "dense", "conv", "fcn32", "pool", "mobilenet"],
        help="The architecture of the model. toy, dense, and pool "
             "should be used with mnist. conv should be used with "
             "mnist, cifar10, or cifar100. fcn32 should be used with "
             "voc. mobilenet should be used with imagenet.")
    parser.add_argument(
        "dataset", choices=["mnist", "cifar10", "cifar100", "imagenet", "voc"],
        help="The dataset to use.")

    parser.add_argument(
        "-h", "--help", action="help",
        help="Display this help message and exit.")

    # File and directory options
    parser.add_argument(
        "-A", "--ann-dir", default=path.join("models", "ann"),
        help="The directory where the final ANN should be saved.")
    parser.add_argument(
        "-C", "--checkpoint-dir", default=path.join("models", "checkpoints"),
        help="The directory for saving and loading checkpoints.")
    parser.add_argument(
        "-I", "--id",
        help="A unique identifier to append to the model name.")
    parser.add_argument(
        "-T", "--tensorboard-dir", default="tensorboard",
        help="The directory where TensorBoard logs should be written. "
             "TensorBoard logging is disabled if this is an empty "
             "string.")
    parser.add_argument(
        "-w", "--weight-file",
        help="A file from which the model weights should be loaded. "
             "This is ignored when training is automatically resumed "
             "from a checkpoint.")

    # Training hyperparameters
    parser.add_argument(
        "-b", "--beta-penalty", type=float,
        help="The activation sparsity penalty to apply during "
             "training.")
    parser.add_argument(
        "-d", "--l2-decay", type=float,
        help="The amount of L2 weight decay to add to the loss. "
             "Applies only to kernels, not biases.")
    parser.add_argument(
        "-e", "--epochs", default=50, type=int,
        help="The number of training epochs.")
    parser.add_argument(
        "-l", "--learning-rates", nargs="+", default=[1e-3], type=float,
        help="A list of one or more learning rate values.")
    parser.add_argument(
        "-L", "--learning-rate-boundaries", default=[], type=int, nargs="*",
        help="The boundaries (in units of epochs) at which the "
             "learning rate should be changed. Should contain one "
             "fewer value than -l/--learning-rates.")
    parser.add_argument(
        "-o", "--optimizer", default="SGD",
        help="The name of the optimizer (case-sensitive). See the "
             "classes listed in the tf.keras.optimizers documentation "
             "for a list of acceptable values.")
    parser.add_argument(
        "-s", "--l1-synapse-penalty", type=float,
        help="The kernel weight sparsity penalty to apply during "
             "training.")

    try:
        # Parse before constructing the distribution strategy so that --help
        # and usage errors exit without creating it. Parsing stays inside the
        # try block so cleanup_tmp() still runs on every exit path.
        args = parser.parse_args()
        # This strategy splits batches over the available GPUs
        with tf.distribute.MirroredStrategy().scope():
            main(args)
    finally:
        cleanup_tmp()
