"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import logging as log

import numpy as np

from extensions.ops.elementwise import Add, Mul
from mo.front.common.layout import get_features_dim
from mo.front.extractor import split_node_in_port
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
from mo.utils.utils import refer_to_faq_msg


class AddMeanScaleValues(MiddleReplacementPattern):
    """
    Insert mean subtraction and scale division right after the network inputs.

    The values are read from ``graph.graph['cmd_params'].mean_scale_values`` (derived from the
    --mean_values / --scale_values options). For every affected Parameter node the graph becomes
    ``Parameter -> Add(-mean) -> Mul(1 / scale) -> consumers``, i.e. the input is normalized as
    ``(x - mean) / scale``. Consumers of type ``ShapeOf`` stay connected to the Parameter directly.
    Trivial values (all means equal to 0, all scales equal to 1) insert nothing.
    """
    enabled = True
    run_not_recursively = True

    def run_after(self):
        return []

    def run_before(self):
        from extensions.middle.pass_separator import MiddleStart
        return [MiddleStart]

    @staticmethod
    def _insert_elementwise_op(graph: Graph, input_node: Node, value: np.ndarray, op: type, name_suffix: str,
                               values_kind: str):
        """
        Insert a binary elementwise operation ``op(input, value)`` after output port 0 of ``input_node``.

        ``value`` is reshaped so that it lies along the features dimension of the input (determined from the
        graph layout) and broadcasts over all other dimensions. Every consumer of the input except ``ShapeOf``
        nodes is re-wired to read from the new operation.

        :param graph: graph being transformed
        :param input_node: node whose output 0 is processed (a Parameter node in this transformation)
        :param value: constant second operand; must contain 1 element or one element per channel
        :param op: operation class to instantiate (``Add`` or ``Mul``)
        :param name_suffix: suffix appended to the input node name to build the name of the new node
        :param values_kind: human-readable kind of the values ('mean' / 'scale'), used in error messages
        :raises Error: if the number of values is neither 1 nor equal to the number of channels
        """
        input_name = input_node.soft_get('name', input_node.id)
        assert input_node.has_valid('shape'), 'Shape of input "{}" is not defined'.format(input_name)

        input_rank = len(input_node.shape)
        features_dim_idx = get_features_dim(graph.graph['layout'], input_rank)
        num_channels = input_node.shape[features_dim_idx]
        # User-provided data: validate with an exception rather than an assert (asserts vanish under -O).
        if value.size != 1 and value.size != num_channels:
            raise Error('Number of {} values ({}) for input "{}" is neither 1 nor equal to the number of '
                        'channels ({})'.format(values_kind, value.size, input_name, num_channels))

        # Place all values on the features axis so they broadcast over the remaining dimensions.
        const_shape = np.ones(input_rank, dtype=np.int64)
        const_shape[features_dim_idx] = value.size
        value = value.reshape(const_shape)

        new_node = create_op_with_const_inputs(graph, op=op, port_value_dict={1: value},
                                               op_attrs={'name': input_name + name_suffix})

        # The inserted op cannot change the shape (the constant broadcasts along the features axis),
        # so ShapeOf consumers keep reading the original input.
        for dst in input_node.out_port(0).get_destinations():
            if dst.node.soft_get('type') != 'ShapeOf':
                dst.get_connection().set_source(new_node.out_port(0))

        input_node.out_port(0).connect(new_node.in_port(0))

    @staticmethod
    def apply_scale(graph: Graph, input_node: Node, node_mean_scale_values: dict):
        """
        Insert ``Mul(input, 1 / scale)`` after ``input_node`` when a non-trivial scale is given.

        Nothing is inserted when the 'scale' entry is missing, is None, or all its values equal 1.

        :param graph: graph being transformed
        :param input_node: input node to be scaled
        :param node_mean_scale_values: dictionary that may contain a 'scale' entry
        :raises Error: if a scale value is 0 or the number of values does not fit the input channels
        """
        scale = node_mean_scale_values.get('scale')
        if scale is None:
            return
        scale = np.array(scale)
        if np.all(scale == 1):
            return
        if np.any(scale == 0):
            raise Error('Scale values for input "{}" contain zero: {}'.format(
                input_node.soft_get('name', input_node.id), scale))

        AddMeanScaleValues._insert_elementwise_op(graph, input_node, 1 / scale, Mul, '/scale_value', 'scale')

    @staticmethod
    def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
        """
        Insert ``Add(input, -mean)`` after ``input_node`` when a non-trivial mean is given.

        Nothing is inserted when the 'mean' entry is missing, is None, or all its values equal 0.

        :param graph: graph being transformed
        :param input_node: input node to be shifted
        :param node_mean_scale_values: dictionary that may contain a 'mean' entry
        :raises Error: if the number of values does not fit the input channels
        """
        mean = node_mean_scale_values.get('mean')
        if mean is None:
            return
        mean = np.array(mean)
        if np.all(mean == 0):
            return

        AddMeanScaleValues._insert_elementwise_op(graph, input_node, mean * (-1), Add, '/mean_value', 'mean')

    @staticmethod
    def _values_list_to_dict(values, input_nodes: dict) -> dict:
        """
        Convert positional mean/scale values into a dictionary keyed by input node name.

        :param values: sequence with one ``(mean, scale)`` pair per network input, ordered like ``input_nodes``
        :param input_nodes: Parameter nodes keyed by node id, in graph iteration order
        :return: ``{input_name: {'mean': ..., 'scale': ...}}``
        :raises Error: if the number of value pairs differs from the number of inputs
        """
        if len(values) != len(input_nodes):
            raise Error('Numbers of inputs and mean/scale values do not match. ' +
                        refer_to_faq_msg(61))

        data = np.copy(values)
        return {node['name']: {'mean': data[idx][0], 'scale': data[idx][1]}
                for idx, node in enumerate(input_nodes.values())}

    @staticmethod
    def _find_input_node_id(graph: Graph, input_nodes: dict, input_name: str):
        """
        Resolve a user-provided input name (optionally carrying a port, as parsed by ``split_node_in_port``)
        to the id of a Parameter node.

        :param graph: graph being transformed
        :param input_nodes: Parameter nodes keyed by node id
        :param input_name: input name as specified by the user
        :return: id of the matching Parameter node
        :raises Error: if no Parameter node corresponds to the name
        """
        node_name, port = split_node_in_port(input_name)
        node_id = None
        try:
            node_id = graph.get_node_id_by_name(node_name)
        except Error:
            log.warning('node_name {} is not found in graph'.format(node_name))
        if node_id in input_nodes:
            return node_id

        # If the user cut off the original network input, the name given in --scale_values / --mean_values
        # does not belong to a Parameter generated by Model Optimizer. The initial input name is kept in the
        # Parameter's 'initial_node_name' attribute; the port is assumed to be the last '_'-separated part
        # of the Parameter id.
        for placeholder in input_nodes.values():
            try:
                placeholder_port = int(placeholder.id.split("_")[-1])
            except (ValueError, AttributeError):
                log.debug('Can not get the port number from the node {}'.format(placeholder.id))
                log.debug('Port will be defined as None')
                # Only this placeholder matches any port; the requested `port` must stay intact so
                # that the remaining placeholders are still checked against it.
                placeholder_port = None
            if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_name and (
                    port is None or placeholder_port is None or placeholder_port == port):
                return placeholder.id

        raise Error('Input with name {} wasn\'t found!'.format(node_name) +
                    refer_to_faq_msg(83))

    def find_and_replace_pattern(self, graph: Graph):
        """
        Apply the mean/scale values from the command line parameters to the matching Parameter nodes.

        :param graph: graph being transformed
        :raises Error: if the values cannot be matched to the network inputs or are invalid
        """
        values = graph.graph['cmd_params'].mean_scale_values

        input_nodes = {}
        for graph_node_id in graph.nodes():
            node = Node(graph, graph_node_id)
            if node.has_valid('op') and node.op == 'Parameter':
                input_nodes[node.id] = node

        if not isinstance(values, dict):
            values = AddMeanScaleValues._values_list_to_dict(values, input_nodes)

        for input_name, node_mean_scale_values in values.items():
            input_node = Node(graph, AddMeanScaleValues._find_input_node_id(graph, input_nodes, input_name))
            # Order matters: each call moves the current consumers of the input to the newly inserted node.
            # Inserting Mul first and Add second produces Parameter -> Add -> Mul -> ..., i.e. (x - mean) / scale.
            AddMeanScaleValues.apply_scale(graph, input_node, node_mean_scale_values)
            AddMeanScaleValues.apply_mean_value(graph, input_node, node_mean_scale_values)
