//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/specialize_function.hpp"

using namespace std;
using namespace ngraph;

// Out-of-class definition of the static constexpr `type_info` member of
// TensorIterator. No initializer is given here, so the value is presumably
// supplied at the in-class declaration (not visible in this file).
constexpr NodeTypeInfo op::v0::TensorIterator::type_info;

// Constructs a TensorIterator whose arguments are `values`, forwarding them
// unchanged to the SubGraphOp base constructor. The constructor body is empty:
// no body function or input/output descriptions are set up here.
op::v0::TensorIterator::TensorIterator(const OutputVector& values)
    : op::util::SubGraphOp(values)
{
}

// Reports this op's attributes to `visitor`: the body function and the
// input/output descriptions, under the names "body", "input_descriptions"
// and "output_descriptions".
//
// NOTE(review): the function returns false unconditionally even though all
// three attributes are visited. Presumably false tells the caller that
// attribute visiting is not (fully) supported for this op -- confirm against
// the AttributeVisitor contract before changing it.
bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("body", m_body);
    visitor.on_attribute("input_descriptions", m_input_descriptions);
    visitor.on_attribute("output_descriptions", m_output_descriptions);

    return false;
}

void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops()
{
    std::stack<std::shared_ptr<Node>, std::vector<std::shared_ptr<Node>>> nodes_to_do;
    std::unordered_set<std::shared_ptr<Node>> nodes_done;

    for (const auto& r : m_body->get_results())
    {
        nodes_to_do.push(r);
    }
    while (nodes_to_do.size() > 0)
    {
        auto node = nodes_to_do.top();
        if (nodes_done.count(node) == 0)
        {
            NGRAPH_CHECK(as_type_ptr<op::v0::TensorIterator>(node) == nullptr,
                         "No nested TensorIterator");
            bool can_add = true;
            size_t arg_count = node->get_input_size();
            for (size_t i = 0; i < arg_count; ++i)
            {
                auto dep = node->input(arg_count - i - 1)
                               .get_source_output()
                               .get_node()
                               ->shared_from_this();
                if (nodes_done.count(dep) == 0)
                {
                    can_add = false;
                    nodes_to_do.push(dep);
                }
            }
            if (can_add)
            {
                nodes_done.insert(node);
                node->revalidate_and_infer_types();
                nodes_to_do.pop();
            }
        }
        else
        {
            nodes_to_do.pop();
        }
    }
}

// Validates the input/output descriptions and infers shapes and element types
// in three passes:
//   1. Input:  the partial shape of every outer input is pushed onto the body
//              parameter named by its description. For a sliced input with a
//              static shape, the sliced axis is replaced by part_size and
//              m_num_iterations is recomputed from start/end/part_size.
//   2. Body:   all ops reachable from the body results are revalidated.
//   3. Output: each output type is derived from the corresponding body value.
//              For a concatenated output the concat axis becomes
//              m_num_iterations * part_size when both are known.
//
// Fails a NODE_VALIDATION_CHECK when
//   - the number of inputs/outputs differs from the number of descriptions,
//   - descriptions are not stored in input/output index order,
//   - a slice/concat axis is outside the rank of the tensor it indexes,
//   - a slice part_size is not positive,
//   - an invariant input is incompatible with its body parameter,
//   - a scalar body value is concatenated along an axis other than 0.
void op::v0::TensorIterator::validate_and_infer_types()
{
    NODE_VALIDATION_CHECK(this,
                          get_input_size() == m_input_descriptions.size(),
                          "Number of inputs must be the same as number of input descriptions");

    NODE_VALIDATION_CHECK(this,
                          get_output_size() == m_output_descriptions.size(),
                          "Number of outputs must be the same as number of output descriptions");

    // Negative positions count back from the end of the sliced dimension.
    auto make_positive = [](int64_t value, uint64_t dim_size) -> int64_t {
        if (value < 0)
        {
            value = dim_size + value;
        }
        return value;
    };

    // Input
    uint64_t index_it = 0;
    for (const auto& input_description : m_input_descriptions)
    {
        auto index = input_description->m_input_index;
        NODE_VALIDATION_CHECK(this, index == index_it, "Input_index not in order");
        index_it++;

        if (auto slice_input_description = as_type_ptr<SliceInputDescription>(input_description))
        {
            auto body_parameter =
                m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
            auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
            if (input_partial_shape.is_static())
            {
                auto input_shape = input_partial_shape.to_shape();
                auto axis = slice_input_description->m_axis;
                auto part_size = slice_input_description->m_part_size;

                // input_shape[axis] is an unchecked access, so reject axes
                // outside the input rank up front.
                NODE_VALIDATION_CHECK(
                    this,
                    axis >= 0 && static_cast<uint64_t>(axis) < input_shape.size(),
                    "Slice axis ",
                    axis,
                    " is out of range for an input of rank ",
                    input_shape.size(),
                    ". TensorIterator input index: ",
                    index);
                // part_size is the divisor below and becomes a dimension of
                // the body parameter, so it has to be strictly positive.
                NODE_VALIDATION_CHECK(this,
                                      part_size > 0,
                                      "Slice part_size must be positive, got ",
                                      part_size,
                                      ". TensorIterator input index: ",
                                      index);

                auto dim_size = input_shape[axis];
                auto start = make_positive(slice_input_description->m_start, dim_size);
                auto end = make_positive(slice_input_description->m_end, dim_size);

                // +1 because the left and right borders are included [start, end]
                m_num_iterations = (abs(end - start) + 1) / part_size;
                // infer type for m_body_parameter
                Shape out_shape{input_shape};
                out_shape[axis] = part_size;
                body_parameter->set_partial_shape(out_shape);
            }
            else
            {
                // Only the rank can be forwarded when the input shape is not
                // fully known.
                body_parameter->set_partial_shape(
                    PartialShape::dynamic(input_partial_shape.rank()));
            }
        }
        else if (auto merged_input_description =
                     as_type_ptr<MergedInputDescription>(input_description))
        {
            // The back-edge value itself is not needed to set the parameter
            // shape; the lookup is kept so that an out-of-range body value
            // index is still rejected by at().
            static_cast<void>(
                m_body->get_results().at(merged_input_description->m_body_value_index));

            auto body_parameter =
                m_body->get_parameters().at(merged_input_description->m_body_parameter_index);

            auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
            body_parameter->set_partial_shape(input_partial_shape);
        }
        else if (auto invariant_input_description =
                     as_type_ptr<InvariantInputDescription>(input_description))
        {
            auto body_parameter =
                m_body->get_parameters().at(invariant_input_description->m_body_parameter_index);

            auto body_param_partial_shape = body_parameter->get_partial_shape();
            auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
            NODE_VALIDATION_CHECK(this,
                                  input_partial_shape.compatible(body_param_partial_shape),
                                  "Iterator initial value is not compatible with body param");
            body_parameter->set_partial_shape(input_partial_shape);
        }
    }

    // Body
    revalidate_and_infer_types_for_body_ops();

    // Output
    index_it = 0;
    for (const auto& output_description : m_output_descriptions)
    {
        auto index = output_description->m_output_index;
        NODE_VALIDATION_CHECK(this, index == index_it, "Output_index not in order");
        index_it++;

        auto body_value =
            m_body->get_results().at(output_description->m_body_value_index)->input_value(0);

        if (auto concat_output_description =
                as_type_ptr<ConcatOutputDescription>(output_description))
        {
            auto body_value_partial_shape = body_value.get_partial_shape();
            // Fallback: stays in effect when the body value is static but the
            // number of iterations is unknown.
            set_output_type(index, body_value.get_element_type(), PartialShape::dynamic());
            if (body_value_partial_shape.is_static())
            {
                auto body_value_shape = body_value_partial_shape.to_shape();
                auto part_size = concat_output_description->m_part_size;
                auto axis = concat_output_description->m_axis;

                Shape out_shape{body_value_shape};

                if (body_value_shape.empty())
                {
                    NODE_VALIDATION_CHECK(
                        this,
                        axis == 0,
                        "Axis must be equal to 0 if concatenated output tensor slices are scalars. "
                        "TensorIterator output index: ",
                        index);
                    // Scalars are concatenated into a 1-D tensor.
                    out_shape = Shape(1);
                }

                if (m_num_iterations != -1)
                {
                    // out_shape[axis] is an unchecked access, so reject axes
                    // outside the output rank first.
                    NODE_VALIDATION_CHECK(
                        this,
                        axis >= 0 && static_cast<uint64_t>(axis) < out_shape.size(),
                        "Concatenation axis ",
                        axis,
                        " is out of range for an output of rank ",
                        out_shape.size(),
                        ". TensorIterator output index: ",
                        index);
                    // for simple RNN case where stride is the same as part_size
                    out_shape[axis] = m_num_iterations * part_size;
                    set_output_type(index, body_value.get_element_type(), out_shape);
                }
            }
            else
            {
                set_output_type(index,
                                body_value.get_element_type(),
                                PartialShape::dynamic(body_value.get_partial_shape().rank()));
            }
        }
        else if (auto body_output_description =
                     as_type_ptr<BodyOutputDescription>(output_description))
        {
            // A plain body output is forwarded unchanged.
            set_output_type(index, body_value.get_element_type(), body_value.get_partial_shape());
        }
    }
}

// Returns the function executed by this op; simply forwards to get_body().
std::shared_ptr<Function> op::v0::TensorIterator::get_function()
{
    return get_body();
}

std::shared_ptr<Node>
    op::v0::TensorIterator::clone_with_new_inputs(const OutputVector& new_args) const
{
    auto op = make_shared<op::v0::TensorIterator>(new_args);
    NGRAPH_CHECK(op.get(),
                 op != nullptr,
                 "Cannot clone ",
                 description(),
                 " operation with name ",
                 get_friendly_name());
    op->set_output_size(m_output_descriptions.size());

    std::vector<::ngraph::element::Type> types(m_body->get_parameters().size());
    std::vector<::ngraph::PartialShape> new_shapes(m_body->get_parameters().size());

    for (size_t input_index = 0; input_index < new_args.size(); ++input_index)
    {
        for (auto& input_description : m_input_descriptions)
        {
            if (input_description->m_input_index == input_index)
            {
                types[input_description->m_body_parameter_index] =
                    new_args[input_index].get_element_type();
                new_shapes[input_description->m_body_parameter_index] =
                    new_args[input_index].get_partial_shape();

                if (new_shapes[input_description->m_body_parameter_index].is_static())
                {
                    if (auto slice_in = ::ngraph::as_type_ptr<
                            ngraph::op::v0::TensorIterator::SliceInputDescription>(
                            input_description))
                    {
                        new_shapes[slice_in->m_body_parameter_index][slice_in->m_axis] =
                            slice_in->m_part_size;
                    }
                }
            }
        }
    }

    op->m_num_iterations = m_num_iterations;
    auto func = std::make_shared<Function>(
        m_body->get_results(), m_body->get_sinks(), m_body->get_parameters());
    auto spec_func =
        specialize_function(func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr));
    op->m_body = std::make_shared<Function>(
        spec_func->get_results(), spec_func->get_sinks(), spec_func->get_parameters());

    for (auto& input_description : m_input_descriptions)
    {
        op->m_input_descriptions.push_back(input_description->copy());
    }
    for (auto& output_description : m_output_descriptions)
    {
        op->m_output_descriptions.push_back(output_description->copy());
    }
    return move(op);
}
