# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A small library of functions dealing with LSTMs applied to images.

Tensors in this library generally have the shape (num_images, height, width,
depth).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.ndlstm.python import lstm1d
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope


def _shape(tensor):
  """Return the static shape of `tensor` as a Python list."""
  static_shape = tensor.get_shape()
  return static_shape.as_list()


def images_to_sequence(tensor):
  """Turn every image row into one sequence running along the width axis.

  Args:
    tensor: a (num_images, height, width, depth) tensor

  Returns:
    (width, num_images*height, depth) sequence tensor
  """
  batch, rows, cols, channels = _shape(tensor)
  # Bring the width axis to the front so that it becomes the step axis.
  width_major = array_ops.transpose(tensor, [2, 0, 1, 3])
  # Fold (num_images, height) into a single batch axis.
  sequence_shape = [cols, batch * rows, channels]
  return array_ops.reshape(width_major, sequence_shape)


def sequence_to_images(tensor, num_image_batches):
  """Undo `images_to_sequence`: rebuild images from a batch of row sequences.

  Args:
    tensor: (num_steps, num_batches, depth) sequence tensor
    num_image_batches: the number of image batches

  Returns:
    (num_images, height, width, depth) tensor
  """
  steps, flat_batch, channels = _shape(tensor)
  # The flat batch axis holds num_image_batches * height rows.
  rows_per_image = flat_batch // num_image_batches
  image_shape = [steps, num_image_batches, rows_per_image, channels]
  unfolded = array_ops.reshape(tensor, image_shape)
  # (width, num_images, height, depth) -> (num_images, height, width, depth)
  return array_ops.transpose(unfolded, [1, 2, 0, 3])


def horizontal_lstm(images, num_filters_out, scope=None):
  """Run an LSTM bidirectionally over all the rows of each image.

  Each image row is treated as a sequence of `width` steps (see
  `images_to_sequence`), so the LSTMs see num_images * height sequences.
  The output depth is split between the two directions: the left-to-right
  pass is asked for num_filters_out // 2 units and the right-to-left pass
  for the remainder, and their outputs are concatenated along depth.

  Args:
    images: (num_images, height, width, depth) tensor
    num_filters_out: output depth
    scope: optional scope name

  Returns:
    (num_images, height, width, num_filters_out) tensor
  """
  with variable_scope.variable_scope(scope, "HorizontalLstm", [images]):
    num_images, _height, _width, _depth = _shape(images)
    rows_as_sequences = images_to_sequence(images)
    forward_depth = num_filters_out // 2
    backward_depth = num_filters_out - forward_depth
    with variable_scope.variable_scope("lr"):
      left_to_right = lstm1d.ndlstm_base(rows_as_sequences, forward_depth)
    with variable_scope.variable_scope("rl"):
      right_to_left = lstm1d.ndlstm_base(
          rows_as_sequences, backward_depth, reverse=1)
    # Join both directions along the depth axis of the sequence tensor.
    both_directions = array_ops.concat([left_to_right, right_to_left], 2)
    return sequence_to_images(both_directions, num_images)


def get_blocks(images, kernel_size):
  """Split images into non-overlapping blocks folded into the depth axis.

  If the height (or width) is not a multiple of the kernel height (or width),
  the images are zero-padded at the bottom (or right) up to the next multiple
  before being split.

  Args:
    images: (num_images, height, width, depth) tensor
    kernel_size: A list of length 2 holding the [kernel_height, kernel_width]
      of the blocks. Can be an int if both values are the same.

  Returns:
    (num_images, ceil(height/kernel_height), ceil(width/kernel_width),
    depth*kernel_height*kernel_width) tensor
  """
  # Honour the documented shorthand: a single int means a square kernel.
  if isinstance(kernel_size, int):
    kernel_size = [kernel_size, kernel_size]
  kernel_height, kernel_width = kernel_size[0], kernel_size[1]

  with variable_scope.variable_scope("image_blocks"):
    batch_size, height, width, channels = _shape(images)
    # Pad with the images' own dtype; a float32 default would make the
    # concat fail for any other input dtype.
    pad_dtype = images.dtype.base_dtype

    if height % kernel_height != 0:
      pad_rows = kernel_height - (height % kernel_height)
      padding = array_ops.zeros([batch_size, pad_rows, width, channels],
                                dtype=pad_dtype)
      images = array_ops.concat([images, padding], 1)
      batch_size, height, width, channels = _shape(images)
    if width % kernel_width != 0:
      pad_cols = kernel_width - (width % kernel_width)
      padding = array_ops.zeros([batch_size, height, pad_cols, channels],
                                dtype=pad_dtype)
      images = array_ops.concat([images, padding], 2)
      batch_size, height, width, channels = _shape(images)

    # Both dimensions are exact multiples of the kernel at this point.
    num_block_rows = height // kernel_height
    num_block_cols = width // kernel_width
    features = kernel_width * kernel_height * channels

    # Cut the images into horizontal strips that are kernel_height tall.
    strips = array_ops.split(images, num_block_rows, axis=1)
    strip_blocks = []
    for strip in strips:
      # (batch, kernel_height, width, depth) -> (batch, width, depth,
      # kernel_height), so that each run of kernel_width columns is contiguous
      # and can be folded into one feature vector.
      strip = array_ops.transpose(strip, [0, 2, 3, 1])
      strip = array_ops.reshape(strip, [batch_size, num_block_cols, features])
      strip_blocks.append(strip)

    return array_ops.stack(strip_blocks, axis=1)


def separable_lstm(images, num_filters_out,
                   kernel_size=None, nhidden=None, scope=None):
  """Run bidirectional LSTMs first horizontally then vertically.

  Args:
    images: (num_images, height, width, depth) tensor
    num_filters_out: output layer depth
    kernel_size: A list of length 2 holding the [kernel_height, kernel_width]
      of the blocks the images are split into before the LSTMs run. Can be an
      int if both values are the same. Set to None for not using blocks.
    nhidden: hidden layer depth (defaults to num_filters_out)
    scope: optional scope name

  Returns:
    (num_images, height, width, num_filters_out) tensor when kernel_size is
    None; otherwise height and width are divided by kernel_height and
    kernel_width respectively (see `get_blocks`).
  """
  with variable_scope.variable_scope(scope, "SeparableLstm", [images]):
    if nhidden is None:
      nhidden = num_filters_out
    if kernel_size is not None:
      # Expand the documented int shorthand here, because the block splitter
      # subscripts kernel_size and would raise TypeError on a bare int.
      if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
      images = get_blocks(images, kernel_size)
    hidden = horizontal_lstm(images, nhidden)
    with variable_scope.variable_scope("vertical"):
      # Swap height and width so the row-wise LSTM scans along columns.
      transposed = array_ops.transpose(hidden, [0, 2, 1, 3])
      output_transposed = horizontal_lstm(transposed, num_filters_out)
    # Swap the axes back to (num_images, height, width, depth).
    output = array_ops.transpose(output_transposed, [0, 2, 1, 3])
    return output


def reduce_to_sequence(images, num_filters_out, scope=None):
  """Reduce an image to a sequence by scanning an LSTM vertically.

  Every image column is fed, top to bottom, through
  `lstm1d.sequence_to_final`, and the per-column results are laid out along
  the width axis.

  Note that the result is batch-major: the leading axis is num_images, not
  width. Transpose with perm [1, 0, 2] if a (width, num_images, depth)
  sequence is required.

  Args:
    images: (num_images, height, width, depth) tensor
    num_filters_out: output layer depth
    scope: optional scope name

  Returns:
    A (num_images, width, num_filters_out) tensor.
  """
  with variable_scope.variable_scope(scope, "ReduceToSequence", [images]):
    batch_size, height, width, depth = _shape(images)
    # Make height the step axis and fold (num_images, width) into the batch.
    transposed = array_ops.transpose(images, [1, 0, 2, 3])
    reshaped = array_ops.reshape(transposed,
                                 [height, batch_size * width, depth])
    reduced = lstm1d.sequence_to_final(reshaped, num_filters_out)
    # Unfold the batch axis again; it was flattened in (num_images, width)
    # order, so the result is batch-major.
    output = array_ops.reshape(reduced, [batch_size, width, num_filters_out])
    return output


def reduce_to_final(images, num_filters_out, nhidden=None, scope=None):
  """Collapse each image to a single vector with two chained LSTM reductions.

  The first reduction scans every column vertically, the second scans the
  resulting per-column values horizontally.

  Args:
    images: (num_images, height, width, depth) tensor
    num_filters_out: output layer depth
    nhidden: hidden layer depth (defaults to num_filters_out)
    scope: optional scope name

  Returns:
    A (num_images, num_filters_out) batch.
  """
  with variable_scope.variable_scope(scope, "ReduceToFinal", [images]):
    if not nhidden:
      nhidden = num_filters_out
    num_images, rows, cols, channels = _shape(images)
    # Height becomes the step axis; (num_images, width) is folded into batch.
    height_major = array_ops.transpose(images, [1, 0, 2, 3])
    columns = array_ops.reshape(height_major,
                                [rows, num_images * cols, channels])
    with variable_scope.variable_scope("reduce1"):
      column_states = lstm1d.sequence_to_final(columns, nhidden)
      per_image = array_ops.reshape(column_states,
                                    [num_images, cols, nhidden])
      # Width becomes the step axis for the second reduction.
      width_major = array_ops.transpose(per_image, [1, 0, 2])
    with variable_scope.variable_scope("reduce2"):
      return lstm1d.sequence_to_final(width_major, num_filters_out)
