#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_

#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

// Returns a human-readable description of node_def that is more concise
// than its text-format proto representation.
string SummarizeNodeDef(const NodeDef& node_def);

// Attr name -> attr value; this is the type of the map returned by
// NodeDef::attr().
using AttrValueMap = protobuf::Map<string, AttrValue>;

// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
//
// The entry is added with insert(), which (like std::unordered_map::insert)
// does not overwrite an existing key: if *node_def already has an attr named
// <name>, that attr is left unchanged.
template <class T>
void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
  AttrValue attr_value;
  SetAttrValue(std::forward<T>(value), &attr_value);
  // attr_value is a local that dies at the end of this call; move it into the
  // map entry rather than deep-copying it (an AttrValue may carry a whole
  // TensorProto).
  node_def->mutable_attr()->insert(
      AttrValueMap::value_type(name, std::move(attr_value)));
}

// Overload of AddNodeAttr for braced lists such as {...}.
// A braced-init-list has no type of its own, so it cannot be deduced through
// the forwarding (T&&) overload; this overload accepts it as a
// std::initializer_list<T> instead.
//
// As with the other overload, an attr already present under <name> is left
// unchanged because the entry is added with insert().
template <class T>
void AddNodeAttr(const string& name, std::initializer_list<T> value,
                 NodeDef* node_def) {
  AttrValue attr_value;
  SetAttrValue(value, &attr_value);
  // Move the local into the map entry instead of copying it.
  node_def->mutable_attr()->insert(
      AttrValueMap::value_type(name, std::move(attr_value)));
}

class AttrSlice {
 public:
  AttrSlice(const NodeDef& node_def)  // NOLINT(runtime/explicit)
      : ndef_(&node_def),
        attrs_(&ndef_->attr()) {}

  explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}

  // Returns the attr with attr_name if found.  Otherwise, returns
  // nullptr.
  const AttrValue* Find(const string& attr_name) const;

  // Returns the attr_value for attr_name if found. Otherwise, returns a
  // NotFound status.
  Status Find(const string& attr_name, const AttrValue** attr_value) const;

 private:
  const NodeDef* ndef_ = nullptr;
  const AttrValueMap* attrs_;
};

// Looks up the attr named attr_name in attrs and stores its value in *value.
// Returns a non-ok status in either of these cases:
//   * attrs has no attr named attr_name;
//   * the attr's type does not match the type of *value.
// The trailing comment on each overload gives the OpDef attr type that it
// reads.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   string* value);  // type: "string"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   int64* value);  // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   int32* value);  // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   float* value);  // type: "float"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   bool* value);  // type: "bool"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   DataType* value);  // type: "type"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   TensorShapeProto* value);  // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   TensorShape* value);  // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   Tensor* value);  // type: "tensor"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<string>* value);  // type: "list(string)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<int64>* value);  // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<int32>* value);  // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<float>* value);  // type: "list(float)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<bool>* value);  // type: "list(bool)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<DataType>* value);  // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   DataTypeVector* value);  // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<TensorShapeProto>* value);  // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<TensorShape>* value);  // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   std::vector<Tensor>* value);  // type: "list(tensor)"

// This version avoids copying the TensorProto: on success, *value points at
// the TensorProto held in the attr that attrs refers to.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
// AttrValueMap) that attrs views.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   const TensorProto** value);  // type: "tensor"

// This version avoids copying the NameAttrList: on success, *value points at
// the NameAttrList held in the attr that attrs refers to.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
// AttrValueMap) that attrs views.
Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
                   const NameAttrList** value);  // type: "func"

// Computes the input and output data types of the node described by node_def,
// for attr-style ops, and stores them in *inputs and *outputs respectively.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                         DataTypeVector* inputs, DataTypeVector* outputs);

// Validates that node_def, checked against op_def:
// * defines all attrs that the OpDef expects;
// * has attrs that all satisfy the constraints from the OpDef;
// * has a signature matching SignatureForNode();
// * passes further checks (etc.).
// Returns a non-ok status describing the first problem found.
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);

// Maps an input/output argument name to the half-open range of flattened
// input/output indices that the argument occupies.
using NameRangeMap = std::unordered_map<string, std::pair<int, int>>;

// Fills *inputs and *outputs with the NameRangeMap of node_def under op_def.
// For example, input "foo" corresponds to input indices
//   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
                         NameRangeMap* inputs, NameRangeMap* outputs);

// For each attr that op_def declares with a default value and that *node_def
// does not already specify, adds that default value to *node_def.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);

// Validates the syntax of a NodeDef provided externally (for example, one that
// did not come from a trusted source). Returns a non-ok status if node_def
// does not match the grammar below.
//
// The following is an EBNF-style syntax for NodeDef objects. Note that
// nodes are actually specified as tensorflow::NodeDef protocol buffers,
// which contain many other fields that are not (currently) validated.
//
// Node         = NodeName, Inputs
// Inputs       = ( DataInput * ), ( ControlInput * )
// DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
// ControlInput = "^", NodeName
// NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);

// Returns a Status derived from `status`, with a description of node_def
// attached as additional text in the error message. This is useful for
// reporting which node an error (e.g. from a kernel) came from.
Status AttachDef(const Status& status, const NodeDef& node_def);

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
