# coding=utf-8
# Copyright 2020 The Heteroscedastic Noisy Labels Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Library of methods to compute heteroscedastic classification predictions."""

from __future__ import absolute_import
from __future__ import division

from __future__ import print_function

import abc
import collections

import heteroscedastic_lib_utils as het_lib_utils
import utils

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

# Result container returned by the heteroscedastic output layers. Fields:
#   probs: the per class probability.
#   log_probs: the per class log probability.
#   locs: the location parameters of the location-scale distribution.
#   logits: can be passed as the logits argument of
#     tf.nn.sparse_softmax_cross_entropy_with_logits (multi-class
#     classification) or tf.nn.sigmoid_cross_entropy_with_logits (binary
#     classification).
#   predictive_variance: the per class variance of the predictive distribution.
MCSoftmaxResult = collections.namedtuple(
    'MCSoftmaxResult',
    ('probs', 'log_probs', 'locs', 'logits', 'predictive_variance'))


class MCSoftmaxOutputLayer(tf.keras.layers.Layer):
  """Base class for MC heteroscedastic output layers."""

  __metaclass__ = abc.ABCMeta

  def __init__(self,
               num_classes,
               logit_noise=het_lib_utils.LogitNoise.NORMAL,
               temperature=1.0,
               train_mc_samples=1000,
               test_mc_samples=1000,
               max_mc_samples_in_memory=1000,
               compute_pred_variance=False,
               scale_positive=True,
               name='MCSoftmaxOutputLayer'):
    """Stores the configuration shared by all MC heteroscedastic layers.

    Args:
      num_classes: Integer. Number of classes for classification task.
      logit_noise: LogitNoise instance. The noise distribution assumed on the
        softmax logits. Possible values: LogitNoise.NORMAL,
        LogitNoise.LOGISTIC, LogitNoise.GUMBEL.
      temperature: Float or scalar `Tensor`. The softmax temperature; sampled
        latents are divided by it before the softmax/sigmoid.
      train_mc_samples: Number of Monte-Carlo samples used to estimate the
        predictive distribution when `training` is true.
      test_mc_samples: Number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      max_mc_samples_in_memory: Upper bound on the number of Monte-Carlo
        samples materialised at once. Estimating the predictive distribution
        computes a `Tensor` of shape
        [batch_size, max_mc_samples_in_memory, num_classes], so pick the
        largest value that does not cause OOM errors.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False, __call__ outputs None for the predictive_variance
        field.
      scale_positive: Boolean. If True, __call__ clamps the scale parameters
        from below with a small positive constant. This must be True when
        scale holds per class standard deviations. For low rank
        approximations scale is the matrix that maps standard normal samples
        to multivariate Gaussian samples, so scale_positive = False.
      name: String. The name of the layer used for name scoping.

    Returns:
      MCSoftmaxOutputLayer instance.
    """
    super(MCSoftmaxOutputLayer, self).__init__(name=name)

    # Task / noise model.
    self._num_classes = num_classes
    self._logit_noise = logit_noise
    self._temperature = temperature

    # Monte-Carlo sampling budget.
    self._train_mc_samples = train_mc_samples
    self._test_mc_samples = test_mc_samples
    self._max_mc_samples_in_memory = max_mc_samples_in_memory

    # Output options.
    self._compute_pred_variance = compute_pred_variance
    self._scale_positive = scale_positive

    # Used by __call__ as the tf.name_scope name.
    self._name = name

  def _compute_noise_samples(self, scale, num_samples, seed):
    """Utility function to compute the samples of the logit noise.

    When seed is None or a Python integer, samples are drawn from a
    tensorflow_probability distribution selected by logit_noise: Normal for
    LogitNoise.NORMAL, Logistic for LogitNoise.LOGISTIC and Gumbel for any
    other value. For any other seed (e.g. a scalar Tensor) samples are drawn
    with tf.random.stateless_normal, which is only supported for
    LogitNoise.NORMAL.

    Args:
      scale: Tensor of shape
        [batch_size, ..., 1 if num_classes == 2 else num_classes].
        Scale parameters of the distributions to be sampled.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      Tensor. Logit noise samples of shape: [batch_size, num_samples, ...,
        1 if num_classes == 2 else num_classes].

    Raises:
      ValueError: when seed is neither None nor a Python integer (e.g. a
        Tensor) and logit_noise is not LogitNoise.NORMAL.
    """
    if seed is None or isinstance(seed, int):
      if self._logit_noise == het_lib_utils.LogitNoise.NORMAL:
        dist = tfp.distributions.Normal(loc=tf.zeros_like(scale), scale=scale)
      elif self._logit_noise == het_lib_utils.LogitNoise.LOGISTIC:
        dist = tfp.distributions.Logistic(loc=tf.zeros_like(scale),
                                          scale=scale)
      else:
        # Every remaining noise type falls through to Gumbel.
        dist = tfp.distributions.Gumbel(loc=tf.zeros_like(scale), scale=scale)

      tf.random.set_seed(seed)
      noise_samples = dist.sample(num_samples, seed=seed)
    else:
      # The second component of the stateless seed is offset by more than the
      # number of in-memory sample batches, avoiding seed collisions over
      # multiple calls to _compute_noise_samples.
      seed_delta = (max(self._train_mc_samples, self._test_mc_samples)//
                    self._max_mc_samples_in_memory) + 1
      if self._logit_noise == het_lib_utils.LogitNoise.NORMAL:
        noise_samples = tf.random.stateless_normal(
            tf.concat([[num_samples], tf.shape(scale)], axis=0),
            [seed, seed + seed_delta],
            mean=tf.zeros_like(tf.expand_dims(scale, axis=0)),
            stddev=tf.expand_dims(scale, axis=0),
            dtype=scale.dtype)
      else:
        # The trailing space matters: adjacent string literals are joined
        # verbatim, and without it the message read "forLogitNoise.NORMAL".
        raise ValueError('Non integer seeds are only supported for '
                         'LogitNoise.NORMAL')

    # Both sampling paths return a Tensor of shape
    # [num_samples, batch_size, ..., d]; swap the first two axes to get
    # [batch_size, num_samples, ..., d] and leave the remaining axes in place.
    return tf.transpose(
        noise_samples,
        tf.concat([[1, 0], tf.range(2, tf.rank(noise_samples))], 0))

  def _mc_sample_fn(self, use_argmax):
    """Picks the sampler: hard (argmax) if use_argmax is truthy, else soft."""
    return (self._compute_hard_mc_samples if use_argmax
            else self._compute_soft_mc_samples)

  def _compute_hard_mc_samples(self, locs, scale, num_samples, seed):
    """Utility function to compute hard Monte-Carlo samples (using argmax).

    Each sample is a 0/1 indicator computed from the noisy latents
    locs + noise: for multi-class problems the indicator of the maximal
    latent along the last axis, for binary problems the indicator that the
    single latent is positive.

    Args:
      locs: Tensor of shape [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Location parameters of the
        distributions to be sampled. A samples axis is inserted at axis 1
        before the noise is added.
      scale: Scale parameters of the distributions to be sampled, in the form
        expected by self._compute_noise_samples.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      Tensor of shape [batch_size, num_samples, ...,
        1 if num_classes == 2 else num_classes] with the dtype of the
        latents. All of the MC samples.

    Raises:
      ValueError: propagated from self._compute_noise_samples.
    """
    locs = tf.expand_dims(locs, axis=1)
    noise_samples = self._compute_noise_samples(scale, num_samples, seed)
    latents = locs + noise_samples
    if self._num_classes == 2:
      # The latents are logits: the soft sampler uses
      # sigmoid(latents / temperature), which exceeds 0.5 exactly when the
      # latent is positive (for a positive temperature). The hard decision
      # boundary on the latent is therefore 0, not 0.5.
      probs = tf.math.greater(latents, 0.0)
    else:
      probs = tf.equal(latents, tf.reduce_max(latents, -1, keepdims=True))
    return tf.cast(probs, latents.dtype)

  def _compute_soft_mc_samples(self, locs, scale, num_samples, seed):
    """Utility function to compute soft Monte-Carlo samples (using softmax).

    The noisy latents locs + noise are divided by the temperature and passed
    through a sigmoid (num_classes == 2) or a softmax (otherwise).

    Args:
      locs: Tensor of shape [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Location parameters of the
        distributions to be sampled. A samples axis is inserted at axis 1
        before the noise is added.
      scale: Scale parameters of the distributions to be sampled, in the form
        expected by self._compute_noise_samples.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      Tensor of shape [batch_size, num_samples, ...,
        1 if num_classes == 2 else num_classes]. All of the MC samples.

    Raises:
      ValueError: propagated from self._compute_noise_samples.
    """
    expanded_locs = tf.expand_dims(locs, axis=1)
    noise = self._compute_noise_samples(scale, num_samples, seed)
    tempered_latents = (expanded_locs + noise) / self._temperature

    if self._num_classes == 2:
      return tf.math.sigmoid(tempered_latents)
    return tf.nn.softmax(tempered_latents)

  def _compute_predictive_dist(self, locs, scale, total_mc_samples, seed,
                               use_argmax):
    """Utility function to compute the estimated predictive distribution.

    If total_mc_samples fits in memory all samples are drawn in one pass.
    Otherwise the samples are drawn in chunks of at most
    max_mc_samples_in_memory, each chunk with its own seed, and the chunk
    means are combined weighted by chunk size.

    Args:
      locs: Tensor of shape [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Location parameters of the
        distributions to be sampled.
      scale: Scale parameters of the distributions to be sampled, in the form
        expected by self._compute_noise_samples.
      total_mc_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or scalar Tensor initial seed, for seeding the random
        number generator.
      use_argmax: Boolean. Whether to use the softmax or argmax to compute
        the predictive distribution.

    Returns:
      Tuple of (probs_mean, seeds_list, samples_list). Where probs_mean is a
      Tensor of shape [batch_size, ..., 1 if num_classes == 2 else
      num_classes] - the mean of the MC samples. seeds_list is a list of the
      Python integer or scalar Tensor seeds used to generate the random
      samples. samples_list is a list of integers containing the number of
      samples taken in each chunk, sum(samples_list) == total_mc_samples.

    Raises:
      ValueError: propagated from the sampling function.
    """
    sampler = self._mc_sample_fn(use_argmax)
    max_in_memory = self._max_mc_samples_in_memory

    if total_mc_samples <= max_in_memory:
      # Single pass. A concrete seed is only generated when the predictive
      # variance will be computed later from the returned seed.
      if self._compute_pred_variance and seed is None:
        seed = het_lib_utils.gen_int_seed()
      all_samples = sampler(locs, scale, total_mc_samples, seed)
      probs = tf.reduce_mean(all_samples, axis=1)
      return probs, [seed], [total_mc_samples]

    # Split total_mc_samples into chunks that fit in memory:
    # (total_mc_samples // max_in_memory) chunks of size max_in_memory plus,
    # if there is a remainder, one final smaller chunk.
    num_full_chunks, leftover = divmod(total_mc_samples, max_in_memory)
    chunk_sizes = [max_in_memory] * num_full_chunks
    if leftover > 0:
      chunk_sizes.append(leftover)

    # Samples are recomputed on the backward pass instead of being stored.
    sampler = tf.recompute_grad(sampler)
    if seed is None:
      seed = het_lib_utils.gen_tensor_seed()

    seeds_list = []
    for index, chunk_size in enumerate(chunk_sizes):
      seed = seed + 1  # unique seed for each set of samples
      chunk_samples = sampler(locs, scale, chunk_size, seed)
      # Each chunk contributes proportionally to its share of all samples.
      weighted_mean = ((chunk_size / float(total_mc_samples)) *
                       tf.reduce_mean(chunk_samples, axis=1))
      if index:
        probs = probs + weighted_mean
      else:
        probs = weighted_mean
      seeds_list.append(seed)

    return probs, seeds_list, chunk_sizes

  def _compute_predictive_variance(self, mean, locs, scale, seeds_list,
                                   samples_list, use_argmax):
    """Utility function to compute the per class predictive variance.

    Re-draws the Monte-Carlo samples chunk by chunk using the supplied seeds
    and sample counts, and accumulates the mean squared deviation from
    `mean`, weighting each chunk by its share of the total sample count.

    Args:
      mean: Tensor of shape [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Estimated predictive
        distribution. A samples axis is inserted at axis 1 before it is
        subtracted from the samples.
      locs: Tensor of shape [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Location parameters of the
        distributions to be sampled.
      scale: Scale parameters of the distributions to be sampled, in the form
        expected by self._compute_noise_samples.
      seeds_list: List of seeds for the random number generator, one per
        chunk.
      samples_list: List of Integers. Number of Monte-Carlo samples to take
        in each chunk.
      use_argmax: Boolean. Whether to use the softmax or argmax to compute
        the predictive distribution.

    Returns:
      Tensor of shape: [batch_size, ...,
        1 if num_classes == 2 else num_classes]. Estimated predictive variance.

    Raises:
      ValueError: propagated from the sampling function.
    """
    sampler = self._mc_sample_fn(use_argmax)
    expanded_mean = tf.expand_dims(mean, axis=1)
    sample_count = float(sum(samples_list))

    for index, (chunk_size, chunk_seed) in enumerate(
        zip(samples_list, seeds_list)):
      chunk_samples = sampler(locs, scale, chunk_size, chunk_seed)
      squared_deviation = (chunk_samples - expanded_mean)**2
      weighted_variance = ((chunk_size / sample_count) *
                           tf.reduce_mean(squared_deviation, axis=1))
      if index:
        total_variance = total_variance + weighted_variance
      else:
        total_variance = weighted_variance

    return total_variance

  @abc.abstractmethod
  def _compute_loc_param(self, inputs):
    """Computes location parameter of the "logits distribution".

    Abstract hook; subclasses supply the implementation. The result is used by
    __call__ as `locs`.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, ..., num_classes].
    """
    return None

  @abc.abstractmethod
  def _compute_scale_param(self, inputs):
    """Computes scale parameter of the "logits distribution".

    Abstract hook; subclasses supply the implementation. The result is used by
    __call__ as `scale`.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, ..., num_classes].
    """
    return None

  def __call__(self, inputs, training=True, argmax_preds=False, seed=None):
    """Computes predictive and log predictive distribution.

    Uses Monte Carlo estimate of softmax approximation to heteroscedastic model
    to compute predictive distribution. O(mc_samples * num_classes).

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.
      training: Boolean. Whether we are training or not. Selects between
        train_mc_samples and test_mc_samples.
      argmax_preds: Boolean. Whether to take the argmax or softmax to compute
        the predictive distribution. Only takes effect when `training` is
        false.
      seed: Python integer or scalar Tensor for seeding the random number
        generator. Matters when the Monte-Carlo samples are recomputed on the
        backward pass to save memory: the seed ensures that the samples on the
        backward pass are the same as on the forward pass. If set the seed
        should be unique to each call to this method. If None, all random
        operations act as if they are unseeded.
        NOTE(review): the previous docstring said the seed is only needed if
        train_mc_samples < max_mc_samples_in_memory, whereas
        self._compute_predictive_dist recomputes samples when the sample count
        exceeds max_mc_samples_in_memory - confirm against
        utils.compute_predictive_dist.

    Returns:
      Instance of MCSoftmaxResult. Contains fields for probs, log_probs, locs,
      logits and predictive_variance. For multi-class classification i.e.
      num_classes > 2 logits = log_probs and logits can be used with the
      standard tf.nn.sparse_softmax_cross_entropy_with_logits loss function.
      For binary classification i.e. num_classes = 2, logits represents the
      argument to a sigmoid function that would yield probs
      (logits = inverse_sigmoid(probs)), so logits can be used with the
      tf.nn.sigmoid_cross_entropy_with_logits loss function.
      predictive_variance is None unless compute_pred_variance was set.

    Raises:
      ValueError: if seed is provided or argmax_preds is set while the model
        is running in graph mode.
    """
    # Neither a seed nor argmaxed predictions are allowed in graph mode.
    graph_mode = not tf.executing_eagerly()
    if graph_mode and seed is not None:
      raise ValueError('Seed should not be provided when running in graph '
                       'mode, but %s was provided.' % seed)
    if graph_mode and argmax_preds:
      raise ValueError('Cannot use argmaxed predictions in graph mode.')

    with tf.name_scope(self._name):
      # NOTE(review): 1e-10 is below float32 machine epsilon, so if the
      # probabilities are float32 the upper bound `1.0 - eps` used below
      # rounds to 1.0 and the binary logits can become +inf when the
      # estimated probability is exactly 1 - confirm dtype / intended eps.
      eps = 1e-10

      locs = self._compute_loc_param(inputs)
      scale = self._compute_scale_param(inputs)
      if self._scale_positive:
        scale = tf.maximum(scale, eps)

      total_mc_samples = (
          self._train_mc_samples if training else self._test_mc_samples)

      # Hard (argmax) samples are only ever used outside of training.
      use_argmax = argmax_preds and not training

      probs_mean, seeds_list, samples_list = utils.compute_predictive_dist(
          locs, scale, total_mc_samples, seed,
          self._mc_sample_fn(use_argmax), self._max_mc_samples_in_memory)

      if self._compute_pred_variance:
        # Uses the unclipped mean together with the seeds and sample counts
        # reported by the predictive distribution computation.
        pred_variance = self._compute_predictive_variance(
            probs_mean, locs, scale, seeds_list, samples_list, use_argmax)
      else:
        pred_variance = None

      # Keep the log finite.
      probs_mean = tf.clip_by_value(probs_mean, eps, 1.0)
      log_probs = tf.math.log(probs_mean)

      if self._num_classes == 2:
        # inverse sigmoid: log(p) - log(1 - p)
        probs_mean = tf.clip_by_value(probs_mean, eps, 1.0 - eps)
        logits = log_probs - tf.math.log(1.0 - probs_mean)
      else:
        logits = log_probs

      return MCSoftmaxResult(probs=probs_mean, log_probs=log_probs, locs=locs,
                             logits=logits, predictive_variance=pred_variance)


class MCSoftmaxDense(MCSoftmaxOutputLayer):
  """Monte Carlo estimation of softmax approx to heteroscedastic predictions."""

  def __init__(self,
               num_classes,
               logit_noise=het_lib_utils.LogitNoise.NORMAL,
               temperature=1.0,
               train_mc_samples=1000,
               test_mc_samples=1000,
               max_mc_samples_in_memory=1000,
               loc_regularizer=None,
               compute_pred_variance=False,
               name='MCSoftmaxDense'):
    """Creates an instance of MCSoftmaxDense.

    This is a MC softmax heteroscedastic drop in replacement for a
    tf.keras.layers.Dense output layer. e.g. simply change:

    logits = tf.keras.layers.Dense(...)(x)

    to

    logits = MCSoftmaxDense(...)(x).logits

    Two dense sub-layers are created: one producing the location parameters
    (no activation) and one producing the scale parameters (tf.abs
    activation). Both output a single unit when num_classes == 2 and
    num_classes units otherwise.

    Args:
      num_classes: Integer, at least 2. Number of classes for classification
        task.
      logit_noise: LogitNoise instance. The noise distribution
        assumed on the softmax logits. Possible values:
        LogitNoise.NORMAL, LogitNoise.LOGISTIC, LogitNoise.GUMBEL.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      train_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      max_mc_samples_in_memory: When estimating the predictive distribution a
        `Tensor` of shape [batch_size, max_mc_samples_in_memory, num_classes]
        will be computed. Set max_mc_samples_in_memory as high as possible for
        efficient computation but low enough such that OOM errors are avoided.
      loc_regularizer: Regularizer function applied to the kernel weights
        matrix of the fully connected layer computing the location parameter of
        the distribution on the logits.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method will output None for the
        predictive_variance tensor.
      name: String. The name of the layer used for name scoping.

    Returns:
      MCSoftmaxDense instance.
    """
    assert num_classes >= 2

    super(MCSoftmaxDense, self).__init__(
        num_classes,
        logit_noise=logit_noise,
        temperature=temperature,
        train_mc_samples=train_mc_samples,
        test_mc_samples=test_mc_samples,
        max_mc_samples_in_memory=max_mc_samples_in_memory,
        compute_pred_variance=compute_pred_variance,
        name=name)

    # Binary classification uses a single logit, otherwise one per class.
    num_units = 1 if num_classes == 2 else num_classes

    self._loc_layer = tf.keras.layers.Dense(
        num_units,
        activation=None,
        kernel_regularizer=loc_regularizer,
        name='loc_layer')
    # tf.abs keeps the raw scale outputs non-negative.
    self._scale_layer = tf.keras.layers.Dense(
        num_units, activation=tf.abs, name='scale_layer')

  def _compute_loc_param(self, inputs):
    """Computes location parameter of the "logits distribution".

    Applies the dense location layer to the inputs.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape
        [batch_size, ..., 1 if num_classes == 2 else num_classes].
    """
    locs = self._loc_layer(inputs)
    return locs

  def _compute_scale_param(self, inputs):
    """Computes scale parameter of the "logits distribution".

    Applies the dense scale layer (tf.abs activation) to the inputs.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape
        [batch_size, ..., 1 if num_classes == 2 else num_classes].
    """
    scale = self._scale_layer(inputs)
    return scale


class MCSoftmaxDenseFA(MCSoftmaxOutputLayer):
  """Softmax and factor analysis approx to heteroscedastic predictions."""

  def __init__(self,
               num_classes,
               num_factors,
               logit_noise=het_lib_utils.LogitNoise.NORMAL,
               temperature=1.0,
               parameter_efficient=False,
               train_mc_samples=1000,
               test_mc_samples=1000,
               max_mc_samples_in_memory=1000,
               compute_pred_variance=False,
               name='MCSoftmaxDenseFA'):
    """Creates an instance of MCSoftmaxDenseFA.

    if we assume:
    u(x) ~ N(mu(x), sigma(x))
    and
    y = softmax(u(x) / temperature)

    we can do a low rank approximation of sigma(x) the full rank matrix as:
    e ~ N(0, I_{RxR}), e_d ~ N(0, I_{KxK})
    if parameter_efficient:
      u = mu(x) + v(x) * matmul(V, e) + d(x) * e_d
      where V is a matrix of dimension [num_classes, num_factors]
      and v(x) is a [num_classes, 1] vector
    else:
      u = mu(x) + matmul(V(x), e) + d(x) * e_d
      where V(x) is a matrix of dimension [num_classes, num_factors]
    and d(x) is a vector of dimension [num_classes, 1]
    num_factors << num_classes => approx to sampling ~ N(mu(x), sigma(x))

    This is a MC softmax heteroscedastic drop in replacement for a
    tf.keras.layers.Dense output layer. e.g. simply change:

    logits = tf.keras.layers.Dense(...)(x)

    to

    logits = MCSoftmaxDenseFA(...)(x).logits

    The base class is configured with scale_positive=False because the scale
    here is a (factor loadings, diagonal) pair rather than a vector of
    standard deviations.

    Args:
      num_classes: Integer, greater than 2. Number of classes for
        classification task.
      num_factors: Integer, at most num_classes. Number of factors to use in
        approximation to full rank covariance matrix.
      logit_noise: LogitNoise instance. The noise distribution
        assumed on the softmax logits. Possible values:
        LogitNoise.NORMAL, LogitNoise.LOGISTIC, LogitNoise.GUMBEL.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      parameter_efficient: Boolean. Whether to use the parameter efficient
        version of the method.
      train_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      max_mc_samples_in_memory: When estimating the predictive distribution a
        `Tensor` of shape [batch_size, max_mc_samples_in_memory, num_classes]
        will be computed. Set max_mc_samples_in_memory as high as possible for
        efficient computation but low enough such that OOM errors are avoided.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method will output None for the
        predictive_variance tensor.
      name: String. The name of the layer used for name scoping.

    Returns:
      MCSoftmaxDenseFA instance.
    """
    # no need to model correlations between classes in binary case
    assert num_classes > 2
    assert num_factors <= num_classes

    super(MCSoftmaxDenseFA, self).__init__(
        num_classes,
        logit_noise=logit_noise,
        temperature=temperature,
        train_mc_samples=train_mc_samples,
        test_mc_samples=test_mc_samples,
        max_mc_samples_in_memory=max_mc_samples_in_memory,
        compute_pred_variance=compute_pred_variance,
        scale_positive=False,
        name=name)

    self._num_factors = num_factors
    self._parameter_efficient = parameter_efficient

    # mu(x)
    self._loc_layer = tf.keras.layers.Dense(
        num_classes, kernel_initializer='he_normal', name='loc_layer')

    if not parameter_efficient:
      # V(x), produced flattened and reshaped to
      # [num_classes, num_factors] when sampling.
      self._scale_layer = tf.keras.layers.Dense(
          num_classes * num_factors, name='scale_layer')
    else:
      # scale_layer_a is applied to the layer inputs and scale_layer_b to the
      # standard normal factor samples; their outputs are multiplied.
      self._scale_layer_a = tf.keras.layers.Dense(
          num_classes, name='scale_layer_a')
      self._scale_layer_b = tf.keras.layers.Dense(
          num_classes, name='scale_layer_b')

    # d(x); softplus keeps the diagonal scale positive.
    self._diag_layer = tf.keras.layers.Dense(
        num_classes, activation=tf.math.softplus, name='diag_layer')

  def _compute_loc_param(self, inputs):
    """Computes location parameter of the "logits distribution".

    Applies the dense location layer to the inputs.

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor of shape [batch_size, ..., num_classes].
    """
    locs = self._loc_layer(inputs)
    return locs

  def _compute_scale_param(self, inputs):
    """Computes scale parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tuple (factor_part, diag_scale). diag_scale is the diagonal layer output
      plus 1e-3, of shape [batch_size, ..., num_classes]. factor_part is the
      scale layer output of shape [batch_size, ..., num_classes * num_factors],
      or, in the parameter efficient variant, the unmodified `inputs` (the
      factor layers are applied later when the noise is sampled).
    """
    if self._parameter_efficient:
      factor_part = inputs
    else:
      factor_part = self._scale_layer(inputs)

    # The small constant keeps the diagonal scale bounded away from zero.
    diag_scale = self._diag_layer(inputs) + 1e-3
    return (factor_part, diag_scale)

  def _compute_diagonal_noise_samples(self, diag_scale, num_samples, seed):
    """Compute samples of the diagonal elements logit noise.

    Draws standard normal noise shaped like diag_scale (with an extra samples
    axis), normalises it to unit L2 norm along the last axis when logit_noise
    is LogitNoise.RADIAL, and finally multiplies it by diag_scale.

    Args:
      diag_scale: `Tensor` of shape [batch_size, ..., num_classes]. Diagonal
        elements of scale parameters of the distribution to be sampled.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      `Tensor`. Logit noise samples of shape: [batch_size, num_samples, ...,
        num_classes].
    """
    python_seed = seed is None or isinstance(seed, int)
    if python_seed:
      unit_normal = tfp.distributions.Normal(
          loc=tf.zeros_like(diag_scale), scale=tf.ones_like(diag_scale))

      tf.random.set_seed(seed)
      unit_noise = unit_normal.sample(num_samples, seed=seed)
    else:
      # The second component of the stateless seed is offset by more than the
      # number of in-memory sample batches, avoiding seed collisions over
      # multiple calls to _compute_noise_samples.
      seed_delta = (max(self._train_mc_samples, self._test_mc_samples)//
                    self._max_mc_samples_in_memory) + 1
      expanded_scale_shape = tf.concat(
          [[num_samples], tf.shape(diag_scale)], axis=0)
      unit_noise = tf.random.stateless_normal(
          expanded_scale_shape,
          [seed, seed + seed_delta],
          mean=tf.zeros_like(tf.expand_dims(diag_scale, axis=0)),
          stddev=tf.ones_like(tf.expand_dims(diag_scale, axis=0)),
          dtype=diag_scale.dtype)

    # Both sampling paths return [num_samples, batch_size, ..., d]; swap the
    # first two axes to obtain [batch_size, num_samples, ..., d].
    swap_leading_axes = tf.concat(
        [[1, 0], tf.range(2, tf.rank(unit_noise))], 0)
    unit_noise = tf.transpose(unit_noise, swap_leading_axes)

    if self._logit_noise == het_lib_utils.LogitNoise.RADIAL:
      # Project every noise vector onto the unit sphere.
      l2_norm = tf.norm(unit_noise, ord=2, axis=-1, keepdims=True)
      unit_noise = unit_noise/l2_norm

    return tf.expand_dims(diag_scale, axis=1) * unit_noise

  def _compute_standard_normal_samples(self, samples_shape, dtype, num_samples,
                                       seed):
    """Utility function to compute samples from a standard normal distribution.

    With seed None or a Python integer, the samples come from a zero
    location, unit scale Laplace distribution when logit_noise is
    LogitNoise.LAPLACE and from a standard normal otherwise. With any other
    seed the samples always come from tf.random.stateless_normal.
    NOTE(review): LogitNoise.LAPLACE is therefore not honoured for
    non-integer seeds - confirm this is intended.

    Args:
      samples_shape: Shape of a single draw. self._compute_noise_samples
        passes the list [batch_size, num_factors].
      dtype: Valid Tensorflow dtype.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      `Tensor`. Samples of shape: [batch_size, num_samples, ..., num_factors].
    """
    python_seed = seed is None or isinstance(seed, int)
    if python_seed:
      if self._logit_noise == het_lib_utils.LogitNoise.LAPLACE:
        dist_cls = tfp.distributions.Laplace
      else:
        dist_cls = tfp.distributions.Normal
      dist = dist_cls(
          loc=tf.zeros(samples_shape, dtype=dtype),
          scale=tf.ones(samples_shape, dtype=dtype))

      tf.random.set_seed(seed)
      samples = dist.sample(num_samples, seed=seed)
    else:
      # The second component of the stateless seed is offset by more than the
      # number of in-memory sample batches, avoiding seed collisions over
      # multiple calls to _compute_noise_samples.
      seed_delta = (max(self._train_mc_samples, self._test_mc_samples)//
                    self._max_mc_samples_in_memory) + 1
      full_shape = tf.concat([[num_samples], samples_shape], axis=0)
      samples = tf.random.stateless_normal(
          full_shape,
          [seed, seed + seed_delta],
          mean=tf.expand_dims(tf.zeros(samples_shape, dtype=dtype), 0),
          stddev=tf.expand_dims(tf.ones(samples_shape, dtype=dtype), 0),
          dtype=dtype)

    # Both sampling paths return [num_samples, batch_size, d]; swap the first
    # two axes to obtain [batch_size, num_samples, d].
    swap_leading_axes = tf.concat([[1, 0], tf.range(2, tf.rank(samples))], 0)
    return tf.transpose(samples, swap_leading_axes)

  def _compute_noise_samples(self, scale, num_samples, seed):
    """Utility function to compute the samples of the logit noise.

    The noise is the sum of a low rank term, built from num_factors standard
    samples per example, and a diagonal term.

    Args:
      scale: Tuple (factor_part, diag_scale) as returned by
        self._compute_scale_param. diag_scale has shape
        [batch_size, ..., num_classes]. factor_part has shape
        [batch_size, ..., num_classes * num_factors], or is the raw layer
        input in the parameter efficient variant.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.

    Returns:
      `Tensor`. Logit noise samples of shape: [batch_size, num_samples, ...,
        num_classes].
    """
    factor_part, diag_scale = scale

    # Diagonal term: d(x) * e_d.
    diag_noise = self._compute_diagonal_noise_samples(
        diag_scale, num_samples, seed)

    # Factor samples e, num_factors per example and Monte-Carlo sample.
    factor_samples = self._compute_standard_normal_samples(
        [tf.shape(diag_scale)[0], self._num_factors], diag_noise.dtype,
        num_samples, seed)

    if self._parameter_efficient:
      # Input dependent per class scaling times a shared linear map of the
      # factor samples; the samples axis is inserted for broadcasting.
      per_class_scale = tf.expand_dims(self._scale_layer_a(factor_part), 1)
      low_rank_noise = per_class_scale * self._scale_layer_b(factor_samples)
    else:
      # reshape scale vector into factor loadings matrix
      loadings = tf.reshape(factor_part,
                            [-1, self._num_classes, self._num_factors])
      loadings = tf.cast(loadings, factor_samples.dtype)

      # transform standard normal into ~ full rank covariance Gaussian
      # samples: [batch, classes, factors] x [batch, samples, factors]
      # -> [batch, samples, classes].
      low_rank_noise = tf.einsum('ijk,iak->iaj', loadings, factor_samples)

    return low_rank_noise + diag_noise


class MCSigmoidDenseFA(MCSoftmaxOutputLayer):
  """Sigmoid and factor analysis approx to heteroscedastic predictions."""

  def __init__(self, num_outputs, num_factors,
               logit_noise=het_lib_utils.LogitNoise.NORMAL,
               temperature=1.0,
               parameter_efficient=False,
               train_mc_samples=1000,
               test_mc_samples=1000,
               max_mc_samples_in_memory=1000,
               compute_pred_variance=False,
               name='MCSigmoidDenseFA'):
    """Creates an instance of MCSigmoidDenseFA.

    Modelling assumption:
      u ~ N(mu(x), sigma(x))
      y = sigmoid(u / temperature)

    The full rank covariance sigma(x) is approximated with a low rank factor
    term plus a diagonal term. With
      e ~ N(0, I_{RxR}),  e_d ~ N(0, I_{KxK})
    (R = num_factors, K = num_outputs) the noisy logits are

      parameter_efficient:
        u = mu(x) + v(x) * f(e) + d(x) * e_d
        where v(x) is the output of the `scale_layer_a` Dense layer and f is
        the `scale_layer_b` Dense layer applied to the factor samples.
      otherwise:
        u = mu(x) + matmul(V(x), e) + d(x) * e_d
        where V(x) is the `scale_layer` output viewed as a
        [num_outputs, num_factors] matrix.

    In both cases d(x) is the `diag_layer` output, a vector of num_outputs
    entries. Choosing num_factors << num_outputs gives a cheap approximation
    to sampling from N(mu(x), sigma(x)).

    The layer is meant as a MC sigmoid heteroscedastic drop in replacement for
    a tf.keras.layers.Dense output layer, e.g. replace

      logits = tf.keras.layers.Dense(...)(x)

    with

      logits = MCSigmoidDenseFA(...)(x).logits

    Args:
      num_outputs: Integer. Number of outputs; unit count of the location and
        diagonal Dense layers.
      num_factors: Integer. Number of factors used in the approximation to the
        full rank covariance matrix. Must not exceed num_outputs. When it is
        not positive no factor layers are created.
      logit_noise: LogitNoise instance, forwarded to the base class. In this
        class the factor samples are Laplace distributed for
        LogitNoise.LAPLACE (integer or None seeds only) and Normal otherwise.
      temperature: Float or scalar `Tensor`. Temperature, forwarded to the
        base class.
      parameter_efficient: Boolean. Whether to use the parameter efficient
        version of the method (two Dense layers with num_outputs units each
        instead of one with num_outputs * num_factors units).
      train_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during training.
      test_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution during testing/inference.
      max_mc_samples_in_memory: Upper bound on the number of Monte-Carlo
        samples held in memory at once when estimating the predictive
        distribution. Set it as high as possible for efficient computation but
        low enough such that OOM errors are avoided.
      compute_pred_variance: Boolean. Whether to estimate the predictive
        variance. If False the __call__ method will output None for the
        predictive_variance tensor.
      name: String. The name of the layer used for name scoping.

    Returns:
      MCSigmoidDenseFA instance.
    """
    assert num_factors <= num_outputs

    # The base class is always configured with 2 classes and without the
    # positivity constraint on the scale (scale_positive=False).
    super(MCSigmoidDenseFA, self).__init__(
        2,
        logit_noise=logit_noise,
        temperature=temperature,
        train_mc_samples=train_mc_samples,
        test_mc_samples=test_mc_samples,
        max_mc_samples_in_memory=max_mc_samples_in_memory,
        compute_pred_variance=compute_pred_variance,
        scale_positive=False,
        name=name)

    self._num_outputs = num_outputs
    self._num_factors = num_factors
    self._parameter_efficient = parameter_efficient

    dense = tf.keras.layers.Dense

    # mu(x): location of the logits distribution.
    self._loc_layer = dense(num_outputs, name='loc_layer')

    # Factor layers only exist when at least one factor is requested.
    has_factors = num_factors > 0
    if has_factors and parameter_efficient:
      self._scale_layer_a = dense(num_outputs, name='scale_layer_a')
      self._scale_layer_b = dense(num_outputs, name='scale_layer_b')
    elif has_factors:
      self._scale_layer = dense(num_outputs * num_factors, name='scale_layer')

    # d(x): softplus keeps the diagonal scale positive.
    self._diag_layer = dense(
        num_outputs,
        activation=tf.math.softplus,
        name='diag_layer',
        bias_initializer='zeros')

  def _compute_loc_param(self, inputs):
    """Computes the location parameter of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tensor. Output of the `loc_layer` Dense layer applied to `inputs`.
    """
    loc = self._loc_layer(inputs)
    return loc

  def _compute_scale_param(self, inputs):
    """Computes the scale parameters of the "logits distribution".

    Args:
      inputs: Tensor. The input to the heteroscedastic output layer.

    Returns:
      Tuple `(factor_param, diag_scale)`.
      `factor_param` is None when num_factors is not positive, the untouched
      `inputs` in the parameter efficient case (the factor layers are applied
      later, when the noise is sampled), and otherwise the `scale_layer`
      output.
      `diag_scale` is the softplus activated `diag_layer` output plus 1e-3.
    """
    factor_param = None
    if self._num_factors > 0:
      if self._parameter_efficient:
        factor_param = inputs
      else:
        factor_param = self._scale_layer(inputs)

    # The 1e-3 offset keeps the diagonal scale away from zero.
    diag_scale = self._diag_layer(inputs)  + 1e-3
    return (factor_param, diag_scale)

  def _compute_diagonal_noise_samples(self, diag_scale, num_samples, seed):
    """Samples zero-mean Normal noise scaled elementwise by `diag_scale`.

    Args:
      diag_scale: `Tensor` whose leading dimension is the batch. Per-element
        scale multiplied onto the unit-variance draws.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: `None` or a Python integer selects the stateful path
        (`tf.random.set_seed` followed by a TFP `Normal` sample); any other
        value, e.g. a Tensor, selects `tf.random.stateless_normal`.

    Returns:
      `Tensor` shaped like `diag_scale` with a `num_samples` axis inserted at
      position 1, i.e. [batch_size, num_samples, ...].
    """
    if seed is None or isinstance(seed, int):
      unit_normal = tfp.distributions.Normal(
          loc=tf.zeros_like(diag_scale), scale=tf.ones_like(diag_scale))

      # NOTE: this sets TensorFlow's global seed on every call.
      tf.random.set_seed(seed)
      unit_samples = unit_normal.sample(num_samples, seed=seed)
    else:
      # Offset for the second seed component, meant to avoid seed collisions
      # over multiple calls to _compute_noise_samples.
      # NOTE(review): _compute_standard_normal_samples builds the very same
      # [seed, seed + seed_delta] pair, so the two stateless draws may share a
      # random stream -- confirm this is intended.
      largest_mc_budget = max(self._train_mc_samples, self._test_mc_samples)
      seed_delta = largest_mc_budget // self._max_mc_samples_in_memory + 1

      leading_axis_scale = tf.expand_dims(diag_scale, axis=0)
      sample_shape = tf.concat([[num_samples], tf.shape(diag_scale)], axis=0)
      unit_samples = tf.random.stateless_normal(
          sample_shape,
          [seed, seed + seed_delta],
          mean=tf.zeros_like(leading_axis_scale),
          stddev=tf.ones_like(leading_axis_scale),
          dtype=diag_scale.dtype)

    # Both sampling paths put the sample axis first:
    # [num_samples, batch_size, ...]. Swap the first two axes so the batch
    # axis leads and every trailing axis keeps its position.
    trailing_axes = tf.range(2, tf.rank(unit_samples))
    batch_major_perm = tf.concat([[1, 0], trailing_axes], 0)
    unit_samples = tf.transpose(unit_samples, batch_major_perm)

    # Broadcast the per-example scale over the sample axis.
    return tf.expand_dims(diag_scale, axis=1) * unit_samples

  def _compute_standard_normal_samples(self, factor_loadings, num_samples,
                                       seed):
    """Draws the unit-scale factor samples, one num_factors vector per example.

    Args:
      factor_loadings: `Tensor`. Only its leading (batch) dimension and its
        dtype are used here.
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: `None` or a Python integer selects the stateful path
        (`tf.random.set_seed` followed by a TFP distribution sample); any
        other value, e.g. a Tensor, selects `tf.random.stateless_normal`.

    Returns:
      `Tensor`. Samples of shape: [batch_size, num_samples, num_factors].
    """
    batch_size = tf.shape(factor_loadings)[0]
    samples_shape = tf.concat([[batch_size], [self._num_factors]], axis=0)
    dtype = factor_loadings.dtype

    if seed is None or isinstance(seed, int):
      # Laplace factors when requested, Normal in every other case.
      if self._logit_noise == het_lib_utils.LogitNoise.LAPLACE:
        dist_cls = tfp.distributions.Laplace
      else:
        dist_cls = tfp.distributions.Normal
      dist = dist_cls(
          loc=tf.zeros(samples_shape, dtype=dtype),
          scale=tf.ones(samples_shape, dtype=dtype))

      # NOTE: this sets TensorFlow's global seed on every call.
      tf.random.set_seed(seed)
      factor_samples = dist.sample(num_samples, seed=seed)
    else:
      # Offset for the second seed component, meant to avoid seed collisions
      # over multiple calls to _compute_noise_samples.
      # NOTE(review): this path always samples a Normal, so
      # LogitNoise.LAPLACE has no effect here -- confirm this is intended.
      largest_mc_budget = max(self._train_mc_samples, self._test_mc_samples)
      seed_delta = largest_mc_budget // self._max_mc_samples_in_memory + 1

      factor_samples = tf.random.stateless_normal(
          tf.concat([[num_samples], samples_shape], axis=0),
          [seed, seed + seed_delta],
          mean=tf.expand_dims(tf.zeros(samples_shape, dtype=dtype), 0),
          stddev=tf.expand_dims(tf.ones(samples_shape, dtype=dtype), 0),
          dtype=dtype)

    # Both sampling paths put the sample axis first:
    # [num_samples, batch_size, num_factors]. Swap the first two axes to get
    # [batch_size, num_samples, num_factors].
    trailing_axes = tf.range(2, tf.rank(factor_samples))
    batch_major_perm = tf.concat([[1, 0], trailing_axes], 0)
    return tf.transpose(factor_samples, batch_major_perm)

  def _compute_noise_samples(self, scale, num_samples, seed):
    """Draws logit noise: diagonal noise plus, if enabled, factor noise.

    Args:
      scale: Tuple `(factor_loadings, diag_scale)`. `diag_scale` scales the
        diagonal noise. `factor_loadings` is only used when num_factors is
        positive: it is passed through `scale_layer_a` in the parameter
        efficient case and otherwise reshaped to
        [-1, num_outputs, num_factors].
      num_samples: Integer. Number of Monte-Carlo samples to take.
      seed: Python integer or Tensor for seeding the random number generator.
        The same value is forwarded to both sampling helpers.

    Returns:
      `Tensor`. Logit noise samples. With num_factors not positive this is
      exactly the diagonal noise; otherwise it is the diagonal noise added to
      the factor noise, which in the non parameter efficient case has shape
      [batch_size, num_samples, num_outputs].
    """
    loadings, diag_scale = scale

    # The diagonal (independent per-output) noise is always needed.
    diag_noise = self._compute_diagonal_noise_samples(
        diag_scale, num_samples, seed)

    # Without factors the diagonal term is the whole noise.
    if not self._num_factors > 0:
      return diag_noise

    latent_samples = self._compute_standard_normal_samples(
        loadings, num_samples, seed)

    if self._parameter_efficient:
      # Input dependent scaling, broadcast over the sample axis, applied to a
      # learned projection of the latent samples.
      input_scaling = tf.expand_dims(self._scale_layer_a(loadings), 1)
      factor_noise = input_scaling * self._scale_layer_b(latent_samples)
    else:
      # Unflatten the loadings vector into one [num_outputs, num_factors]
      # matrix per example, in the dtype of the latent samples.
      loadings_matrix = tf.reshape(
          loadings, [-1, self._num_outputs, self._num_factors])
      loadings_matrix = tf.cast(loadings_matrix, latent_samples.dtype)

      # Per-example matrix-vector product for every sample:
      # [i, j, k] x [i, a, k] -> [i, a, j].
      factor_noise = tf.einsum('ijk,iak->iaj', loadings_matrix, latent_samples)

    return factor_noise + diag_noise
