# coding=utf-8
# Copyright 2020 The Heteroscedastic Noisy Labels Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Shared utilities between heteroscedastic_lib.py."""

import heteroscedastic_lib_utils as het_lib_utils

import tensorflow.compat.v2 as tf


def compute_predictive_dist(locs, scale, total_mc_samples, seed,
                            compute_mc_samples, max_mc_samples_in_memory):
  """Utility function to compute the estimated predictive distribution.

  If `total_mc_samples` does not exceed `max_mc_samples_in_memory`, all samples
  are drawn with a single call to `compute_mc_samples`. Otherwise the samples
  are drawn in batches of at most `max_mc_samples_in_memory` samples. Each
  batch is averaged over axis 1 (the sample axis), and the batch means are
  combined with weights proportional to the batch sizes. This equals the mean
  over all `total_mc_samples` samples. In the batched case
  `compute_mc_samples` is wrapped with `tf.recompute_grad`.

  Args:
    locs: Tensor of shape [batch_size, total_mc_samples, ...,
      1 if num_classes == 2 else num_classes]. Location parameters of the
      distributions to be sampled.
    scale: Tensor of shape [batch_size, total_mc_samples, ...,
      1 if num_classes == 2 else num_classes]. Scale parameters of the
      distributions to be sampled.
    total_mc_samples: Integer. Number of Monte-Carlo samples.
    seed: Python integer or scalar Tensor initial seed, for seeding the random
      number generator. If None, a seed is generated with
      `het_lib_utils.gen_int_seed()` (single batch) or
      `het_lib_utils.gen_tensor_seed()` (multiple batches).
    compute_mc_samples: function called as
      `compute_mc_samples(locs, scale, num_samples, seed)` which computes
      Monte-Carlo samples of selected location-scale distributions. Its output
      is averaged over axis 1, which is expected to be the sample axis of
      length `num_samples`.
    max_mc_samples_in_memory: Integer. Maximum number of samples drawn in a
      single call to `compute_mc_samples`.

  Returns:
    Tuple of (samples_mean, seeds_list, num_samples_list), where `samples_mean`
    is a Tensor of shape [batch_size, ..., d] - the mean of the MC samples.
    `seeds_list` is a list of Python integer or scalar Tensor seeds used to
    generate the random samples, one per batch. `num_samples_list` is a list
    of integers containing the number of samples taken in each batch;
    sum(num_samples_list) == total_mc_samples.
  """
  seeds_list = []
  num_samples_list = []

  if total_mc_samples <= max_mc_samples_in_memory:
    if seed is None:
      seed = het_lib_utils.gen_int_seed()
    samples = compute_mc_samples(locs, scale, total_mc_samples, seed)

    seeds_list.append(seed)
    num_samples_list.append(total_mc_samples)
    return tf.reduce_mean(samples, axis=1), seeds_list, num_samples_list

  # Divide total_mc_samples into batches of samples that fit in memory: this
  # needs (total_mc_samples // max_mc_samples_in_memory) batches of size
  # max_mc_samples_in_memory and maybe 1 additional batch of size
  # total_mc_samples % max_mc_samples_in_memory.
  num_full_batches, remainder_samples = divmod(total_mc_samples,
                                               max_mc_samples_in_memory)
  batch_sizes = [max_mc_samples_in_memory] * num_full_batches
  if remainder_samples > 0:
    batch_sizes.append(remainder_samples)

  compute_mc_samples = tf.recompute_grad(compute_mc_samples)
  if seed is None:
    seed = het_lib_utils.gen_tensor_seed()

  samples_mean = None
  for num_samples in batch_sizes:
    seed = seed + 1  # unique seed for each set of samples
    batch_samples = compute_mc_samples(locs, scale, num_samples, seed)
    # Reduce over the sample axis *before* combining batches: the remainder
    # batch has a different length along axis 1, so the raw sample tensors of
    # different batches cannot be added elementwise. Weighting each batch mean
    # by num_samples / total_mc_samples yields the mean over all samples.
    weight = num_samples / float(total_mc_samples)
    weighted_batch_mean = weight * tf.reduce_mean(batch_samples, axis=1)
    if samples_mean is None:
      samples_mean = weighted_batch_mean
    else:
      samples_mean = samples_mean + weighted_batch_mean

    seeds_list.append(seed)
    num_samples_list.append(num_samples)

  return samples_mean, seeds_list, num_samples_list
