/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"

#include <algorithm>
#include <deque>
#include <stack>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/optional.h"

namespace tensorflow {

namespace {

// Op type names of the placeholder nodes that stand in for function arguments
// and return values in the graphs built for loop conditions and loop bodies.
// They are also used as the prefix of the generated node names (see
// BuildArgNode and BuildRetvalNode).
const char* const kArgOp = "_Arg";
const char* const kRetValOp = "_Retval";

// Information about a loop argument.
//
// All members carry default initializers so that a default-constructed Arg is
// always in a well-defined state, even before its fields are filled in.
struct Arg {
  // Every loop argument has an Enter node. Null until it has been assigned.
  Node* enter = nullptr;

  // Is the loop argument a loop-invariant value? Taken from the `is_constant`
  // attribute on the Enter node.
  bool is_loop_invariant = false;

  // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant
  // arguments must have all of the following nodes:
  Node* merge = nullptr;
  Node* switch_node = nullptr;
  Node* next_iteration = nullptr;
  // The Exit node may legitimately stay nullptr when the loop output is not
  // consumed; FunctionalizeLoop checks for that before rewiring its edges.
  Node* exit = nullptr;
};

// Information about a loop frame.
struct Frame {
  // Name of the frame; used in log and error messages.
  string name;

  // Pointer to the parent frame. The root frame has a pointer to itself.
  Frame* parent = nullptr;
  // NOTE(review): presumably the number of frames whose parent is this frame;
  // it is maintained outside the code shown here -- confirm against the caller.
  int num_children = 0;

  // Arguments to this loop. FunctionalizeLoop rebuilds and sorts this vector
  // before constructing the loop functions.
  std::vector<Arg> args;

  // The loop condition of the loop. There should be exactly one loop condition
  // in every loop.
  Node* loop_cond = nullptr;

  // Set of nodes that belong to the loop frame. Once the loop has been
  // functionalized these nodes are removed from the graph and the set is
  // cleared.
  std::unordered_set<Node*> nodes;
};

// Formats the names of the nodes in `nodes` as "{name1,name2,...}".
// `nodes` may be any container of Node pointers accepted by str_util::Join.
template <typename T>
string NodesToString(const T& nodes) {
  const auto append_name = [](string* out, const Node* n) {
    strings::StrAppend(out, n->name());
  };
  const string joined_names = str_util::Join(nodes, ",", append_name);
  return strings::StrCat("{", joined_names, "}");
}

// Copies part of `graph` into `output` by walking input edges backwards
// (reverse DFS) from the nodes initially on `stack`.
//
// `node_map` is indexed by source node ID and holds the corresponding node in
// `output` (or nullptr if the node has not been copied yet). A node that
// already has an entry in `node_map` is never traversed into, so callers cut
// the graph by pre-populating `node_map`. Every node initially on `stack` must
// already have an entry, because it is used as the destination of the edges
// that get copied.
//
// If `frame` is non-null, reaching a source node outside `frame->nodes` is
// reported as an Internal error; the caller must have provided enough cut
// nodes to keep the traversal inside the frame.
//
// `squash_src_outputs` holds one bool per source node ID. When set, data edges
// leaving that node are copied with source output 0. This is needed when a
// Switch node is replaced by an _Arg node: the edge may have used either
// Switch output, but an _Arg node only has output 0. Control edges are never
// squashed.
Status CopySubgraph(const Graph& graph, const Frame* frame,
                    std::vector<Node*> stack,
                    const std::vector<bool>& squash_src_outputs,
                    std::vector<Node*>* node_map, Graph* output) {
  VLOG(3) << "Stack: " << NodesToString(stack);
  std::vector<Node*>& copies = *node_map;
  std::vector<bool> seen(graph.num_node_ids(), false);
  while (!stack.empty()) {
    Node* const current = stack.back();
    stack.pop_back();

    VLOG(5) << "Copying node " << current->name();

    if (seen[current->id()]) {
      continue;
    }
    seen[current->id()] = true;

    for (const Edge* edge : current->in_edges()) {
      Node* const src = edge->src();
      const bool escaped_frame =
          frame != nullptr && frame->nodes.count(src) == 0;
      if (escaped_frame) {
        // We traversed out of the loop frame, without encountering a cut node.
        return errors::Internal("Graph traversal of loop frame ", frame->name,
                                " escaped frame at ", src->name(),
                                " without encountering an argument node.");
      }

      // Copy the source node on first encounter and schedule it for a visit.
      Node*& src_copy = copies[src->id()];
      if (src_copy == nullptr) {
        src_copy = output->CopyNode(src);
        stack.push_back(src);
      }

      const bool squash =
          squash_src_outputs[src->id()] && !edge->IsControlEdge();
      const int src_output = squash ? 0 : edge->src_output();
      Node* const dst_copy = copies[edge->dst()->id()];
      output->AddEdge(src_copy, src_output, dst_copy, edge->dst_input());
    }
  }
  return Status::OK();
}

// Adds a node described by `node_def` to `graph`. Returns the new node, or the
// error reported by Graph::AddNode.
xla::StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
  Status add_status;
  Node* const added = graph->AddNode(node_def, &add_status);
  if (add_status.ok()) {
    return added;
  }
  return add_status;
}

// Adds an _Arg node of the given `type` to `graph`. The node is named
// "_Arg<index>" and carries `index` as its "index" attribute.
xla::StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
  const string node_name = strings::StrCat(kArgOp, index);
  NodeDefBuilder arg_builder(node_name, kArgOp);
  arg_builder.Attr("T", type).Attr("index", index);

  NodeDef arg_def;
  TF_RETURN_IF_ERROR(arg_builder.Finalize(&arg_def));
  return AddNode(arg_def, graph);
}

// Adds a _Retval node of the given `type` to `graph`. The node is named
// "_Retval<index>" and carries `index` as its "index" attribute. The NodeDef is
// assembled by hand, without any inputs; callers attach the input edge later.
xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
  NodeDef retval_def;
  retval_def.set_name(strings::StrCat(kRetValOp, index));
  retval_def.set_op(kRetValOp);
  AddNodeAttr("T", type, &retval_def);
  AddNodeAttr("index", index, &retval_def);
  return AddNode(retval_def, graph);
}

// Builds the graph of the loop condition of `frame` into `*cond_output`.
//
// The resulting graph has one _Arg node per loop argument (in the order of
// `frame->args`) and a single boolean _Retval node that takes the place of the
// LoopCond node.
Status BuildLoopCondition(const Graph& graph, Frame* frame,
                          std::unique_ptr<Graph>* cond_output) {
  VLOG(2) << "Building loop condition for " << frame->name;
  *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
  Graph* output = cond_output->get();

  // Map from nodes in the original graph to the condition graph.
  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
  // Nothing is squashed for the condition; the cut nodes used here are only
  // Enter and Merge nodes.
  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);

  // Build one _Arg node for each Enter node. Loop-invariant arguments are cut
  // at their Enter node; loop-varying arguments are cut at their Merge node.
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];

    TF_ASSIGN_OR_RETURN(Node * arg_node,
                        BuildArgNode(output, arg.enter->input_type(0), i));
    Node* const cut_node = arg.is_loop_invariant ? arg.enter : arg.merge;
    node_map[cut_node->id()] = arg_node;
  }

  // Build a Retval node for the loop condition. The LoopCond nodes are always
  // boolean because of the type constraints on the LoopCond op.
  TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, DT_BOOL, 0));
  node_map[frame->loop_cond->id()] = retval_node;

  // Performs a reverse DFS, copying nodes and edges to the output graph.
  // The _Arg and _Retval nodes were added unconditionally above, so we are
  // guaranteed to get the correct function signature.
  std::vector<Node*> roots{frame->loop_cond};
  return CopySubgraph(graph, frame, std::move(roots), squash_src_outputs,
                      &node_map, output);
}

// Builds the graph of the loop body of `frame` into `*body_output`.
//
// One _Arg node is created per loop argument (in the order of `frame->args`)
// and the argument's data type is appended to `*arg_types`. Every argument
// that is not a resource also gets a _Retval node with the same index.
Status BuildLoopBody(const Graph& graph, Frame* frame,
                     DataTypeVector* arg_types,
                     std::unique_ptr<Graph>* body_output) {
  VLOG(2) << "Building loop body for " << frame->name;
  *body_output = xla::MakeUnique<Graph>(graph.op_registry());
  Graph* output = body_output->get();

  // Map from nodes in the original graph to the body graph.
  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);

  // Build one _Arg node for each Enter node. The NextIteration nodes of the
  // loop-varying arguments are the roots of the traversal below.
  const auto num_args = frame->args.size();
  std::vector<Node*> next_iterations;
  next_iterations.reserve(num_args);
  arg_types->reserve(num_args);
  for (int i = 0; i < num_args; ++i) {
    const Arg& arg = frame->args[i];

    const DataType dtype = arg.enter->input_type(0);
    arg_types->push_back(dtype);

    TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));

    if (dtype == DT_RESOURCE) {
      // The convention of the XLA bridge is that resource variable arguments
      // are only inputs to the loop body and have no corresponding output.
      // TODO(b/37741920): change the convention so that DT_RESOURCE variables
      // are both inputs and outputs, and then remove this case.
      TF_RET_CHECK(arg.is_loop_invariant);
      node_map[arg.enter->id()] = arg_node;
      continue;
    }

    TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));

    if (arg.is_loop_invariant) {
      // Loop-invariant argument: pass the _Arg straight through to the
      // _Retval.
      node_map[arg.enter->id()] = arg_node;
      output->AddEdge(arg_node, 0, retval_node, 0);
      continue;
    }

    // Loop-varying argument: the Switch becomes the _Arg and the
    // NextIteration becomes the _Retval.
    node_map[arg.switch_node->id()] = arg_node;
    // A Switch node has two outputs whereas an _Arg node has one, so ask
    // CopySubgraph to rewrite the source output of data edges leaving the
    // Switch to 0 instead of copying the original output number.
    squash_src_outputs[arg.switch_node->id()] = true;
    node_map[arg.next_iteration->id()] = retval_node;
    next_iterations.push_back(arg.next_iteration);
  }

  // Performs a reverse DFS, copying nodes and edges to the output graph.
  // The _Arg and _Retval nodes were added unconditionally above, so we are
  // guaranteed to get the correct function signature.
  return CopySubgraph(graph, frame, std::move(next_iterations),
                      squash_src_outputs, &node_map, output);
}

// Replaces the loop described by `frame` with a single XlaWhile node in
// `graph`.
//
// The steps are:
//   1. Split loop-varying Enter nodes that have several successors, so that
//      each loop argument has its own Enter node.
//   2. Sort the arguments (resources last, then by Enter node name).
//   3. For every loop-varying argument, locate its Merge, NextIteration,
//      Switch and (optional) Exit nodes.
//   4. Build the condition and body graphs, convert them to functions and add
//      them to `library`.
//   5. Add the XlaWhile node, rewire the inputs of the Enter nodes and the
//      consumers of the Exit nodes onto it, and remove the frame's nodes.
//
// On success `frame->nodes` is cleared and the new While node is added to the
// node set of the parent frame. Returns an error if the loop does not have the
// expected structure. Note that the graph may already have been modified (by
// step 1) when an error is returned.
Status FunctionalizeLoop(Graph* graph, Frame* frame,
                         FunctionLibraryDefinition* library) {
  VLOG(2) << "Frame " << frame->name << " before: "
          << dump_graph::DumpGraphToFile("functionalize_before", *graph);

  // Split loop-varying Enter nodes with multiple successors. If the same
  // Tensor is fed as input to multiple loop arguments, we may end up with a
  // shared Enter node. We clone Enter nodes with multiple successors to
  // maintain the invariant of a unique Enter node per argument of the final
  // loop.
  std::vector<Arg> args;
  for (const Arg& arg : frame->args) {
    if (arg.is_loop_invariant) {
      args.push_back(arg);
    } else {
      // Snapshot the out-edges: the loop below removes and adds edges on the
      // Enter node, so its live edge set cannot be iterated directly.
      std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
                                     arg.enter->out_edges().end());
      for (int i = 0; i < edges.size(); ++i) {
        if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
          continue;
        }
        TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
        Arg new_arg;
        new_arg.is_loop_invariant = false;
        if (i == 0) {
          // The successor on the first edge keeps the original Enter node.
          new_arg.enter = arg.enter;
        } else {
          // Every other successor gets its own copy of the Enter node, fed by
          // the same inputs as the original.
          new_arg.enter = graph->CopyNode(arg.enter);
          frame->nodes.insert(new_arg.enter);
          for (Edge const* e : arg.enter->in_edges()) {
            graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
                           e->IsControlEdge() ? Graph::kControlSlot : 0);
          }
          Node* dst = edges[i]->dst();
          int dst_input = edges[i]->dst_input();
          graph->RemoveEdge(edges[i]);
          graph->AddEdge(new_arg.enter, 0, dst, dst_input);
        }
        args.push_back(new_arg);
      }
    }
  }
  frame->args = std::move(args);

  // Order the arguments so that:
  // a) resource variables are last, and
  // b) sort lexicographically by name (for deterministic output).
  std::sort(frame->args.begin(), frame->args.end(),
            [](const Arg& a, const Arg& b) {
              bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
              bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
              return std::tie(a_is_resource, a.enter->name()) <
                     std::tie(b_is_resource, b.enter->name());
            });

  if (frame->loop_cond == nullptr) {
    return errors::InvalidArgument("Loop ", frame->name,
                                   " has no LoopCond node");
  }

  // Find the set of Switch nodes that are successors of the LoopCond.
  std::unordered_set<Node*> switches;
  for (const Edge* edge : frame->loop_cond->out_edges()) {
    if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
        edge->dst_input() == 1) {
      switches.insert(edge->dst());
    }
  }

  // For each non-constant argument, looks for the following pattern of nodes:
  // Enter ----> Merge  -------->  Switch  --> Exit
  //               ^                  ^
  //               |                  |
  //         NextIteration         LoopCond
  //               ^                  ^
  //               |                  |
  //              ...                ...
  for (Arg& arg : frame->args) {
    if (!arg.is_loop_invariant) {
      // Follow the edge from the Enter to Merge.
      const Edge* enter_merge = nullptr;
      for (const Edge* e : arg.enter->out_edges()) {
        // Ignore control-edges to the sink node. These are allowed by the
        // graph invariants, although probably they should have been stripped
        // off earlier.
        if (e->IsControlEdge() && e->dst()->IsSink()) {
          continue;
        }
        if (enter_merge != nullptr) {
          return errors::Internal(
              "Enter node for loop-varying argument ", arg.enter->name(),
              " has multiple successors: ", enter_merge->dst()->name(), " and ",
              e->dst()->name());
        }
        enter_merge = e;
      }
      if (enter_merge == nullptr) {
        return errors::Internal("Enter node for loop-varying argument ",
                                arg.enter->name(), " has zero successors");
      }
      arg.merge = enter_merge->dst();
      if (!IsMerge(arg.merge)) {
        return errors::InvalidArgument(
            "Successor of Enter node for loop-varying argument ",
            arg.merge->name(),
            " is not a Merge node; got: ", arg.merge->type_string());
      }

      // Find the NextIteration from the merge. There should be two inputs to
      // the Merge and the NextIteration should be the other input.
      if (arg.merge->input_types().size() != 2) {
        return errors::InvalidArgument(
            "Unexpected number of inputs to Merge node for loop-varying "
            "argument ",
            arg.merge->name(), "; expected 2, got ",
            arg.merge->input_types().size());
      }
      TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
                                               &arg.next_iteration));
      if (!IsNextIteration(arg.next_iteration)) {
        return errors::InvalidArgument(
            "Expected NextIteration node as input to Merge node; got node ",
            arg.next_iteration->name(), " with kind ",
            arg.next_iteration->type_string());
      }

      // Find the Switch successor of the Merge. There should be exactly one
      // Switch node that is a successor of both the Merge and the LoopCond.
      for (const Edge* edge : arg.merge->out_edges()) {
        if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
            switches.find(edge->dst()) != switches.end()) {
          if (arg.switch_node != nullptr) {
            return errors::InvalidArgument("Duplicate Switch successors to ",
                                           arg.merge->name());
          }
          arg.switch_node = edge->dst();
        }
      }
      if (arg.switch_node == nullptr) {
        return errors::InvalidArgument("Missing Switch successor to ",
                                       arg.merge->name());
      }

      // Loop over the switch node's outputs to:
      // - Find the Exit successor (on output 0 of the Switch). At most one is
      //   allowed; having none is fine and leaves `arg.exit` as nullptr.
      // - Set the sharding on all Identity outputs of the switch. These
      //   Identity nodes are values used by the loop body or condition.
      //   The Identity node may have the wrong device, so its sharding is
      //   taken from the nodes on its out-edges instead.
      for (const Edge* edge : arg.switch_node->out_edges()) {
        if (edge->src_output() == 0 && IsExit(edge->dst())) {
          if (arg.exit != nullptr) {
            return errors::InvalidArgument("Duplicate Exit successors to ",
                                           arg.switch_node->name());
          }
          arg.exit = edge->dst();
        } else if (StringPiece(edge->dst()->type_string()) == "Identity") {
          TF_RETURN_IF_ERROR(
              SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
        }
      }
    }
  }

  // Builds the condition and body functions.
  std::unique_ptr<Graph> cond_graph;
  TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
  DataTypeVector arg_types;
  std::unique_ptr<Graph> body_graph;
  TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));

  VLOG(2) << "Frame " << frame->name << " condition: "
          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
          << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);

  // A process-wide counter gives each cond/body function pair a distinct
  // name.
  static std::atomic<int64> sequence_num(0LL);
  int64 id = ++sequence_num;
  NameAttrList cond_name;
  cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
  NameAttrList body_name;
  body_name.set_name(strings::StrCat("_functionalize_body_", id));
  FunctionDef cond_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
  FunctionDef body_fdef;
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));

  TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
  TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));

  // Builds a While operator. It reuses the name of the LoopCond node.
  NodeDef while_def;
  NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
  builder.Attr("T", arg_types);
  builder.Attr("cond", cond_name);
  builder.Attr("body", body_name);
  std::vector<NodeDefBuilder::NodeOut> inputs;
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];
    const Edge* in_edge;
    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    if (in_edge->IsControlEdge()) {
      builder.ControlInput(in_edge->src()->name());
    } else {
      inputs.push_back(NodeDefBuilder::NodeOut(
          in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
    }
  }
  builder.Input(inputs);
  TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
  TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph));

  // Copies edges to the Enter nodes and from the Exit nodes onto the While.
  for (int i = 0; i < frame->args.size(); ++i) {
    const Arg& arg = frame->args[i];
    const Edge* in_edge;
    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
    if (in_edge->IsControlEdge()) {
      graph->AddControlEdge(in_edge->src(), while_node);
    } else {
      graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
    }

    if (!arg.is_loop_invariant) {
      // Add output edges if the output of the loop is consumed.
      if (arg.exit != nullptr) {
        // Snapshot the out-edges because they are removed while iterating.
        std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
                                       arg.exit->out_edges().end());
        for (const Edge* edge : edges) {
          Node* dst = edge->dst();
          int dst_input = edge->dst_input();
          graph->RemoveEdge(edge);

          if (dst_input == Graph::kControlSlot) {
            graph->AddControlEdge(while_node, dst);
          } else {
            graph->AddEdge(while_node, i, dst, dst_input);
          }
        }
      }
    }
  }

  // Remove the old nodes from the graph, and add the while node to the parent
  // frame.
  for (Node* node : frame->nodes) {
    graph->RemoveNode(node);
  }
  frame->nodes.clear();
  frame->parent->nodes.insert(while_node);

  VLOG(2) << "Frame " << frame->name << " after: "
          << dump_graph::DumpGraphToFile("functionalize_after", *graph);

  return Status::OK();
}

// Converts Switch/Merge based conditionals in a graph into XlaIf nodes. The
// public entry point is the static Functionalize() method; instances are only
// created internally (the constructor is private).
class FunctionalizeCond {
 public:
  // All nodes are assumed to be either in no branch, then branch, else branch,
  // or both branches (such as merge nodes).
  //
  // The values of kElseBranch and kThenBranch are significant: the source
  // output index of a data edge leaving a Switch node is converted directly
  // into a Branch, and the validation code computes the "other" branch as
  // `1 - index`.
  enum Branch {
    kElseBranch = 0,
    kThenBranch = 1,
    kBoth = 2,
    kNeither = 3,
    kNumBranchTypes = 4
  };

  // Returns a textual representation of the Branch b.
  static string Branch_Name(FunctionalizeCond::Branch b);

  // Comparison function used for sorting nodes consistently. Nodes whose first
  // input has type DT_RESOURCE sort after all other nodes; ties are broken by
  // node name. Nodes without inputs are treated as non-resource.
  struct CondCmp {
    bool operator()(const Node* lhs, const Node* rhs) const {
      bool lhs_is_resource =
          lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
      bool rhs_is_resource =
          rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
      return std::tie(lhs_is_resource, lhs->name()) <
             std::tie(rhs_is_resource, rhs->name());
    }
  };

  // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
  // nodes. That is, attempt to transform every remaining switch and merge nodes
  // in the graph into XlaIf nodes.
  // Precondition: All while loops have been removed from graph.
  static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library);

 private:
  // Per-node state of the forward traversal that starts at the Switch nodes.
  struct ForwardFlowNode {
    explicit ForwardFlowNode(Branch branch = Branch::kNeither)
        : branch(branch), count(0) {}
    string ToString() const {
      return strings::StrCat("branch=", Branch_Name(branch), " count=", count);
    }
    // Branch the node has been assigned to so far.
    Branch branch;
    // Number of incoming edges of the node that have been joined so far. A
    // node is fully covered once this equals its number of in-edges.
    int count;
  };

  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
      : library_(library), graph_(graph) {}

  // Perform the actual cond functionalization. Iterate over groups of switch
  // nodes (linked by common predicate), from innermost to outermost, and
  // extract into XlaIf nodes.
  Status FunctionalizeInternal();

  // Converts a Merge node to a XlaIf. This encapsulates the process of
  // extracting the bodies needed for the then and else branch, creates a XlaIf
  // node, removing the nodes of the branches from the graph and replacing the
  // merge node with a XlaIf.
  Status ConvertCorrespondingMergeToXlaIf(
      const std::vector<Node*>& switch_nodes,
      const std::vector<Node*>& merge_nodes, Node* predicate);

  // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
  xla::StatusOr<Node*> BuildAndAddXlaIfOp(
      const std::vector<Node*>& switch_nodes,
      const std::vector<Node*>& merge_nodes, Node* predicate);

  // Extracts a function body corresponding to the given input edge of the merge
  // node.
  Status ExtractBody(const std::vector<Node*>& switch_nodes,
                     const std::vector<Node*>& merge_nodes, int input_edge,
                     Graph* body);

  // Adds all the input edges to `if_node` corresponding to the arguments.
  Status AddInputEdges(const std::vector<Node*>& cond_args, Node* predicate,
                       Node* if_node);

  // Adds all output edges from the `if_node`.
  Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);

  // Returns the switches of graph_ in postorder. Dead switch nodes are skipped
  // and removed from the graph.
  std::vector<Node*> DetermineSwitchOrder();

  // Update the state for destination based on the state of source and the node
  // being updated.
  Status Join(const ForwardFlowNode& src_state, const Node* dst,
              ForwardFlowNode* dst_state);

  // Validates that the branch_map and frontier of nodes for the conditional
  // section are as expected.
  Status ValidBranchMapAndFrontier(
      const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
      const std::unordered_set<Node*>& frontier);

  // Function library that the generated branch functions are added to.
  FunctionLibraryDefinition* library_;
  // Graph being rewritten.
  Graph* graph_;
};

// Returns true if `node` has no effective consumers: every node it feeds must
// be an Identity whose out-edges are all control edges to the sink. A node
// without any out-edges is also considered dead.
bool IsDeadSwitch(const Node* node) {
  for (const Edge* out_edge : node->out_edges()) {
    const Node* consumer = out_edge->dst();
    if (!consumer->IsIdentity()) {
      return false;
    }
    for (const Edge* consumer_edge : consumer->out_edges()) {
      const bool control_edge_to_sink =
          consumer_edge->IsControlEdge() && consumer_edge->dst()->IsSink();
      if (!control_edge_to_sink) {
        return false;
      }
    }
  }
  return true;
}

// Returns the lower-case name of branch `b`. The table is indexed by the
// numeric value of the enumerator and has one extra entry ("count") for
// kNumBranchTypes itself.
string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) {
  static const char* const kBranchNames[FunctionalizeCond::kNumBranchTypes +
                                        1] = {"else", "then", "both",
                                              "neither", "count"};
  return kBranchNames[b];
}

// Checks the result of the forward traversal from a group of Switch nodes.
//
// Fails with FailedPrecondition if some node in `branch_map` did not have all
// of its in-edges accounted for. The frontier nodes are then grouped by the
// branch they were assigned to: none may be in kNeither, and those in kBoth
// must be Merge nodes.
//
// Side effect: when VLOG level 1 is enabled, every node in `branch_map` gets a
// "_XlaFunctionalizeBranch" attribute naming its branch.
Status FunctionalizeCond::ValidBranchMapAndFrontier(
    const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
        branch_map,
    const std::unordered_set<Node*>& frontier) {
  for (const auto& entry : branch_map) {
    Node* const node = entry.first;
    const ForwardFlowNode& state = entry.second;
    if (state.count != node->in_edges().size()) {
      return errors::FailedPrecondition("Value ", node->DebugString(),
                                        " not dominated by switch nodes.");
    }
    if (VLOG_IS_ON(1)) {
      // Append attribute to the graph if running with logging to make the
      // changes clearer in the visualization.
      node->AddAttr("_XlaFunctionalizeBranch", Branch_Name(state.branch));
    }
  }

  // Group the frontier nodes by the branch they ended up in.
  std::unordered_set<const Node*> pending[kNumBranchTypes];
  for (Node* frontier_node : frontier) {
    const Branch branch = branch_map.at(frontier_node).branch;
    pending[branch].insert(frontier_node);
  }

  TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]);
  for (const Node* both_node : pending[kBoth]) {
    // Merge nodes may be in then or else branch too
    TF_RET_CHECK(IsMerge(both_node)) << both_node->DebugString();
  }

  // Scan the smaller of the then/else sets and probe the other one.
  const bool then_is_smaller =
      pending[kThenBranch].size() <= pending[kElseBranch].size();
  const int scanned = then_is_smaller ? kThenBranch : kElseBranch;
  const int probed = 1 - scanned;
  for (const Node* candidate : pending[scanned]) {
    if (pending[probed].count(candidate) > 0) {
      return errors::Internal(
          "Node (", candidate->DebugString().c_str(),
          ") in both Else and Then branch should be in Both.");
    }
  }
  return Status::OK();
}

// Folds the branch state arriving over one edge (`src_state`) into the state
// of the edge's destination node `dst` (`dst_state`) and counts the edge.
//
// - If the destination has no branch yet (kNeither), it adopts the source's.
// - If source and destination disagree and the source is in a branch, the
//   destination must be a Merge node and becomes kBoth; otherwise an Internal
//   error is returned.
// - Otherwise the destination's branch is left unchanged.
//
// The destination must not already be kBoth (or the kNumBranchTypes
// sentinel) on entry.
Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
                               const Node* dst, ForwardFlowNode* dst_state) {
  TF_RET_CHECK(dst_state->branch != Branch::kBoth &&
               dst_state->branch != Branch::kNumBranchTypes)
      << "Unexpected/Invalid branch type: Merging "
      << Branch_Name(src_state.branch) << " with "
      << Branch_Name(dst_state->branch);

  const Branch incoming = src_state.branch;
  if (dst_state->branch == Branch::kNeither) {
    dst_state->branch = incoming;
  } else {
    const bool branches_conflict =
        incoming != dst_state->branch && incoming != Branch::kNeither;
    if (branches_conflict) {
      if (!IsMerge(dst)) {
        return errors::Internal("Illegal merge: ", src_state.ToString(),
                                " with ", dst_state->ToString(), " for ",
                                dst->DebugString());
      }
      dst_state->branch = Branch::kBoth;
    }
  }

  dst_state->count += 1;
  return Status::OK();
}

std::vector<Node*> FunctionalizeCond::DetermineSwitchOrder() {
  std::vector<Node*> dead_switches;
  std::vector<Node*> switch_order;
  DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
    if (IsSwitch(n)) {
      if (IsDeadSwitch(n)) {
        dead_switches.push_back(n);
      } else {
        switch_order.push_back(n);
      }
    }
  });

  // Remove all dead switch nodes.
  for (Node* n : dead_switches) {
    graph_->RemoveNode(n);
  }

  return switch_order;
}

// Rewrites every Switch/Merge conditional of graph_ into an XlaIf node.
//
// The Switch nodes are grouped by their predicate (input 1). For each group,
// starting with the last group created, the function:
//   1. propagates branch information forward from the Switch nodes, recording
//      for every reached node which branch it belongs to and how many of its
//      in-edges have been covered;
//   2. collects the Merge nodes whose in-edges are all covered (the
//      "frontier");
//   3. validates the result, converts the cluster to an XlaIf node and removes
//      the nodes that were reached as well as the Switch nodes themselves.
//
// Returns an error if a Switch node has no predicate input, if the branch
// propagation finds an inconsistent graph, or if the conversion fails.
Status FunctionalizeCond::FunctionalizeInternal() {
  std::vector<Node*> switch_order = DetermineSwitchOrder();
  // If there are no switch nodes, then terminate.
  if (switch_order.empty()) {
    return Status::OK();
  }

  struct PredicateSwitches {
    explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}

    Node* predicate;
    std::vector<Node*> switches;
  };

  // Merge Switch nodes with common predicate.
  std::vector<PredicateSwitches> predicate_switch_order;
  std::unordered_map<Node*, int> predicate_index;
  // The nodes in switch_order are in reverse topological order, but the
  // clustered switches need not be (i.e., when considered as a cluster one
  // element of a cluster may be later in the topological order than another
  // node whose cluster is later in the topological order of clustered
  // switches).
  for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
    Node* pred = nullptr;
    // Report a malformed Switch (no predicate input) to the caller instead of
    // aborting the process.
    TF_RETURN_IF_ERROR((*it)->input_node(1, &pred));
    auto found = predicate_index.find(pred);
    if (found == predicate_index.end()) {
      const int new_index = static_cast<int>(predicate_switch_order.size());
      found = predicate_index.emplace(pred, new_index).first;
      predicate_switch_order.emplace_back(pred);
    }
    predicate_switch_order[found->second].switches.push_back(*it);
  }

  // Iterate from innermost set of clustered switches to outermost, replacing
  // matching switch->merge subgraphs with single XlaIf nodes.
  for (auto it = predicate_switch_order.rbegin();
       it != predicate_switch_order.rend(); ++it) {
    auto& ps = *it;
    VLOG(3) << "Flow down from: " << ps.predicate->name() << " -> "
            << NodesToString(ps.switches);

    // Branch state of every node reached from the Switch nodes.
    std::unordered_map<Node*, ForwardFlowNode> branch_map;
    // Merge nodes all of whose in-edges have been reached.
    std::unordered_set<Node*> frontier;

    std::vector<Node*> stack = ps.switches;
    std::vector<bool> visited(graph_->num_node_ids(), false);
    while (!stack.empty()) {
      Node* n = stack.back();
      stack.pop_back();

      if (visited[n->id()]) {
        continue;
      }
      visited[n->id()] = true;

      // Propagate branch state along each edge of a switch node.
      bool sink_only = true;
      for (const Edge* e : n->out_edges()) {
        Node* out = e->dst();
        if (!out->IsOp()) {
          continue;
        }
        sink_only = false;
        // Propagate branch information.
        ForwardFlowNode& ffn = branch_map[out];
        if (IsSwitch(n)) {
          // The output index of a Switch data edge selects the branch; a
          // control edge carries no branch information.
          int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
          TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
        } else {
          TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
        }
        if (IsMerge(out)) {
          // Traversal stops at Merge nodes; they form the frontier once all
          // of their inputs have been seen.
          if (out->in_edges().size() == ffn.count) {
            frontier.insert(out);
          }
        } else if (!visited[out->id()] && ffn.count == out->in_edges().size()) {
          // If all predecessors are dominated by the switch nodes, then add
          // the output to the stack.
          stack.push_back(out);
        }
      }
      if (sink_only) {
        if (!IsIdentity(n)) {
          VLOG(1) << "Feeding into sink: " << n->DebugString();
        }
      }
    }

    TF_RETURN_IF_ERROR(ValidBranchMapAndFrontier(branch_map, frontier));
    VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
            << dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
    // Sort both node lists so that the generated XlaIf signature is
    // deterministic.
    std::vector<Node*> switch_nodes(ps.switches);
    std::sort(switch_nodes.begin(), switch_nodes.end(), CondCmp());
    std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
    std::sort(merge_nodes.begin(), merge_nodes.end(), CondCmp());
    TF_RETURN_IF_ERROR(ConvertCorrespondingMergeToXlaIf(
        switch_nodes, merge_nodes, ps.predicate));
    // The cluster is now represented by the XlaIf node; drop the old nodes.
    for (auto& del_kv : branch_map) {
      graph_->RemoveNode(del_kv.first);
    }
    for (Node* node : switch_nodes) {
      graph_->RemoveNode(node);
    }
    VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): "
            << dump_graph::DumpGraphToFile("functionalize_ac", *graph_);
  }
  return Status::OK();
}

// Creates an XlaIf node for the conditional delimited by `switch_nodes` and
// `merge_nodes` and adds it to graph_.
//
// Both branches are extracted into their own graphs with ExtractBody(),
// converted to FunctionDefs and registered with library_. The node's first
// input is `predicate`, followed by the data inputs of the switch nodes; it
// declares one output type per merge node. Returns the added node, or the
// first error encountered.
xla::StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
    const std::vector<Node*>& switch_nodes,
    const std::vector<Node*>& merge_nodes, Node* predicate) {
  VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
          << NodesToString(switch_nodes);

  // The new node is named after the predicate that drives the conditional.
  NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf");

  // Monotonic counter that keeps the names of extracted functions distinct.
  static std::atomic<int64> sequence_num(0LL);

  // Attribute names, indexed by the merge-node input the branch is read from.
  const string kBranchAttrs[] = {"else_branch", "then_branch"};
  for (int branch_index = 0; branch_index < 2; ++branch_index) {
    const string& attr_name = kBranchAttrs[branch_index];
    const int64 id = ++sequence_num;
    const string function_name =
        strings::StrCat("_functionalize_if_", attr_name, "_", id);

    NameAttrList function_attr;
    function_attr.set_name(function_name);

    auto branch_graph = xla::MakeUnique<Graph>(graph_->op_registry());
    TF_RETURN_IF_ERROR(ExtractBody(switch_nodes, merge_nodes, branch_index,
                                   branch_graph.get()));
    VLOG(3) << "Body " << attr_name << ": " << DebugString(branch_graph.get());

    FunctionDef branch_fdef;
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*branch_graph, function_name, &branch_fdef));
    TF_RETURN_IF_ERROR(library_->AddFunctionDef(branch_fdef));
    builder.Attr(attr_name, function_attr);
  }

  // Collect the inputs of the switch nodes: data inputs become (typed) inputs
  // of the XlaIf node, control inputs become control inputs.
  std::vector<NodeDefBuilder::NodeOut> data_inputs;
  DataTypeVector input_types;
  for (const Node* switch_node : switch_nodes) {
    const Edge* in_edge = nullptr;
    TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in_edge));
    if (!in_edge->IsControlEdge()) {
      const DataType dtype = switch_node->input_type(0);
      data_inputs.push_back(NodeDefBuilder::NodeOut(
          in_edge->src()->name(), in_edge->src_output(), dtype));
      input_types.push_back(dtype);
      continue;
    }
    builder.ControlInput(in_edge->src()->name());
  }
  builder.Attr("Tin", input_types);

  // One output per merge node, typed like the merge's first output.
  DataTypeVector output_types;
  for (const Node* merge_node : merge_nodes) {
    output_types.push_back(merge_node->output_type(0));
  }
  builder.Attr("Tout", output_types);

  builder.Attr("Tcond", DT_BOOL);
  builder.Device(predicate->assigned_device_name());
  // The predicate comes first in the input list ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0)));
  // ... and the switch data inputs follow it.
  builder.Input(data_inputs);

  NodeDef if_def;
  TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
  TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_));
  return if_node;
}

// Populates `body` with the computation of one branch of the conditional.
//
// For every node in `switch_nodes` an argument node is created with
// BuildArgNode() (argument index = position in `switch_nodes`). For every node
// in `merge_nodes` a return-value node is created with BuildRetvalNode()
// (return index = position in `merge_nodes`) and fed from the producer of the
// merge node's input number `input_edge`. The remaining nodes are copied into
// `body` by CopySubgraph(), starting from the merge inputs that are not switch
// nodes.
//
// Returns the first error reported by any of the helpers.
Status FunctionalizeCond::ExtractBody(const std::vector<Node*>& switch_nodes,
                                      const std::vector<Node*>& merge_nodes,
                                      int input_edge, Graph* body) {
  VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
          << input_edge;
  // Both vectors are indexed by node id in graph_. `squash_src_outputs` is
  // true exactly for the switch nodes; `node_map` maps a node of graph_ to its
  // counterpart in `body` (nullptr if it has none yet).
  std::vector<bool> squash_src_outputs(graph_->num_node_ids(), false);
  std::vector<Node*> node_map(graph_->num_node_ids(), nullptr);
  int arg_count = 0;
  for (const Node* arg : switch_nodes) {
    DataType dtype = arg->input_type(0);
    TF_ASSIGN_OR_RETURN(Node * arg_node,
                        BuildArgNode(body, dtype, arg_count++));
    node_map.at(arg->id()) = arg_node;
    squash_src_outputs.at(arg->id()) = true;
  }

  // Seeds for CopySubgraph(): at most one entry is pushed per merge node.
  std::vector<Node*> stack;
  stack.reserve(merge_nodes.size());
  const int num_merges = static_cast<int>(merge_nodes.size());
  for (int j = 0; j < num_merges; ++j) {
    Node* node = merge_nodes[j];
    TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
                        BuildRetvalNode(body, node->output_type(0),
                                        /*index=*/j));
    const Edge* in_edge;
    TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge));
    Node* in = in_edge->src();

    // `squash_src_outputs` already records which nodes are switch nodes, so no
    // linear search through `switch_nodes` is needed.
    if (squash_src_outputs.at(in->id())) {
      // The merge is fed directly by a switch: wire output 0 of the switch's
      // argument node to the return value. Such inputs need no further
      // copying, so they are not added to the stack.
      body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0);
      continue;
    }

    if (node_map.at(in->id()) == nullptr) {
      node_map.at(in->id()) = body->CopyNode(in);
    }
    body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
                  node_map.at(node->id()), 0);
    stack.push_back(in);
  }

  return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map,
                      body);
}

// Wires the inputs of `if_node` in graph_: `predicate` is connected to input 0
// and the producer of each node in `cond_args` (taken from that node's input
// 0) is connected to the following inputs, in order. Producers reached through
// a control edge are attached as control dependencies instead and do not
// consume an input slot.
Status FunctionalizeCond::AddInputEdges(const std::vector<Node*>& cond_args,
                                        Node* predicate, Node* if_node) {
  VLOG(3) << "AddInputEdges for " << if_node->name();

  // Slot 0 always carries the predicate; data inputs start at slot 1.
  graph_->AddEdge(predicate, 0, if_node, 0);
  int next_input = 1;

  for (const Node* cond_arg : cond_args) {
    const Edge* incoming = nullptr;
    TF_RETURN_IF_ERROR(cond_arg->input_edge(0, &incoming));

    Node* producer = incoming->src();
    if (!incoming->IsControlEdge()) {
      graph_->AddEdge(producer, incoming->src_output(), if_node, next_input);
      ++next_input;
    } else {
      graph_->AddControlEdge(producer, if_node);
    }
  }
  return Status::OK();
}

// Redirects every outgoing edge of the nodes in `outputs` so that it
// originates from `if_node` instead: the edges of outputs[i] are replaced by
// edges from output i of `if_node`, and control edges stay control edges.
//
// Returns Unimplemented if an edge leaves one of the nodes from an output
// index greater than zero. Edges handled before that point have already been
// rewritten when the error is returned.
Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
                                         Node* if_node) {
  VLOG(3) << "AddOutputEdges for " << if_node->name();

  int output_index = 0;
  for (Node* merge : outputs) {
    // Iterate over a snapshot, since the loop removes edges from the node.
    const std::vector<const Edge*> outgoing(merge->out_edges().begin(),
                                            merge->out_edges().end());
    for (const Edge* old_edge : outgoing) {
      // Destination details are captured before the edge leaves the graph.
      Node* consumer = old_edge->dst();
      const int consumer_input = old_edge->dst_input();

      if (old_edge->src_output() > 0) {
        return errors::Unimplemented("Output of index (",
                                     old_edge->src_output(),
                                     ") of merge node ", merge->name());
      }
      graph_->RemoveEdge(old_edge);

      int src_output = output_index;
      if (consumer_input == Graph::kControlSlot) {
        src_output = Graph::kControlSlot;
      }
      graph_->AddEdge(if_node, src_output, consumer, consumer_input);
    }
    ++output_index;
  }
  return Status::OK();
}

// Replaces the conditional delimited by `switch_nodes` and `merge_nodes` with
// a single XlaIf node driven by `predicate`: the node is built and added to
// the graph, then its input edges and finally its output edges are connected.
// Stops at, and returns, the first error.
Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf(
    const std::vector<Node*>& switch_nodes,
    const std::vector<Node*>& merge_nodes, Node* predicate) {
  VLOG(1) << "ConvertMergeToXlaIf for " << NodesToString(switch_nodes) << " -> "
          << NodesToString(merge_nodes);

  // Extracts both branch bodies and builds the If operator.
  Node* if_node = nullptr;
  TF_ASSIGN_OR_RETURN(if_node,
                      BuildAndAddXlaIfOp(switch_nodes, merge_nodes, predicate));

  TF_RETURN_IF_ERROR(AddInputEdges(switch_nodes, predicate, if_node));
  return AddOutputEdges(merge_nodes, if_node);
}

// Static entry point: constructs a FunctionalizeCond for `graph` and `library`
// and returns the result of running FunctionalizeInternal() on it.
Status FunctionalizeCond::Functionalize(Graph* graph,
                                        FunctionLibraryDefinition* library) {
  VLOG(1) << "FunctionalizeCond::Functionalize";
  FunctionalizeCond functionalizer(graph, library);
  const Status status = functionalizer.FunctionalizeInternal();
  return status;
}

}  // namespace

// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
//
// Loops are handled first: the nodes of `graph` are grouped into frames, and
// FunctionalizeLoop() is applied to each non-root frame, innermost frames
// first. Conditionals are then handled by FunctionalizeCond::Functionalize().
// The graph is dumped at VLOG(2) before and after the transformation.
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "FunctionalizeControlFlow (initial): "
          << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
  // Note: BuildControlFlowInfo() requires that the graph's source node is
  // connected to all source nodes in the graph. Many graphs violate this
  // invariant.
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));

  // Group the op nodes into Frames, keyed by frame name.
  std::unordered_map<string, Frame> frames;
  for (Node* node : graph->op_nodes()) {
    const ControlFlowInfo& cf = cf_info[node->id()];

    VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
            << " frame: " << (cf.frame ? cf.frame->name() : "---")
            << " parent_frame: "
            << (cf.parent_frame ? cf.parent_frame->name() : "---");
    TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);

    const string& parent_frame_name =
        cf_info[cf.parent_frame->id()].frame_name;
    Frame& current = frames[cf.frame_name];
    Frame* enclosing = &frames[parent_frame_name];

    if (current.parent != nullptr) {
      // The frame was seen before; every node must agree on its parent.
      if (current.parent != enclosing) {
        return errors::InvalidArgument("Mismatched parent frames for ",
                                       cf.frame->id(), ": ", enclosing->name,
                                       " vs ", current.parent->name);
      }
    } else {
      // First node of this frame: link the frame to its parent.
      current.parent = enclosing;
      current.name = cf.frame_name;
      ++enclosing->num_children;
    }

    if (IsEnter(node)) {
      // Each Enter node contributes one loop argument to its frame.
      Arg loop_arg;
      loop_arg.enter = node;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "is_constant",
                                     &loop_arg.is_loop_invariant));
      current.args.push_back(loop_arg);
    } else if (IsLoopCond(node)) {
      // A frame may contain at most one LoopCond node.
      if (current.loop_cond != nullptr) {
        return errors::InvalidArgument(
            "Loop ", cf.frame_name,
            " has more than one LoopCond node: ", node->name(), " and ",
            current.loop_cond->name());
      }
      current.loop_cond = node;
    }
    current.nodes.insert(node);
  }

  // Seed the worklist with the frames that have no children, i.e. the
  // innermost frames.
  std::deque<Frame*> worklist;
  for (auto& name_and_frame : frames) {
    Frame* candidate = &name_and_frame.second;
    if (candidate->num_children == 0) {
      worklist.push_back(candidate);
    }
  }

  // Eliminate loops from innermost to outermost.
  while (!worklist.empty()) {
    Frame* frame = worklist.front();
    worklist.pop_front();

    // The root frame is its own parent and is not a loop; nothing to do.
    if (frame->parent != frame) {
      TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));

      // Once its last child has been processed, the parent becomes eligible.
      Frame* parent = frame->parent;
      if (--parent->num_children == 0) {
        worklist.push_back(parent);
      }
    }
  }

  // FunctionalizeControlFlow is invoked for every function, so the loops'
  // bodies and conditionals that were extracted into functions will be handled
  // in successive invocations.
  TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));

  VLOG(2) << "FunctionalizeControlFlow (final): "
          << dump_graph::DumpGraphToFile("functionalize_final", *graph);
  return Status::OK();
}

}  // namespace tensorflow
