#include "graph_opt.h"

using namespace tensorflow;
using namespace tensorflow::grappler;

namespace annc {
// Records the current position of every node of `graph` in `node_indexes`,
// keyed by node name. Entries for names present in the graph are overwritten;
// entries for any other name are left untouched.
void update_node_indexes(const GraphDef* graph,
                         std::unordered_map<std::string, int>& node_indexes) {
  const int total = graph->node_size();
  int position = 0;
  while (position < total) {
    node_indexes[graph->node(position).name()] = position;
    ++position;
  }
}

// Takes ownership of `rewriter` and appends it to the list of rewriters.
// optimize() tries the rewriters in the order they were registered.
void GraphOptimizer::register_rewriter(
    std::unique_ptr<PatternRewriter> rewriter) {
  rewriters_.emplace_back(std::move(rewriter));
}

void GraphOptimizer::optimize() {
  update_node_indexes(graph_, node_indexes_);
  int node_index = 0;
  const int node_size = graph_->node_size();
  while (node_index < node_size) {
    const NodeDef& node = graph_->node(node_index);
    const std::string& node_name = node.name();
    for (auto& rewriter : rewriters_) {
      if (rewriter->match_and_rewrite(&node, graph_, node_indexes_)) {
        update_node_indexes(graph_, node_indexes_);
        const std::string new_node_name = node_name + fusion_appendix;
        node_index = node_indexes_.at(new_node_name);
        break;
      }
    }
    node_index++;
  }
}

// Strips the output-port suffix from a tensor name: everything from the last
// ':' onwards is dropped ("foo:1" -> "foo"). Names without ':' are returned
// unchanged.
std::string get_node_name(const std::string& name) {
  const std::string::size_type separator = name.rfind(':');
  return separator == std::string::npos ? name : name.substr(0, separator);
}

// Writes the "fused_ops", "num_args" and "epsilon" attributes on `fused`.
// `epsilon` is required only for BatchNorm fusions.
void set_fusedop_attributes(NodeDef* fused,
                            const absl::Span<const absl::string_view> fused_ops,
                            int num_args = 1, float epsilon = 0.0) {
  auto& attrs = *fused->mutable_attr();
  SetAttrValue(fused_ops, &attrs["fused_ops"]);
  SetAttrValue(num_args, &attrs["num_args"]);
  SetAttrValue(epsilon, &attrs["epsilon"]);
}

// Resolves a tensor name ("node" or "node:port") to the node that produces it.
// The lookup uses unordered_map::at, so an unknown node name throws
// std::out_of_range rather than returning nullptr.
const NodeDef* PatternRewriter::get_node(const std::string& name) {
  const int position = indexes_->at(get_node_name(name));
  return &graph_->node(position);
}

// Resolves a tensor name ("node" or "node:port") to a mutable pointer to the
// node that produces it, or nullptr when no node of that name is indexed.
NodeDef* PatternRewriter::get_mutable_node(const std::string& name) {
  const auto entry = indexes_->find(get_node_name(name));
  if (entry == indexes_->end()) {
    return nullptr;
  }
  return graph_->mutable_node(entry->second);
}

// Returns the first input of `node` (in input order) whose producing node has
// op == `op_type`, or nullptr when there is none. Inputs that cannot be
// resolved to an indexed node are skipped.
NodeDef* PatternRewriter::get_operand(const NodeDef* node,
                                      std::string op_type) {
  for (const std::string& input_name : node->input()) {
    NodeDef* candidate = get_mutable_node(input_name);
    if (candidate == nullptr) continue;
    if (candidate->op() == op_type) return candidate;
  }
  return nullptr;
}

// Returns the first node (in graph order) with op == `op_type` that consumes
// output `index` of `node`, or nullptr when there is none.
//
// Output 0 is matched by the bare node name; any other output is matched as
// "<name>:<index>". The previous code declared a new, self-initialised local
// inside the `if`, so the ":<index>" suffix never reached the name that was
// actually compared (and reading the uninitialised string was undefined
// behaviour).
const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index,
                                         const std::string& op_type) {
  std::string tensor_name = node->name();
  if (index != 0) {
    tensor_name += ":" + std::to_string(index);
  }
  for (int i = 0; i < graph_->node_size(); ++i) {
    const NodeDef& candidate = graph_->node(i);
    // The op test does not depend on the input, so check it once per node.
    if (candidate.op() != op_type) continue;
    for (int j = 0; j < candidate.input_size(); ++j) {
      if (candidate.input(j) == tensor_name) {
        return &candidate;
      }
    }
  }
  return nullptr;
}

void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
                                             int old_index,
                                             const NodeDef* new_node,
                                             int new_index, GraphDef* graph) {
  std::string old_name = old_node->name();
  if (old_index) old_name = old_name + ":" + std::to_string(old_index);
  std::string new_name = new_node->name();
  if (new_index) new_name = new_name + ":" + std::to_string(new_index);
  for (int i = 0; i < graph->node_size(); ++i) {
    NodeDef* node = graph->mutable_node(i);
    for (int j = 0; j < node->input_size(); ++j) {
      if (node->input(j) == old_name) {
        node->set_input(j, new_name);
      }
    }
  }
}

class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedSparseDynamicStitch"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
                  node->input_size() % 2 == 0)
    int num_inputs = node->input_size();
    int num_partitions = num_inputs / 2;
    // left branch
    const NodeDef* partition = get_node(node->input(0));
    CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
                  partition->input_size() == 2)
    const NodeDef* range = get_node(partition->input(0));
    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
    // Range start=0, delta=1
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(range->input(0)), {0}))
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(range->input(2)), {1}))
    const NodeDef* size = get_node(range->input(1));
    CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
    const NodeDef* cast = get_node(partition->input(1));
    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
    const NodeDef* floor_mod = get_node(cast->input(0));
    CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
    CHECK_NODE_OK(check_const_value<int64_t>(
        get_mutable_node(floor_mod->input(1)), {num_partitions}))

    CHECK_NODE_OK(check_int_attr(node, "N", {num_partitions}))
    CHECK_NODE_OK(check_int_attr(partition, "num_partitions", {num_partitions}))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(size->input(0));
    // right branch
    for (int i = num_partitions; i < num_inputs; ++i) {
      const NodeDef* gather = get_node(node->input(i));
      CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
      // Gather axis=0
      CHECK_NODE_OK(
          check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
      const NodeDef* partition_1 = get_node(gather->input(1));
      CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
                    partition_1->input_size() == 2)
      CHECK_NODE_OK(
          check_int_attr(partition_1, "num_partitions", {num_partitions}))
      const NodeDef* floor_div = get_node(partition_1->input(0));
      CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
                    floor_div->input_size() == 2)
      CHECK_NODE_OK(check_const_value<int64_t>(
          get_mutable_node(floor_div->input(1)), {num_partitions}))
      fused_node->add_input(gather->input(0));
    }
    (*fused_node->mutable_attr())["N"].set_i(num_partitions);
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(node, 0, fused_node, 0, graph);
    return true;
  }
};

class KPFusedSparseSegmentReduceRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedSparseSegmentReduce"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
    const NodeDef* shape = get_node(node->input(0));
    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
    const NodeDef* ss_reduce = get_node(shape->input(0));
    CHECK_NODE_OK(ss_reduce->input_size() == 3)
    AttrValue combiner;
    if (ss_reduce->op() == "SparseSegmentMean")
      combiner.set_i(1);
    else if (ss_reduce->op() == "SparseSegmentSum")
      combiner.set_i(0);
    else
      return false;
    const NodeDef* cast = get_node(ss_reduce->input(2));
    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
    const NodeDef* strided_slice = get_node(cast->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)

    // check fusion conditions
    CHECK_NODE_OK(
        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
    CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1}))
    CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(ss_reduce->input(0));
    fused_node->add_input(ss_reduce->input(1));
    fused_node->add_input(strided_slice->input(0));
    fused_node->add_input(strided_slice->input(1));
    fused_node->add_input(node->input(1));
    AddNodeAttr("combiner", combiner, fused_node);

    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(ss_reduce, 0, fused_node, 0, graph);
    replace_all_users_with(node, 0, fused_node, 1, graph);
    return true;
  }
};

// Fuses GatherND(SparseSegmentMean|Sum(...), Where(NotEqual(., ZerosLike(.))))
// together with the Cast consuming Where and the Cast(Shape(reduce)) chain
// into one KPFusedSparseSegmentReduceNonzero node.
//
// Fused inputs: the reduce op's inputs 0 and 1, and inputs 0 and 1 of the
// StridedSlice that feeds the segment-ids Cast. The "combiner" attribute is 1
// for SparseSegmentMean and 0 for SparseSegmentSum.
// Rewiring: Cast(Shape(reduce)) -> output 0, Cast(Where) -> output 1,
// GatherND -> output 2.
class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
 public:
  std::string name() const override {
    return "KPFusedSparseSegmentReduceNonzero";
  }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(node->op() == "GatherND" &&
                  node->input_size() == 2)  // output:2
    const NodeDef* ss_reduce = get_node(node->input(0));
    CHECK_NODE_OK(ss_reduce->input_size() == 3)
    AttrValue combiner;
    if (ss_reduce->op() == "SparseSegmentMean") {
      combiner.set_i(1);
    } else if (ss_reduce->op() == "SparseSegmentSum") {
      combiner.set_i(0);
    } else {
      return false;
    }
    const NodeDef* where = get_node(node->input(1));
    CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
    const NodeDef* cast = get_user(where, 0, "Cast");
    CHECK_NODE_OK(cast != nullptr)  // output: 1
    const NodeDef* notequal = get_node(where->input(0));
    CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2)
    const NodeDef* zerolike = get_node(notequal->input(1));
    CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)
    const NodeDef* cast_1 = get_node(ss_reduce->input(2));
    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
    // The StridedSlice feeds the segment-ids Cast (cast_1). Looking it up
    // through `cast` (the consumer of Where) tested the wrong producer chain,
    // so the pattern could not match.
    const NodeDef* strided_slice = get_node(cast_1->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)
    const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
    CHECK_NODE_OK(shape != nullptr)
    const NodeDef* cast_2 = get_user(shape, 0, "Cast");  // output: 0
    CHECK_NODE_OK(cast_2 != nullptr)

    CHECK_NODE_OK(
        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(ss_reduce->input(0));
    fused_node->add_input(ss_reduce->input(1));
    fused_node->add_input(strided_slice->input(0));
    fused_node->add_input(strided_slice->input(1));
    AddNodeAttr("combiner", combiner, fused_node);

    // Put the fused node into the matched node's slot.
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name() << "\n";
    replace_all_users_with(cast_2, 0, fused_node, 0, graph);
    replace_all_users_with(cast, 0, fused_node, 1, graph);
    replace_all_users_with(node, 0, fused_node, 2, graph);
    return true;
  }
};

class KPFusedEmbeddingPaddingRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedEmbeddingPadding"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(IsReshape(*node))
    const NodeDef* user = get_user(node, 0, "ConcatV2");
    CHECK_NODE_OK(user != nullptr)
    const NodeDef* concat = get_node(node->input(0));
    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(concat->input(2)), {0}))
    const NodeDef* fill = get_node(concat->input(1));
    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
    NodeDef* pack = get_operand(fill, "Pack");
    CHECK_NODE_OK(pack != nullptr && IsPack(*pack) && pack->input_size() == 2)
    NodeDef* fill_const = get_operand(fill, "Const");
    CHECK_NODE_OK(fill_const != nullptr &&
                  check_const_value<int>(fill_const, {0}))
    NodeDef* sub = get_operand(pack, "Sub");
    CHECK_NODE_OK(sub != nullptr && IsSub(*sub) && sub->input_size() == 2)
    const NodeDef* strided_slice = get_node(sub->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)
    const NodeDef* cast = get_node(strided_slice->input(0));
    CHECK_NODE_OK(IsCast(*cast))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(cast->input(0));
    fused_node->add_input(concat->input(0));
    fused_node->add_input(sub->input(1));
    fused_node->add_input(node->input(1));
    const NodeDef* pack_left = get_node(pack->input(0));
    const NodeDef* pack_right = get_node(pack->input(1));
    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
      fused_node->add_input(pack->input(0));
    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
      fused_node->add_input(pack->input(1));
    } else {
      return false;
    }

    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(sub, 0, fused_node, 0, graph);
    replace_all_users_with(node, 0, fused_node, 1, graph);
    return true;
  }
};

class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedEmbeddingPaddingFast"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
    CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {0}))
    CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(2)), {1}))
    CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(3)), {1}))
    CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
    const NodeDef* shape = get_node(node->input(0));
    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
    const NodeDef* reshape = get_node(shape->input(0));
    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
    const NodeDef* concat = get_node(reshape->input(0));
    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
    const NodeDef* fill = get_node(concat->input(1));
    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
    const NodeDef* pack = get_node(fill->input(0));
    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
    const NodeDef* sub = get_node(pack->input(0));
    CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2)
    const NodeDef* strided_slice = get_node(sub->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)
    const NodeDef* cast = get_node(strided_slice->input(0));
    CHECK_NODE_OK(IsCast(*cast))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(cast->input(0));
    fused_node->add_input(concat->input(0));
    fused_node->add_input(sub->input(1));
    fused_node->add_input(reshape->input(1));
    const NodeDef* pack_left = get_node(pack->input(0));
    const NodeDef* pack_right = get_node(pack->input(1));
    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
      fused_node->add_input(pack->input(0));
    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
      fused_node->add_input(pack->input(1));
    } else {
      return false;
    }
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(sub, 0, fused_node, 0, graph);
    replace_all_users_with(node, 0, fused_node, 1, graph);
    return true;
  }
};

// Fuses a ConcatV2 of two Select chains (Select(Equal, ., Select(Equal, Fill,
// Cast(Greater(Reshape, .)))) and Select(Equal, ., Fill)) into one
// KPFusedSparseSelect node.
//
// Fused inputs: input 0 of the three Reshape operands of the Equal nodes,
// followed by the constant operand of equal, equal_1, equal_2 and greater (in
// that order).
// Rewiring: first Reshape -> output 0, select_0 -> output 1, Concat -> output 2.
// All validation, including locating the constant operands, happens before the
// graph is modified so a failed match leaves the graph untouched.
class KPFusedSparseSelectRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedSparseSelect"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
    const NodeDef* select_0 = get_node(node->input(0));
    CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
    const NodeDef* select_1 = get_node(select_0->input(2));
    CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
    const NodeDef* fill = get_node(select_1->input(1));
    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
    const NodeDef* cast = get_node(select_1->input(2));
    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
    const NodeDef* equal = get_node(select_1->input(0));
    CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
    NodeDef* reshape = get_operand(equal, "Reshape");
    // reshape->input(0) is read when the fused node is built.
    CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape) &&
                  reshape->input_size() >= 1)
    const NodeDef* greater = get_node(cast->input(0));
    CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
    const NodeDef* reshape_4 = get_node(greater->input(0));
    CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
    const NodeDef* equal_1 = get_node(select_0->input(0));
    CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
    NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
    // reshape_1->input(1) is read on the next check.
    CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1) &&
                  reshape_1->input_size() >= 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))

    // right branch
    const NodeDef* select_2 = get_node(node->input(1));
    CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
    const NodeDef* equal_2 = get_node(select_2->input(0));
    CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
    const NodeDef* fill_1 = get_node(select_2->input(2));
    CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
    const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
    // reshape_2->input(1) is read on the next check.
    CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2) &&
                  reshape_2->input_size() >= 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))

    // Locate the constant operand of every comparison before mutating the
    // graph; bail out if any comparison has none.
    std::vector<const NodeDef*> const_inputs = {equal, equal_1, equal_2,
                                                greater};
    std::vector<std::string> const_operands;
    const_operands.reserve(const_inputs.size());
    for (const NodeDef* const_node : const_inputs) {
      const NodeDef* left = get_node(const_node->input(0));
      const NodeDef* right = get_node(const_node->input(1));
      if (IsConstant(*left) || IsHostConstant(*left)) {
        const_operands.push_back(const_node->input(0));
      } else if (IsConstant(*right) || IsHostConstant(*right)) {
        const_operands.push_back(const_node->input(1));
      } else {
        return false;
      }
    }

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(reshape->input(0));
    fused_node->add_input(reshape_1->input(0));
    fused_node->add_input(reshape_2->input(0));
    for (const std::string& operand : const_operands) {
      fused_node->add_input(operand);
    }
    // Put the fused node into the matched node's slot.
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(reshape, 0, fused_node, 0, graph);
    replace_all_users_with(select_0, 0, fused_node, 1, graph);
    replace_all_users_with(node, 0, fused_node, 2, graph);
    return true;
  }
};

// Fuses GatherV2(GatherV2(params, ., 0), Unique(Unique(StridedSlice(...))), 0)
// into one KPFusedGather node.
//
// Fused inputs: the inner gather's input 0 and inputs 0 and 1 of the
// StridedSlice.
// Rewiring: inner Unique outputs 0 and 1 -> fused outputs 0 and 1, matched
// GatherV2 -> fused output 2.
class KPFusedGatherRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedGather"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {0}))
    const NodeDef* gather = get_node(node->input(0));
    CHECK_NODE_OK(gather->op() == "GatherV2" &&
                  gather->input_size() == 3)  // input:0
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
    const NodeDef* unique = get_node(node->input(1));  // output:1
    CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
    const NodeDef* unique_1 = get_node(unique->input(0));
    CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
    const NodeDef* strided_slice = get_node(unique_1->input(0));
    // Inputs 0 and 1 of the StridedSlice are forwarded to the fused node, so
    // make sure they exist before indexing them.
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() >= 2)  // input:1 2
    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(gather->input(0));
    fused_node->add_input(strided_slice->input(0));
    fused_node->add_input(strided_slice->input(1));
    // Put the fused node into the matched node's slot.
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(unique_1, 0, fused_node, 0, graph);
    replace_all_users_with(unique_1, 1, fused_node, 1, graph);
    replace_all_users_with(node, 0, fused_node, 2, graph);
    return true;
  }
};

class KPFusedSparseReshapeRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedSparseReshape"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
    const NodeDef* concat = get_node(node->input(0));
    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
    const NodeDef* reshape = get_node(concat->input(1));
    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
    const NodeDef* strided_slice = get_node(reshape->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)
    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
    const NodeDef* cast_1 = get_node(concat->input(0));
    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
    const NodeDef* reshape_1 = get_node(cast_1->input(0));
    CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
    const NodeDef* range = get_node(reshape_1->input(0));
    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
    // Range start=0, delta=1
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(range->input(0)), {0}))
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(range->input(2)), {1}))
    const NodeDef* cast = get_node(node->input(1));
    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
    const NodeDef* pack = get_node(cast->input(0));
    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
    const NodeDef* strided_slice_1 = get_node(pack->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
                  strided_slice_1->input_size() == 4)
    CHECK_NODE_OK(check_const_value<int>(
        get_mutable_node(strided_slice_1->input(1)), {0}))
    CHECK_NODE_OK(check_const_value<int>(
        get_mutable_node(strided_slice_1->input(2)), {1}))
    CHECK_NODE_OK(check_const_value<int>(
        get_mutable_node(strided_slice_1->input(3)), {1}))
    CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
    const NodeDef* shape = get_node(strided_slice_1->input(0));
    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(shape->input(0));
    fused_node->add_input(strided_slice->input(1));
    fused_node->add_input(node->input(2));
    fused_node->add_input(pack->input(1));
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(node, 0, fused_node, 0, graph);
    replace_all_users_with(node, 1, fused_node, 1, graph);
    return true;
  }
};

class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedEmbeddingActionIdGather"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
    const NodeDef* reshape = get_node(node->input(0));
    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
    const NodeDef* gather = get_node(reshape->input(0));
    const NodeDef* pack_1 = get_node(reshape->input(1));
    CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
    CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
    const NodeDef* gather_1 = get_node(gather->input(0));
    CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
    const NodeDef* fill = get_node(node->input(1));
    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
    const NodeDef* pack = get_node(fill->input(0));
    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(gather_1->input(1));
    fused_node->add_input(gather_1->input(0));
    fused_node->add_input(gather->input(1));
    fused_node->add_input(pack->input(0));
    fused_node->add_input(pack->input(1));
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(node, 0, fused_node, 0, graph);
    return true;
  }
};

class KPFusedBatchMatMulAddSigmoidRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedBatchMatMulAddSigmoid"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;
    CHECK_NODE_OK(node->op() == "Sigmoid" && node->input_size() == 1)

    const NodeDef* add = get_node(node->input(0));
    CHECK_NODE_OK(add->op() == "AddV2" && add->input_size() == 2)

    int matmul_input_idx = -1;
    int bias_input_idx = -1;
    if (IsAnyBatchMatMul(*get_node(add->input(0)))) {
      matmul_input_idx = 0;
      bias_input_idx = 1;
    } else if (IsAnyBatchMatMul(*get_node(add->input(1)))) {
      matmul_input_idx = 1;
      bias_input_idx = 0;
    } else {
      return false;
    }

    const NodeDef* batch_matmul = get_node(add->input(matmul_input_idx));
    CHECK_NODE_OK(batch_matmul->input_size() == 2)

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(batch_matmul->input(0));
    fused_node->add_input(batch_matmul->input(1));
    fused_node->add_input(add->input(bias_input_idx));
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);

    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(node, 0, fused_node, 0, graph);
    return true;
  }
};

// Registers the fusion rewriters selected through environment variables and
// runs them over `graph`.
//
// A rewriter is registered only when its own variable is set to exactly "1",
// or, for most rewriters, when ANNC_FUSED_ALL is "1". The
// EmbeddingActionIdGather and SparseSegmentReduceNonzero rewriters are NOT
// covered by ANNC_FUSED_ALL and need their own variable. With no variable set,
// no rewriter is registered. Registration order is the order in which
// rewriters are tried on each node.
void run_graph_optimization(GraphDef* graph) {
  GraphOptimizer optimizer(graph);

  // True when environment variable `variable` is set to exactly "1".
  const auto flag_is_on = [](const char* variable) {
    const char* value = getenv(variable);
    return value != nullptr && strcmp(value, "1") == 0;
  };

  const bool enable_all = flag_is_on("ANNC_FUSED_ALL");

  if (enable_all || flag_is_on("ANNC_FUSED_SPS_STITCH")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseDynamicStitchRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_SPS_REDUCE")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSegmentReduceRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_EMD_PADDING_FAST")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingPaddingFastRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_EMD_PADDING")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingPaddingRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_SPS_SELECT")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSelectRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_GATHER")) {
    optimizer.register_rewriter(std::make_unique<KPFusedGatherRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_SPS_RESHAPE")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseReshapeRewriter>());
  }
  // The next two rewriters are opt-in only.
  if (flag_is_on("ANNC_FUSED_EMB_ACTIONID_GATHER")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingActionIdGatherRewriter>());
  }
  if (flag_is_on("ANNC_FUSED_SPS_REDUCE_NONZERO")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSegmentReduceNonzeroRewriter>());
  }
  if (enable_all || flag_is_on("ANNC_FUSED_BATCHMATMUL_ADD_SIGMOID")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedBatchMatMulAddSigmoidRewriter>());
  }
  optimizer.optimize();
}
}  // namespace annc