# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Independent distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib


class Independent(distribution_lib.Distribution):
  """Independent distribution from batch of distributions.

  This distribution is useful for regarding a collection of independent,
  non-identical distributions as a single random variable. For example, the
  `Independent` distribution composed of a collection of `Bernoulli`
  distributions might define a distribution over an image (where each
  `Bernoulli` is a distribution over each pixel).

  More precisely, a collection of `B` (independent) `E`-variate random variables
  (rv) `{X_1, ..., X_B}`, can be regarded as a `[B, E]`-variate random variable
  `(X_1, ..., X_B)` with probability
  `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the
  probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes.

  Similarly, the `Independent` distribution specifies a distribution over
  `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch dims
  as part of the event dimensions. The `reduce_batch_ndims` parameter controls
  the number of batch dims which are absorbed as event dims;
  `reduce_batch_ndims <= len(batch_shape)`.  For example, the `log_prob`
  function entails a `reduce_sum` over the rightmost `reduce_batch_ndims` after
  calling the base distribution's `log_prob`.  In other words, since the batch
  dimension(s) index independent distributions, the resultant multivariate will
  have independent components.

  #### Mathematical Details

  The probability function is,

  ```none
  prob(x; reduce_batch_ndims) = tf.reduce_prod(
      dist.prob(x),
      axis=-1-range(reduce_batch_ndims))
  ```

  #### Examples

  ```python
  ds = tf.contrib.distributions

  # Make independent distribution from a 2-batch Normal.
  ind = ds.Independent(
      distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
      reduce_batch_ndims=1)

  # All batch dims have been "absorbed" into event dims.
  ind.batch_shape  # ==> []
  ind.event_shape  # ==> [2]

  # Make independent distribution from a 2-batch bivariate Normal.
  ind = ds.Independent(
      distribution=ds.MultivariateNormalDiag(
          loc=[[-1., 1], [1, -1]],
          scale_identity_multiplier=[1., 0.5]),
      reduce_batch_ndims=1)

  # All batch dims have been "absorbed" into event dims.
  ind.batch_shape  # ==> []
  ind.event_shape  # ==> [2, 2]
  ```

  """

  def __init__(
      self, distribution, reduce_batch_ndims=1, validate_args=False, name=None):
    """Construct an `Independent` distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reduce_batch_ndims: Scalar, integer number of rightmost batch dims which
        will be regarded as event dims.
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reduce_batch_ndims` exceeds `distribution.batch_ndims`
    """
    # Must stay the first statement so that only the constructor arguments are
    # captured.
    parameters = locals()
    name = name or "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name):
      reduce_batch_ndims = ops.convert_to_tensor(
          reduce_batch_ndims, dtype=dtypes.int32, name="reduce_batch_ndims")
      self._reduce_batch_ndims = reduce_batch_ndims
      self._static_reduce_batch_ndims = tensor_util.constant_value(
          reduce_batch_ndims)
      # Prefer the statically known value (when there is one) so that shape
      # arithmetic elsewhere in the class can happen without graph ops.
      if self._static_reduce_batch_ndims is not None:
        self._reduce_batch_ndims = self._static_reduce_batch_ndims
      super(Independent, self).__init__(
          dtype=self._distribution.dtype,
          reparameterization_type=self._distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=self._distribution.allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              [reduce_batch_ndims] +
              distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
      self._runtime_assertions = self._make_runtime_assertions(
          distribution, reduce_batch_ndims, validate_args)

  @property
  def distribution(self):
    """The base distribution instance passed to the constructor."""
    return self._distribution

  @property
  def reduce_batch_ndims(self):
    """Number of rightmost batch dims absorbed as event dims.

    This is the statically known value when one could be determined at
    construction time, otherwise the `int32` `Tensor`.
    """
    return self._reduce_batch_ndims

  @staticmethod
  def _get_batch_ndims(batch_shape):
    """Returns the length of the 1-D `batch_shape` `Tensor`.

    Args:
      batch_shape: `Tensor` as returned by a distribution's
        `batch_shape_tensor()`.

    Returns:
      The statically known length when available (including a length of `0`),
      otherwise a scalar `Tensor` computed from `array_ops.shape`.
    """
    static_ndims = batch_shape.shape.with_rank_at_least(1)[0].value
    # Compare against `None` explicitly: a static length of `0` is a valid,
    # known value and must not fall through to the dynamic path.
    if static_ndims is not None:
      return static_ndims
    return array_ops.shape(batch_shape)[0]

  def _batch_shape_tensor(self):
    # Keeps the leading batch dims of the base distribution which are not
    # absorbed as event dims.
    with ops.control_dependencies(self._runtime_assertions):
      batch_shape = self.distribution.batch_shape_tensor()
      batch_ndims = self._get_batch_ndims(batch_shape)
      return batch_shape[:batch_ndims - self.reduce_batch_ndims]

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`; yields an unknown shape when
    # either quantity involved is not statically known.
    batch_shape = self.distribution.batch_shape
    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
      return tensor_shape.TensorShape(None)
    d = batch_shape.ndims - self._static_reduce_batch_ndims
    return batch_shape[:d]

  def _event_shape_tensor(self):
    # The absorbed (rightmost) batch dims followed by the base distribution's
    # own event dims.
    with ops.control_dependencies(self._runtime_assertions):
      batch_shape = self.distribution.batch_shape_tensor()
      batch_ndims = self._get_batch_ndims(batch_shape)
      return array_ops.concat([
          batch_shape[batch_ndims - self.reduce_batch_ndims:],
          self.distribution.event_shape_tensor(),
      ], axis=0)

  def _event_shape(self):
    # Static counterpart of `_event_shape_tensor`; yields an unknown shape when
    # either quantity involved is not statically known.
    batch_shape = self.distribution.batch_shape
    if self._static_reduce_batch_ndims is None or batch_shape.ndims is None:
      return tensor_shape.TensorShape(None)
    d = batch_shape.ndims - self._static_reduce_batch_ndims
    return batch_shape[d:].concatenate(self.distribution.event_shape)

  def _sample_n(self, n, seed):
    # Samples are drawn directly from the base distribution.
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.sample(sample_shape=n, seed=seed)

  def _log_prob(self, x):
    # Sum of the base log-probs over the absorbed batch dims.
    with ops.control_dependencies(self._runtime_assertions):
      return self._reduce_sum(self.distribution.log_prob(x))

  def _entropy(self):
    # Sum of the base entropies over the absorbed batch dims.
    with ops.control_dependencies(self._runtime_assertions):
      return self._reduce_sum(self.distribution.entropy())

  def _mean(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.mean()

  def _variance(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.variance()

  def _stddev(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.stddev()

  def _mode(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.mode()

  def _make_runtime_assertions(
      self, distribution, reduce_batch_ndims, validate_args):
    """Checks that `reduce_batch_ndims` does not exceed the batch rank.

    Args:
      distribution: The base distribution instance.
      reduce_batch_ndims: Scalar `int32` `Tensor`.
      validate_args: Python `bool`; whether a graph assertion is created when
        the check cannot be done statically.

    Returns:
      A (possibly empty) list of assertion ops.

    Raises:
      ValueError: if both quantities are statically known and
        `reduce_batch_ndims` exceeds the number of batch dims.
    """
    assertions = []
    static_reduce_batch_ndims = tensor_util.constant_value(reduce_batch_ndims)
    batch_ndims = distribution.batch_shape.ndims
    if batch_ndims is not None and static_reduce_batch_ndims is not None:
      # Everything is known at graph construction time: fail fast.
      if static_reduce_batch_ndims > batch_ndims:
        raise ValueError("reduce_batch_ndims({}) cannot exceed "
                         "distribution.batch_ndims({})".format(
                             static_reduce_batch_ndims, batch_ndims))
    elif validate_args:
      batch_shape = distribution.batch_shape_tensor()
      batch_ndims = self._get_batch_ndims(batch_shape)
      assertions.append(check_ops.assert_less_equal(
          reduce_batch_ndims, batch_ndims,
          message="reduce_batch_ndims cannot exceed distribution.batch_ndims"))
    return assertions

  def _reduce_sum(self, stat):
    """Sums `stat` over its rightmost `reduce_batch_ndims` dimensions.

    Args:
      stat: `Tensor` to reduce.

    Returns:
      `stat` summed over axes `[-1, -2, ..., -reduce_batch_ndims]`.
    """
    if self._static_reduce_batch_ndims is None:
      range_ = array_ops.range(self._reduce_batch_ndims)
    else:
      range_ = np.arange(self._static_reduce_batch_ndims)
    # `-1 - [0, 1, ..., r-1]` yields `[-1, -2, ..., -r]`.
    return math_ops.reduce_sum(stat, axis=-1-range_)