from . import backbone
from . import project

import tensorflow as tf
import numpy as np

def softplus(x, temperature=None): # Numerically stable for large x
    """Compute softplus, ``log(1 + exp(x))``, optionally with a temperature.

    The plain form is evaluated as ``log1p(exp(-|x|)) + max(x, 0)`` so that
    the exponential is only ever taken of a non-positive number.

    Args:
        x: Input value(s) passed to the TensorFlow math ops.
        temperature: If given, the result is
            ``temperature * softplus(x / temperature)``.

    Returns:
        The softplus of ``x``.
    """
    if temperature is not None:
        # Scale in, apply the un-tempered form, scale back out.
        return temperature * softplus(x / temperature)

    # exp() of a non-positive argument cannot overflow.
    decaying_term = tf.math.log1p(tf.math.exp(-tf.math.abs(x)))
    linear_term = tf.math.maximum(x, 0)
    return decaying_term + linear_term

def softmax_x(x, axis=-1):
    """Normalize ``x`` by its sum along ``axis``.

    The division uses ``tf.math.divide_no_nan``, so a slice with a zero sum
    does not produce NaN/Inf.

    Args:
        x: Values to normalize.
        axis: Axis along which the sum is taken (kept as a size-1 axis for
            broadcasting).

    Returns:
        ``x`` divided by its sum along ``axis``.
    """
    total = tf.reduce_sum(x, axis=axis, keepdims=True)
    return tf.math.divide_no_nan(x, total)

def softmax_exp(x, axis=-1):
    """Exponential softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so the
    largest exponent is zero; the exponentials are then normalized by their
    sum via ``softmax_x``.

    Args:
        x: Input logits.
        axis: Axis along which to normalize.

    Returns:
        The normalized exponentials of ``x``.
    """
    peak = tf.reduce_max(x, axis=axis, keepdims=True)
    exponentials = tf.math.exp(x - peak)
    return softmax_x(exponentials, axis=axis)
# The exponential variant is exposed under the name ``softmax``.
softmax = softmax_exp

def softmax_softplus(x, axis=-1):
    """Softmax-like normalization that uses softplus in place of exp.

    Args:
        x: Input values.
        axis: Axis along which to normalize.

    Returns:
        ``softplus(x)`` divided by its sum along ``axis``.
    """
    activations = softplus(x)
    return softmax_x(activations, axis=axis)

def get_inferred_value(n):
    """Return a concrete integer for ``n`` if one is available, else ``None``.

    Args:
        n: Either a Python ``int`` or an object that may carry an
            ``_inferred_value`` instance attribute (a sequence whose first
            element is the value).

    Returns:
        ``n`` itself when it is an ``int``; ``int(n._inferred_value[0])`` when
        that attribute exists on the instance and neither it nor its first
        element is ``None``; otherwise ``None``.
    """
    if isinstance(n, int):
        return n

    # Only look at attributes stored directly on the instance.
    if "__dict__" not in dir(n) or "_inferred_value" not in vars(n):
        return None

    if n._inferred_value is None:
        return None
    if n._inferred_value[0] is None:
        return None
    return int(n._inferred_value[0])

def pad_matrix(m, pad=None, rank=None):
    """Grow the last two axes of ``m``, filling new entries from an identity.

    Exactly one of ``pad`` and ``rank`` must be given. Padding is appended at
    the end of each of the two trailing axes; leading axes are left untouched.
    Original entries are kept, and every newly added entry takes the value of
    the corresponding entry of an identity matrix.

    Args:
        m: Tensor whose last two axes form the matrix. Its static rank must be
            known, since ``len(m.shape)`` is used.
        pad: Number of rows and columns to append.
        rank: Target size for both trailing axes.

    Returns:
        The padded tensor, or ``m`` itself when ``rank`` is given and both
        trailing dimensions are already inferred to equal ``rank``.

    Raises:
        ValueError: If both or neither of ``pad`` and ``rank`` are provided.
    """
    rank_given = rank is not None
    if rank_given == (pad is not None):
        raise ValueError("Expected either pad or rank")

    # Nothing to do when both trailing dims are already known to match.
    if rank_given and all(
        get_inferred_value(tf.shape(m)[axis]) == rank for axis in (-1, -2)
    ):
        return m

    leading_axes = len(m.shape) - 2
    if rank_given:
        extra_rows = rank - tf.shape(m)[-2]
        extra_cols = rank - tf.shape(m)[-1]
        identity_size = rank
    else:
        # The validation above guarantees pad is set on this path.
        extra_rows = extra_cols = pad
        # NOTE(review): size is derived from the last axis only, so this path
        # presumably expects square matrices -- confirm with callers.
        identity_size = tf.shape(m)[-1] + pad
    paddings = [[0, 0] for _ in range(leading_axes)] + [[0, extra_rows], [0, extra_cols]]

    # True where an original entry of m lives, False in the padded region.
    is_original = tf.pad(tf.ones(tf.shape(m), dtype="int32"), paddings=paddings) == 1
    padded = tf.pad(m, paddings=paddings)

    # Give the identity one size-1 axis per leading axis so it broadcasts.
    identity = tf.eye(identity_size, dtype=m.dtype)
    for _ in range(leading_axes):
        identity = identity[tf.newaxis]
    return tf.where(is_original, padded, identity)

def rescale(image, min=None, max=None, epsilon=1e-7, use_tf=None, dtype="float32"):
    """Clip ``image`` to ``[min, max]`` and map that range linearly onto ~[0, 1].

    Args:
        image: A tensor (TensorFlow path) or an object with ``astype``
            such as a NumPy array (NumPy path).
        min: Lower bound; defaults to the minimum of ``image`` after casting.
            (The name shadows the builtin but is part of the interface.)
        max: Upper bound; defaults to the maximum of ``image`` after casting.
        epsilon: Added to ``max - min`` in the denominator so that a constant
            image does not divide by zero.
        use_tf: Force the TensorFlow (truthy) or NumPy (falsy) path. When
            ``None``, the path is chosen with ``tf.is_tensor(image)``.
        dtype: Type the image is cast to before any computation.

    Returns:
        ``(clipped - min) / (max - min + epsilon)``.
    """
    if use_tf is None:
        use_tf = tf.is_tensor(image)

    if use_tf:
        values = tf.cast(image, dtype)
        lower = tf.reduce_min(values) if min is None else min
        upper = tf.reduce_max(values) if max is None else max
        clipped = tf.clip_by_value(values, lower, upper)
        return (clipped - lower) / (upper - lower + epsilon)

    values = image.astype(dtype)
    lower = np.amin(values) if min is None else min
    upper = np.amax(values) if max is None else max
    clipped = np.clip(values, lower, upper)
    return (clipped - lower) / (upper - lower + epsilon)
