from os import path as path

import numpy as np
import tensorflow as tf
from tensorflow.keras.metrics import Metric

# Note: VOC_H needs to be divisible by 32 * VOC_PATCHES_H * VOC_DOWNSIZE and
# VOC_W by 32 * VOC_PATCHES_W * VOC_DOWNSIZE.
# (The factor 32 is presumably the total stride of the network -- TODO confirm.)
# Images and labels are resized to (VOC_H // VOC_DOWNSIZE, VOC_W // VOC_DOWNSIZE).
VOC_H = 512
VOC_W = 512
VOC_BATCH_SIZE = 16
# Label channels: 0 is background, 1 to VOC_CLASSES-2 are objects and
# VOC_CLASSES-1 is void.
VOC_CLASSES = 22
VOC_DOWNSIZE = 2
# Number of patches each resized image is split into along height / width.
VOC_PATCHES_H = 1
VOC_PATCHES_W = 1
VOC_PREFETCH = 4 * VOC_BATCH_SIZE

# Number of items held back from the end of the "train" list as the "val" split.
VOC_N_VAL = 300 * VOC_PATCHES_H * VOC_PATCHES_W


# Definitions used in metrics:
#   tp: True positive
#   ap: Actual positive
#   pp: Predicted positive


class MeanAccuracy(Metric):
    """Mean per-class accuracy (tp / ap), averaged over the classes seen so far.

    ``y_true`` is expected to be one-hot with one more channel than ``logits``:
    its last channel (the void label) is dropped before the comparison, so void
    pixels contribute to neither ``tp`` nor ``ap``.

    Args:
        name: Name of the metric.
        n_classes: Number of predicted classes (channels of ``logits``).
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name="mean_accuracy", n_classes=21, **kwargs):
        super(MeanAccuracy, self).__init__(name=name, **kwargs)
        self.n_classes = n_classes
        # The shape is passed as a tuple rather than a bare int.
        self.tp = self.add_weight(name="tp", shape=(n_classes,), initializer="zeros")
        self.ap = self.add_weight(name="ap", shape=(n_classes,), initializer="zeros")

    def update_state(self, y_true, logits, sample_weight=None):
        """Accumulate per-class counts from one batch.

        The reductions run over axes (0, 1, 2), i.e. the inputs are expected to
        be 4-D with the class channel last. ``sample_weight`` is ignored.
        """
        labels = y_true[..., :-1]
        one_hot = tf.one_hot(tf.argmax(logits, axis=-1), self.n_classes)
        correct = one_hot * tf.cast(one_hot == labels, self.dtype)
        self.tp.assign_add(tf.reduce_sum(correct, axis=(0, 1, 2)))
        self.ap.assign_add(tf.reduce_sum(labels, axis=(0, 1, 2)))

    def result(self):
        """Return the mean of tp / ap over the classes with at least one actual positive.

        Classes that have not occurred yet (ap == 0) are left out of the mean
        instead of turning the whole result into NaN; with no data the result is 0.
        """
        present = tf.cast(self.ap > 0, self.dtype)
        per_class = tf.math.divide_no_nan(self.tp, self.ap)
        return tf.math.divide_no_nan(tf.reduce_sum(per_class), tf.reduce_sum(present))

    def reset_state(self):
        """Reset all accumulated counts to zero."""
        self.tp.assign(tf.zeros((self.n_classes,)))
        self.ap.assign(tf.zeros((self.n_classes,)))

    def reset_states(self):
        """Legacy spelling of ``reset_state``, kept for existing callers."""
        self.reset_state()


class MeanIOU(Metric):
    """Mean intersection-over-union: tp / (ap + pp - tp), averaged over classes.

    ``y_true`` is expected to be one-hot with one more channel than ``logits``;
    its last channel marks void pixels. Void pixels are ignored entirely: they
    count towards none of ``tp``, ``ap`` and ``pp``.

    Args:
        name: Name of the metric.
        n_classes: Number of predicted classes (channels of ``logits``).
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name="mean_iou", n_classes=21, **kwargs):
        super(MeanIOU, self).__init__(name=name, **kwargs)
        self.n_classes = n_classes
        # The shape is passed as a tuple rather than a bare int.
        self.tp = self.add_weight(name="tp", shape=(n_classes,), initializer="zeros")
        self.ap = self.add_weight(name="ap", shape=(n_classes,), initializer="zeros")
        self.pp = self.add_weight(name="pp", shape=(n_classes,), initializer="zeros")

    def update_state(self, y_true, logits, sample_weight=None):
        """Accumulate per-class counts from one batch.

        The reductions run over axes (0, 1, 2), i.e. the inputs are expected to
        be 4-D with the class channel last. ``sample_weight`` is ignored.
        """
        labels = y_true[..., :-1]
        # 1 where the pixel carries a real label, 0 where it is void. Without
        # this mask every prediction on a void pixel would be counted in pp
        # and therefore act as a false positive.
        not_void = tf.cast(y_true[..., -1:] != 1, self.dtype)
        one_hot = tf.one_hot(tf.argmax(logits, axis=-1), self.n_classes) * not_void
        correct = one_hot * tf.cast(one_hot == labels, self.dtype)
        self.tp.assign_add(tf.reduce_sum(correct, axis=(0, 1, 2)))
        self.ap.assign_add(tf.reduce_sum(labels, axis=(0, 1, 2)))
        self.pp.assign_add(tf.reduce_sum(one_hot, axis=(0, 1, 2)))

    def result(self):
        """Return the mean IoU over the classes whose union is non-empty.

        Classes that were neither present nor predicted (union == 0) are left
        out of the mean instead of turning the result into NaN; with no data
        the result is 0.
        """
        union = self.ap + self.pp - self.tp
        valid = tf.cast(union > 0, self.dtype)
        iou = tf.math.divide_no_nan(self.tp, union)
        return tf.math.divide_no_nan(tf.reduce_sum(iou), tf.reduce_sum(valid))

    def reset_state(self):
        """Reset all accumulated counts to zero."""
        self.tp.assign(tf.zeros((self.n_classes,)))
        self.ap.assign(tf.zeros((self.n_classes,)))
        self.pp.assign(tf.zeros((self.n_classes,)))

    def reset_states(self):
        """Legacy spelling of ``reset_state``, kept for existing callers."""
        self.reset_state()


class PixelAccuracy(Metric):
    """Fraction of non-void pixels whose predicted class equals the true class.

    ``y_true`` is expected to be one-hot with one more channel than ``logits``;
    its last channel marks void pixels. A void pixel has a true class index
    that no prediction can take, so it never counts as correct, and it is
    excluded from the pixel total.

    Args:
        name: Name of the metric.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name="pixel_accuracy", **kwargs):
        super(PixelAccuracy, self).__init__(name=name, **kwargs)
        self.tp = self.add_weight(name="tp", initializer="zeros")
        self.ap = self.add_weight(name="ap", initializer="zeros")

    def update_state(self, y_true, logits, sample_weight=None):
        """Accumulate correct and total pixel counts; ``sample_weight`` is ignored."""
        correct = tf.argmax(logits, axis=-1) == tf.argmax(y_true, axis=-1)
        correct = tf.cast(correct, self.dtype)
        self.tp.assign_add(tf.reduce_sum(correct))
        # Summing the non-void channels of a one-hot label counts non-void pixels.
        self.ap.assign_add(tf.reduce_sum(y_true[..., :-1]))

    def result(self):
        """Return tp / ap, or 0 (rather than NaN) when no non-void pixel was seen."""
        return tf.math.divide_no_nan(self.tp, self.ap)

    def reset_state(self):
        """Reset all accumulated counts to zero."""
        self.tp.assign(0.0)
        self.ap.assign(0.0)

    def reset_states(self):
        """Legacy spelling of ``reset_state``, kept for existing callers."""
        self.reset_state()


class WeightedMeanIOU(Metric):
    """Per-class IoU averaged with each class weighted by its pixel count ``ap``.

    ``y_true`` is expected to be one-hot with one more channel than ``logits``;
    its last channel marks void pixels. Void pixels are ignored entirely: they
    count towards none of ``tp``, ``ap`` and ``pp``.

    Args:
        name: Name of the metric.
        n_classes: Number of predicted classes (channels of ``logits``).
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name="weighted_mean_iou", n_classes=21, **kwargs):
        super(WeightedMeanIOU, self).__init__(name=name, **kwargs)
        self.n_classes = n_classes
        # The shape is passed as a tuple rather than a bare int.
        self.tp = self.add_weight(name="tp", shape=(n_classes,), initializer="zeros")
        self.ap = self.add_weight(name="ap", shape=(n_classes,), initializer="zeros")
        self.pp = self.add_weight(name="pp", shape=(n_classes,), initializer="zeros")

    def update_state(self, y_true, logits, sample_weight=None):
        """Accumulate per-class counts from one batch.

        The reductions run over axes (0, 1, 2), i.e. the inputs are expected to
        be 4-D with the class channel last. ``sample_weight`` is ignored.
        """
        labels = y_true[..., :-1]
        # 1 where the pixel carries a real label, 0 where it is void. Without
        # this mask every prediction on a void pixel would be counted in pp
        # and therefore act as a false positive.
        not_void = tf.cast(y_true[..., -1:] != 1, self.dtype)
        one_hot = tf.one_hot(tf.argmax(logits, axis=-1), self.n_classes) * not_void
        correct = one_hot * tf.cast(one_hot == labels, self.dtype)
        self.tp.assign_add(tf.reduce_sum(correct, axis=(0, 1, 2)))
        self.ap.assign_add(tf.reduce_sum(labels, axis=(0, 1, 2)))
        self.pp.assign_add(tf.reduce_sum(one_hot, axis=(0, 1, 2)))

    def result(self):
        """Return sum(ap * iou) / sum(ap).

        A class with an empty union contributes 0 instead of 0/0, and the
        result is 0 (rather than NaN) when no labelled pixel was seen.
        """
        iou = tf.math.divide_no_nan(self.tp, self.ap + self.pp - self.tp)
        return tf.math.divide_no_nan(tf.reduce_sum(self.ap * iou), tf.reduce_sum(self.ap))

    def reset_state(self):
        """Reset all accumulated counts to zero."""
        self.tp.assign(tf.zeros((self.n_classes,)))
        self.ap.assign(tf.zeros((self.n_classes,)))
        self.pp.assign(tf.zeros((self.n_classes,)))

    def reset_states(self):
        """Legacy spelling of ``reset_state``, kept for existing callers."""
        self.reset_state()


def load_voc(split, val_split=True):
    """Return ``(dataset, n_items)`` for the requested split.

    "test" is served from the "val" image list on disk. "train" and "val" are
    both carved out of the "train" image list: the last ``VOC_N_VAL`` items
    form "val" and the rest form "train", unless ``val_split`` is False, in
    which case "train" is the whole list.

    Args:
        split: One of "train", "val" or "test".
        val_split: Whether to hold back a validation split from "train".

    Returns:
        A tuple of the dataset and the number of items it contains.

    Raises:
        ValueError: If ``split`` is unknown, if "val" is requested with
            ``val_split=False``, or if the "train" list holds fewer than
            ``VOC_N_VAL`` items.
    """
    # Validate the arguments first so that bad calls fail before any file I/O.
    if split not in ("train", "val", "test"):
        raise ValueError('split must be "train", "val", or "test"')
    if split == "val" and not val_split:
        raise ValueError('requested split "val", but val_split=False')

    if split == "test":
        return _load_voc_split("val")

    full, n_full = _load_voc_split("train")
    if not val_split:
        return full, n_full

    n_train = n_full - VOC_N_VAL
    if n_train < 0:
        # A negative count must not reach take()/skip(): it would not yield
        # the split sizes reported to the caller.
        raise ValueError(
            "training list has {} items, fewer than VOC_N_VAL={}".format(n_full, VOC_N_VAL))
    if split == "train":
        return full.take(n_train), n_train
    return full.skip(n_train), VOC_N_VAL


def voc_blend(x, y, alpha=0.5):
    """Overlay the colour-coded labels ``y`` on the image ``x``.

    Args:
        x: Image; it is scaled by 255, so values are presumably in [0, 1].
        y: Labels with ``VOC_CLASSES`` channels in the last axis (they are
            matrix-multiplied with the class colour map).
        alpha: Weight of the label colours; the image gets ``1 - alpha``.

    Returns:
        The blend as a ``uint8`` NumPy array.
    """
    palette = voc_cmap(VOC_CLASSES)
    coloured_labels = y @ palette
    image_weight = (1.0 - alpha) * 255.0
    mixed = image_weight * x + alpha * coloured_labels
    # Eager tensors have to be turned into NumPy arrays before astype().
    as_array = mixed.numpy() if isinstance(mixed, tf.Tensor) else mixed
    return as_array.astype(np.uint8)


# From https://gist.github.com/wllhf/a4533e0adebe57e3ed06d4b50c8419ae
# with a few minor modifications
def voc_cmap(n_classes=256):
    """Build the colour map for class indices 0 .. n_classes-1.

    For every class index, successive groups of three bits (lowest first) feed
    the red, green and blue channels, filling each channel from its most
    significant bit downwards.

    Returns:
        A ``uint8`` array of shape (n_classes, 3) with one RGB row per class.
    """
    palette = np.zeros((n_classes, 3), dtype=np.uint8)
    for class_index in range(n_classes):
        rgb = [0, 0, 0]
        remaining = class_index
        for shift in range(7, -1, -1):
            # Bits 0, 1, 2 of the current group go to R, G, B respectively.
            for channel in range(3):
                rgb[channel] = rgb[channel] | (_get_bit(remaining, channel) << shift)
            remaining = remaining >> 3
        palette[class_index] = np.array(rgb)
    return palette


def voc_item_size():
    """Return the (height, width) of one dataset item after downsizing and patching."""
    item_height = VOC_H // VOC_DOWNSIZE // VOC_PATCHES_H
    item_width = VOC_W // VOC_DOWNSIZE // VOC_PATCHES_W
    return item_height, item_width


def voc_loss(y_true, logits):
    """Per-pixel softmax cross-entropy with void pixels masked out.

    Args:
        y_true: One-hot labels whose last channel is the void class; that
            channel is dropped before the cross-entropy is computed.
        logits: Class scores with one channel fewer than ``y_true``.

    Returns:
        The cross-entropy per pixel, set to 0 wherever the label is void.
    """
    # Derive the void index from the class count instead of hard-coding 21.
    void_index = VOC_CLASSES - 1
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true[..., :-1], logits=logits)
    mask = tf.cast(tf.argmax(y_true, axis=-1) != void_index, loss.dtype)
    return mask * loss


def voc_trim(image, size, start_axis=0):
    """Keep only the leading ``size[0]`` x ``size[1]`` entries of two adjacent axes.

    Args:
        image: Array-like to trim.
        size: Pair of extents to keep along the two trimmed axes.
        start_axis: Index of the first of the two trimmed axes (e.g. 1 when
            there is a leading batch axis).

    Returns:
        The trimmed array.
    """
    trimmed = image
    for offset in (0, 1):
        keep = range(0, size[offset])
        trimmed = np.take(trimmed, keep, axis=start_axis + offset)
    return trimmed


def voc_trim_size(y_true):
    """Return the extent along axes 0 and 1 that still contains non-void labels.

    For each of the two axes the index of the last slice that is not entirely
    void is located; the extent is that index plus one, or 0 when every slice
    is void. The result therefore always has exactly two entries.

    Args:
        y_true: One-hot labels whose last channel is the void class (index
            ``VOC_CLASSES - 1`` for labels produced by this module).

    Returns:
        A ``(size_axis0, size_axis1)`` tuple of ints.
    """
    size = []
    for axis in (0, 1):
        extent = 0
        # Scan from the far end down to and including index 0.
        for i in range(y_true.shape[axis] - 1, -1, -1):
            all_void = np.all(np.take(y_true, i, axis=axis)[..., -1] == 1)
            if not all_void:
                extent = i + 1
                break
        size.append(extent)
    return tuple(size)


def _create_voc_patches(tensor, n_channels):
    """Cut one resized image or label tensor into a stack of equally sized patches.

    Args:
        tensor: Tensor to split; assumed to be (height, width, n_channels)
            with the resized VOC height and width.
        n_channels: Size of the last axis of ``tensor``.

    Returns:
        A tensor of shape (VOC_PATCHES_H * VOC_PATCHES_W, patch_h, patch_w,
        n_channels).
    """
    patch_h = VOC_H // VOC_DOWNSIZE // VOC_PATCHES_H
    patch_w = VOC_W // VOC_DOWNSIZE // VOC_PATCHES_W
    n_patches = VOC_PATCHES_H * VOC_PATCHES_W

    # Split along the height and pack the pieces along a new leading axis.
    row_pieces = tf.split(tensor, VOC_PATCHES_H, axis=0)
    rows = tf.convert_to_tensor(row_pieces)
    # The new leading axis moved the width axis from position 1 to position 2.
    column_pieces = tf.split(rows, VOC_PATCHES_W, axis=2)
    grid = tf.convert_to_tensor(column_pieces)

    patch_shape = (n_patches, patch_h, patch_w, n_channels)
    return tf.reshape(grid, patch_shape)


def _get_bit(byte, k):
    return (byte & (1 << k)) != 0


def _load_voc_split(split):
    """Build the dataset for one image list under ``voc/ImageSets/Segmentation``.

    Every line of ``<split>.txt`` names one sample. A sample is used only when
    both its prepared image (``.jpg``) and its prepared label (``.png``) exist,
    so images and labels always stay paired.

    Args:
        split: Name of the image list file without the ``.txt`` extension.

    Returns:
        A tuple ``(dataset, n_items)``: the dataset yields one (image patch,
        one-hot label patch) pair per item, and ``n_items`` is the number of
        usable samples times the number of patches per sample.
    """
    split_filename = path.join("voc", "ImageSets", "Segmentation", split + ".txt")
    with open(split_filename, "r") as split_file:
        lines = split_file.read().splitlines()

    image_prefix = path.join("voc", "JPEGImagesPrepared")
    label_prefix = path.join("voc", "SegmentationClassPrepared")

    # Filter image and label together: filtering the two lists independently
    # could leave them equally long but pointing at different samples.
    image_filenames = []
    label_filenames = []
    for line in lines:
        image_filename = path.join(image_prefix, line + ".jpg")
        label_filename = path.join(label_prefix, line + ".png")
        if path.exists(image_filename) and path.exists(label_filename):
            image_filenames.append(image_filename)
            label_filenames.append(label_filename)

    dataset = tf.data.Dataset.from_tensor_slices((image_filenames, label_filenames))
    dataset = dataset.map(_preprocess_voc)
    dataset = dataset.map(
        lambda x, y: (_create_voc_patches(x, 3), _create_voc_patches(y, VOC_CLASSES)))
    # Each element holds all patches of one sample; unbatch to one patch per item.
    dataset = dataset.unbatch()

    return dataset, len(image_filenames) * VOC_PATCHES_H * VOC_PATCHES_W


def _preprocess_voc(image_filename, label_filename):
    """Load one image/label pair and bring both to the downsized VOC resolution.

    Args:
        image_filename: Path of the JPEG image.
        label_filename: Path of the PNG label map; it is assumed to decode to
            a single channel of class indices, with 255 marking void pixels.

    Returns:
        A tuple ``(image, label)``: ``image`` is float32 with 3 channels and
        ``label`` is one-hot with ``VOC_CLASSES`` channels, both of size
        (VOC_H // VOC_DOWNSIZE, VOC_W // VOC_DOWNSIZE).
    """
    target_size = (VOC_H // VOC_DOWNSIZE, VOC_W // VOC_DOWNSIZE)

    # Force 3 channels: later stages reshape images to 3 channels, which a
    # grayscale JPEG would otherwise break.
    image = tf.io.decode_jpeg(tf.io.read_file(image_filename), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, target_size)

    # 0: background, 1 to VOC_CLASSES-2: objects, VOC_CLASSES-1: void
    label = tf.io.decode_png(tf.io.read_file(label_filename))
    # Remap the void marker 255 to the last class index, keep everything else.
    mask_keep = tf.cast(label != 255, tf.uint8)
    mask_void = tf.cast(label == 255, tf.uint8)
    label = mask_keep * label + mask_void * (VOC_CLASSES - 1)
    # Nearest-neighbour resizing keeps the values valid class indices.
    label = tf.image.resize(label, target_size, method="nearest")
    label = tf.one_hot(tf.squeeze(label, axis=-1), VOC_CLASSES)

    return image, label
