#!/usr/bin/env python3

import os
import os.path as path
import pickle

from sarnn.simulation import simulate
from snn.utils import load_snn_dataset, load_model

# Simulate every network in its "base" and "final" variant and collect what
# simulate() returns in a nested dict: results[name][version].
results = {}
for name in "dense_mnist", "conv_mnist", "conv_cifar10", "conv_cifar100":
    results[name] = {}

    # The dataset is determined solely by the network name (the part after
    # the underscore), so load the test split once per network instead of
    # once per version.
    # NOTE(review): this assumes simulate() does not modify x or y in place;
    # confirm before relying on the shared arrays.
    dataset = name.split("_")[1]
    (x, y), _ = load_snn_dataset(dataset, "test")
    n_time_chunks = 5 if dataset == "mnist" else 100

    for version in "base", "final":
        print(name + " " + version)

        if version == "base":
            # Baseline network, simulated with a zero v_initial.
            snn_filename = path.join("models", "snn", name + ".h5")
            v_initial = 0.0
        else:
            # "_sparse" network from models/opt; its v_initial is read from
            # the last entry of the matching pickle in results/optimize.
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; presumably this file is written by this project's own
            # optimization step. Only run against trusted result files.
            snn_filename = path.join("models", "opt", name + "_sparse.h5")
            pickle_filename = path.join(
                "results", "optimize", name + "_sparse.p")
            with open(pickle_filename, "rb") as pickle_file:
                v_initial = pickle.load(pickle_file)[-1]["v_initial"]

        snn = load_model(snn_filename)
        results[name][version] = simulate(
            snn, x, y,
            v_initial=v_initial,
            n_time_chunks=n_time_chunks)

# Persist all collected results in a single pickle file.
os.makedirs("results", exist_ok=True)
with open(path.join("results", "accuracy_time.p"), "wb") as pickle_file:
    pickle.dump(results, pickle_file)
