#!/usr/bin/env python3

# %%

import os
import os.path as path

import tensorflow as tf
from imageio import mimwrite
from skimage.io import imsave

from sarnn.simulation import simulate
from snn.utils import load_model
from snn.voc import VOC_CLASSES, load_voc, voc_blend, voc_trim, voc_trim_size

# %%

# load_voc() returns two splits; only the test split (and its size) is used.
# Each element of `test` is an (image, label) pair, as unpacked in the
# rendering loop below.
_, (test, n_test) = load_voc()
x_test = test.map(lambda image, label: image)
# The ground truth is the *second* element of each pair. Mapping to the first
# element would hand the input images to simulate() in place of the labels.
y_test = test.map(lambda image, label: label)

# Conventional network and its spiking counterpart (presumably produced by an
# ANN-to-SNN conversion of the same FCN-32 model -- confirm against training).
ann = load_model(path.join("models", "ann", "fcn32_voc.h5"))
snn = load_model(path.join("models", "snn", "fcn32_voc_n.h5"))

# %%

# Settings forwarded to sarnn.simulation.simulate. The meanings below are
# inferred from the parameter names only -- confirm against simulate():
#   v_initial          -- presumably the starting membrane potential
#   n_time_chunks      -- number of time chunks to simulate
#   n_items            -- number of test items to run
#   mask_class         -- label value to be masked out (22 here; verify)
#   return_predictions -- include per-item predictions in the result dict;
#                         the rendering loop reads results["predictions"]
simulation_options = {
    "v_initial": 0.5,
    "n_time_chunks": 20,
    "n_items": 50,
    "mask_class": 22,
    "return_predictions": True,
}
results = simulate(snn, x_test, y_test, **simulation_options)

# %%

out_dir = path.join("results", "voc_animation")

# One subdirectory per output kind: ground-truth overlays, ANN prediction
# overlays, and SNN prediction GIFs.
true_dir = path.join(out_dir, "true")
pred_ann_dir = path.join(out_dir, "pred_ann")
pred_snn_dir = path.join(out_dir, "pred_snn")

# exist_ok=True makes re-runs safe and avoids the race between a separate
# path.exists() check and the makedirs() call.
for out_subdir in (true_dir, pred_ann_dir, pred_snn_dir):
    os.makedirs(out_subdir, exist_ok=True)

# %%

test_iter = iter(test)
snn_predictions = results["predictions"]

# Walk the test set in step with the simulated items. Each item produces
# three files: the ground-truth overlay (PNG), the ANN prediction overlay
# (PNG) and the SNN prediction as an animation (GIF).
for i in range(snn_predictions.shape[0]):
    x, y_true = next(test_iter)

    # ANN: run a batch of one, drop the batch axis, take the argmax over the
    # last (class) axis, and re-encode as one-hot so it blends like y_true.
    ann_output = tf.squeeze(ann(tf.expand_dims(x, axis=0)), axis=0)
    y_pred_ann = tf.one_hot(tf.argmax(ann_output, axis=-1), VOC_CLASSES)

    # SNN predictions are one-hot encoded the same way. They carry an extra
    # leading axis (presumably time chunks -- hence start_axis=1 below and
    # the GIF output).
    y_pred_snn = tf.one_hot(snn_predictions[i], VOC_CLASSES)

    # Trim every array to the size voc_trim_size derives from the label.
    size = voc_trim_size(y_true)
    x, y_true, y_pred_ann = [voc_trim(t, size) for t in (x, y_true, y_pred_ann)]
    y_pred_snn = voc_trim(y_pred_snn, size, start_axis=1)

    outputs = (
        (imsave, true_dir, "{:05d}.png", y_true),
        (imsave, pred_ann_dir, "{:05d}.png", y_pred_ann),
        (mimwrite, pred_snn_dir, "{:05d}.gif", y_pred_snn),
    )
    for writer, target_dir, name_fmt, labels in outputs:
        writer(path.join(target_dir, name_fmt.format(i)), voc_blend(x, labels))
