#include "graph_opt.h"
using namespace tensorflow;
using namespace tensorflow::grappler;
namespace annc {
// Records, for every node currently in `graph`, its position in the node list
// under its name in `node_indexes`. Existing entries with the same name are
// overwritten; entries for names not present in `graph` are left as they are.
void update_node_indexes(const GraphDef* graph,
                         std::unordered_map<std::string, int>& node_indexes) {
  int position = 0;
  for (const auto& graph_node : graph->node()) {
    node_indexes[graph_node.name()] = position;
    ++position;
  }
}
void GraphOptimizer::register_rewriter(
std::unique_ptr<PatternRewriter> rewriter) {
rewriters_.push_back(std::move(rewriter));
}
// Walks the graph once and, for each visited node, offers it to the registered
// rewriters in registration order; the first rewriter that reports a match
// wins for that node.
//
// Only the index range [0, node count at entry) is visited, so nodes that
// rewriters append past that range are not offered to rewriters by this pass.
//
// After a successful rewrite the name -> index map is rebuilt and the walk
// continues after the node named `<original name> + fusion_appendix`.
// `node_indexes_.at()` throws std::out_of_range if no such node exists.
//
// If a rewriter returns false after having already appended nodes to the
// graph, those trailing nodes are deleted again so a failed match never leaves
// partially built fused nodes behind.
void GraphOptimizer::optimize() {
  update_node_indexes(graph_, node_indexes_);
  const int node_size = graph_->node_size();
  for (int node_index = 0; node_index < node_size; ++node_index) {
    // NodeDef objects are held by pointer inside the repeated field, so these
    // references stay valid while rewriters append to or swap within the list.
    const NodeDef& node = graph_->node(node_index);
    const std::string& node_name = node.name();
    for (auto& rewriter : rewriters_) {
      const int size_before = graph_->node_size();
      if (rewriter->match_and_rewrite(&node, graph_, node_indexes_)) {
        update_node_indexes(graph_, node_indexes_);
        const std::string new_node_name = node_name + fusion_appendix;
        node_index = node_indexes_.at(new_node_name);
        break;
      }
      // Roll back anything the failed attempt appended to the node list.
      const int appended = graph_->node_size() - size_before;
      if (appended > 0) {
        graph_->mutable_node()->DeleteSubrange(size_before, appended);
      }
    }
  }
}
// Returns `name` with everything from its last ':' onwards removed, e.g.
// "foo:1" -> "foo". A name without ':' is returned unchanged. No other prefix
// or suffix handling is done (a leading '^' is kept as-is).
std::string get_node_name(const std::string& name) {
  const std::string::size_type separator = name.rfind(':');
  return separator == std::string::npos ? name : name.substr(0, separator);
}
// Writes the three attributes "fused_ops", "num_args" and "epsilon" on
// `fused`, replacing any values already stored under those keys.
void set_fusedop_attributes(NodeDef* fused,
                            const absl::Span<const absl::string_view> fused_ops,
                            int num_args = 1, float epsilon = 0.0) {
  auto& attributes = *fused->mutable_attr();
  SetAttrValue(fused_ops, &attributes["fused_ops"]);
  SetAttrValue(num_args, &attributes["num_args"]);
  SetAttrValue(epsilon, &attributes["epsilon"]);
}
// Resolves an input string such as "name" or "name:port" to the node it
// refers to in the current graph. The lookup goes through the name -> index
// map with `at()`, so an unknown name throws std::out_of_range instead of
// returning null.
const NodeDef* PatternRewriter::get_node(const std::string& name) {
  const int position = indexes_->at(get_node_name(name));
  return &graph_->node(position);
}
// Mutable counterpart of get_node(): resolves "name" / "name:port" to the
// node in the current graph, but returns nullptr when the name is not in the
// index map rather than throwing.
NodeDef* PatternRewriter::get_mutable_node(const std::string& name) {
  const auto entry = indexes_->find(get_node_name(name));
  if (entry == indexes_->end()) {
    return nullptr;
  }
  return graph_->mutable_node(entry->second);
}
// Returns the first input of `node` (in input order) that resolves to a known
// node whose op equals `op_type`, or nullptr when there is none. Inputs whose
// names are not in the index map are skipped.
NodeDef* PatternRewriter::get_operand(const NodeDef* node,
                                      std::string op_type) {
  for (const auto& input_name : node->input()) {
    NodeDef* candidate = get_mutable_node(input_name);
    if (candidate == nullptr) {
      continue;
    }
    if (candidate->op() == op_type) {
      return candidate;
    }
  }
  return nullptr;
}
// Returns the first node in the graph (in node-list order) whose op equals
// `op_type` and that has an input string equal to output `index` of `node`,
// or nullptr when no such consumer exists.
//
// Output 0 is matched by the bare node name; any other output is matched as
// "<name>:<index>". Inputs spelled "<name>:0" are therefore not matched for
// index 0.
//
// Fix: the original built the ":<index>" name in a freshly declared variable
// that shadowed the one actually used for matching (and read itself while
// uninitialised), so a non-zero `index` was silently ignored.
const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index,
                                         const std::string& op_type) {
  std::string tensor_name = node->name();
  if (index != 0) {
    tensor_name += ":" + std::to_string(index);
  }
  for (int i = 0; i < graph_->node_size(); ++i) {
    const NodeDef& candidate = graph_->node(i);
    if (candidate.op() != op_type) {
      continue;
    }
    for (int j = 0; j < candidate.input_size(); ++j) {
      if (candidate.input(j) == tensor_name) {
        return &candidate;
      }
    }
  }
  return nullptr;
}
void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
int old_index,
const NodeDef* new_node,
int new_index, GraphDef* graph) {
std::string old_name = old_node->name();
if (old_index) old_name = old_name + ":" + std::to_string(old_index);
std::string new_name = new_node->name();
if (new_index) new_name = new_name + ":" + std::to_string(new_index);
for (int i = 0; i < graph->node_size(); ++i) {
NodeDef* node = graph->mutable_node(i);
for (int j = 0; j < node->input_size(); ++j) {
if (node->input(j) == old_name) {
node->set_input(j, new_name);
}
}
}
}
class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedSparseDynamicStitch"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
node->input_size() % 2 == 0)
int num_inputs = node->input_size();
int num_partitions = num_inputs / 2;
const NodeDef* partition = get_node(node->input(0));
CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
partition->input_size() == 2)
const NodeDef* range = get_node(partition->input(0));
CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(range->input(0)), {0}))
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(range->input(2)), {1}))
const NodeDef* size = get_node(range->input(1));
CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
const NodeDef* cast = get_node(partition->input(1));
CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
const NodeDef* floor_mod = get_node(cast->input(0));
CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
CHECK_NODE_OK(check_const_value<int64_t>(
get_mutable_node(floor_mod->input(1)), {num_partitions}))
CHECK_NODE_OK(check_int_attr(node, "N", {num_partitions}))
CHECK_NODE_OK(check_int_attr(partition, "num_partitions", {num_partitions}))
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(size->input(0));
for (int i = num_partitions; i < num_inputs; ++i) {
const NodeDef* gather = get_node(node->input(i));
CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
const NodeDef* partition_1 = get_node(gather->input(1));
CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
partition_1->input_size() == 2)
CHECK_NODE_OK(
check_int_attr(partition_1, "num_partitions", {num_partitions}))
const NodeDef* floor_div = get_node(partition_1->input(0));
CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
floor_div->input_size() == 2)
CHECK_NODE_OK(check_const_value<int64_t>(
get_mutable_node(floor_div->input(1)), {num_partitions}))
fused_node->add_input(gather->input(0));
}
(*fused_node->mutable_attr())["N"].set_i(num_partitions);
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(node, 0, fused_node, 0, graph);
return true;
}
};
class KPFusedSparseSegmentReduceRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedSparseSegmentReduce"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
const NodeDef* shape = get_node(node->input(0));
CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
const NodeDef* ss_reduce = get_node(shape->input(0));
CHECK_NODE_OK(ss_reduce->input_size() == 3)
AttrValue combiner;
if (ss_reduce->op() == "SparseSegmentMean")
combiner.set_i(1);
else if (ss_reduce->op() == "SparseSegmentSum")
combiner.set_i(0);
else
return false;
const NodeDef* cast = get_node(ss_reduce->input(2));
CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
const NodeDef* strided_slice = get_node(cast->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
strided_slice->input_size() == 4)
CHECK_NODE_OK(
check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1}))
CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(ss_reduce->input(0));
fused_node->add_input(ss_reduce->input(1));
fused_node->add_input(strided_slice->input(0));
fused_node->add_input(strided_slice->input(1));
fused_node->add_input(node->input(1));
AddNodeAttr("combiner", combiner, fused_node);
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(ss_reduce, 0, fused_node, 0, graph);
replace_all_users_with(node, 0, fused_node, 1, graph);
return true;
}
};
// Replaces a gather-of-nonzero-entries subgraph built on a sparse segment
// reduction with one "KPFusedSparseSegmentReduceNonzero" node.
//
// Anchored at a 2-input GatherNd node whose
//   * first input is a 3-input SparseSegmentMean/SparseSegmentSum whose third
//     input is a 1-input Cast of a 4-input StridedSlice (shrink_axis_mask == 2,
//     begin_mask == 1, end_mask == 1, second input accepted by
//     check_const_shape as {2});
//   * second input is a 1-input Where fed by a 2-input NotEqual whose second
//     operand is a 1-input ZerosLike.
// In addition the Where node must have a "Cast" consumer, and the reduction
// must have a "Shape" consumer that itself has a "Cast" consumer.
//
// The fused node's attr "combiner" is 1 for SparseSegmentMean and 0 for
// SparseSegmentSum. Output 0 replaces Cast(Shape(reduction)), output 1
// replaces Cast(Where) and output 2 replaces the GatherNd node.
//
// Fixes over the previous version:
//   * the segment-id StridedSlice is now read from the Cast feeding the
//     reduction (`cast_1`). It used to be read from the Cast that consumes
//     Where, whose input is the Where node itself, so the StridedSlice check
//     could never succeed and the rewriter never fired;
//   * the anchor accepts the op name "GatherNd" (the name TensorFlow registers)
//     as well as the previously used spelling "GatherND".
class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
 public:
  std::string name() const override {
    return "KPFusedSparseSegmentReduceNonzero";
  }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;

    CHECK_NODE_OK((node->op() == "GatherNd" || node->op() == "GatherND") &&
                  node->input_size() == 2)
    const NodeDef* ss_reduce = get_node(node->input(0));
    CHECK_NODE_OK(ss_reduce->input_size() == 3)
    AttrValue combiner;
    if (ss_reduce->op() == "SparseSegmentMean") {
      combiner.set_i(1);
    } else if (ss_reduce->op() == "SparseSegmentSum") {
      combiner.set_i(0);
    } else {
      return false;
    }

    const NodeDef* where = get_node(node->input(1));
    CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
    const NodeDef* cast = get_user(where, 0, "Cast");
    CHECK_NODE_OK(cast != nullptr)
    const NodeDef* notequal = get_node(where->input(0));
    CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2)
    const NodeDef* zerolike = get_node(notequal->input(1));
    CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)

    // Segment ids of the reduction: Cast(StridedSlice(...)).
    const NodeDef* cast_1 = get_node(ss_reduce->input(2));
    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
    const NodeDef* strided_slice = get_node(cast_1->input(0));
    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
                  strided_slice->input_size() == 4)

    const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
    CHECK_NODE_OK(shape != nullptr)
    const NodeDef* cast_2 = get_user(shape, 0, "Cast");
    CHECK_NODE_OK(cast_2 != nullptr)
    CHECK_NODE_OK(
        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(ss_reduce->input(0));
    fused_node->add_input(ss_reduce->input(1));
    fused_node->add_input(strided_slice->input(0));
    fused_node->add_input(strided_slice->input(1));
    AddNodeAttr("combiner", combiner, fused_node);

    // Put the fused node where the matched node was; the matched node moves to
    // the end of the list.
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name() << "\n";
    replace_all_users_with(cast_2, 0, fused_node, 0, graph);
    replace_all_users_with(cast, 0, fused_node, 1, graph);
    replace_all_users_with(node, 0, fused_node, 2, graph);
    return true;
  }
};
class KPFusedEmbeddingPaddingRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedEmbeddingPadding"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(IsReshape(*node))
const NodeDef* user = get_user(node, 0, "ConcatV2");
CHECK_NODE_OK(user != nullptr)
const NodeDef* concat = get_node(node->input(0));
CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(concat->input(2)), {0}))
const NodeDef* fill = get_node(concat->input(1));
CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
NodeDef* pack = get_operand(fill, "Pack");
CHECK_NODE_OK(pack != nullptr && IsPack(*pack) && pack->input_size() == 2)
NodeDef* fill_const = get_operand(fill, "Const");
CHECK_NODE_OK(fill_const != nullptr &&
check_const_value<int>(fill_const, {0}))
NodeDef* sub = get_operand(pack, "Sub");
CHECK_NODE_OK(sub != nullptr && IsSub(*sub) && sub->input_size() == 2)
const NodeDef* strided_slice = get_node(sub->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
strided_slice->input_size() == 4)
const NodeDef* cast = get_node(strided_slice->input(0));
CHECK_NODE_OK(IsCast(*cast))
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(cast->input(0));
fused_node->add_input(concat->input(0));
fused_node->add_input(sub->input(1));
fused_node->add_input(node->input(1));
const NodeDef* pack_left = get_node(pack->input(0));
const NodeDef* pack_right = get_node(pack->input(1));
if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
fused_node->add_input(pack->input(0));
} else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
fused_node->add_input(pack->input(1));
} else {
return false;
}
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(sub, 0, fused_node, 0, graph);
replace_all_users_with(node, 0, fused_node, 1, graph);
return true;
}
};
class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedEmbeddingPaddingFast"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {0}))
CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(2)), {1}))
CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(3)), {1}))
CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
const NodeDef* shape = get_node(node->input(0));
CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
const NodeDef* reshape = get_node(shape->input(0));
CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
const NodeDef* concat = get_node(reshape->input(0));
CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
const NodeDef* fill = get_node(concat->input(1));
CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
const NodeDef* pack = get_node(fill->input(0));
CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
const NodeDef* sub = get_node(pack->input(0));
CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2)
const NodeDef* strided_slice = get_node(sub->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
strided_slice->input_size() == 4)
const NodeDef* cast = get_node(strided_slice->input(0));
CHECK_NODE_OK(IsCast(*cast))
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(cast->input(0));
fused_node->add_input(concat->input(0));
fused_node->add_input(sub->input(1));
fused_node->add_input(reshape->input(1));
const NodeDef* pack_left = get_node(pack->input(0));
const NodeDef* pack_right = get_node(pack->input(1));
if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
fused_node->add_input(pack->input(0));
} else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
fused_node->add_input(pack->input(1));
} else {
return false;
}
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(sub, 0, fused_node, 0, graph);
replace_all_users_with(node, 0, fused_node, 1, graph);
return true;
}
};
// Replaces a concat of nested Select nodes with one "KPFusedSparseSelect"
// node.
//
// Anchored at a 3-input concat:
//   * input 0 is a 3-input Select (`select_0`) whose condition is a 2-input
//     Equal (`equal_1`) with a Reshape operand (shape accepted by
//     check_const_value as {-1, 1}) and whose third input is another 3-input
//     Select (`select_1`);
//   * `select_1` has a 2-input Equal condition (`equal`) with a Reshape
//     operand, a 2-input Fill as second input (fill value accepted as {1.0f})
//     and, as third input, a 1-input Cast of a 2-input Greater whose first
//     input is a 2-input Reshape with shape accepted as {-1, 1};
//   * input 1 is a 3-input Select (`select_2`) whose condition is a 2-input
//     Equal (`equal_2`) with a Reshape operand (shape accepted as {-1, 1}) and
//     whose third input is a 2-input Fill with fill value accepted as {1.0f}.
// Each of equal, equal_1, equal_2 and greater must have a constant operand
// (IsConstant or IsHostConstant); the left operand is preferred.
//
// Fused inputs: the inputs of the three Reshape operands (equal, equal_1,
// equal_2 order) followed by the four constant operands (equal, equal_1,
// equal_2, greater order). Output 0 replaces the Reshape under `equal`,
// output 1 replaces `select_0`, output 2 replaces the concat.
//
// Fix over the previous version: the constant operands are collected before
// the fused node is created. The old code created the node first and returned
// false from inside the collection loop, leaving an incomplete fused node in
// the graph.
class KPFusedSparseSelectRewriter : public PatternRewriter {
 public:
  std::string name() const override { return "KPFusedSparseSelect"; }

  bool match_and_rewrite(
      const NodeDef* node, GraphDef* graph,
      std::unordered_map<std::string, int>& node_indexes) override {
    graph_ = graph;
    indexes_ = &node_indexes;

    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
    const NodeDef* select_0 = get_node(node->input(0));
    CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
    const NodeDef* select_1 = get_node(select_0->input(2));
    CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
    const NodeDef* fill = get_node(select_1->input(1));
    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
    const NodeDef* cast = get_node(select_1->input(2));
    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
    const NodeDef* equal = get_node(select_1->input(0));
    CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
    NodeDef* reshape = get_operand(equal, "Reshape");
    CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape))
    const NodeDef* greater = get_node(cast->input(0));
    CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
    const NodeDef* reshape_4 = get_node(greater->input(0));
    CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
    const NodeDef* equal_1 = get_node(select_0->input(0));
    CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
    NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
    CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1))
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
    const NodeDef* select_2 = get_node(node->input(1));
    CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
    const NodeDef* equal_2 = get_node(select_2->input(0));
    CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
    const NodeDef* fill_1 = get_node(select_2->input(2));
    CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
    CHECK_NODE_OK(
        check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
    const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
    CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2))
    CHECK_NODE_OK(
        check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))

    // Resolve the constant operand of every comparison before touching the
    // graph, so a mismatch cannot leave a partially built fused node behind.
    std::vector<const NodeDef*> const_inputs = {equal, equal_1, equal_2,
                                                greater};
    std::vector<std::string> const_operands;
    const_operands.reserve(const_inputs.size());
    for (const NodeDef* const_node : const_inputs) {
      const NodeDef* left = get_node(const_node->input(0));
      const NodeDef* right = get_node(const_node->input(1));
      if (IsConstant(*left) || IsHostConstant(*left)) {
        const_operands.push_back(const_node->input(0));
      } else if (IsConstant(*right) || IsHostConstant(*right)) {
        const_operands.push_back(const_node->input(1));
      } else {
        return false;
      }
    }

    auto nodes = graph->mutable_node();
    NodeDef* fused_node = nodes->Add();
    fused_node->set_name(node->name() + fusion_appendix);
    fused_node->set_op(name());
    fused_node->set_device(node->device());
    fused_node->add_input(reshape->input(0));
    fused_node->add_input(reshape_1->input(0));
    fused_node->add_input(reshape_2->input(0));
    for (const std::string& operand : const_operands) {
      fused_node->add_input(operand);
    }

    // Put the fused node where the matched node was; the matched node moves to
    // the end of the list.
    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
            << fused_node->name();
    replace_all_users_with(reshape, 0, fused_node, 0, graph);
    replace_all_users_with(select_0, 0, fused_node, 1, graph);
    replace_all_users_with(node, 0, fused_node, 2, graph);
    return true;
  }
};
class KPFusedGatherRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedGather"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {0}))
const NodeDef* gather = get_node(node->input(0));
CHECK_NODE_OK(gather->op() == "GatherV2" &&
gather->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
const NodeDef* unique = get_node(node->input(1));
CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
const NodeDef* unique_1 = get_node(unique->input(0));
CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
const NodeDef* strided_slice = get_node(unique_1->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice))
CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(gather->input(0));
fused_node->add_input(strided_slice->input(0));
fused_node->add_input(strided_slice->input(1));
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(unique_1, 0, fused_node, 0, graph);
replace_all_users_with(unique_1, 1, fused_node, 1, graph);
replace_all_users_with(node, 0, fused_node, 2, graph);
return true;
}
};
class KPFusedSparseReshapeRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedSparseReshape"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
const NodeDef* concat = get_node(node->input(0));
CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
const NodeDef* reshape = get_node(concat->input(1));
CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
const NodeDef* strided_slice = get_node(reshape->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
strided_slice->input_size() == 4)
CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
const NodeDef* cast_1 = get_node(concat->input(0));
CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
const NodeDef* reshape_1 = get_node(cast_1->input(0));
CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
const NodeDef* range = get_node(reshape_1->input(0));
CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(range->input(0)), {0}))
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(range->input(2)), {1}))
const NodeDef* cast = get_node(node->input(1));
CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
const NodeDef* pack = get_node(cast->input(0));
CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
const NodeDef* strided_slice_1 = get_node(pack->input(0));
CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
strided_slice_1->input_size() == 4)
CHECK_NODE_OK(check_const_value<int>(
get_mutable_node(strided_slice_1->input(1)), {0}))
CHECK_NODE_OK(check_const_value<int>(
get_mutable_node(strided_slice_1->input(2)), {1}))
CHECK_NODE_OK(check_const_value<int>(
get_mutable_node(strided_slice_1->input(3)), {1}))
CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
const NodeDef* shape = get_node(strided_slice_1->input(0));
CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(shape->input(0));
fused_node->add_input(strided_slice->input(1));
fused_node->add_input(node->input(2));
fused_node->add_input(pack->input(1));
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(node, 0, fused_node, 0, graph);
replace_all_users_with(node, 1, fused_node, 1, graph);
return true;
}
};
class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedEmbeddingActionIdGather"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
const NodeDef* reshape = get_node(node->input(0));
CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
const NodeDef* gather = get_node(reshape->input(0));
const NodeDef* pack_1 = get_node(reshape->input(1));
CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
const NodeDef* gather_1 = get_node(gather->input(0));
CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
CHECK_NODE_OK(
check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
const NodeDef* fill = get_node(node->input(1));
CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
const NodeDef* pack = get_node(fill->input(0));
CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(gather_1->input(1));
fused_node->add_input(gather_1->input(0));
fused_node->add_input(gather->input(1));
fused_node->add_input(pack->input(0));
fused_node->add_input(pack->input(1));
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(node, 0, fused_node, 0, graph);
return true;
}
};
class KPFusedBatchMatMulAddSigmoidRewriter : public PatternRewriter {
public:
std::string name() const override { return "KPFusedBatchMatMulAddSigmoid"; }
bool match_and_rewrite(
const NodeDef* node, GraphDef* graph,
std::unordered_map<std::string, int>& node_indexes) override {
graph_ = graph;
indexes_ = &node_indexes;
CHECK_NODE_OK(node->op() == "Sigmoid" && node->input_size() == 1)
const NodeDef* add = get_node(node->input(0));
CHECK_NODE_OK(add->op() == "AddV2" && add->input_size() == 2)
int matmul_input_idx = -1;
int bias_input_idx = -1;
if (IsAnyBatchMatMul(*get_node(add->input(0)))) {
matmul_input_idx = 0;
bias_input_idx = 1;
} else if (IsAnyBatchMatMul(*get_node(add->input(1)))) {
matmul_input_idx = 1;
bias_input_idx = 0;
} else {
return false;
}
const NodeDef* batch_matmul = get_node(add->input(matmul_input_idx));
CHECK_NODE_OK(batch_matmul->input_size() == 2)
auto nodes = graph->mutable_node();
NodeDef* fused_node = nodes->Add();
fused_node->set_name(node->name() + fusion_appendix);
fused_node->set_op(name());
fused_node->set_device(node->device());
fused_node->add_input(batch_matmul->input(0));
fused_node->add_input(batch_matmul->input(1));
fused_node->add_input(add->input(bias_input_idx));
nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
<< fused_node->name();
replace_all_users_with(node, 0, fused_node, 0, graph);
return true;
}
};
// Builds a GraphOptimizer for `graph`, registers the rewriters selected
// through environment variables and runs the optimization pass.
//
// A rewriter is enabled when its own variable is set to exactly "1".
// ANNC_FUSED_ALL=1 additionally enables every rewriter except the two guarded
// by ANNC_FUSED_EMB_ACTIONID_GATHER and ANNC_FUSED_SPS_REDUCE_NONZERO, which
// must be requested individually. Registration order below is the order in
// which rewriters are tried on each node.
//
// Note: the two padding variables are spelled "EMD" (not "EMB") in their
// names; that spelling is what the code reads.
void run_graph_optimization(GraphDef* graph) {
  GraphOptimizer optimizer(graph);

  // True when the named environment variable exists and equals "1".
  const auto flag_on = [](const char* variable) {
    const char* value = getenv(variable);
    return value != nullptr && strcmp(value, "1") == 0;
  };

  const bool enable_all = flag_on("ANNC_FUSED_ALL");

  if (enable_all || flag_on("ANNC_FUSED_SPS_STITCH")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseDynamicStitchRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_SPS_REDUCE")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSegmentReduceRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_EMD_PADDING_FAST")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingPaddingFastRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_EMD_PADDING")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingPaddingRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_SPS_SELECT")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSelectRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_GATHER")) {
    optimizer.register_rewriter(std::make_unique<KPFusedGatherRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_SPS_RESHAPE")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseReshapeRewriter>());
  }
  // These two are opt-in only; ANNC_FUSED_ALL does not turn them on.
  if (flag_on("ANNC_FUSED_EMB_ACTIONID_GATHER")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedEmbeddingActionIdGatherRewriter>());
  }
  if (flag_on("ANNC_FUSED_SPS_REDUCE_NONZERO")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedSparseSegmentReduceNonzeroRewriter>());
  }
  if (enable_all || flag_on("ANNC_FUSED_BATCHMATMUL_ADD_SIGMOID")) {
    optimizer.register_rewriter(
        std::make_unique<KPFusedBatchMatMulAddSigmoidRewriter>());
  }

  optimizer.optimize();
}
}