#!/usr/bin/env python3

import argparse
import os
import os.path as path

import numpy as np
from skimage.io import imread, imsave

from snn.utils import cleanup_tmp, ensure_exists, register_handlers
from snn.voc import VOC_H, VOC_W, voc_cmap


def main(args):
    """Decode and pad every VOC segmentation label and its matching image.

    For each file in ``<dataset_dir>/SegmentationClass``:

    * the color-coded label is converted to a 2-D ``uint8`` map holding, for
      each pixel, the index of the matching row of ``voc_cmap()``;
    * that map is placed in the top-left corner of a ``(VOC_H, VOC_W)``
      array filled with 255 and saved under ``SegmentationClassPrepared``;
    * the same-named ``.jpg`` from ``JPEGImages`` is placed in the top-left
      corner of a zero-filled ``(VOC_H, VOC_W, channels)`` array and saved
      under ``JPEGImagesPrepared``;
    * the file name without its extension is printed.

    Args:
        args: Parsed command-line arguments. Only ``args.dataset_dir`` is
            read; it is the directory holding the VOC subdirectories.
    """
    register_handlers()

    label_in_dir = path.join(args.dataset_dir, "SegmentationClass")
    image_in_dir = path.join(args.dataset_dir, "JPEGImages")

    label_out_dir = path.join(args.dataset_dir, "SegmentationClassPrepared")
    ensure_exists(label_out_dir)
    image_out_dir = path.join(args.dataset_dir, "JPEGImagesPrepared")
    ensure_exists(image_out_dir)

    cmap = voc_cmap()

    for label_filename in sorted(os.listdir(label_in_dir)):
        # Strip only the real extension; a substring replace would also
        # rewrite a ".png" occurring in the middle of the name.
        stem = path.splitext(label_filename)[0]

        # Drop the last channel (presumably alpha -- TODO confirm against the
        # reader in use) so each pixel can be compared with a cmap row.
        label = imread(path.join(label_in_dir, label_filename))[:, :, :-1]

        # Start from 255 rather than uninitialized memory: a pixel whose
        # color matches no cmap row must not end up with an arbitrary class.
        # 255 is the same value used for the padding below (presumably the
        # VOC "void" label -- confirm against the consumer of these files).
        decoded = np.full(label.shape[:-1], 255, dtype="uint8")
        for i in range(cmap.shape[0]):
            decoded[np.all(label == cmap[i], axis=-1)] = i

        label_padded = np.full((VOC_H, VOC_W), 255, dtype="uint8")
        label_padded[:decoded.shape[0], :decoded.shape[1]] = decoded
        imsave(path.join(label_out_dir, label_filename), label_padded, check_contrast=False)

        image_filename = stem + ".jpg"
        image = imread(path.join(image_in_dir, image_filename))
        image_padded = np.zeros((VOC_H, VOC_W, image.shape[2]), dtype="uint8")
        image_padded[:image.shape[0], :image.shape[1], :] = image
        imsave(path.join(image_out_dir, image_filename), image_padded, check_contrast=False)

        print(stem)


if __name__ == "__main__":
    # Command-line entry point: parse arguments, run the preparation, and
    # always clean up temporary state afterwards.
    parser = argparse.ArgumentParser(
        description="Performs one-time preparation of VOC data.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False)

    # The default help action is disabled above so that the help text can be
    # customized here.
    parser.add_argument(
        "-h", "--help", action="help",
        help="Display this help message and exit.")

    # main() reads ``args.dataset_dir``; without an explicit ``dest`` argparse
    # would store this option as ``args.voc_dir`` and main() would fail with
    # AttributeError. ``metavar`` keeps the usage/help output unchanged.
    parser.add_argument(
        "-V", "--voc-dir", dest="dataset_dir", metavar="VOC_DIR",
        default="voc",
        help="The directory containing the VOC dataset. New "
             "subdirectories will be added here. Data can be "
             "downloaded from "
             "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar.")

    try:
        main(parser.parse_args())
    finally:
        # Runs whether main() returns normally or raises.
        cleanup_tmp()
