#!/usr/bin/env python3

import argparse
import itertools
import os.path as path

import tensorflow as tf
from skimage.io import imsave

from snn.utils import cleanup_tmp, ensure_exists, load_model, load_ann_dataset, register_handlers
from snn.voc import VOC_CLASSES, voc_blend, voc_trim, voc_trim_size


def main(args):
    """Save ground-truth and predicted segmentation overlays for VOC test images.

    For each of the first ``args.n_items`` examples of the VOC "test" split,
    the model is run on the image, and two PNG blends are written:

    * ``<out_dir>/true/NNNNN.png`` - the image blended with the ground truth.
    * ``<out_dir>/pred/NNNNN.png`` - the image blended with the prediction.

    Files are numbered ``00000``, ``00001``, ... in dataset order. If the
    split contains fewer than ``args.n_items`` examples, every available
    example is processed and the function returns normally. A negative
    ``args.n_items`` processes nothing.

    Args:
        args: Parsed command-line namespace providing ``model`` (model
            filename), ``out_dir`` (output root), ``n_items`` (number of
            examples) and ``alpha`` (blend fraction passed to ``voc_blend``).
    """
    register_handlers()

    true_dir = path.join(args.out_dir, "true")
    ensure_exists(true_dir)
    pred_dir = path.join(args.out_dir, "pred")
    ensure_exists(pred_dir)

    # islice stops cleanly when the dataset is exhausted instead of letting
    # next() raise StopIteration. It rejects negative stops, so clamp at 0
    # (matching range() on a negative count, which yields nothing).
    # islice calls iter() immediately, so the dataset is opened before the
    # model is loaded, as in the original ordering.
    examples = itertools.islice(load_ann_dataset("voc", "test")[0], max(args.n_items, 0))
    ann = load_model(args.model)

    for i, (x, y_true) in enumerate(examples):
        # Run the model on a batch of one, then drop the batch axis again.
        y_pred = ann(tf.expand_dims(x, axis=0))
        y_pred = tf.squeeze(y_pred, axis=0)
        # Reduce the last-axis scores to a hard label per position, then
        # re-expand to a VOC_CLASSES-deep one-hot tensor (presumably matching
        # y_true's encoding so voc_trim/voc_blend treat both alike - confirm).
        y_pred = tf.argmax(y_pred, axis=-1)
        y_pred = tf.one_hot(y_pred, VOC_CLASSES)

        # Crop image, ground truth and prediction to the same size, derived
        # from y_true (presumably removing dataset padding - confirm in snn.voc).
        size = voc_trim_size(y_true)
        x = voc_trim(x, size)
        y_true = voc_trim(y_true, size)
        y_pred = voc_trim(y_pred, size)

        imsave(path.join(true_dir, "{:05d}.png".format(i)), voc_blend(x, y_true, alpha=args.alpha))
        imsave(path.join(pred_dir, "{:05d}.png".format(i)), voc_blend(x, y_pred, alpha=args.alpha))


if __name__ == "__main__":
    # Command-line entry point: parse options, run main(), and always clean
    # up temporary files afterwards, even if parsing exits (e.g. --help).
    cli = argparse.ArgumentParser(
        description="Visualizes the predictions of a VOC model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
    )
    add_option = cli.add_argument

    add_option("model", help="The filename of the segmentation model.")

    # The automatic -h/--help is disabled above and re-registered here so its
    # description can be customized.
    add_option("-h", "--help", action="help",
               help="Display this help message and exit.")

    add_option("-a", "--alpha", type=float, default=0.5,
               help="The fraction of the output image blend which should be comprised of the segmentation results.")
    add_option("-n", "--n-items", type=int, default=20,
               help="The number of examples to process and save.")
    add_option("-O", "--out-dir", default=path.join("results", "fcn32_voc"),
               help="The directory where the results should be saved. This is created if it does not already exist.")

    try:
        parsed_args = cli.parse_args()
        main(parsed_args)
    finally:
        cleanup_tmp()
