#!/usr/bin/env python3

import os
import os.path as path
import pickle

import numpy as np

from sarnn.simulation import evaluate
from snn.utils import load_ann_dataset, load_model

# Candidate initial values to sweep: ten evenly spaced points from 0 to
# 0.9 inclusive. 1.0 is deliberately excluded because it would be an
# absurd initialization (every neuron firing on the first time step by
# default).
init_values = np.linspace(0, 0.9, num=10)

# Per-model accuracy threshold passed to the evaluation. For MNIST it is
# the floor of the ANN accuracy; for CIFAR it is the floor minus one.
thresholds = dict(
    dense_mnist=0.98,
    conv_mnist=0.99,
    conv_cifar10=0.88,
    conv_cifar100=0.62,
)

# Maps each model name to a list of evaluation outputs, one entry per
# candidate initial value (in the same order as init_values).
results = {}
for name in "dense_mnist", "conv_mnist", "conv_cifar10", "conv_cifar100":
    # Section banner: the model name centered in a 79-column rule of "="
    banner = " " + name + " "
    print(banner.center(79, "="))

    model_file = name + ".h5"
    snn = load_model(path.join("models", "snn", model_file))

    # Model names have the form "<architecture>_<dataset>"
    dataset = name.split("_")[1]
    # Only the first element returned by the loader is used here
    test_split, _ = load_ann_dataset(dataset, "test")
    x_test, y_test = test_split

    # Reference accuracy of the matching ANN on the same test data
    # (assumes index 1 of the evaluate() output is the accuracy metric;
    # confirm against how the ANN models were compiled)
    ann = load_model(path.join("models", "ann", model_file))
    ann_metrics = ann.evaluate(x_test, y_test, verbose=0)
    acc_ann = ann_metrics[1]

    # Other experiments use 50 time chunks for CIFAR, but that would be
    # too computationally expensive for this sweep (and probably
    # wouldn't add much)
    if dataset == "mnist":
        n_time_chunks = 5
    else:
        n_time_chunks = 20

    # Arguments that stay fixed across the whole sweep for this model
    fixed_kwargs = dict(
        acc_ann=acc_ann,
        threshold=thresholds[name],
        n_time_chunks=n_time_chunks,
    )

    sweep = results[name] = []
    for v_init in init_values:
        print(format(v_init, ".1f"))
        outcome = evaluate(
            snn, x_test, y_test, v_initial=v_init, **fixed_kwargs)
        sweep.append(outcome)

# Persist the collected results to caches/optimal_init.p.
# exist_ok=True creates the directory only when needed, without the
# check-then-create race of a separate path.exists() test.
cache_dir = "caches"
os.makedirs(cache_dir, exist_ok=True)
with open(path.join(cache_dir, "optimal_init.p"), "wb") as pickle_file:
    pickle.dump(results, pickle_file)
