#!/usr/bin/env python3

import os
import os.path as path
import pickle

from sarnn.simulation import simulate
from snn.utils import load_snn_dataset, load_model

# Simulate the base CIFAR-10 conv SNN ("base") and the optimized sparse
# variant ("final") on the "test" split, then pickle both simulation results
# to results/curve_improvement.p.

# Both simulations must use the same number of time chunks so their results
# are comparable; keep the value in one place so they cannot drift apart.
N_TIME_CHUNKS = 30

results = {}
(x, y), _ = load_snn_dataset("cifar10", "test")

# Base model, simulated with v_initial fixed at 0.0.
model_base = load_model(path.join("models", "snn", "conv_cifar10.h5"))
results["base"] = simulate(
    model_base, x, y, v_initial=0.0, n_time_chunks=N_TIME_CHUNKS
)

# Optimized model, simulated with the "v_initial" value stored in the last
# entry of the pickled optimization log.
# NOTE(review): pickle.load can execute arbitrary code on crafted input; this
# file is assumed to be produced locally by the optimization run — confirm.
model_final = load_model(
    path.join("models", "optimized", "conv_cifar10_sparse.h5")
)
with open(path.join("results", "optimize", "conv_cifar10_sparse.p"), "rb") as f:
    optimization_log = pickle.load(f)
v_initial_final = optimization_log[-1]["v_initial"]
results["final"] = simulate(
    model_final, x, y, v_initial=v_initial_final, n_time_chunks=N_TIME_CHUNKS
)

# Ensure the output directory exists before writing the combined results.
os.makedirs("results", exist_ok=True)
with open(path.join("results", "curve_improvement.p"), "wb") as f:
    pickle.dump(results, f)
