import os
import os.path as path
import signal
import sys

import numpy as np
import tensorflow as tf
from mpi4py import MPI
from tensorflow.keras.datasets import mnist, cifar10, cifar100
from tensorflow.keras.utils import CustomObjectScope, to_categorical

import sarnn.utils
from snn.imagenet import load_imagenet, IMAGENET_BATCH_SIZE, IMAGENET_PREFETCH
from snn.voc import (
    MeanAccuracy, MeanIOU, PixelAccuracy, WeightedMeanIOU, VOC_BATCH_SIZE, VOC_PREFETCH, load_voc,
    voc_loss)

# Custom Keras objects imported from snn.voc, keyed by name (presumably the
# names recorded in saved model files — confirm). load_model() registers
# these alongside sarnn.utils.CUSTOM_OBJECTS while deserializing.
CUSTOM_OBJECTS = dict(
    voc_loss=voc_loss,
    MeanAccuracy=MeanAccuracy,
    MeanIOU=MeanIOU,
    PixelAccuracy=PixelAccuracy,
    WeightedMeanIOU=WeightedMeanIOU,
)


def cleanup_tmp():
    """Best-effort removal of ``/tmp/tmp*.py`` entries owned by the current user.

    Only MPI rank 0 of ``COMM_WORLD`` does any work; every other rank returns
    immediately. NOTE(review): on multi-node runs this touches only rank 0's
    node-local ``/tmp`` — confirm that is intended.

    Entries that disappear mid-scan, are not removable (e.g. directories), or
    otherwise raise ``OSError`` are skipped rather than aborting the cleanup.
    """
    if MPI.COMM_WORLD.Get_rank() != 0:
        return
    tmp_dir = "/tmp"
    if not path.isdir(tmp_dir):
        return

    uid = os.getuid()
    for item in os.listdir(tmp_dir):
        if not (item.startswith("tmp") and item.endswith(".py")):
            continue
        item_path = path.join(tmp_dir, item)
        try:
            # lstat judges the directory entry itself: a symlink we own is
            # removed (not its target), and one owned by someone else is left.
            if os.lstat(item_path).st_uid == uid:
                os.remove(item_path)
        except OSError:
            # The entry vanished between listdir and here, or it cannot be
            # unlinked (e.g. a directory). Skip it and continue the cleanup.
            continue


def evaluate_ann(ann, data, steps):
    """Run ``ann.evaluate`` on ``data`` with progress output disabled.

    ``data`` is either a ``tf.data.Dataset``, which is passed as ``x`` with
    ``y=None`` (so it presumably yields input/target pairs), or an indexable
    pair whose items 0 and 1 are used as ``x`` and ``y``. ``steps`` is
    forwarded unchanged. Returns whatever ``ann.evaluate`` returns.
    """
    if isinstance(data, tf.data.Dataset):
        inputs, targets = data, None
    else:
        inputs, targets = data[0], data[1]
    return ann.evaluate(x=inputs, y=targets, steps=steps, verbose=0)


def ensure_exists(dirname):
    """Create directory ``dirname``, including missing parents, if absent.

    Safe to call when the directory already exists or when several processes
    create it concurrently. Raises ``FileExistsError`` if ``dirname`` exists
    but is not a directory.
    """
    # exist_ok avoids the check-then-create race: another process (e.g. a
    # second MPI rank) may create the directory between an isdir() test and
    # makedirs(), which would otherwise raise FileExistsError.
    os.makedirs(dirname, exist_ok=True)


def initialize_mpi(ranks_per_node=None):
    """Return the world communicator and optionally restrict this rank's GPUs.

    Every rank reports its physical GPU count on stderr.

    Args:
      ranks_per_node: if given, this rank keeps visible only the GPUs whose
        index ``i`` satisfies ``i % ranks_per_node == rank % ranks_per_node``.
        Rank 0 prints a warning when the GPU count is not a multiple of
        ``ranks_per_node``. NOTE(review): this assumes consecutive ranks are
        placed on the same node — confirm against the launcher's mapping.

    Returns:
      ``(comm, n_ranks, rank)`` for ``MPI.COMM_WORLD``.
    """
    comm = MPI.COMM_WORLD
    n_ranks = comm.Get_size()
    rank = comm.Get_rank()

    physical_gpus = tf.config.list_physical_devices("GPU")
    n_gpus = len(physical_gpus)
    print("info: MPI rank {} got {} GPUs".format(rank, n_gpus), file=sys.stderr)

    if ranks_per_node is None:
        return comm, n_ranks, rank

    uneven_split = n_gpus % ranks_per_node != 0
    if rank == 0 and uneven_split:
        message = ("warning: {} GPUs cannot be evenly divided among {} "
                   "ranks; this will likely result in load imbalance")
        print(message.format(n_gpus, ranks_per_node), file=sys.stderr)

    first_gpu = rank % ranks_per_node
    tf.config.set_visible_devices(physical_gpus[first_gpu::ranks_per_node], "GPU")
    return comm, n_ranks, rank


def load_ann_dataset(dataset, split, n_items=-1, val_split=True, batch=True):
    """Load one split of a dataset for ANN training or evaluation.

    Args:
      dataset: "voc" or "imagenet" (streamed through tf.data), or a NumPy
        dataset handled by ``_load_np_dataset`` ("mnist", "cifar10",
        "cifar100").
      split: for voc/imagenet, passed straight to the loader. For NumPy
        datasets, "train" and "test" select those sets; any other value
        selects the validation set.
      n_items: keep only the first ``n_items`` samples; negative keeps all.
      val_split: whether a validation set is carved out of the training set.
      batch: for voc/imagenet, whether to batch the dataset (ignored for
        NumPy datasets).

    Returns:
      For voc/imagenet, ``(tf.data.Dataset, n)`` from ``_prepare_tf``, where
      ``n`` is a batch count if ``batch`` else an item count. For NumPy
      datasets, ``((x, y), n_items_kept)`` from ``_take_np``.

    Raises:
      ValueError: a NumPy dataset's validation split was requested while
        ``val_split`` is False, so no validation set exists.
    """
    if dataset == "voc":
        data, n_all = load_voc(split, val_split=val_split)
        return _prepare_tf(data, n_all, n_items, batch, VOC_BATCH_SIZE, VOC_PREFETCH)
    if dataset == "imagenet":
        data, n_all = load_imagenet(split, val_split=val_split)
        return _prepare_tf(data, n_all, n_items, batch, IMAGENET_BATCH_SIZE, IMAGENET_PREFETCH)

    data = _load_np_dataset(dataset, val_split)
    # _load_np_dataset returns (train, test[, val]); unrecognized split names
    # fall through to the validation set.
    index = {"train": 0, "test": 1}.get(split, 2)
    if index >= len(data):
        raise ValueError(
            "split {!r} selects the validation set of {!r}, but val_split=False "
            "so no validation set was created".format(split, dataset))
    return _take_np(data[index], n_items)


def load_model(filename):
    """Load a saved Keras model from ``filename`` with custom objects in scope.

    ``sarnn.utils.CUSTOM_OBJECTS`` and this module's ``CUSTOM_OBJECTS`` are
    both registered via ``CustomObjectScope`` for the duration of the load.
    """
    scope = CustomObjectScope(sarnn.utils.CUSTOM_OBJECTS, CUSTOM_OBJECTS)
    with scope:
        model = tf.keras.models.load_model(filename)
    return model


def load_snn_dataset(dataset, split, n_items=-1, val_split=True):
    """Load an unbatched dataset split with inputs and targets separated.

    Delegates to ``load_ann_dataset(..., batch=False)``. A ``tf.data.Dataset``
    result is split into an inputs-only and a targets-only dataset. Any other
    result is returned unchanged. The second return value is the count
    reported by ``load_ann_dataset`` (an item count, since batching is off).
    """
    loaded, n_steps = load_ann_dataset(
        dataset, split, n_items=n_items, val_split=val_split, batch=False)
    if not isinstance(loaded, tf.data.Dataset):
        return loaded, n_steps

    # NOTE(review): the two datasets below iterate the source independently.
    # If the source reshuffles per iteration, inputs and targets would drift
    # out of alignment. Confirm load_voc/load_imagenet are deterministic.
    inputs = loaded.map(lambda x, y: x)
    targets = loaded.map(lambda x, y: y)
    return (inputs, targets), n_steps


def register_handlers():
    """Install ``_sigterm_handler`` as the process-wide SIGTERM handler."""
    handler = _sigterm_handler
    signal.signal(signal.SIGTERM, handler)


def _load_np_dataset(name, val_split, seed=0, n_val=10000):
    """Load a Keras built-in image dataset as float32 arrays with one-hot labels.

    Args:
      name: one of "mnist", "cifar10", "cifar100".
      val_split: if true, shuffle the training set with ``seed`` and move its
        last ``n_val`` samples into a validation set.
      seed: seed for the validation shuffle. A private ``RandomState`` is
        used, so the split matches what seeding the global NumPy RNG produced
        but the global RNG state is left untouched.
      n_val: number of training samples held out for validation.

    Returns:
      ``((x_train, y_train), (x_test, y_test))``, plus ``(x_val, y_val)`` as
      a third element when ``val_split`` is true. Images are scaled by
      ``_norm`` and labels one-hot encoded by ``to_categorical``.

    Raises:
      ValueError: ``name`` is not a supported dataset.
    """
    loaders = {"mnist": mnist, "cifar10": cifar10, "cifar100": cifar100}
    if name not in loaders:
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise ValueError("unknown dataset {!r}; expected one of {}".format(
            name, sorted(loaders)))

    (x_train, y_train), (x_test, y_test) = loaders[name].load_data()
    if name == "mnist":
        # Append a trailing size-1 channel axis (presumably to match the
        # channel-last CIFAR arrays).
        x_train = np.expand_dims(x_train, -1)
        x_test = np.expand_dims(x_test, -1)

    x_train, x_test = _norm(x_train), _norm(x_test)
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    if not val_split:
        return (x_train, y_train), (x_test, y_test)

    shuffle = np.random.RandomState(seed).permutation(x_train.shape[0])
    x_train, y_train = x_train[shuffle], y_train[shuffle]
    # Compute the cut explicitly so n_val=0 yields an empty validation set
    # (a negative-index slice like [:-0] would instead empty the train set).
    cut = x_train.shape[0] - n_val
    x_train, x_val = x_train[:cut], x_train[cut:]
    y_train, y_val = y_train[:cut], y_train[cut:]
    return (x_train, y_train), (x_test, y_test), (x_val, y_val)


def _norm(x):
    return x.astype(np.float32) / 255.0


def _prepare_tf(data, n_all, n_items, batch, batch_size, prefetch):
    if n_items >= 0:
        data = data.take(n_items)
    n_items = n_all if n_items < 0 else n_items
    data = data.prefetch(prefetch)
    if batch:
        return data.batch(batch_size), int(np.ceil(n_items / batch_size))
    else:
        return data, n_items


def _sigterm_handler(signum, frame):
    """SIGTERM handler: attempt temp-file cleanup, then exit.

    ``sys.exit`` with a string prints it to stderr and exits with status 1.
    A failure in ``cleanup_tmp`` is reported on stderr rather than allowed to
    escape. An exception raised inside a signal handler surfaces at an
    arbitrary point in the main thread, where it could be caught and leave
    the process running despite SIGTERM.
    """
    try:
        cleanup_tmp()
    except Exception as exc:  # best-effort: termination must still happen
        print("warning: temp cleanup on SIGTERM failed: {}".format(exc),
              file=sys.stderr)
    sys.exit("Received SIGTERM")


def _take_np(data, n_items):
    if n_items >= 0:
        return (data[0][:n_items], data[1][:n_items]), n_items
    else:
        return data, data[0].shape[0]
