@@ -673,7 +673,13 @@ cc_library(
]) + if_tensorrt([
"//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
- ]) + tf_tpu_dependencies() + tf_dtensor_tpu_dependencies(),
+ ]) + tf_tpu_dependencies() + tf_dtensor_tpu_dependencies()
+ + select({
+ "@platforms//cpu:aarch64": [
+ "//tensorflow/core/kernels:embedding_fused_ops"
+ ],
+ "//conditions:default": [],
+ }),
)
cc_library(
@@ -920,6 +926,7 @@ filegroup(
"candidate_sampling_ops_op_lib",
"checkpoint_ops_op_lib",
"clustering_ops_op_lib",
+ "embedding_fused_ops_op_lib",
"collective_ops_op_lib",
"control_flow_ops_op_lib",
"count_ops_op_lib",
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingActionIdGather"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingPadding"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingPaddingFast"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedGather"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseDynamicStitch"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseReshape"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSegmentReduce"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSegmentReduceNonzero"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSelect"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingActionIdGather"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingPadding"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedEmbeddingPaddingFast"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedGather"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseDynamicStitch"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseReshape"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSegmentReduce"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSegmentReduceNonzero"
+}
new file mode 100644
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "KPFusedSparseSelect"
+}
@@ -904,7 +904,11 @@ tf_kernel_library(
"//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
"@com_google_absl//absl/container:flat_hash_set",
- ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]),
+ ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"])
+ + select({
+ "@platforms//cpu:aarch64": ["//tensorflow/core/grappler/optimizers/graph_optimizer:annc_graph_opt"],
+ "//conditions:default": [],
+ }),
)
tf_cuda_cc_test(
new file mode 100644
@@ -0,0 +1,21 @@
+# Package-level defaults for the ANNC graph-optimizer package: targets are
+# publicly visible unless they say otherwise, and the package is declared
+# under the "notice" license class.
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"],
+)
+
+# Library assembled from every .cc/.h file in this package (picked up by glob,
+# so new sources are included without editing this rule).
+# alwayslink = True asks the linker to keep all object files of this library
+# in any binary that depends on it, even if no symbol is referenced directly.
+# NOTE(review): the explicit visibility repeats the package default above.
+cc_library(
+ name = "annc_graph_opt",
+ srcs = glob(["*.cc"]),
+ hdrs = glob(["*.h"]),
+ linkstatic = True,
+ alwayslink = True,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ ],
+)
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,791 @@
+#include "graph_opt.h"
+
+using namespace tensorflow;
+using namespace tensorflow::grappler;
+
+namespace annc {
+// Records, for every node currently in `graph`, its position in the node list
+// under the node's name. Existing entries are overwritten; entries whose name
+// no longer appears in `graph` are left as they were.
+void update_node_indexes(const GraphDef* graph,
+                         std::unordered_map<std::string, int>& node_indexes) {
+  int position = 0;
+  for (const NodeDef& current : graph->node()) {
+    node_indexes[current.name()] = position;
+    ++position;
+  }
+}
+
+// Takes ownership of `rewriter` and appends it to the list of rewriters that
+// optimize() tries, in registration order, on each node.
+void GraphOptimizer::register_rewriter(
+    std::unique_ptr<PatternRewriter> rewriter) {
+  rewriters_.emplace_back(std::move(rewriter));
+}
+
+// Walks the nodes that exist on entry and offers each one to the registered
+// rewriters in order; the first rewriter that matches wins for that node.
+//
+// Only the first `node_size` slots are visited. A successful rewriter swaps
+// its fused node into the matched node's slot and pushes the matched node to
+// the tail, so the tail never needs to be scanned.
+void GraphOptimizer::optimize() {
+  update_node_indexes(graph_, node_indexes_);
+  const int node_size = graph_->node_size();
+  for (int node_index = 0; node_index < node_size; ++node_index) {
+    const NodeDef& node = graph_->node(node_index);
+    const std::string& node_name = node.name();
+    for (auto& rewriter : rewriters_) {
+      const int size_before = graph_->node_size();
+      if (!rewriter->match_and_rewrite(&node, graph_, props_, node_indexes_)) {
+        // A rewriter may have appended a partially built fused node before
+        // discovering that the pattern does not match. Drop it: left in place
+        // it would be a dangling node, and its name would collide with the
+        // fused node a later rewriter creates for the same input node.
+        const int leftover = graph_->node_size() - size_before;
+        if (leftover > 0) {
+          graph_->mutable_node()->DeleteSubrange(size_before, leftover);
+        }
+        continue;
+      }
+      update_node_indexes(graph_, node_indexes_);
+      // Resume right after the fused node. Use a checked lookup so a rewriter
+      // that names its node differently cannot terminate the process.
+      const auto fused = node_indexes_.find(node_name + fusion_appendix);
+      if (fused != node_indexes_.end()) node_index = fused->second;
+      break;
+    }
+  }
+}
+
+// Returns `name` with everything from its last ':' onward removed, i.e. the
+// node part of a "node:port" tensor name. A name without ':' is returned
+// unchanged.
+std::string get_node_name(const std::string& name) {
+  const size_t separator = name.rfind(':');
+  if (separator == std::string::npos) {
+    return name;
+  }
+  return name.substr(0, separator);
+}
+
+// Writes the "fused_ops", "num_args" and "epsilon" attributes on `fused`,
+// creating or overwriting each entry.
+void set_fusedop_attributes(NodeDef* fused,
+                            const absl::Span<const absl::string_view> fused_ops,
+                            int num_args = 1, float epsilon = 0.0) {
+  auto& attrs = *fused->mutable_attr();
+  SetAttrValue(fused_ops, &attrs["fused_ops"]);
+  SetAttrValue(num_args, &attrs["num_args"]);
+  // epsilon is required only for BatchNorm
+  SetAttrValue(epsilon, &attrs["epsilon"]);
+}
+
+// Resolves a tensor name ("node" or "node:port") to the node that produces it.
+//
+// Never returns nullptr. When the name is not in the index (for example a
+// "^node" control input, or an input that refers to a node absent from the
+// graph) the empty default NodeDef is returned instead of terminating on an
+// unchecked map lookup. Its op() is empty and it has no inputs, so the
+// callers' pattern checks simply fail and the rewrite is skipped.
+const NodeDef* PatternRewriter::get_node(const std::string& name) {
+  const auto entry = indexes_->find(get_node_name(name));
+  if (entry == indexes_->end()) {
+    return &NodeDef::default_instance();
+  }
+  return &graph_->node(entry->second);
+}
+
+// Resolves a tensor name ("node" or "node:port") to a mutable pointer to the
+// producing node, or nullptr when the node name is not in the index.
+NodeDef* PatternRewriter::get_mutable_node(const std::string& name) {
+  const auto entry = indexes_->find(get_node_name(name));
+  return entry == indexes_->end() ? nullptr
+                                  : graph_->mutable_node(entry->second);
+}
+
+// Returns the first input of `node`, in input order, whose producing node has
+// op type `op_type`; nullptr when no input qualifies. Inputs that cannot be
+// resolved through the index are skipped.
+NodeDef* PatternRewriter::get_operand(const NodeDef* node,
+                                      std::string op_type) {
+  for (const auto& input_name : node->input()) {
+    NodeDef* candidate = get_mutable_node(input_name);
+    if (candidate == nullptr) continue;
+    if (candidate->op() == op_type) return candidate;
+  }
+  return nullptr;
+}
+
+// Returns the first node in graph order that has op type `op_type` and
+// consumes output `index` of `node`; nullptr when there is none.
+//
+// Output 0 is matched by the bare node name only, other outputs by
+// "name:index"; the comparison against each input string is exact.
+const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index,
+                                         const std::string& op_type) {
+  const std::string wanted_input =
+      index ? node->name() + ":" + std::to_string(index) : node->name();
+  for (const NodeDef& candidate : graph_->node()) {
+    if (candidate.op() != op_type) continue;
+    for (const auto& input_name : candidate.input()) {
+      if (input_name == wanted_input) return &candidate;
+    }
+  }
+  return nullptr;
+}
+
+// Rewires every input in `graph` that reads output `old_index` of `old_node`
+// so that it reads output `new_index` of `new_node` instead.
+//
+// Output 0 is spelled as the bare node name, other outputs as "name:index";
+// only inputs that equal the old spelling exactly are replaced.
+void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
+                                             int old_index,
+                                             const NodeDef* new_node,
+                                             int new_index, GraphDef* graph) {
+  const auto tensor_name = [](const NodeDef* producer, int output) {
+    return output ? producer->name() + ":" + std::to_string(output)
+                  : producer->name();
+  };
+  const std::string from = tensor_name(old_node, old_index);
+  const std::string to = tensor_name(new_node, new_index);
+  for (NodeDef& consumer : *graph->mutable_node()) {
+    for (int slot = 0; slot < consumer.input_size(); ++slot) {
+      if (consumer.input(slot) == from) consumer.set_input(slot, to);
+    }
+  }
+}
+
+// Fuses a partitioned-lookup pattern rooted at a ParallelDynamicStitch node
+// into one KPFusedSparseDynamicStitch node.
+//
+// With N = input_size / 2, the pattern is:
+//   input 0      : DynamicPartition(Range(0, Size(x), 1), Cast(FloorMod(_, N)))
+//   inputs N..2N : GatherV2(params_i, DynamicPartition(FloorDiv(_, N), _), 0)
+// The fused node takes `x` followed by the N `params_i` tensors and carries
+// the attribute "N". Consumers of the stitch node are rewired to it.
+//
+// NOTE(review): only input 0 of the first N inputs is inspected; inputs
+// 1..N-1 are assumed to be the other outputs of the same DynamicPartition.
+class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseDynamicStitch"; }
+
+  // Returns true and mutates `graph` only when the whole pattern matches; on
+  // a mismatch the graph is left untouched.
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph, const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    // Require at least one (indices, data) pair so input(0) below exists.
+    CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
+                  node->input_size() >= 2 && node->input_size() % 2 == 0)
+    const int num_inputs = node->input_size();
+    const int num_partitions = num_inputs / 2;
+    // left branch
+    const NodeDef* partition = get_node(node->input(0));
+    CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
+                  partition->input_size() == 2)
+    const NodeDef* range = get_node(partition->input(0));
+    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+    // Range start=0, delta=1
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+    const NodeDef* size = get_node(range->input(1));
+    CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
+    const NodeDef* cast = get_node(partition->input(1));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* floor_mod = get_node(cast->input(0));
+    CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
+    CHECK_NODE_OK(check_const_value<int64_t>(
+        get_mutable_node(floor_mod->input(1)), {num_partitions}))
+
+    CHECK_NODE_OK(check_int_attr(node, "N", {num_partitions}))
+    CHECK_NODE_OK(check_int_attr(partition, "num_partitions", {num_partitions}))
+
+    // right branch
+    // Validate every gather before touching the graph, so that a mismatch on
+    // a later input cannot leave a half-built fused node behind.
+    std::vector<std::string> gather_params;
+    gather_params.reserve(num_partitions);
+    for (int i = num_partitions; i < num_inputs; ++i) {
+      const NodeDef* gather = get_node(node->input(i));
+      CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+      // Gather axis=0
+      CHECK_NODE_OK(
+          check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+      const NodeDef* partition_1 = get_node(gather->input(1));
+      CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
+                    partition_1->input_size() == 2)
+      CHECK_NODE_OK(
+          check_int_attr(partition_1, "num_partitions", {num_partitions}))
+      const NodeDef* floor_div = get_node(partition_1->input(0));
+      CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
+                    floor_div->input_size() == 2)
+      CHECK_NODE_OK(check_const_value<int64_t>(
+          get_mutable_node(floor_div->input(1)), {num_partitions}))
+      CHECK_NODE_OK(check_input_dims(props, gather, 0, 2))
+      gather_params.push_back(gather->input(0));
+    }
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(size->input(0));
+    for (const std::string& params : gather_params) {
+      fused_node->add_input(params);
+    }
+    (*fused_node->mutable_attr())["N"].set_i(num_partitions);
+    // Put the fused node in the matched node's slot; the matched node moves
+    // to the end of the node list.
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    return true;
+  }
+};
+
+// Fuses a SparseSegmentMean/SparseSegmentSum together with the StridedSlice
+// taken from its Shape into one KPFusedSparseSegmentReduce node.
+//
+// Anchor: a 4-input StridedSlice `node` whose input 0 is Shape(ss_reduce),
+// where ss_reduce has 3 inputs and its input 2 is Cast(StridedSlice(...)).
+// Fused inputs, in order: ss_reduce inputs 0 and 1, the inner StridedSlice's
+// inputs 0 and 1, and `node`'s input 1. Attributes: "combiner" (1 for Mean,
+// 0 for Sum) and "Tidx" copied from ss_reduce.
+// Consumers of ss_reduce read fused output 0; consumers of `node` read
+// fused output 1.
+class KPFusedSparseSegmentReduceRewriter : public PatternRewriter {
+ public:
+ std::string name() const override { return "KPFusedSparseSegmentReduce"; }
+
+ // Returns true when the pattern matched and the graph was rewritten;
+ // every check below returns false before any node is added.
+ bool match_and_rewrite(
+ const NodeDef* node, GraphDef* graph,
+ const GraphProperties& props,
+ std::unordered_map<std::string, int>& node_indexes) override {
+ graph_ = graph;
+ indexes_ = &node_indexes;
+ CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
+ const NodeDef* shape = get_node(node->input(0));
+ CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+ const NodeDef* ss_reduce = get_node(shape->input(0));
+ CHECK_NODE_OK(ss_reduce->input_size() == 3)
+ // combiner encodes the reduction kind: 1 = mean, 0 = sum.
+ AttrValue combiner;
+ if (ss_reduce->op() == "SparseSegmentMean")
+ combiner.set_i(1);
+ else if (ss_reduce->op() == "SparseSegmentSum")
+ combiner.set_i(0);
+ else
+ return false;
+ const NodeDef* cast = get_node(ss_reduce->input(2));
+ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+ const NodeDef* strided_slice = get_node(cast->input(0));
+ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+ strided_slice->input_size() == 4)
+
+ // check fusion conditions
+ CHECK_NODE_OK(
+ check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+ CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1}))
+ CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
+
+ // presumably rank checks on input 0 of each node - confirm against
+ // check_input_dims
+ CHECK_NODE_OK(check_input_dims(props, ss_reduce, 0, 2))
+ CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+ auto nodes = graph->mutable_node();
+ NodeDef* fused_node = nodes->Add();
+ fused_node->set_name(node->name() + fusion_appendix);
+ fused_node->set_op(name());
+ fused_node->set_device(node->device());
+ fused_node->add_input(ss_reduce->input(0)); // input_tensor
+ fused_node->add_input(ss_reduce->input(1)); // indices
+ fused_node->add_input(strided_slice->input(0)); // slice_input
+ fused_node->add_input(strided_slice->input(1)); // begin
+ fused_node->add_input(node->input(1)); // begin_1
+ AddNodeAttr("combiner", combiner, fused_node);
+ AddNodeAttr("Tidx", ss_reduce->attr().at("Tidx"), fused_node);
+
+ // The fused node was appended last; swap it into the matched node's slot
+ // so the matched node ends up at the tail of the node list.
+ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+ << fused_node->name();
+ replace_all_users_with(ss_reduce, 0, fused_node, 0, graph);
+ replace_all_users_with(node, 0, fused_node, 1, graph);
+ return true;
+ }
+};
+
+// Fuses a SparseSegmentMean/SparseSegmentSum followed by a non-zero
+// selection (NotEqual/Where/GatherNd) into one
+// KPFusedSparseSegmentReduceNonzero node.
+//
+// Anchor: a 2-input GatherNd `node` with
+//   input 0: ss_reduce (3 inputs, input 2 is Cast(StridedSlice(...)))
+//   input 1: Where(NotEqual(_, ZerosLike(_)))
+// In addition the graph must contain a Cast consumer of the Where node, a
+// Shape consumer of ss_reduce and a Cast consumer of that Shape.
+// Fused inputs, in order: ss_reduce inputs 0 and 1, then the StridedSlice's
+// inputs 0 and 1. Attributes: "combiner" (1 for Mean, 0 for Sum) and "Tidx".
+// Rewiring: consumers of Cast(Shape) read fused output 0, consumers of
+// Cast(Where) read fused output 1, consumers of `node` read fused output 2.
+class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
+ public:
+ std::string name() const override {
+ return "KPFusedSparseSegmentReduceNonzero";
+ }
+
+ // Returns true when the pattern matched and the graph was rewritten;
+ // every check below returns false before any node is added.
+ bool match_and_rewrite(
+ const NodeDef* node, GraphDef* graph,
+ const GraphProperties& props,
+ std::unordered_map<std::string, int>& node_indexes) override {
+ graph_ = graph;
+ indexes_ = &node_indexes;
+ CHECK_NODE_OK(node->op() == "GatherNd" &&
+ node->input_size() == 2) // output:2
+ const NodeDef* ss_reduce = get_node(node->input(0));
+ CHECK_NODE_OK(ss_reduce->input_size() == 3)
+ // combiner encodes the reduction kind: 1 = mean, 0 = sum.
+ AttrValue combiner;
+ if (ss_reduce->op() == "SparseSegmentMean") {
+ combiner.set_i(1);
+ } else if (ss_reduce->op() == "SparseSegmentSum") {
+ combiner.set_i(0);
+ } else {
+ return false;
+ }
+ const NodeDef* where = get_node(node->input(1));
+ CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
+ const NodeDef* cast = get_user(where, 0, "Cast");
+ CHECK_NODE_OK(cast != nullptr) // output: 1
+ const NodeDef* notequal = get_node(where->input(0));
+ CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2);
+ const NodeDef* zerolike = get_node(notequal->input(1));
+ CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)
+ const NodeDef* cast_1 = get_node(ss_reduce->input(2));
+ CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+ const NodeDef* strided_slice = get_node(cast_1->input(0));
+ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+ strided_slice->input_size() == 4)
+ const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
+ CHECK_NODE_OK(shape != nullptr)
+ const NodeDef* cast_2 = get_user(shape, 0, "Cast"); // output: 0
+ CHECK_NODE_OK(cast_2 != nullptr)
+
+ CHECK_NODE_OK(
+ check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+ // presumably accepts ss_reduce input 0 of rank 1 or 2 - confirm against
+ // check_input_dims
+ CHECK_NODE_OK(check_input_dims(props, ss_reduce, 0, 1) ||
+ check_input_dims(props, ss_reduce, 0, 2))
+ CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+ auto nodes = graph->mutable_node();
+ NodeDef* fused_node = nodes->Add();
+ fused_node->set_name(node->name() + fusion_appendix);
+ fused_node->set_op(name());
+ fused_node->set_device(node->device());
+ fused_node->add_input(ss_reduce->input(0));
+ fused_node->add_input(ss_reduce->input(1));
+ fused_node->add_input(strided_slice->input(0));
+ fused_node->add_input(strided_slice->input(1));
+ AddNodeAttr("combiner", combiner, fused_node);
+ AddNodeAttr("Tidx", ss_reduce->attr().at("Tidx"), fused_node);
+
+ // The fused node was appended last; swap it into the matched node's slot
+ // so the matched node ends up at the tail of the node list.
+ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+ << fused_node->name() << "\n";
+ replace_all_users_with(cast_2, 0, fused_node, 0, graph);
+ replace_all_users_with(cast, 0, fused_node, 1, graph);
+ replace_all_users_with(node, 0, fused_node, 2, graph);
+ return true;
+ }
+};
+
+// Fuses a "pad with zero rows, then reshape" pattern into one
+// KPFusedEmbeddingPadding node.
+//
+// Anchor: a Reshape `node` that has a ConcatV2 consumer and whose input 0 is
+// a 3-input Concat with constant axis 0. The Concat's input 1 is a
+// Fill(Pack(.., Sub(StridedSlice(Cast(_)), rows), ..), 0).
+// Fused inputs, in order: the Cast's input, the Concat's input 0, the Sub's
+// input 1, `node`'s input 1, and whichever Pack input is a constant.
+// Attribute "T" is copied from the Pack node.
+// Consumers of the Sub read fused output 0; consumers of `node` read fused
+// output 1.
+class KPFusedEmbeddingPaddingRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedEmbeddingPadding"; }
+
+  // Returns true and mutates `graph` only when the whole pattern matches; on
+  // a mismatch the graph is left untouched.
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph, const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsReshape(*node))
+    const NodeDef* user = get_user(node, 0, "ConcatV2");
+    CHECK_NODE_OK(user != nullptr)
+    const NodeDef* concat = get_node(node->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(concat->input(2)), {0}))
+    const NodeDef* fill = get_node(concat->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    NodeDef* pack = get_operand(fill, "Pack");
+    CHECK_NODE_OK(pack != nullptr && IsPack(*pack) && pack->input_size() == 2)
+    NodeDef* fill_const = get_operand(fill, "Const");
+    CHECK_NODE_OK(fill_const != nullptr &&
+                  check_const_value<int>(fill_const, {0}))
+    NodeDef* sub = get_operand(pack, "Sub");
+    CHECK_NODE_OK(sub != nullptr && IsSub(*sub) && sub->input_size() == 2)
+    const NodeDef* strided_slice = get_node(sub->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* cast = get_node(strided_slice->input(0));
+    CHECK_NODE_OK(IsCast(*cast))
+
+    CHECK_NODE_OK(check_input_dims(props, cast, 0, 1))
+    CHECK_NODE_OK(check_input_dims(props, concat, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, sub, 1, 0))
+    CHECK_NODE_OK(check_input_dims(props, node, 1, 1))
+
+    // Pick the constant Pack operand before creating the fused node, so that
+    // "neither operand is constant" cannot leave a half-built node behind.
+    const NodeDef* pack_left = get_node(pack->input(0));
+    const NodeDef* pack_right = get_node(pack->input(1));
+    std::string pack_const_input;
+    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
+      pack_const_input = pack->input(0);
+    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
+      pack_const_input = pack->input(1);
+    } else {
+      return false;
+    }
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(cast->input(0));    // origin_shape
+    fused_node->add_input(concat->input(0));  // input
+    fused_node->add_input(sub->input(1));     // input_rows
+    fused_node->add_input(node->input(1));    // reshape_sizes
+    fused_node->add_input(pack_const_input);
+    AddNodeAttr("T", pack->attr().at("T"), fused_node);
+
+    // Put the fused node in the matched node's slot; the matched node moves
+    // to the end of the node list.
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(sub, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 0, fused_node, 1, graph);
+    return true;
+  }
+};
+
+// Fuses a "pad with zero rows, reshape, take dimension 0 of the shape"
+// pattern into one KPFusedEmbeddingPaddingFast node.
+//
+// Anchor: a 4-input StridedSlice `node` with constant begin 0, end 1,
+// stride 1 and shrink_axis_mask 1, reading
+// Shape(Reshape(Concat(input, Fill(Pack(Sub(StridedSlice(Cast(_)), rows), _), _), _))).
+// Fused inputs, in order: the Cast's input, the Concat's input 0, the Sub's
+// input 1, the Reshape's input 1, and whichever Pack input is a constant.
+// Consumers of the Sub read fused output 0; consumers of `node` read fused
+// output 1.
+class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedEmbeddingPaddingFast"; }
+
+  // Returns true and mutates `graph` only when the whole pattern matches; on
+  // a mismatch the graph is left untouched.
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph, const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(1)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {1}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(3)), {1}))
+    CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
+    const NodeDef* shape = get_node(node->input(0));
+    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+    const NodeDef* reshape = get_node(shape->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    const NodeDef* concat = get_node(reshape->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    const NodeDef* fill = get_node(concat->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    const NodeDef* pack = get_node(fill->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+    const NodeDef* sub = get_node(pack->input(0));
+    CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2)
+    const NodeDef* strided_slice = get_node(sub->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* cast = get_node(strided_slice->input(0));
+    CHECK_NODE_OK(IsCast(*cast))
+
+    CHECK_NODE_OK(check_input_dims(props, cast, 0, 1))
+    CHECK_NODE_OK(check_input_dims(props, concat, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, sub, 1, 0))
+    CHECK_NODE_OK(check_input_dims(props, reshape, 1, 1))
+
+    // Pick the constant Pack operand before creating the fused node, so that
+    // "neither operand is constant" cannot leave a half-built node behind.
+    const NodeDef* pack_left = get_node(pack->input(0));
+    const NodeDef* pack_right = get_node(pack->input(1));
+    std::string pack_const_input;
+    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
+      pack_const_input = pack->input(0);
+    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
+      pack_const_input = pack->input(1);
+    } else {
+      return false;
+    }
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(cast->input(0));
+    fused_node->add_input(concat->input(0));
+    fused_node->add_input(sub->input(1));
+    fused_node->add_input(reshape->input(1));
+    fused_node->add_input(pack_const_input);
+    // Put the fused node in the matched node's slot; the matched node moves
+    // to the end of the node list.
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(sub, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 0, fused_node, 1, graph);
+    return true;
+  }
+};
+
+// Fuses a chain of Select/Equal/Greater nodes feeding a 3-input Concat into
+// one KPFusedSparseSelect node.
+//
+// Anchor: a 3-input Concat `node` with
+//   input 0: Select(Equal(Reshape(b, {-1,1}), c),
+//                   _,
+//                   Select(Equal(Reshape(_), c), Fill(_, 1.0),
+//                          Cast(Greater(Reshape(a, {-1,1}), c))))
+//   input 1: Select(Equal(Reshape(d, {-1,1}), c), _, Fill(_, 1.0))
+// Fused inputs, in order: a, b, d, followed by the constant operand of the
+// Greater node and of each of the three Equal nodes (inner, outer, right).
+// Rewiring: consumers of Reshape(a) read fused output 0, consumers of the
+// outer Select read fused output 1, consumers of `node` read fused output 2.
+class KPFusedSparseSelectRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseSelect"; }
+
+  // Returns true and mutates `graph` only when the whole pattern matches; on
+  // a mismatch the graph is left untouched.
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph, const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+    const NodeDef* select_0 = get_node(node->input(0));
+    CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
+    const NodeDef* select_1 = get_node(select_0->input(2));
+    CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
+    const NodeDef* fill = get_node(select_1->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
+    const NodeDef* cast = get_node(select_1->input(2));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* equal = get_node(select_1->input(0));
+    CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
+    NodeDef* reshape = get_operand(equal, "Reshape");
+    CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape))
+    const NodeDef* greater = get_node(cast->input(0));
+    CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
+    const NodeDef* reshape_4 = get_node(greater->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
+    const NodeDef* equal_1 = get_node(select_0->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
+    NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
+    CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+
+    // right branch
+    const NodeDef* select_2 = get_node(node->input(1));
+    CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
+    const NodeDef* equal_2 = get_node(select_2->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
+    const NodeDef* fill_1 = get_node(select_2->input(2));
+    CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
+    const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
+    CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))
+
+    // Resolve the constant operand of every comparison before creating the
+    // fused node, so that a comparison without a constant operand cannot
+    // leave a half-built node behind.
+    const std::vector<const NodeDef*> comparisons = {
+        greater,
+        equal,
+        equal_1,
+        equal_2,
+    };
+    std::vector<std::string> const_operands;
+    const_operands.reserve(comparisons.size());
+    for (const NodeDef* comparison : comparisons) {
+      const NodeDef* left = get_node(comparison->input(0));
+      const NodeDef* right = get_node(comparison->input(1));
+      if (IsConstant(*left) || IsHostConstant(*left)) {
+        const_operands.push_back(comparison->input(0));
+      } else if (IsConstant(*right) || IsHostConstant(*right)) {
+        const_operands.push_back(comparison->input(1));
+      } else {
+        return false;
+      }
+    }
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(reshape_4->input(0));
+    fused_node->add_input(reshape_1->input(0));
+    fused_node->add_input(reshape_2->input(0));
+    for (const std::string& operand : const_operands) {
+      fused_node->add_input(operand);
+    }
+    // Put the fused node in the matched node's slot; the matched node moves
+    // to the end of the node list.
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(reshape_4, 0, fused_node, 0, graph);
+    replace_all_users_with(select_0, 0, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
+// Fuses two stacked GatherV2 nodes indexed through two chained Unique nodes
+// into one KPFusedGather node.
+//
+// Anchor: a 3-input GatherV2 `node` with constant axis 0, where
+//   input 0: GatherV2(params, _, 0)
+//   input 1: an output of Unique(Unique(StridedSlice(...)))
+// and the StridedSlice has shrink_axis_mask 2, begin_mask 1, end_mask 1.
+// Fused inputs, in order: params and the StridedSlice's inputs 0 and 1.
+// Rewiring: consumers of the inner Unique's outputs 0 and 1 read fused
+// outputs 0 and 1; consumers of `node` read fused output 2.
+class KPFusedGatherRewriter : public PatternRewriter {
+ public:
+ std::string name() const override { return "KPFusedGather"; }
+
+ // Returns true when the pattern matched and the graph was rewritten;
+ // every check below returns false before any node is added.
+ bool match_and_rewrite(
+ const NodeDef* node, GraphDef* graph,
+ const GraphProperties& props,
+ std::unordered_map<std::string, int>& node_indexes) override {
+ graph_ = graph;
+ indexes_ = &node_indexes;
+ CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
+ CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {0}))
+ const NodeDef* gather = get_node(node->input(0));
+ CHECK_NODE_OK(gather->op() == "GatherV2" &&
+ gather->input_size() == 3) // input:0
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+ const NodeDef* unique = get_node(node->input(1)); // output:1
+ CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
+ const NodeDef* unique_1 = get_node(unique->input(0));
+ CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
+ const NodeDef* strided_slice = get_node(unique_1->input(0));
+ // NOTE(review): unlike the other rewriters, the StridedSlice input count
+ // is not checked here although input(1) is read below.
+ CHECK_NODE_OK(IsStridedSlice(*strided_slice)) // input:1 2
+ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+ CHECK_NODE_OK(check_input_dims(props, gather, 0, 2))
+ CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+ auto nodes = graph->mutable_node();
+ NodeDef* fused_node = nodes->Add();
+ fused_node->set_name(node->name() + fusion_appendix);
+ fused_node->set_op(name());
+ fused_node->set_device(node->device());
+ fused_node->add_input(gather->input(0));
+ fused_node->add_input(strided_slice->input(0));
+ fused_node->add_input(strided_slice->input(1));
+ // The fused node was appended last; swap it into the matched node's slot
+ // so the matched node ends up at the tail of the node list.
+ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+ << fused_node->name();
+ replace_all_users_with(unique_1, 0, fused_node, 0, graph);
+ replace_all_users_with(unique_1, 1, fused_node, 1, graph);
+ replace_all_users_with(node, 0, fused_node, 2, graph);
+ return true;
+ }
+};
+
+// Fuses the index/shape construction feeding a SparseReshape, together with
+// the SparseReshape itself, into one KPFusedSparseReshape node.
+//
+// Anchor: a 3-input SparseReshape `node` with
+//   input 0: Concat(Cast(Reshape(Range(0, _, 1), {-1,1})),
+//                   Reshape(StridedSlice(...), {-1,1}), axis = -1)
+//   input 1: Cast(Pack(StridedSlice_1(Shape(x), 0, 1, 1), pack_const))
+// Fused inputs, in order: x, the first StridedSlice's input 1, `node`'s
+// input 2 and the Pack's input 1. Attribute "T" is copied from the Pack.
+// Consumers of `node` outputs 0 and 1 read fused outputs 0 and 1.
+class KPFusedSparseReshapeRewriter : public PatternRewriter {
+ public:
+ std::string name() const override { return "KPFusedSparseReshape"; }
+
+ // Returns true when the pattern matched and the graph was rewritten;
+ // every check below returns false before any node is added.
+ bool match_and_rewrite(
+ const NodeDef* node, GraphDef* graph,
+ const GraphProperties& props,
+ std::unordered_map<std::string, int>& node_indexes) override {
+ graph_ = graph;
+ indexes_ = &node_indexes;
+ CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
+ const NodeDef* concat = get_node(node->input(0));
+ CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
+ const NodeDef* reshape = get_node(concat->input(1));
+ CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
+ const NodeDef* strided_slice = get_node(reshape->input(0));
+ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+ strided_slice->input_size() == 4)
+ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+ const NodeDef* cast_1 = get_node(concat->input(0));
+ CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+ const NodeDef* reshape_1 = get_node(cast_1->input(0));
+ CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+ const NodeDef* range = get_node(reshape_1->input(0));
+ CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+ // Range start=0, delta=1
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+ // Second SparseReshape input: the dense shape built from Shape(x)[0] and
+ // a second Pack operand.
+ const NodeDef* cast = get_node(node->input(1));
+ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+ const NodeDef* pack = get_node(cast->input(0));
+ CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+ const NodeDef* strided_slice_1 = get_node(pack->input(0));
+ CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
+ strided_slice_1->input_size() == 4)
+ CHECK_NODE_OK(check_const_value<int>(
+ get_mutable_node(strided_slice_1->input(1)), {0}))
+ CHECK_NODE_OK(check_const_value<int>(
+ get_mutable_node(strided_slice_1->input(2)), {1}))
+ CHECK_NODE_OK(check_const_value<int>(
+ get_mutable_node(strided_slice_1->input(3)), {1}))
+ CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
+ const NodeDef* shape = get_node(strided_slice_1->input(0));
+ CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+
+ CHECK_NODE_OK(check_input_dims(props, shape, 0, 2))
+ CHECK_NODE_OK(check_input_dims(props, pack, 1, 0))
+
+ auto nodes = graph->mutable_node();
+ NodeDef* fused_node = nodes->Add();
+ fused_node->set_name(node->name() + fusion_appendix);
+ fused_node->set_op(name());
+ fused_node->set_device(node->device());
+ fused_node->add_input(shape->input(0)); // slice_input
+ fused_node->add_input(strided_slice->input(1)); // begin
+ fused_node->add_input(node->input(2)); // new_shape
+ fused_node->add_input(pack->input(1)); // pack_const
+ AddNodeAttr("T", pack->attr().at("T"), fused_node);
+ // The fused node was appended last; swap it into the matched node's slot
+ // so the matched node ends up at the tail of the node list.
+ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+ << fused_node->name();
+ replace_all_users_with(node, 0, fused_node, 0, graph);
+ replace_all_users_with(node, 1, fused_node, 1, graph);
+ return true;
+ }
+};
+
+// Fuses two stacked GatherV2 nodes, a Reshape and a zero-Fill joined by a
+// Concat into one KPFusedEmbeddingActionIdGather node.
+//
+// Anchor: a 3-input Concat `node` with constant axis -1, where
+//   input 0: Reshape(GatherV2(GatherV2(params, indices1, 0), indices2, 0),
+//                    Pack(_, -1))
+//   input 1: Fill(Pack(pack_dim, pack), 0)
+// Fused inputs, in order: indices1, params, indices2, pack_dim, pack.
+// Attributes "Tindices1"/"Tindices2" are copied from the inner and outer
+// gather's "Tindices". Consumers of `node` are rewired to the fused node.
+class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
+ public:
+ std::string name() const override { return "KPFusedEmbeddingActionIdGather"; }
+
+ // Returns true when the pattern matched and the graph was rewritten;
+ // every check below returns false before any node is added.
+ bool match_and_rewrite(
+ const NodeDef* node, GraphDef* graph,
+ const GraphProperties& props,
+ std::unordered_map<std::string, int>& node_indexes) override {
+ graph_ = graph;
+ indexes_ = &node_indexes;
+ CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
+ const NodeDef* reshape = get_node(node->input(0));
+ CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+ const NodeDef* gather = get_node(reshape->input(0));
+ const NodeDef* pack_1 = get_node(reshape->input(1));
+ CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
+ CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+ const NodeDef* gather_1 = get_node(gather->input(0));
+ CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
+ CHECK_NODE_OK(
+ check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
+ const NodeDef* fill = get_node(node->input(1));
+ CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+ CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
+ const NodeDef* pack = get_node(fill->input(0));
+ CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+
+ // presumably rank constraints on the gather and pack inputs - confirm
+ // against check_input_dims
+ CHECK_NODE_OK(check_input_dims(props, gather_1, 1, 1) ||
+ check_input_dims(props, gather_1, 1, 2))
+ CHECK_NODE_OK(check_input_dims(props, gather, 1, 1 ) ||
+ check_input_dims(props, gather, 1, 2))
+ CHECK_NODE_OK(check_input_dims(props, gather_1, 0, 2) ||
+ check_input_dims(props, gather_1, 0, 3))
+ CHECK_NODE_OK(check_input_dims(props, pack, 0, 0))
+ CHECK_NODE_OK(check_input_dims(props, pack, 1, 0))
+
+ auto nodes = graph->mutable_node();
+ NodeDef* fused_node = nodes->Add();
+ fused_node->set_name(node->name() + fusion_appendix);
+ fused_node->set_op(name());
+ fused_node->set_device(node->device());
+ fused_node->add_input(gather_1->input(1)); // indices1
+ fused_node->add_input(gather_1->input(0)); // params
+ fused_node->add_input(gather->input(1)); // indices2
+ fused_node->add_input(pack->input(0)); // pack_dim
+ fused_node->add_input(pack->input(1)); // pack
+ AddNodeAttr("Tindices1", gather_1->attr().at("Tindices"), fused_node);
+ AddNodeAttr("Tindices2", gather->attr().at("Tindices"), fused_node);
+ // The fused node was appended last; swap it into the matched node's slot
+ // so the matched node ends up at the tail of the node list.
+ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+ << fused_node->name();
+ replace_all_users_with(node, 0, fused_node, 0, graph);
+ return true;
+ }
+};
+
+// Runs the KP fusion rewriters over `graph`.
+//
+// Nothing is enabled by default.  A rewriter is registered only when its own
+// environment variable is exactly "1", or when ANNC_FUSED_ALL is exactly "1"
+// (which turns every rewriter on).  Rewriters are registered in the order
+// listed below.  Note that the two padding switches are spelled "EMD".
+void run_graph_optimization(GraphDef* graph, GraphProperties props) {
+  GraphOptimizer optimizer(graph, props);
+
+  // True when the environment variable `var_name` is set to exactly "1".
+  const auto env_is_one = [](const char* var_name) {
+    const char* value = getenv(var_name);
+    return value != nullptr && strcmp(value, "1") == 0;
+  };
+  const bool enable_all = env_is_one("ANNC_FUSED_ALL");
+  const auto requested = [&](const char* var_name) {
+    return enable_all || env_is_one(var_name);
+  };
+
+  if (requested("ANNC_FUSED_SPS_STITCH")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseDynamicStitchRewriter>());
+  }
+  if (requested("ANNC_FUSED_SPS_REDUCE")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSegmentReduceRewriter>());
+  }
+  if (requested("ANNC_FUSED_EMD_PADDING_FAST")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingPaddingFastRewriter>());
+  }
+  if (requested("ANNC_FUSED_EMD_PADDING")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingPaddingRewriter>());
+  }
+  if (requested("ANNC_FUSED_SPS_SELECT")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSelectRewriter>());
+  }
+  if (requested("ANNC_FUSED_GATHER")) {
+    optimizer.register_rewriter(std::make_unique<KPFusedGatherRewriter>());
+  }
+  if (requested("ANNC_FUSED_SPS_RESHAPE")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseReshapeRewriter>());
+  }
+  if (requested("ANNC_FUSED_EMB_ACTIONID_GATHER")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingActionIdGatherRewriter>());
+  }
+  if (requested("ANNC_FUSED_SPS_REDUCE_NONZERO")) {
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSegmentReduceNonzeroRewriter>());
+  }
+  optimizer.optimize();
+}
+} // namespace annc
new file mode 100644
@@ -0,0 +1,178 @@
+#ifndef ANNC_TF_GRAPH_OPT_H_
+#define ANNC_TF_GRAPH_OPT_H_
+#include <type_traits>
+#include <unordered_map>
+
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+
+namespace annc {
+// Makes the enclosing function return false when condition `x` does not hold.
+// The expansion is a complete if-block, which is why the rewriters invoke it
+// without a trailing semicolon.
+#define CHECK_NODE_OK(x) \
+ if (!(x)) { \
+ return false; \
+ }
+
+// Suffix appended to the matched node's name to form the fused node's name.
+static const std::string fusion_appendix = "/kp_fused";
+
+// Defined outside this header.
+// NOTE(review): presumably rebuilds the node-name -> position map for `graph`
+// into `node_indexes` -- confirm against the definition.
+void update_node_indexes(const tensorflow::GraphDef* graph,
+ std::unordered_map<std::string, int>& node_indexes);
+
+// Base class for one graph-pattern fusion.  A subclass implements
+// match_and_rewrite(); the check_* helpers below are small predicates used
+// while matching.
+class PatternRewriter {
+ public:
+  PatternRewriter() {}
+  virtual ~PatternRewriter() = default;
+
+  // Attempts to match the pattern rooted at `node` and, on success, rewrites
+  // `graph` in place.  Returns true only when a rewrite happened.
+  // `node_indexes` maps node names to their position in `graph`.
+  virtual bool match_and_rewrite(
+      const tensorflow::NodeDef* node, tensorflow::GraphDef* graph,
+      const tensorflow::grappler::GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) = 0;
+
+  // Human-readable rewriter name; subclasses return the fused op name.
+  virtual std::string name() const { return "PatternRewriter"; }
+
+  // The lookup/rewiring helpers below are defined outside this header.
+  // NOTE(review): presumably they resolve names through graph_ / indexes_,
+  // which match_and_rewrite() implementations set first -- confirm.
+  const tensorflow::NodeDef* get_node(const std::string& name);
+  tensorflow::NodeDef* get_mutable_node(const std::string& name);
+
+  tensorflow::NodeDef* get_operand(const tensorflow::NodeDef* node,
+                                   std::string op_type);
+
+  const tensorflow::NodeDef* get_user(const tensorflow::NodeDef* node,
+                                      int index, const std::string& op_type);
+
+  void replace_all_users_with(const tensorflow::NodeDef* old_node,
+                              int old_index,
+                              const tensorflow::NodeDef* new_node,
+                              int new_index, tensorflow::GraphDef* graph);
+
+  // Returns true when input `input_index` of `op` has a statically known
+  // rank equal to `expected_dim_size`.  Returns false when the index is out
+  // of range or the rank is unknown.
+  bool check_input_dims(
+      const tensorflow::grappler::GraphProperties& graph_properties,
+      const tensorflow::NodeDef* op, int input_index, int expected_dim_size) {
+    const auto& input_props = graph_properties.GetInputProperties(op->name());
+    if (input_index < 0 ||
+        input_index >= static_cast<int>(input_props.size())) {
+      return false;
+    }
+    const tensorflow::TensorShapeProto& shape =
+        input_props[input_index].shape();
+    std::string shape_str = "[";
+    if (shape.unknown_rank()) {
+      shape_str = "[?]";  // unknown rank
+    } else {
+      for (int i = 0; i < shape.dim_size(); ++i) {
+        if (i > 0) shape_str += ", ";
+        // An unknown dimension (-1) is printed as "?".
+        auto dim_size = shape.dim(i).size();
+        if (dim_size == -1) {
+          shape_str += "?";
+        } else {
+          shape_str += std::to_string(dim_size);
+        }
+      }
+      shape_str += "]";
+    }
+
+    LOG(INFO) << " Full input shape: " << shape_str
+              << ", rank: " << (shape.unknown_rank() ? -1 : shape.dim_size())
+              << ", expected rank: " << expected_dim_size;
+
+    // A shape of unknown rank carries no dims, so dim_size() is 0 and would
+    // otherwise be mistaken for a scalar.
+    if (shape.unknown_rank()) return false;
+    return shape.dim_size() == expected_dim_size;
+  }
+
+  // Returns true when `op` is a constant whose value tensor has rank
+  // `dim_size`.
+  bool check_const_dims(tensorflow::NodeDef* op, int dim_size) {
+    if (!is_value_const(op)) return false;
+    return op->attr().at("value").tensor().tensor_shape().dim_size() ==
+           dim_size;
+  }
+
+  // Returns true when `op` is a constant whose value tensor has exactly the
+  // shape `dims`.
+  bool check_const_shape(tensorflow::NodeDef* op, std::vector<int> dims) {
+    if (!is_value_const(op)) return false;
+    const auto& shape = op->attr().at("value").tensor().tensor_shape();
+    if (shape.dim_size() != static_cast<int>(dims.size())) return false;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      if (shape.dim(i).size() != dims[i]) return false;
+    }
+    return true;
+  }
+
+  // Returns true when `op` is a constant whose first cmp.size() elements
+  // equal `cmp` (floats are compared with an absolute tolerance of 1e-5).
+  // T selects which value list is read: float -> float_val, int -> int_val,
+  // int64_t -> int64_val; any other T yields false.
+  // NOTE(review): the tensor's dtype is not compared against T, so callers
+  // must pick the T that matches the constant they expect.
+  template <typename T>
+  bool check_const_value(tensorflow::NodeDef* op, std::vector<T> cmp) {
+    if (!is_value_const(op)) return false;
+
+    const tensorflow::TensorProto& tensor = op->attr().at("value").tensor();
+    const auto& shape = tensor.tensor_shape();
+    int64_t num_elements = 1;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      num_elements *= shape.dim(i).size();
+    }
+    if (num_elements < static_cast<int64_t>(cmp.size())) return false;
+
+    const std::string& raw = tensor.tensor_content();
+    if (std::is_same<T, float>::value) {
+      return leading_values_match<float>(
+          tensor.float_val(), raw, cmp, [](float actual, float expected) {
+            return std::fabs(actual - expected) < 1e-5f;
+          });
+    }
+    if (std::is_same<T, int>::value) {
+      return leading_values_match<int>(
+          tensor.int_val(), raw, cmp,
+          [](int actual, int expected) { return actual == expected; });
+    }
+    if (std::is_same<T, int64_t>::value) {
+      return leading_values_match<int64_t>(
+          tensor.int64_val(), raw, cmp,
+          [](int64_t actual, int64_t expected) { return actual == expected; });
+    }
+    // Unsupported element type.
+    return false;
+  }
+
+  // Returns true when `op` has an integer attribute `name` equal to `value`.
+  bool check_int_attr(const tensorflow::NodeDef* op, std::string name,
+                      int value) {
+    if (!HasNodeAttr(*op, name)) return false;
+    const tensorflow::AttrValue& attr = op->attr().at(name);
+    return attr.value_case() == tensorflow::AttrValue::kI &&
+           attr.i() == value;
+  }
+
+  // Graph and name -> index map of the rewrite in progress; assigned by
+  // match_and_rewrite() implementations.
+  tensorflow::GraphDef* graph_ = nullptr;
+  std::unordered_map<std::string, int>* indexes_ = nullptr;
+
+ private:
+  // True when `op` is a non-null (host) constant carrying a "value" attr.
+  static bool is_value_const(const tensorflow::NodeDef* op) {
+    return op != nullptr &&
+           (tensorflow::grappler::IsConstant(*op) ||
+            tensorflow::grappler::IsHostConstant(*op)) &&
+           HasNodeAttr(*op, "value");
+  }
+
+  // Compares the first expected.size() stored elements with `expected`.
+  // The typed value list is used when it is non-empty, otherwise the raw
+  // `raw_bytes` are used.  No read goes past what is stored:
+  //  * a typed list shorter than `expected` repeats its last value;
+  //  * a constant with no stored values at all is treated as all zeros;
+  //  * raw bytes that are present but too short yield false.
+  // NOTE(review): the first two rules follow how TensorFlow appears to
+  // expand compact TensorProtos (Tensor::FromProto) -- confirm.
+  template <typename Stored, typename Values, typename Expected,
+            typename Matches>
+  static bool leading_values_match(const Values& typed_values,
+                                   const std::string& raw_bytes,
+                                   const std::vector<Expected>& expected,
+                                   Matches matches) {
+    const int typed_count = typed_values.size();
+    const int raw_count = static_cast<int>(raw_bytes.size() / sizeof(Stored));
+    const Stored* raw_values =
+        reinterpret_cast<const Stored*>(raw_bytes.data());
+    for (int i = 0; i < static_cast<int>(expected.size()); ++i) {
+      Stored actual = Stored();
+      if (typed_count > 0) {
+        actual = typed_values.Get(i < typed_count ? i : typed_count - 1);
+      } else if (i < raw_count) {
+        actual = raw_values[i];
+      } else if (!raw_bytes.empty()) {
+        return false;
+      }
+      if (!matches(actual, expected[i])) return false;
+    }
+    return true;
+  }
+};
+
+// Owns a list of PatternRewriters and the state they share while rewriting
+// one graph.
+class GraphOptimizer {
+ public:
+  // `graph` is stored as a raw pointer and is not owned; `graph_properties`
+  // is copied into the optimizer.
+  GraphOptimizer(tensorflow::GraphDef* graph,
+                 tensorflow::grappler::GraphProperties graph_properties)
+      : graph_(graph), props_(graph_properties) {}
+
+  virtual ~GraphOptimizer() = default;
+
+  // Takes ownership of `rewriter`.  Defined outside this header.
+  void register_rewriter(std::unique_ptr<PatternRewriter> rewriter);
+
+  // Defined outside this header.
+  // NOTE(review): presumably applies the registered rewriters to the graph
+  // -- confirm against the definition.
+  void optimize();
+
+ private:
+  tensorflow::GraphDef* graph_;                   // graph being rewritten
+  tensorflow::grappler::GraphProperties props_;   // inferred shapes
+  std::unordered_map<std::string, int> node_indexes_;  // node name -> index
+  std::vector<std::unique_ptr<PatternRewriter>> rewriters_;
+};
+
+// Entry point used by the grappler remapper: registers the KP fusion
+// rewriters selected through the ANNC_FUSED_* environment variables and runs
+// them over `graph`.  `props` is taken by value.
+void run_graph_optimization(tensorflow::GraphDef* graph, tensorflow::grappler::GraphProperties props);
+} // namespace annc
+#endif // ANNC_TF_GRAPH_OPT_H_
@@ -43,6 +43,9 @@ limitations under the License.
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tsl/platform/errors.h"
+#if defined(__aarch64__)
+#include "tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h"
+#endif // __aarch64__
#ifdef INTEL_MKL
#include "tensorflow/core/util/mkl_heuristics.h"
#endif // INTEL_MKL
@@ -4941,6 +4944,19 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
}
TF_RETURN_IF_ERROR(mutation->Apply());
+#if defined(__aarch64__)
+ // ========== infer shape ==========
+ tensorflow::grappler::GraphProperties graph_props(mutable_item);
+ bool assume_valid_feeds = (opt_level_ == RewriterConfig::AGGRESSIVE);
+ TF_RETURN_IF_ERROR(graph_props.InferStatically(
+ assume_valid_feeds,
+ /*aggressive_shape_inference=*/false,
+ /*include_input_tensor_values=*/true,
+ /*include_output_tensor_values=*/true));
+
+ annc::run_graph_optimization(&mutable_item.graph, graph_props);
+#endif // __aarch64__
+
*optimized_graph = std::move(mutable_item.graph);
return OkStatus();
@@ -3664,6 +3664,97 @@ tf_kernel_library(
]) + [":fft_impl"],
)
+# One kernel library per fused embedding op; each wraps a single source file.
+
+# Kernel built from embedding_fused_action_id_gather.cc.
+tf_kernel_library(
+ name = "embedding_fused_action_id_gather_op",
+ srcs = ["embedding_fused_action_id_gather.cc"],
+ deps = MATH_DEPS,
+)
+
+# Kernel built from embedding_fused_gather.cc.
+tf_kernel_library(
+ name = "embedding_fused_gather_op",
+ srcs = ["embedding_fused_gather.cc"],
+ deps = MATH_DEPS,
+)
+
+# Kernel built from embedding_fused_padding.cc.
+tf_kernel_library(
+ name = "embedding_fused_padding_op",
+ srcs = ["embedding_fused_padding.cc"],
+ deps = MATH_DEPS,
+)
+
+# Kernel built from embedding_fused_sparse_dynamic_stitch.cc.
+tf_kernel_library(
+ name = "embedding_fused_sparse_dynamic_stitch_op",
+ srcs = ["embedding_fused_sparse_dynamic_stitch.cc"],
+ deps = MATH_DEPS,
+)
+
+# Kernel built from embedding_fused_sparse_reshape.cc; additionally depends on
+# :reshape_util.  Note the target name omits "sparse", unlike its source file.
+tf_kernel_library(
+ name = "embedding_fused_reshape_op",
+ srcs = ["embedding_fused_sparse_reshape.cc"],
+ deps = MATH_DEPS + [
+ ":reshape_util",
+ ],
+)
+
+# Kernel built from embedding_fused_sparse_segment_reduce.cc.
+tf_kernel_library(
+ name = "embedding_fused_sparse_segment_reduce_op",
+ srcs = ["embedding_fused_sparse_segment_reduce.cc"],
+ deps = MATH_DEPS,
+)
+
+# Kernel built from embedding_fused_sparse_segment_reduce_nonzero.cc;
+# additionally depends on absl flat_hash_map.
+tf_kernel_library(
+ name = "embedding_fused_sparse_segment_reduce_nonzero_op",
+ srcs = ["embedding_fused_sparse_segment_reduce_nonzero.cc"],
+ deps = MATH_DEPS + ["@com_google_absl//absl/container:flat_hash_map"],
+)
+
+# Kernel built from embedding_fused_sparse_select.cc.
+tf_kernel_library(
+ name = "embedding_fused_sparse_select_op",
+ srcs = ["embedding_fused_sparse_select.cc"],
+ deps = MATH_DEPS,
+)
+
+# Umbrella target that pulls in all eight fused embedding kernel libraries
+# declared above; it has no sources of its own.
+cc_library(
+ name = "embedding_fused_ops",
+ deps = [
+ ":embedding_fused_action_id_gather_op",
+ ":embedding_fused_gather_op",
+ ":embedding_fused_padding_op",
+ ":embedding_fused_sparse_dynamic_stitch_op",
+ ":embedding_fused_reshape_op",
+ ":embedding_fused_sparse_segment_reduce_op",
+ ":embedding_fused_sparse_segment_reduce_nonzero_op",
+ ":embedding_fused_sparse_select_op",
+ ],
+)
+
+# Single test binary covering all fused embedding kernels: one *_test.cc per
+# kernel source, linked against the :embedding_fused_ops umbrella target.
+tf_cc_test(
+ name = "embedding_fused_ops_test",
+ size = "small",
+ srcs = [
+ "embedding_fused_action_id_gather_test.cc",
+ "embedding_fused_sparse_dynamic_stitch_test.cc",
+ "embedding_fused_sparse_segment_reduce_test.cc",
+ "embedding_fused_sparse_segment_reduce_nonzero_test.cc",
+ "embedding_fused_padding_test.cc",
+ "embedding_fused_sparse_select_test.cc",
+ "embedding_fused_gather_test.cc",
+ "embedding_fused_sparse_reshape_test.cc",
+ ],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":embedding_fused_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_kernel_library(
name = "reduction_ops",
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
new file mode 100644
@@ -0,0 +1,263 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <atomic>
+#include <cstdint>
+#include <cstring>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+// Fused double gather along axis 0:
+//   Step 1: temp   = gather(params, indices1)
+//   Step 2: output = gather(temp, indices2)
+//   Fused:  output[u, k, :] = params[indices1[indices2[u], k], :]
+// where `u` enumerates the elements of indices2 in row-major order.
+//
+// Shapes:
+//   params:   [P0, P1] or [P0, P1, P2]; P_row = P1 (2D) or P1 * P2 (3D)
+//   indices1: [I10] or [I10, I11], values in [0, P0)
+//   indices2: [I21] or [I20, I21], values in [0, I10)
+//   output:   indices2 dims, then I11 (only for 2D indices1), then P_row
+//
+// `output` is allocated here as a temporary.  Every failure (bad rank,
+// allocation failure, index out of range) is reported through `context`;
+// callers must check context->status() before using `output`.
+template <typename Tindices1, typename Tindices2>
+static void FusedDoubleGatherImpl(OpKernelContext* context,
+    const float* params_data, const TensorShape& params_shape,
+    const Tindices1* indices1_data, const TensorShape& indices1_shape,
+    const Tindices2* indices2_data, const TensorShape& indices2_shape,
+    Tensor* output) {
+  OP_REQUIRES(context, params_shape.dims() >= 2 && params_shape.dims() <= 3,
+              errors::InvalidArgument("params must be 2D or 3D matrix"));
+  OP_REQUIRES(context,
+              indices1_shape.dims() >= 1 && indices1_shape.dims() <= 2,
+              errors::InvalidArgument("indices1 must be 1D or 2D matrix"));
+  OP_REQUIRES(context,
+              indices2_shape.dims() >= 1 && indices2_shape.dims() <= 2,
+              errors::InvalidArgument("indices2 must be 1D or 2D matrix"));
+
+  // All sizes and offsets are 64-bit so that large tensors cannot overflow
+  // the index arithmetic below.
+  const int64_t P0 = params_shape.dim_size(0);
+  int64_t P_row = params_shape.dim_size(1);
+  if (params_shape.dims() == 3) {
+    P_row *= params_shape.dim_size(2);
+  }
+
+  // A 1D indices1 behaves like [I10, 1].
+  const int indices1_dims = indices1_shape.dims();
+  const int64_t I10 = indices1_shape.dim_size(0);
+  const int64_t I11 = (indices1_dims == 2) ? indices1_shape.dim_size(1) : 1;
+
+  // A 1D indices2 behaves like [1, I21].
+  const int indices2_dims = indices2_shape.dims();
+  const int64_t I20 = (indices2_dims == 2) ? indices2_shape.dim_size(0) : 1;
+  const int64_t I21 = (indices2_dims == 2) ? indices2_shape.dim_size(1)
+                                           : indices2_shape.dim_size(0);
+
+  TensorShape output_shape;
+  if (indices2_dims == 2) {
+    output_shape.AddDim(I20);
+  }
+  output_shape.AddDim(I21);
+  if (indices1_dims == 2) {
+    output_shape.AddDim(I11);
+  }
+  output_shape.AddDim(P_row);
+
+  OP_REQUIRES_OK(context,
+                 context->allocate_temp(DT_FLOAT, output_shape, output));
+  VLOG(1) << "fused gather output shape: " << output->shape().DebugString();
+
+  float* output_data = output->flat<float>().data();
+
+  // One work unit per indices2 element.  Unit `u` copies I11 rows of P_row
+  // floats into the disjoint output range starting at u * I11 * P_row, so
+  // the shards need no synchronization for the copies themselves.
+  const int64_t total_units = I20 * I21;
+  const int64_t cost_per_unit = I11 * P_row;
+
+  // Shards only record that an index was bad; the failure is raised once,
+  // after ParallelFor has returned.
+  // NOTE(review): this assumes updating the context status concurrently from
+  // several shards is not safe -- confirm against OpKernelContext.
+  std::atomic<int64_t> bad_indices2_unit{-1};
+  std::atomic<bool> bad_indices1{false};
+
+  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+  worker_threads->workers->ParallelFor(
+      total_units, cost_per_unit, [&](int64_t begin, int64_t end) {
+        for (int64_t unit = begin; unit < end; ++unit) {
+          const int64_t idx2 = static_cast<int64_t>(indices2_data[unit]);
+          if (TF_PREDICT_FALSE(idx2 < 0 || idx2 >= I10)) {
+            int64_t unset = -1;
+            bad_indices2_unit.compare_exchange_strong(unset, unit);
+            return;
+          }
+          const Tindices1* row_indices = indices1_data + idx2 * I11;
+          float* dst = output_data + unit * I11 * P_row;
+          for (int64_t k = 0; k < I11; ++k) {
+            const int64_t idx1 = static_cast<int64_t>(row_indices[k]);
+            if (TF_PREDICT_FALSE(idx1 < 0 || idx1 >= P0)) {
+              bad_indices1.store(true);
+              return;
+            }
+            std::memcpy(dst + k * P_row, params_data + idx1 * P_row,
+                        sizeof(float) * P_row);
+          }
+        }
+      });
+
+  const int64_t bad_unit = bad_indices2_unit.load();
+  if (bad_unit >= 0) {
+    // For 1D indices2, I20 == 1, so the first coordinate printed is 0.
+    context->CtxFailure(errors::InvalidArgument(
+        "FusedGather: indices2[", bad_unit / I21, ",", bad_unit % I21,
+        "]=", indices2_data[bad_unit], " out of range [0, ", I10, ")"));
+    return;
+  }
+  if (bad_indices1.load()) {
+    context->CtxFailure(
+        errors::InvalidArgument("GatherV2 axis=0: index out of range"));
+    return;
+  }
+}
+
+
+// CPU kernel for KPFusedEmbeddingActionIdGather.
+//
+// Inputs:
+//   0: indices1  (Tindices1, rank 1 or 2)
+//   1: params    (float, rank 2 or 3)
+//   2: indices2  (Tindices2, rank 1 or 2)
+//   3: pack_dim  (int32 scalar) -- number of output rows
+//   4: pack      (int32 scalar) -- number of zero columns appended per row
+// Output 0: float matrix [pack_dim, gathered_elements / pack_dim + pack].
+// Each row holds the double-gather result followed by `pack` zeros.
+template <typename Tindices1, typename Tindices2>
+class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
+public:
+  explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& indices1 = context->input(0);
+    const Tensor& params = context->input(1);
+    const Tensor& indices2 = context->input(2);
+    const Tensor& pack_dim = context->input(3);
+    const Tensor& pack = context->input(4);
+
+    VLOG(1) << "indices1 shape: " << indices1.shape().DebugString();
+    VLOG(1) << "params shape: " << params.shape().DebugString();
+    VLOG(1) << "indices2 shape: " << indices2.shape().DebugString();
+    OP_REQUIRES(context, indices1.dims() >= 1 && indices1.dims() <= 2,
+                errors::InvalidArgument("indices1 dims must be 1 or 2"));
+    OP_REQUIRES(context, indices2.dims() >= 1 && indices2.dims() <= 2,
+                errors::InvalidArgument("indices2 dims must be 1 or 2"));
+    OP_REQUIRES(context, params.dims() >= 2 && params.dims() <= 3,
+                errors::InvalidArgument("params dims must = 2 or 3"));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(pack_dim.shape()),
+                errors::InvalidArgument("pack_dim is scalar"));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(pack.shape()),
+                errors::InvalidArgument("pack const is scalar"));
+
+    // Validate the scalars before doing any gather work.
+    const int pack_size = pack_dim.scalar<int32>()();
+    const int pack_const = pack.scalar<int32>()();
+    OP_REQUIRES(context, pack_size > 0,
+                errors::InvalidArgument("pack_size must > 0"));
+    // A negative padding width would otherwise reach memset and the output
+    // shape computation.
+    OP_REQUIRES(context, pack_const >= 0,
+                errors::InvalidArgument("pack const must >= 0, got ",
+                                        pack_const));
+
+    // gathered = params[indices1[indices2[...]]]
+    Tensor gathered;
+    FusedDoubleGatherImpl<Tindices1, Tindices2>(
+        context, params.flat<float>().data(), params.shape(),
+        indices1.flat<Tindices1>().data(), indices1.shape(),
+        indices2.flat<Tindices2>().data(), indices2.shape(), &gathered);
+    // The helper reports its failures through the context and then returns
+    // normally, so stop here rather than padding a missing or partial result.
+    if (!context->status().ok()) return;
+
+    const int64_t num_gathered = gathered.NumElements();
+    // The gathered data is viewed as [pack_size, cols]; reject a pack_dim
+    // that does not divide it instead of failing a CHECK.
+    OP_REQUIRES(context, num_gathered % pack_size == 0,
+                errors::InvalidArgument("gathered element count ",
+                                        num_gathered,
+                                        " is not divisible by pack_size ",
+                                        pack_size));
+    const int64_t gathered_cols = num_gathered / pack_size;
+    const int64_t output_cols = gathered_cols + pack_const;
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({pack_size, output_cols}), &output));
+    if (output_cols == 0) return;  // nothing to copy or pad
+
+    const float* src = gathered.flat<float>().data();
+    float* dst = output->flat<float>().data();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    // Rows are independent: copy the gathered part, then zero the padding.
+    worker_threads->workers->ParallelFor(
+        pack_size, output_cols, [&](int64_t start_row, int64_t end_row) {
+          for (int64_t row = start_row; row < end_row; ++row) {
+            float* dst_row = dst + row * output_cols;
+            std::memcpy(dst_row, src + row * gathered_cols,
+                        sizeof(float) * gathered_cols);
+            std::memset(dst_row + gathered_cols, 0,
+                        sizeof(float) * pack_const);
+          }
+        });
+  }
+};
+
+// Registers the CPU kernel for one (Tindices1, Tindices2) combination; the
+// two types are matched against the op's "Tindices1" / "Tindices2" attrs.
+#define REGISTER_CPU_KERNEL(Tindices1, Tindices2) \
+ REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingActionIdGather") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<Tindices1>("Tindices1") \
+ .TypeConstraint<Tindices2>("Tindices2"), \
+ KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>);
+
+// All four int32/int64 index-type combinations are supported.
+REGISTER_CPU_KERNEL(int64, int32)
+REGISTER_CPU_KERNEL(int32, int32)
+REGISTER_CPU_KERNEL(int64, int64)
+REGISTER_CPU_KERNEL(int32, int64)
+
+}
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,369 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Fixture for the KPFusedEmbeddingActionIdGather kernel tests.
+class KPFusedEmbeddingActionIdGatherTest : public OpsTestBase {
+ protected:
+  // Builds and initializes the op for the given index dtypes.  Input order:
+  // indices1, params (float), indices2, pack_dim (int32), pack (int32).
+  void MakeOp(DataType indices1_type, DataType indices2_type) {
+    NodeDefBuilder builder("fused_embedding_action_id_gather",
+                           "KPFusedEmbeddingActionIdGather");
+    builder.Input(FakeInput(indices1_type));  // indices1
+    builder.Input(FakeInput(DT_FLOAT));       // params
+    builder.Input(FakeInput(indices2_type));  // indices2
+    builder.Input(FakeInput(DT_INT32));       // pack_dim
+    builder.Input(FakeInput(DT_INT32));       // pack
+    TF_ASSERT_OK(builder.Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  // Resets the fixture inputs, rebuilds the op for <Tindices1, Tindices2>,
+  // feeds the five inputs and returns the kernel's status.
+  template <typename Tindices1, typename Tindices2>
+  Status FeedAndRun(const std::vector<Tindices1>& indices1_data,
+                    const TensorShape& indices1_shape,
+                    const std::vector<float>& params_data,
+                    const TensorShape& params_shape,
+                    const std::vector<Tindices2>& indices2_data,
+                    const TensorShape& indices2_shape, int pack_dim_value,
+                    int pack_value) {
+    inputs_.clear();
+    input_types_.clear();
+
+    const DataType first_dtype = DataTypeToEnum<Tindices1>::v();
+    const DataType second_dtype = DataTypeToEnum<Tindices2>::v();
+    MakeOp(first_dtype, second_dtype);
+
+    const TensorShape scalar_shape({});
+    AddInputFromArray<Tindices1>(indices1_shape, indices1_data);
+    AddInputFromArray<float>(params_shape, params_data);
+    AddInputFromArray<Tindices2>(indices2_shape, indices2_data);
+    AddInputFromArray<int32>(scalar_shape, {pack_dim_value});
+    AddInputFromArray<int32>(scalar_shape, {pack_value});
+    return RunOpKernel();
+  }
+};
+
+// 2D indices1 + 2D indices2: indices2 = [1, 0] selects indices1 rows
+// [2] and [0], i.e. params rows 2 and 0, each followed by one zero column.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, NormalCase) {
+  const std::vector<int64> first_level = {0, 2};
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<int32> second_level = {1, 0};
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_level, TensorShape({2, 1}), table, TensorShape({3, 2}),
+      second_level, TensorShape({2, 1}), /*pack_dim_value=*/2,
+      /*pack_value=*/1)));
+
+  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
+  test::FillValues<float>(&want, {5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+// 1D indices1 + 1D indices2: indices2 = [1, 0] -> indices1 values 1 and 0
+// -> params rows 1 and 0, each followed by one zero column.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, OneDIndices1AndOneDIndices2) {
+  const std::vector<int64> first_level = {0, 1, 2};
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<int32> second_level = {1, 0};
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_level, TensorShape({3}), table, TensorShape({3, 2}),
+      second_level, TensorShape({2}), /*pack_dim_value=*/2,
+      /*pack_value=*/1)));
+
+  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 3}));
+  test::FillValues<float>(&want, {3.0f, 4.0f, 0.0f, 1.0f, 2.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+// 1D indices1 + 2D indices2: the four indices2 entries [1, 0, 2, 1] map
+// through indices1 = [0, 2, 1] to params rows 2, 0, 1, 2.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, OneDIndices1AndTwoDIndices2) {
+  const std::vector<int64> first_level = {0, 2, 1};
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<int32> second_level = {1, 0, 2, 1};
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_level, TensorShape({3}), table, TensorShape({3, 2}),
+      second_level, TensorShape({2, 2}), /*pack_dim_value=*/4,
+      /*pack_value=*/1)));
+
+  Tensor want(allocator(), DT_FLOAT, TensorShape({4, 3}));
+  test::FillValues<float>(&want, {5.0f, 6.0f, 0.0f,   //
+                                  1.0f, 2.0f, 0.0f,   //
+                                  3.0f, 4.0f, 0.0f,   //
+                                  5.0f, 6.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+// 2D indices1 + 1D indices2: indices2 = [1, 0] selects indices1 rows
+// [1, 0] and [0, 2], so each output row holds two params rows plus one zero.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, TwoDIndices1AndOneDIndices2) {
+  const std::vector<int64> first_level = {0, 2, 1, 0};
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<int32> second_level = {1, 0};
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_level, TensorShape({2, 2}), table, TensorShape({3, 2}),
+      second_level, TensorShape({2}), /*pack_dim_value=*/2,
+      /*pack_value=*/1)));
+
+  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&want, {3.0f, 4.0f, 1.0f, 2.0f, 0.0f,   //
+                                  1.0f, 2.0f, 5.0f, 6.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+// Runs the same 2D/2D case for every registered (Tindices1, Tindices2)
+// combination; the result must not depend on the index dtypes.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, DifferentIndexTypes) {
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const TensorShape table_shape({3, 2});
+  const TensorShape index_shape({2, 1});
+  const Tensor want =
+      test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3});
+
+  // indices1: int64, indices2: int32
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      std::vector<int64>{0, 2}, index_shape, table, table_shape,
+      std::vector<int32>{1, 0}, index_shape, 2, 1)));
+  test::ExpectTensorNear<float>(*GetOutput(0), want, 1e-5);
+
+  // indices1: int32, indices2: int32
+  TF_ASSERT_OK((FeedAndRun<int32, int32>(
+      std::vector<int32>{0, 2}, index_shape, table, table_shape,
+      std::vector<int32>{1, 0}, index_shape, 2, 1)));
+  test::ExpectTensorNear<float>(*GetOutput(0), want, 1e-5);
+
+  // indices1: int64, indices2: int64
+  TF_ASSERT_OK((FeedAndRun<int64, int64>(
+      std::vector<int64>{0, 2}, index_shape, table, table_shape,
+      std::vector<int64>{1, 0}, index_shape, 2, 1)));
+  test::ExpectTensorNear<float>(*GetOutput(0), want, 1e-5);
+
+  // indices1: int32, indices2: int64
+  TF_ASSERT_OK((FeedAndRun<int32, int64>(
+      std::vector<int32>{0, 2}, index_shape, table, table_shape,
+      std::vector<int64>{1, 0}, index_shape, 2, 1)));
+  test::ExpectTensorNear<float>(*GetOutput(0), want, 1e-5);
+}
+
+// Feeds a rank-1 params tensor and expects the kernel to fail with the
+// "params dims must = 2 or 3" error.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidParamsDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  // indices1, params (deliberately 1-D), indices2.
+  AddInputFromArray<int64>(TensorShape({2, 1}), {0, 2});
+  AddInputFromArray<float>(TensorShape({4}), {1.0f, 2.0f, 3.0f, 4.0f});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {1, 0});
+  // The two trailing int32 scalar inputs.
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(status.ToString(), "params dims must = 2 or 3"))
+      << status;
+}
+
+// Feeds the fourth input as a one-element vector instead of a scalar and
+// expects the "pack_dim is scalar" error.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDimDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  // indices1, params, indices2.
+  AddInputFromArray<int64>(TensorShape({2, 1}), {0, 2});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {1, 0});
+  // Fourth input has shape {1} (the fault under test); fifth is a scalar.
+  AddInputFromArray<int32>(TensorShape({1}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.ToString(), "pack_dim is scalar"))
+      << status;
+}
+
+// Feeds the fifth input as a one-element vector instead of a scalar and
+// expects the "pack const is scalar" error.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  // indices1, params, indices2.
+  AddInputFromArray<int64>(TensorShape({2, 1}), {0, 2});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {1, 0});
+  // Fourth input is a scalar; fifth has shape {1} (the fault under test).
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({1}), {1});
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.ToString(), "pack const is scalar"))
+      << status;
+}
+
+// Feeds 0 as the fourth (scalar) input and expects the
+// "pack_size must > 0" error.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackSize) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  // indices1, params, indices2.
+  AddInputFromArray<int64>(TensorShape({2, 1}), {0, 2});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {1, 0});
+  // Fourth input carries the invalid value 0; fifth is a valid scalar.
+  AddInputFromArray<int32>(TensorShape({}), {0});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.ToString(), "pack_size must > 0"))
+      << status;
+}
+
+// Feeds indices1 = {0, 5} against a params tensor with only three rows and
+// expects the "GatherV2 axis=0: index out of range" error.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, IndexOutOfRange) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  // indices1 (contains the out-of-range value 5), params, indices2.
+  AddInputFromArray<int64>(TensorShape({2, 1}), {0, 5});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {1, 0});
+  // The two trailing int32 scalar inputs.
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.ToString(),
+                                "GatherV2 axis=0: index out of range"))
+      << status;
+}
+
+// Runs the op with a {2, 2, 2} params tensor and {2, 1}-shaped indices1 and
+// indices2, then checks the {2, 5} output row by row.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, ThreeDParams2DIndices1And2DIndices2) {
+  std::vector<int64> first_indices = {0, 1};
+  std::vector<int32> second_indices = {1, 0};
+
+  // Row-major storage: params[0] = [[1,2],[3,4]], params[1] = [[5,6],[7,8]].
+  std::vector<float> params_values = {1.0f, 2.0f, 3.0f, 4.0f,
+                                      5.0f, 6.0f, 7.0f, 8.0f};
+
+  const int pack_dim_value = 2;
+  const int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_indices, TensorShape({2, 1}), params_values,
+      TensorShape({2, 2, 2}), second_indices, TensorShape({2, 1}),
+      pack_dim_value, pack_value)));
+
+  // Expected rows: the values 5..8 then a 0, and the values 1..4 then a 0.
+  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&want,
+                          {5.0f, 6.0f, 7.0f, 8.0f, 0.0f,
+                           1.0f, 2.0f, 3.0f, 4.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+// Runs the op with a {3, 2, 2} params tensor and rank-1 indices1 / indices2,
+// then checks the {2, 5} output row by row.
+TEST_F(KPFusedEmbeddingActionIdGatherTest, ThreeDParams1DIndices1And1DIndices2) {
+  std::vector<int64> first_indices = {0, 1, 2};
+  std::vector<int32> second_indices = {2, 0};
+
+  // Three 2x2 slices stored row-major: 1..4, 5..8 and 9..12.
+  std::vector<float> params_values = {1.0f, 2.0f,  3.0f,  4.0f,
+                                      5.0f, 6.0f,  7.0f,  8.0f,
+                                      9.0f, 10.0f, 11.0f, 12.0f};
+
+  const int pack_dim_value = 2;
+  const int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      first_indices, TensorShape({3}), params_values, TensorShape({3, 2, 2}),
+      second_indices, TensorShape({2}), pack_dim_value, pack_value)));
+
+  // Expected rows: the values 9..12 then a 0, and the values 1..4 then a 0.
+  Tensor want(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&want,
+                          {9.0f, 10.0f, 11.0f, 12.0f, 0.0f,
+                           1.0f, 2.0f, 3.0f, 4.0f, 0.0f});
+  test::ExpectTensorNear<float>(want, *GetOutput(0), 1e-5);
+}
+
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,90 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+// CPU kernel for the "KPFusedGather" op.
+//
+// Inputs, as read by Compute:
+//   0 data  : 2-D float tensor whose rows are gathered.
+//   1 keys  : 2-D int64 tensor; one of its columns supplies the row ids.
+//   2 begin : int32 tensor with exactly two elements; begin[1] selects the
+//             key column.
+// Outputs:
+//   0 : distinct values of the selected key column, in first-seen order.
+//   1 : for every key row, the position of its value within output 0 (int32).
+//   2 : one row of `data` per distinct value,
+//       shape [num_unique, data.dim_size(1)].
+//
+// Every validation failure is reported as errors::Internal.
+class KPFusedGather : public OpKernel {
+ public:
+  explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& data = context->input(0);
+    const Tensor& keys = context->input(1);
+    const Tensor& begin = context->input(2);
+    VLOG(1) << "Embedding table size: " << data.shape().DebugString();
+    VLOG(1) << "Input key shape: " << keys.shape().DebugString();
+    VLOG(1) << "Slice begin value: " << begin.DebugString();
+
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(keys.shape()),
+                errors::Internal("Input key must be 2D"));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(data.shape()),
+                errors::Internal("Embedding table shape must be 2D"));
+    OP_REQUIRES(context, begin.NumElements() == 2,
+                errors::Internal("begin must be same as keys rank"));
+    const int32 col = begin.flat<int32>()(1);
+    // A negative column would index before the start of the key matrix, so
+    // both bounds have to be checked, not just the upper one.
+    OP_REQUIRES(context, col >= 0 && col < keys.dim_size(1),
+                errors::Internal("slice cols out of keys range"));
+
+    // Keep sizes in 64 bits; narrowing them to int32 would corrupt the output
+    // shapes for very large inputs.
+    const int64_t num_keys = keys.dim_size(0);
+    const int64_t table_rows = data.dim_size(0);
+    const int64_t embedding_dims = data.dim_size(1);
+
+    Tensor* out_indices = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({num_keys}),
+                                            &out_indices));
+    int32* out_indices_data = out_indices->flat<int32>().data();
+
+    // Deduplicate the selected column while preserving first-seen order.
+    const auto keys_mat = keys.matrix<int64>();
+    std::vector<int64_t> unique_values;
+    std::unordered_map<int64_t, int32_t> value_to_index;
+    for (int64_t i = 0; i < num_keys; ++i) {
+      const int64_t key = keys_mat(i, col);
+      const auto inserted = value_to_index.emplace(
+          key, static_cast<int32_t>(unique_values.size()));
+      if (inserted.second) {
+        // Validate each distinct row id once, before it is ever used to
+        // address the table.
+        OP_REQUIRES(context, key >= 0 && key < table_rows,
+                    errors::Internal("idx out of table range"));
+        unique_values.push_back(key);
+      }
+      out_indices_data[i] = inserted.first->second;
+    }
+
+    const int64_t num_unique = static_cast<int64_t>(unique_values.size());
+    Tensor* out_unique_value = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({num_unique}),
+                                            &out_unique_value));
+    Tensor* out_data = nullptr;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(2, TensorShape({num_unique, embedding_dims}),
+                                 &out_data));
+
+    // Nothing to copy; also avoids handing memcpy an empty vector's pointer.
+    if (num_unique == 0) {
+      return;
+    }
+    std::memcpy(out_unique_value->data(), unique_values.data(),
+                unique_values.size() * sizeof(int64_t));
+
+    if (embedding_dims == 0) {
+      return;
+    }
+    const float* table = data.flat<float>().data();
+    float* gathered = out_data->flat<float>().data();
+    for (int64_t row = 0; row < num_unique; ++row) {
+      std::memcpy(gathered + row * embedding_dims,
+                  table + unique_values[row] * embedding_dims,
+                  embedding_dims * sizeof(float));
+    }
+  }
+};
+
+// Registers the KPFusedGather class as the CPU kernel for the
+// "KPFusedGather" op.
+REGISTER_KERNEL_BUILDER(Name("KPFusedGather").Device(DEVICE_CPU),
+ KPFusedGather);
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,186 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::ExpectClose;
+using tensorflow::test::FillValues;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+
+// Test fixture for the KPFusedGather kernel. Both entry points build the op
+// with a float, an int64 and an int32 input, and always feed the int32 input
+// as a two-element vector.
+class KPFusedGatherTest : public OpsTestBase {
+ protected:
+  // Runs the op on inputs that are expected to succeed and compares all three
+  // outputs with the expectations. The width of the expected output 2 is
+  // taken from the second dimension of `data_shape`, so `data_shape` must be
+  // 2-D.
+  void RunValidCase(const TensorShape& data_shape,
+                    const TensorShape& slice_shape,
+                    const std::vector<int32>& begin_val,
+                    const std::vector<int64>& slice_data,
+                    const std::vector<float>& data_data,
+                    const std::vector<int64>& expected_unique,
+                    const std::vector<int32>& expected_indices,
+                    const std::vector<float>& expected_output_data) {
+    BuildOpAndFeed(data_shape, slice_shape, begin_val, slice_data, data_data);
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& out_unique = *GetOutput(0);
+    const Tensor& out_indices = *GetOutput(1);
+    const Tensor& out_data = *GetOutput(2);
+
+    const auto num_unique = static_cast<int64_t>(expected_unique.size());
+    const auto num_indices = static_cast<int64_t>(expected_indices.size());
+    const int64_t embedding_dims = data_shape.dim_size(1);
+
+    // Verify output 0: unique_values.
+    Tensor expected_unique_tensor(allocator(), DT_INT64,
+                                  TensorShape({num_unique}));
+    FillValues<int64>(&expected_unique_tensor, expected_unique);
+    ExpectTensorEqual<int64>(expected_unique_tensor, out_unique);
+
+    // Verify output 1: indices.
+    Tensor expected_indices_tensor(allocator(), DT_INT32,
+                                   TensorShape({num_indices}));
+    FillValues<int32>(&expected_indices_tensor, expected_indices);
+    ExpectTensorEqual<int32>(expected_indices_tensor, out_indices);
+
+    // Verify output 2: out_data. Floats are compared with ExpectClose.
+    Tensor expected_data_tensor(allocator(), DT_FLOAT,
+                                TensorShape({num_unique, embedding_dims}));
+    FillValues<float>(&expected_data_tensor, expected_output_data);
+    ExpectClose(expected_data_tensor, out_data);
+  }
+
+  // Builds the op, feeds the inputs and returns the status of running the
+  // kernel so that callers can inspect the failure.
+  Status RunOpExpectFailure(const TensorShape& data_shape,
+                            const TensorShape& slice_shape,
+                            const std::vector<int32>& begin_val,
+                            const std::vector<int64>& slice_data,
+                            const std::vector<float>& data_data) {
+    BuildOpAndFeed(data_shape, slice_shape, begin_val, slice_data, data_data);
+    return RunOpKernel();
+  }
+
+ private:
+  // Creates the KPFusedGather node and feeds its three inputs. A failure to
+  // build the node aborts the test binary, because nothing meaningful can run
+  // without it.
+  void BuildOpAndFeed(const TensorShape& data_shape,
+                      const TensorShape& slice_shape,
+                      const std::vector<int32>& begin_val,
+                      const std::vector<int64>& slice_data,
+                      const std::vector<float>& data_data) {
+    TF_CHECK_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
+                    .Input(FakeInput(DT_FLOAT))
+                    .Input(FakeInput(DT_INT64))
+                    .Input(FakeInput(DT_INT32))
+                    .Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+
+    AddInputFromArray<float>(data_shape, data_data);
+    AddInputFromArray<int64>(slice_shape, slice_data);
+    AddInputFromArray<int32>(TensorShape({2}), begin_val);
+  }
+};
+
+// Positive test: well-formed input. begin[1] = 1 selects key column 1, whose
+// values are 1, 1, 0, 1.
+TEST_F(KPFusedGatherTest, Valid_NormalInput) {
+  const std::vector<int32> begin = {0, 1};
+  // slice_input data, 4 rows x 3 columns.
+  const std::vector<int64> keys = {1, 1, 3,
+                                   0, 1, 5,
+                                   1, 0, 7,
+                                   0, 1, 9};
+  // Table with 2 rows x 12 columns.
+  const std::vector<float> table = {
+      1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,
+      7.0f,  8.0f,  9.0f,  10.0f, 11.0f, 12.0f,
+      13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+      19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f};
+
+  // Unique values from column 1, then the per-row mapping into them.
+  const std::vector<int64> want_unique = {1, 0};
+  const std::vector<int32> want_indices = {0, 0, 1, 0};
+  // data[1] followed by data[0].
+  const std::vector<float> want_rows = {
+      13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
+      19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
+      1.0f,  2.0f,  3.0f,  4.0f,  5.0f,  6.0f,
+      7.0f,  8.0f,  9.0f,  10.0f, 11.0f, 12.0f};
+
+  RunValidCase(TensorShape({2, 12}),  // data shape
+               TensorShape({4, 3}),   // slice_input shape
+               begin, keys, table, want_unique, want_indices, want_rows);
+}
+
+// data is not 2-D: the kernel must fail with
+// "Embedding table shape must be 2D".
+TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) {
+  const std::vector<float> table = {1.0f, 2.0f, 3.0f, 4.0f};
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> keys = {0, 1, 2, 3};
+
+  const Status status =
+      RunOpExpectFailure(TensorShape({4}),  // data is rank 1, not rank 2
+                         TensorShape({2, 2}), begin, keys, table);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(),
+                                "Embedding table shape must be 2D"));
+}
+
+// The key tensor is not 2-D: the kernel must fail with
+// "Input key must be 2D".
+TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) {
+  const std::vector<float> table(2 * 12, 1.0f);
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> keys = {0, 1, 2, 3};
+
+  const Status status =
+      RunOpExpectFailure(TensorShape({2, 12}),
+                         TensorShape({4}),  // 1-D slice_input
+                         begin, keys, table);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "Input key must be 2D"));
+}
+
+// begin[1] is past the last key column: the kernel must fail with
+// "slice cols out of keys range".
+TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) {
+  const std::vector<float> table(2 * 12, 1.0f);
+  // begin[1] = 2, but the key tensor has only 2 columns (valid: 0 and 1).
+  const std::vector<int32> begin = {0, 2};
+  const std::vector<int64> keys = {0, 1, 2, 3};
+
+  const Status status = RunOpExpectFailure(
+      TensorShape({2, 12}), TensorShape({2, 2}), begin, keys, table);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(status.message(), "slice cols out of keys range"));
+}
+
+// A gather index exceeds the number of data rows: the kernel must fail with
+// "idx out of table range".
+TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) {
+  const std::vector<float> table(2 * 12, 1.0f);
+  const std::vector<int32> begin = {0, 0};
+  // Index 2 is outside the data rows (only rows 0 and 1 exist).
+  const std::vector<int64> keys = {0, 1,
+                                   2, 3};
+
+  const Status status = RunOpExpectFailure(
+      TensorShape({2, 12}), TensorShape({2, 2}), begin, keys, table);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "idx out of table range"));
+}
+
+}
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,126 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+// CPU kernel shared by the "KPFusedEmbeddingPadding" and
+// "KPFusedEmbeddingPaddingFast" ops; the variant is chosen from type_string()
+// at construction time.
+//
+// Inputs, as read by Compute:
+//   0 origin_shape  : 1-D int64 with two elements; element 0 is the row count
+//                     the padding is computed against.
+//   1 input         : 2-D tensor (read as float by the non-fast variant).
+//   2 input_rows    : int32 scalar.
+//   3 reshape_sizes : 1-D int32 with two elements; element 0 must be -1 and
+//                     element 1 must be > 0 (column count after reshaping).
+//   4 pack          : int32 scalar that must equal input.dim_size(1).
+// Outputs:
+//   0 : int32 scalar, origin_shape[0] - input_rows.
+//   1 : fast variant    -> int32 scalar with the reshaped row count;
+//       regular variant -> float [reshaped_rows, reshape_cols] holding the
+//                          contents of `input` followed by zeros.
+//
+// All validation failures are reported as errors::InvalidArgument.
+class KPFusedEmbeddingPaddingOp : public OpKernel {
+ public:
+  explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& origin_shape = context->input(0);
+    const Tensor& input = context->input(1);
+    const Tensor& input_rows = context->input(2);
+    const Tensor& reshape_sizes = context->input(3);
+    const Tensor& pack = context->input(4);
+
+    VLOG(1) << "Input shape: " << input.shape().DebugString();
+
+    // Shape validation. Every accessor used further down (scalar<>, flat<>
+    // with a fixed element index) relies on these checks having passed.
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(origin_shape.shape()),
+                errors::InvalidArgument("origin_shape dims must 1D, not ",
+                                        origin_shape.shape().DebugString()));
+    OP_REQUIRES(context, origin_shape.NumElements() == 2,
+                errors::InvalidArgument(
+                    "origin_shape NumElements must == 2, not ",
+                    origin_shape.NumElements()));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input.shape()),
+                errors::InvalidArgument("input dims must 2D, not ",
+                                        input.shape().DebugString()));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_rows.shape()),
+                errors::InvalidArgument("input_rows must be a scalar"));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(reshape_sizes.shape()),
+                errors::InvalidArgument("sizes input must be 1-D, not ",
+                                        reshape_sizes.shape().DebugString()));
+    OP_REQUIRES(context, reshape_sizes.NumElements() == 2,
+                errors::InvalidArgument("reshape_sizes NumElements must == 2"));
+    // Without this check a non-scalar `pack` would trip the CHECK inside
+    // scalar<>() and abort the process instead of failing the op.
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(pack.shape()),
+                errors::InvalidArgument("pack must be a scalar"));
+
+    const int input_rows_value = input_rows.scalar<int32>()();
+    const int32 origin_rows =
+        static_cast<int32>(origin_shape.flat<int64>()(0));
+    const int padding_rows = origin_rows - input_rows_value;
+    const auto reshape_dims = reshape_sizes.flat<int32>();
+    const int32 reshape_cols = reshape_dims(1);
+    const int32 pack_value = pack.scalar<int32>()();
+    const int64_t copied_rows = input.dim_size(0);
+    const int output_cols = input.dim_size(1);
+
+    // Value validation.
+    OP_REQUIRES(context, padding_rows >= 0,
+                errors::InvalidArgument("Pooling size(", input_rows_value,
+                                        ") is greater than Input size(",
+                                        origin_rows, ")"));
+    OP_REQUIRES(context, reshape_cols > 0,
+                errors::InvalidArgument("reshape_cols must > 0"));
+    OP_REQUIRES(context, reshape_dims(0) == -1,
+                errors::InvalidArgument("reshape[0] is not -1"));
+    OP_REQUIRES(context, pack_value == output_cols,
+                errors::InvalidArgument("pack(", pack_value,
+                                        ") is not equal to embedding dims"));
+
+    Tensor* output0 = nullptr;
+    Tensor* output1 = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({}), &output0));
+    output0->scalar<int32>()() = padding_rows;
+
+    // The element count is computed in 64 bits so that rows * cols cannot
+    // overflow an int.
+    const int64_t output_rows = padding_rows + copied_rows;
+    const int64_t total_elements = output_rows * output_cols;
+    OP_REQUIRES(context, total_elements % reshape_cols == 0,
+                errors::InvalidArgument("padding cannot reshape to [-1, ",
+                                        reshape_cols, "]"));
+    const int64_t reshape_rows = total_elements / reshape_cols;
+
+    if (fast_) {
+      // The fast variant reports only the reshaped row count.
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, TensorShape({}), &output1));
+      output1->scalar<int32>()() = static_cast<int32>(reshape_rows);
+      return;
+    }
+
+    TensorShape reshaped_shape(
+        {reshape_rows, static_cast<int64_t>(reshape_cols)});
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, reshaped_shape, &output1));
+    float* output_data = output1->flat<float>().data();
+    const float* input_data = input.flat<float>().data();
+
+    // Copy the input rows, then zero-fill the padding rows. All-zero bytes
+    // are the float value 0.0f, so memset with 0 is sufficient.
+    const int64_t copied_elements = copied_rows * output_cols;
+    const int64_t padded_elements =
+        static_cast<int64_t>(padding_rows) * output_cols;
+    if (copied_elements > 0) {
+      std::memcpy(output_data, input_data, copied_elements * sizeof(float));
+    }
+    if (padded_elements > 0) {
+      std::memset(output_data + copied_elements, 0,
+                  padded_elements * sizeof(float));
+    }
+  }
+
+ private:
+  // True when this instance serves "KPFusedEmbeddingPaddingFast".
+  bool fast_;
+};
+
+
+// Both op names are served by the same kernel class; the class tells the two
+// apart through type_string() in its constructor.
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPadding").Device(DEVICE_CPU),
+ KPFusedEmbeddingPaddingOp);
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPaddingFast").Device(DEVICE_CPU),
+ KPFusedEmbeddingPaddingOp);
+
+}
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,307 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test fixture for the KPFusedEmbeddingPadding / KPFusedEmbeddingPaddingFast
+// kernels. The regular and fast entry points differ only in the node and op
+// names, so they share the private helpers below.
+class KPFusedEmbeddingPaddingTest : public OpsTestBase {
+ protected:
+  // Builds a "KPFusedEmbeddingPadding" node. Inputs are declared in the order
+  // input_shape_type, pooling_type, const_type, reshape_type, const_type.
+  void MakeOp(DataType input_shape_type, DataType pooling_type,
+              DataType reshape_type, DataType const_type) {
+    BuildOp("fused_padding", "KPFusedEmbeddingPadding", input_shape_type,
+            pooling_type, reshape_type, const_type);
+  }
+
+  // Builds the regular op, feeds the standard inputs and runs the kernel.
+  Status FeedAndRun(const int embedding_dims, const int table_size,
+                    const int pooling_size, const int reshape_size) {
+    MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+    return FeedInputsAndRun(embedding_dims, table_size, pooling_size,
+                            reshape_size);
+  }
+
+  // Same as MakeOp, but for the "KPFusedEmbeddingPaddingFast" op.
+  void MakeFastOp(DataType input_shape_type, DataType pooling_type,
+                  DataType reshape_type, DataType const_type) {
+    BuildOp("fused_padding_fast", "KPFusedEmbeddingPaddingFast",
+            input_shape_type, pooling_type, reshape_type, const_type);
+  }
+
+  // Same as FeedAndRun, but for the fast op.
+  Status FeedAndRunFast(const int embedding_dims, const int table_size,
+                        const int pooling_size, const int reshape_size) {
+    MakeFastOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+    return FeedInputsAndRun(embedding_dims, table_size, pooling_size,
+                            reshape_size);
+  }
+
+ private:
+  // Creates and initialises a node named `node_name` for op `op_name` with
+  // five inputs.
+  void BuildOp(const char* node_name, const char* op_name,
+               DataType input_shape_type, DataType pooling_type,
+               DataType reshape_type, DataType const_type) {
+    TF_ASSERT_OK(NodeDefBuilder(node_name, op_name)
+                     .Input(FakeInput(input_shape_type))
+                     .Input(FakeInput(pooling_type))
+                     .Input(FakeInput(const_type))
+                     .Input(FakeInput(reshape_type))
+                     .Input(FakeInput(const_type))
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  // Feeds {table_size, embedding_dims}, a [pooling_size, embedding_dims] float
+  // tensor filled with 1, 2, 3, ..., the scalar pooling_size,
+  // {-1, reshape_size} and the scalar embedding_dims, then runs the kernel.
+  Status FeedInputsAndRun(const int embedding_dims, const int table_size,
+                          const int pooling_size, const int reshape_size) {
+    AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+    AddInput<float>(TensorShape({pooling_size, embedding_dims}),
+                    [](int i) -> float { return static_cast<float>(i + 1); });
+    AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+    AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+    AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+    return RunOpKernel();
+  }
+};
+
+// Regular op with embedding_dims=10, table_size=151, pooling_size=151 and
+// reshape_size=1510.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_0) {
+  const int embedding_dims = 10;
+  const int table_size = 151;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(
+      FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to hold i + 1 for the first
+  // pooling_size * embedding_dims elements and 0 afterwards.
+  const int fed_elements = pooling_size * embedding_dims;
+  const int reshaped_rows = table_size * embedding_dims / reshape_size;
+  Tensor want_padded(allocator(), DT_FLOAT,
+                     TensorShape({reshaped_rows, reshape_size}));
+  test::FillFn<float>(&want_padded, [fed_elements](int i) -> float {
+    return i < fed_elements ? static_cast<float>(i + 1) : 0.0f;
+  });
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorNear<float>(want_padded, *GetOutput(1), 1e-5);
+}
+
+// Regular op with embedding_dims=10, table_size=1510, pooling_size=151 and
+// reshape_size=1510.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_1) {
+  const int embedding_dims = 10;
+  const int table_size = 1510;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(
+      FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to hold i + 1 for the first
+  // pooling_size * embedding_dims elements and 0 afterwards.
+  const int fed_elements = pooling_size * embedding_dims;
+  const int reshaped_rows = table_size * embedding_dims / reshape_size;
+  Tensor want_padded(allocator(), DT_FLOAT,
+                     TensorShape({reshaped_rows, reshape_size}));
+  test::FillFn<float>(&want_padded, [fed_elements](int i) -> float {
+    return i < fed_elements ? static_cast<float>(i + 1) : 0.0f;
+  });
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorNear<float>(want_padded, *GetOutput(1), 1e-5);
+}
+
+// Regular op with embedding_dims=12, table_size=2, pooling_size=2 and
+// reshape_size=24.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_0) {
+  const int embedding_dims = 12;
+  const int table_size = 2;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(
+      FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to hold i + 1 for the first
+  // pooling_size * embedding_dims elements and 0 afterwards.
+  const int fed_elements = pooling_size * embedding_dims;
+  const int reshaped_rows = table_size * embedding_dims / reshape_size;
+  Tensor want_padded(allocator(), DT_FLOAT,
+                     TensorShape({reshaped_rows, reshape_size}));
+  test::FillFn<float>(&want_padded, [fed_elements](int i) -> float {
+    return i < fed_elements ? static_cast<float>(i + 1) : 0.0f;
+  });
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorNear<float>(want_padded, *GetOutput(1), 1e-5);
+}
+
+// Regular op with embedding_dims=12, table_size=200, pooling_size=2 and
+// reshape_size=24.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_1) {
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(
+      FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to hold i + 1 for the first
+  // pooling_size * embedding_dims elements and 0 afterwards.
+  const int fed_elements = pooling_size * embedding_dims;
+  const int reshaped_rows = table_size * embedding_dims / reshape_size;
+  Tensor want_padded(allocator(), DT_FLOAT,
+                     TensorShape({reshaped_rows, reshape_size}));
+  test::FillFn<float>(&want_padded, [fed_elements](int i) -> float {
+    return i < fed_elements ? static_cast<float>(i + 1) : 0.0f;
+  });
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorNear<float>(want_padded, *GetOutput(1), 1e-5);
+}
+
+// Fast op with embedding_dims=10, table_size=151, pooling_size=151 and
+// reshape_size=1510. Both outputs are int32 scalars.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_0) {
+  const int embedding_dims = 10;
+  const int table_size = 151;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(
+      FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to be table_size * embedding_dims / reshape_size.
+  Tensor want_rows(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_rows,
+                          {table_size * embedding_dims / reshape_size});
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(want_rows, *GetOutput(1));
+}
+
+// Fast op with embedding_dims=10, table_size=1510, pooling_size=151 and
+// reshape_size=1510. Both outputs are int32 scalars.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_1) {
+  const int embedding_dims = 10;
+  const int table_size = 1510;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(
+      FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to be table_size * embedding_dims / reshape_size.
+  Tensor want_rows(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_rows,
+                          {table_size * embedding_dims / reshape_size});
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(want_rows, *GetOutput(1));
+}
+
+// Fast op with embedding_dims=12, table_size=2, pooling_size=2 and
+// reshape_size=24. Both outputs are int32 scalars.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_0) {
+  const int embedding_dims = 12;
+  const int table_size = 2;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(
+      FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to be table_size * embedding_dims / reshape_size.
+  Tensor want_rows(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_rows,
+                          {table_size * embedding_dims / reshape_size});
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(want_rows, *GetOutput(1));
+}
+
+// Fast op with embedding_dims=12, table_size=200, pooling_size=2 and
+// reshape_size=24. Both outputs are int32 scalars.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_1) {
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(
+      FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Output 0 is expected to be table_size - pooling_size.
+  Tensor want_padding(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_padding, {table_size - pooling_size});
+
+  // Output 1 is expected to be table_size * embedding_dims / reshape_size.
+  Tensor want_rows(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_rows,
+                          {table_size * embedding_dims / reshape_size});
+
+  test::ExpectTensorEqual<int32>(want_padding, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(want_rows, *GetOutput(1));
+}
+
+// Feeds reshape sizes {10, reshape_size} instead of {-1, reshape_size} and
+// expects the "reshape[0] is not -1" error.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectReshape) {
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
+                  [](int i) { return static_cast<float>(i + 1); });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  // First reshape entry is 10 rather than -1: the fault under test.
+  AddInputFromArray<int32>(TensorShape({2}), {10, reshape_size});
+  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+
+  const Status status = RunOpKernel();
+  EXPECT_TRUE(absl::StrContains(status.ToString(), "reshape[0] is not -1"))
+      << status;
+}
+
+// Feeds 10 as the last scalar input while the float input has 12 columns and
+// expects the "pack(10) is not equal to embedding dims" error.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectPack) {
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
+                  [](int i) { return static_cast<float>(i + 1); });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+  // Last input is 10 instead of embedding_dims: the fault under test.
+  AddInputFromArray<int32>(TensorShape({}), {10});
+
+  const Status status = RunOpKernel();
+  EXPECT_TRUE(absl::StrContains(status.ToString(),
+                                "pack(10) is not equal to embedding dims"))
+      << status;
+}
+
+// Uses pooling_size=201 with table_size=200 and expects the
+// "Pooling size(201) is greater than Input size(200)" error.
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithPoolingSizeGreaterInput) {
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 201;  // one more than table_size
+  const int reshape_size = 24;
+
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
+                  [](int i) { return static_cast<float>(i + 1); });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+
+  const Status status = RunOpKernel();
+  EXPECT_TRUE(absl::StrContains(
+      status.ToString(), "Pooling size(201) is greater than Input size(200)"))
+      << status;
+}
+
+} // end namespace tensorflow
new file mode 100644
@@ -0,0 +1,87 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+// CPU kernel for the "KPFusedSparseDynamicStitch" op.
+//
+// Input 0 holds int64 ids; inputs 1..N are 2-D float tables that must all
+// share the same number of columns. For every id the kernel copies row
+// (id / N) of table (id % N) into the matching row of the
+// [num_ids, num_columns] float output.
+class KPFusedSparseDynamicStitchOp : public OpKernel {
+public:
+  explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  // Validates the table inputs, allocates the output and performs the row
+  // copies, sharded across the device's CPU worker threads.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& x = context->input(0);
+    auto x_flat = x.flat<int64>();
+    const int64_t num_elems = x_flat.size();
+
+    const int num_inputs = context->num_inputs();
+    const int num_partitions = num_inputs - 1;
+    OP_REQUIRES(context, num_partitions > 1, errors::InvalidArgument("num partitions must > 1"));
+
+    // Per-table base pointer and row count, indexed by (input index - 1).
+    int64_t output_stride = 0;
+    std::vector<const float*> variables(num_partitions);
+    std::vector<int64_t> variable_rows(num_partitions);
+    for (int i = 1; i < num_inputs; ++i) {
+      const Tensor& input_tensor = context->input(i);
+      OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input dims must == 2"));
+      if (i == 1) {
+        // The first table fixes the row width every other table must match.
+        output_stride = input_tensor.dim_size(1);
+      } else {
+        OP_REQUIRES(context, input_tensor.dim_size(1) == output_stride,
+                    errors::InvalidArgument("All inputs must have same second dimension"));
+      }
+      variables[i - 1] = input_tensor.flat<float>().data();
+      variable_rows[i - 1] = input_tensor.dim_size(0);
+    }
+
+    OP_REQUIRES(context, output_stride > 0, errors::InvalidArgument("output_stride must > 0"));
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({num_elems, output_stride}),
+                                            &output_tensor));
+    // Typed accessor instead of casting away const from tensor_data().
+    float* output = output_tensor->flat<float>().data();
+
+    const size_t copy_size = output_stride * sizeof(float);
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = 120;  // Actual single cycle execution time
+    // 64-bit bounds: Shard hands out 64-bit ranges, so `int` parameters would
+    // truncate them for very large id tensors.
+    auto work = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; ++i) {
+        const int64_t global_id = x_flat(i);
+        // A negative id would yield a negative table index (C++ remainder
+        // keeps the sign of the dividend), so reject it before it is used to
+        // index variable_rows / variables.
+        OP_REQUIRES(context, global_id >= 0, errors::InvalidArgument(
+                        "row_id out of range."));
+        const int64_t table_id = global_id % num_partitions;
+        const int64_t row_id = global_id / num_partitions;
+
+        OP_REQUIRES(context, row_id < variable_rows[table_id], errors::InvalidArgument(
+                        "row_id out of range."));
+
+        std::memcpy(output + i * output_stride,
+                    variables[table_id] + row_id * output_stride, copy_size);
+      }
+    };
+
+    Shard(worker_threads->num_threads, worker_threads->workers, num_elems,
+          cost_per_unit, work);
+  }
+};
+
+// Registers KPFusedSparseDynamicStitchOp as the CPU kernel for the
+// "KPFusedSparseDynamicStitch" op.
+REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
+ KPFusedSparseDynamicStitchOp);
new file mode 100644
@@ -0,0 +1,108 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+// Fixture for the KPFusedSparseDynamicStitch kernel tests.
+class KPFusedSparseDynamicStitchOpTest : public OpsTestBase {
+ protected:
+  // Creates and initializes a node with one int64 input followed by N float
+  // inputs.
+  void MakeOp(int N) {
+    NodeDefBuilder builder("kp_fused_sparse_dynamic_stitch",
+                           "KPFusedSparseDynamicStitch");
+    builder.Input(FakeInput(DT_INT64));
+    builder.Input(FakeInput(N, DT_FLOAT));
+    TF_ASSERT_OK(builder.Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestTwoTables) {
+  // Happy path with two float tables; the output holds one table row per id.
+  const std::vector<int64> ids = {0, 3, 2, 1};
+  const std::vector<float> first_table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<float> second_table = {7.0f, 8.0f, 9.0f, 10.0f};
+  const std::vector<float> want = {1.0f, 2.0f, 9.0f, 10.0f,
+                                   3.0f, 4.0f, 7.0f, 8.0f};
+
+  MakeOp(2);  // num_partitions = 2
+  AddInputFromArray<int64>(TensorShape({4}), ids);
+  AddInputFromArray<float>(TensorShape({3, 2}), first_table);
+  AddInputFromArray<float>(TensorShape({2, 2}), second_table);
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected, want);
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestDifferentStride) {
+  // The two tables disagree on their second dimension (2 vs. 4), which the
+  // kernel must reject.
+  const std::vector<int64> ids = {0, 3, 2, 1};
+  const std::vector<float> narrow_table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<float> wide_table = {7.0f, 8.0f, 9.0f, 10.0f};
+
+  MakeOp(2);
+  AddInputFromArray<int64>(TensorShape({4}), ids);
+  AddInputFromArray<float>(TensorShape({3, 2}), narrow_table);
+  AddInputFromArray<float>(TensorShape({1, 4}), wide_table);
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(),
+                                "All inputs must have same second dimension"));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestIndicesOutOfBounds) {
+  // Id 6 addresses a row beyond the end of its table, so the kernel must
+  // fail with the range error checked below.
+  const std::vector<int64> ids = {0, 6, 2, 1};
+  const std::vector<float> first_table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<float> second_table = {7.0f, 8.0f, 9.0f, 10.0f};
+
+  MakeOp(2);
+  AddInputFromArray<int64>(TensorShape({4}), ids);
+  AddInputFromArray<float>(TensorShape({3, 2}), first_table);
+  AddInputFromArray<float>(TensorShape({2, 2}), second_table);
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "row_id out of range"));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestInputDims) {
+  // Both tables are rank 3 instead of rank 2, which the kernel must reject.
+  const std::vector<int64> ids = {0, 6, 2, 1};
+  const std::vector<float> first_table = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::vector<float> second_table = {7.0f, 8.0f, 9.0f, 10.0f};
+
+  MakeOp(2);
+  AddInputFromArray<int64>(TensorShape({4}), ids);
+  AddInputFromArray<float>(TensorShape({3, 2, 1}), first_table);
+  AddInputFromArray<float>(TensorShape({2, 2, 1}), second_table);
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "input dims must == 2"));
+}
+
+} // namespace
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,203 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow/core/kernels/reshape_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+using namespace tensorflow;
+
+// Re-expresses sparse indices given against `input_shape_in` as indices
+// against `target_shape_in`, preserving each entry's row-major position.
+//
+// context:            kernel context that receives the outputs or the error.
+// input_indices_in:   [nnz, input_rank] int64 matrix of indices.
+// input_shape_in:     int64 vector holding the dense input shape.
+// target_shape_in:    int64 vector holding the requested shape; at most one
+//                     entry may be -1, in which case it is inferred.
+// output_indices_idx: output slot for the [nnz, output_rank] indices.
+// output_shape_idx:   output slot for the resolved output shape vector.
+//
+// On any validation failure the context status is set to InvalidArgument (or
+// the status produced by the shape helpers) and no outputs are produced.
+static void ReshapeKp(OpKernelContext *context, const Tensor &input_indices_in,
+                      const Tensor &input_shape_in, const Tensor &target_shape_in,
+                      int output_indices_idx, int output_shape_idx) {
+  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
+              errors::InvalidArgument(
+                  "Input indices should be a matrix but received shape ",
+                  input_indices_in.shape().DebugString()));
+  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
+              errors::InvalidArgument(
+                  "Input shape should be a vector but received shape ",
+                  input_shape_in.shape().DebugString()));
+  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
+              errors::InvalidArgument(
+                  "Target shape should be a vector but received shape ",
+                  target_shape_in.shape().DebugString()));
+
+  const int64 input_rank = input_shape_in.NumElements();
+  const int64 output_rank = target_shape_in.NumElements();
+
+  // Build the input shape through the status-returning helper: constructing a
+  // TensorShape directly CHECK-fails (aborting the process) when a dimension
+  // is negative or the element count overflows.
+  TensorShape input_shape;
+  OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+                              input_shape_in.vec<int64>().data(), input_rank,
+                              &input_shape));
+  const int64 dense_size = input_shape.num_elements();
+  const int64 nnz = input_indices_in.shape().dim_size(0);
+
+  // Resolve the target shape. `product` is the element count of the known
+  // dimensions; the single optional -1 entry is filled in afterwards.
+  TensorShape output_shape;
+  int64 product = 1;
+  int unknown_index = -1;
+  auto target_shape = target_shape_in.vec<int64>();
+  for (int d = 0; d < output_rank; ++d) {
+    const int64 size = target_shape(d);
+    if (size == -1) {
+      OP_REQUIRES(
+          context, unknown_index == -1,
+          errors::InvalidArgument("only one output dimension may be -1, "
+                                  "not both ",
+                                  unknown_index, " and ", d));
+      unknown_index = d;
+      output_shape.AddDim(1);
+    } else {
+      OP_REQUIRES(context, size >= 0,
+                  errors::InvalidArgument("size ", d,
+                                          " must be non-negative, not ", size));
+      // Add the dimension first: the status-returning variant reports an
+      // element-count overflow as an error, which also guarantees that the
+      // multiplication below cannot overflow.
+      OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(size));
+      product *= size;
+    }
+  }
+  if (unknown_index != -1) {
+    OP_REQUIRES(
+        context, product > 0,
+        errors::InvalidArgument("reshape cannot infer the missing "
+                                "input size for an empty tensor unless all "
+                                "specified input sizes are non-zero"));
+    const int64 missing = dense_size / product;
+    OP_REQUIRES(
+        context, product * missing == dense_size,
+        errors::InvalidArgument(
+            "Input to reshape is a SparseTensor with ", dense_size,
+            " dense values, but the requested shape requires a multiple of ",
+            product, ". input_shape=", input_shape.DebugString(),
+            " output_shape=", output_shape.DebugString()));
+    output_shape.set_dim(unknown_index, missing);
+  }
+
+  OP_REQUIRES(
+      context, output_shape.num_elements() == dense_size,
+      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
+                              " dense values, but the requested shape has ",
+                              output_shape.num_elements(),
+                              ". input_shape=", input_shape.DebugString(),
+                              " output_shape=", output_shape.DebugString()));
+
+  // Identical shapes: forward the inputs untouched.
+  if (input_shape == output_shape) {
+    context->set_output(output_indices_idx, input_indices_in);
+    context->set_output(output_shape_idx, input_shape_in);
+    return;
+  }
+
+  // With an empty dense shape an output stride can be zero, and converting
+  // any stored index would divide by it. Reject that combination up front.
+  OP_REQUIRES(context, nnz == 0 || dense_size > 0,
+              errors::InvalidArgument(
+                  "Input tensor has ", nnz,
+                  " non zero elements but input shape (",
+                  input_shape.DebugString(), ") or output shape (",
+                  output_shape.DebugString(), ") is empty"));
+
+  // Row-major strides of the input shape.
+  gtl::InlinedVector<int64, 8> input_strides(input_rank);
+  if (input_rank > 0) {
+    input_strides[input_rank - 1] = 1;
+    for (int d = input_rank - 2; d >= 0; --d) {
+      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+    }
+  }
+
+  // Row-major strides of the output shape.
+  gtl::InlinedVector<int64, 8> output_strides(output_rank);
+  if (output_rank > 0) {
+    output_strides[output_rank - 1] = 1;
+    for (int d = output_rank - 2; d >= 0; --d) {
+      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
+    }
+  }
+
+  Tensor *result_indices = nullptr;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(output_indices_idx,
+                                          TensorShape({nnz, output_rank}),
+                                          &result_indices));
+  auto input_ind = input_indices_in.matrix<int64>();
+  auto output_ind = result_indices->matrix<int64>();
+  // 64-bit counters: nnz is an int64 and may exceed the range of int.
+  for (int64 i = 0; i < nnz; ++i) {
+    // Flatten the input index, then peel it apart along the output strides.
+    int64 id = 0;
+    for (int64 j = 0; j < input_rank; ++j) {
+      id += input_ind(i, j) * input_strides[j];
+    }
+    for (int64 j = 0; j < output_rank; ++j) {
+      output_ind(i, j) = id / output_strides[j];
+      id %= output_strides[j];
+    }
+  }
+
+  Tensor *result_shape = nullptr;
+  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
+                                                   TensorShape({output_rank}),
+                                                   &result_shape));
+  auto output_shape_vec = result_shape->vec<int64>();
+  for (int j = 0; j < output_shape.dims(); ++j) {
+    output_shape_vec(j) = output_shape.dim_size(j);
+  }
+}
+
+// CPU kernel for the "KPFusedSparseReshape" op.
+//
+// Inputs, as read by Compute():
+//   0: `slice_input`, a 2-D int64 matrix.
+//   1: `begin`, a 1-D int32 tensor of length 2; element 1 selects the column
+//      of `slice_input` to read.
+//   2: `new_shape`, the requested target shape (two entries).
+//   3: `pack_const`, a scalar of type T used as the second input dimension.
+//
+// The kernel forms sparse indices (row, slice_input[row, column]) against the
+// dense shape [num_rows, pack_const] and hands them to ReshapeKp, which
+// produces output 0 (reshaped indices) and output 1 (resolved shape).
+template <typename T>
+class KPFusedSparseReshapeOp : public OpKernel {
+ public:
+  explicit KPFusedSparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) { }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& slice_input = context->input(0);
+    const Tensor& begin = context->input(1);
+    const Tensor& new_shape = context->input(2);
+    const Tensor& pack_const = context->input(3);
+
+    OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2"));
+    // Check the rank first so dim_size(0) is never queried on a rank-0 tensor.
+    OP_REQUIRES(context, new_shape.dims() >= 1 && new_shape.dim_size(0) == 2,
+                errors::Internal("new_shape dim size must == 2"));
+    OP_REQUIRES(context, pack_const.dims() == 0,
+                errors::InvalidArgument("pack_const must be a scalar"));
+    VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
+    VLOG(1) << "Input begin value: " << begin.DebugString();
+    VLOG(1) << "Input new_shape value: " << new_shape.DebugString();
+
+    OP_REQUIRES(context, begin.dims() == 1 && begin.dim_size(0) == 2,
+                errors::InvalidArgument("begin must be 1D with at least 2 elements"));
+    const int32 col = begin.flat<int32>().data()[1];
+    // The column must lie inside slice_input; a negative value would read
+    // before the start of each row.
+    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
+                errors::Internal("begin[1] must < slice_input.dim_size(1)"));
+    const int64_t num_rows = slice_input.dim_size(0);
+    auto slice_input_mat = slice_input.matrix<int64>();
+
+    VLOG(1) << "num_rows: " << num_rows;
+    VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0);
+    VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1);
+    VLOG(1) << "Column index from begin: " << col;
+
+    // Dense shape the generated indices refer to: [num_rows, pack_const].
+    Tensor shape_in(DT_INT64, TensorShape({2}));
+    auto tensor_flat = shape_in.flat<int64>();
+    tensor_flat(0) = num_rows;
+    tensor_flat(1) = static_cast<int64>(pack_const.scalar<T>()());
+
+    // One index pair per row: (row, value found in the selected column).
+    Tensor indices_in(DT_INT64, TensorShape({num_rows, 2}));
+    auto indices_in_mat = indices_in.matrix<int64>();
+    for (int64_t i = 0; i < num_rows; ++i) {
+      indices_in_mat(i, 0) = i;
+      indices_in_mat(i, 1) = slice_input_mat(i, col);
+    }
+
+    ReshapeKp(context, indices_in, shape_in, new_shape, 0, 1);
+  }
+};
+
+// Registers KPFusedSparseReshapeOp<type> as the CPU kernel for the
+// "KPFusedSparseReshape" op with attribute T constrained to `type`; the macro
+// is undefined again once both instantiations below are registered.
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("KPFusedSparseReshape") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T"), \
+ KPFusedSparseReshapeOp<type>)
+
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(int32);
+#undef REGISTER_KERNEL
new file mode 100644
@@ -0,0 +1,281 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::FillValues;
+using tensorflow::test::ExpectTensorEqual;
+
+// Fixture for the KPFusedSparseReshape kernel tests.
+class KPFusedSparseReshapeTest : public OpsTestBase {
+ protected:
+  // Builds the op, feeds the given inputs (begin and new_shape as length-2
+  // vectors, pack_const as a scalar), requires the kernel to succeed, and then
+  // checks the shape of output 0 and the contents of output 1.
+  void RunValidCase(const TensorShape& slice_shape,
+                    const std::vector<int64>& slice_data,
+                    const std::vector<int32>& begin_val,
+                    const std::vector<int64>& new_shape_val,
+                    const std::vector<int64>& pack_const_val,
+                    const TensorShape& expected_indices_shape,
+                    const std::vector<int64>& expected_shape_val) {
+    NodeDefBuilder builder("kp_fused_sparse_reshape", "KPFusedSparseReshape");
+    builder.Input(FakeInput(DT_INT64));  // slice_input
+    builder.Input(FakeInput(DT_INT32));  // begin
+    builder.Input(FakeInput(DT_INT64));  // new_shape
+    builder.Input(FakeInput(DT_INT64));  // pack_const
+    TF_EXPECT_OK(builder.Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+
+    const TensorShape pair_shape({2});
+    const TensorShape scalar_shape({});
+    AddInputFromArray<int64>(slice_shape, slice_data);
+    AddInputFromArray<int32>(pair_shape, begin_val);
+    AddInputFromArray<int64>(pair_shape, new_shape_val);
+    AddInputFromArray<int64>(scalar_shape, pack_const_val);
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    // Output 0: result_indices (only its shape is verified).
+    EXPECT_EQ(GetOutput(0)->shape(), expected_indices_shape);
+
+    // Output 1: result_shape (verified element by element).
+    const int64 shape_len = static_cast<int64>(expected_shape_val.size());
+    Tensor want_shape(DT_INT64, TensorShape({shape_len}));
+    FillValues<int64>(&want_shape, expected_shape_val);
+    ExpectTensorEqual<int64>(want_shape, *GetOutput(1));
+  }
+
+  // Builds the op, feeds the given inputs (begin and new_shape sized from
+  // their vectors, pack_const as a scalar) and returns the kernel's status so
+  // the caller can inspect the failure.
+  Status RunOpExpectFailure(const TensorShape& slice_shape,
+                            const std::vector<int64>& slice_data,
+                            const std::vector<int32>& begin_val,
+                            const std::vector<int64>& new_shape_val,
+                            const std::vector<int64>& pack_const_val) {
+    NodeDefBuilder builder("kp_fused_sparse_reshape", "KPFusedSparseReshape");
+    builder.Input(FakeInput(DT_INT64));  // slice_input
+    builder.Input(FakeInput(DT_INT32));  // begin
+    builder.Input(FakeInput(DT_INT64));  // new_shape
+    builder.Input(FakeInput(DT_INT64));  // pack_const
+    TF_CHECK_OK(builder.Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+
+    const int64 begin_len = static_cast<int64>(begin_val.size());
+    const int64 new_shape_len = static_cast<int64>(new_shape_val.size());
+    AddInputFromArray<int64>(slice_shape, slice_data);
+    AddInputFromArray<int32>(TensorShape({begin_len}), begin_val);
+    AddInputFromArray<int64>(TensorShape({new_shape_len}), new_shape_val);
+    AddInputFromArray<int64>(TensorShape({}), pack_const_val);
+
+    return RunOpKernel();
+  }
+};
+
+// ==================== Positive tests ====================
+
+// Regular reshape case with pack_const = 2.
+TEST_F(KPFusedSparseReshapeTest, Valid_NormalInput) {
+  const TensorShape slice_shape({4, 2});
+  const std::vector<int64> slice_data = {0, 1,
+                                         1, 2,
+                                         2, 3,
+                                         3, 0};
+  const std::vector<int32> begin = {0, 1};      // begin = (0, 1): column 1
+  const std::vector<int64> new_shape = {2, 4};  // new_shape = [2, 4]
+  const std::vector<int64> pack_const = {2};    // pack_const = 2
+
+  RunValidCase(slice_shape, slice_data, begin, new_shape, pack_const,
+               /*expected_indices_shape=*/TensorShape({4, 2}),
+               /*expected_shape_val=*/{2, 4});
+}
+
+// Single-row input with pack_const = 1 and an inferred leading dimension.
+TEST_F(KPFusedSparseReshapeTest, Valid_PackConst1) {
+  const TensorShape slice_shape({1, 2});
+  const std::vector<int64> slice_data = {0, 1};
+  const std::vector<int32> begin = {0, 1};       // begin = (0, 1): column 1
+  const std::vector<int64> new_shape = {-1, 1};  // new_shape = [-1, 1]
+  const std::vector<int64> pack_const = {1};     // pack_const = 1
+
+  RunValidCase(slice_shape, slice_data, begin, new_shape, pack_const,
+               /*expected_indices_shape=*/TensorShape({1, 2}),
+               /*expected_shape_val=*/{1, 1});
+}
+
+// ==================== Negative tests ====================
+
+// Negative case 1: slice_input is not two-dimensional.
+TEST_F(KPFusedSparseReshapeTest, Invalid_SliceInputNot2D) {
+  const TensorShape slice_shape({4});  // rank 1 instead of rank 2
+  const std::vector<int64> slice_data = {0, 1, 2, 3};
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> new_shape = {2, 2};
+  const std::vector<int64> pack_const = {4};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "slice_input dims must == 2"));
+}
+
+// Negative case 2: new_shape does not have exactly two entries.
+TEST_F(KPFusedSparseReshapeTest, Invalid_NewShapeNotLen2) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> new_shape = {4, 2, 1};  // one entry too many
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "new_shape dim size must == 2"));
+}
+
+// Negative case 3: begin[1] is past the last column of slice_input.
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginOutOfRange) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 2};  // column 2 does not exist
+  const std::vector<int64> new_shape = {2, 2};
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "begin[1] must < slice_input.dim_size(1)"));
+}
+
+// Negative case 4: the target shape contains more than one -1.
+TEST_F(KPFusedSparseReshapeTest, Invalid_MultipleUnknownDims) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 1};
+  const std::vector<int64> new_shape = {-1, -1};  // two unknown dimensions
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "only one output dimension may be -1"));
+}
+
+// Negative case 5: the inferred dimension cannot be an integer because the
+// dense size is not a multiple of the known dimensions
+// (product * missing != dense_size).
+TEST_F(KPFusedSparseReshapeTest, Invalid_InferredShapeDoesNotMatch) {
+  // Six rows with pack_const = 1; the target shape is ?x4.
+  const Status status = RunOpExpectFailure(
+      TensorShape({6, 2}),
+      {0, 0,
+       0, 1,
+       0, 2,
+       1, 0,
+       1, 1,
+       1, 2},
+      /*begin_val=*/{0, 0},
+      /*new_shape_val=*/{-1, 4},
+      /*pack_const_val=*/{1});
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "Input to reshape is a SparseTensor with"));
+}
+
+// Negative case 6: the element count after reshape differs from the input
+// (output_shape.num_elements() != dense_size).
+TEST_F(KPFusedSparseReshapeTest, Invalid_SizeMismatch) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 1};
+  const std::vector<int64> new_shape = {3, 3};  // 9 elements vs. dense size 4
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "Input to reshape is a tensor with"));
+}
+
+// Negative case 7: the target shape holds a negative value other than -1.
+TEST_F(KPFusedSparseReshapeTest, Invalid_NegativeDimNotMinusOne) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> new_shape = {2, -2};  // -2 is not a legal size
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "size 1 must be non-negative, not -2"))
+      << "Actual error: " << status.message();
+}
+
+// Negative case 8: the target shape has a -1 while the remaining dimensions
+// multiply to zero, so the missing dimension cannot be inferred.
+TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) {
+  const TensorShape slice_shape({0, 2});  // empty slice_input: no rows
+  const std::vector<int64> no_slice_data;
+  const std::vector<int32> begin = {0, 0};
+  const std::vector<int64> new_shape = {-1, 0};  // known product = 0
+  const std::vector<int64> pack_const = {2};
+
+  const Status status = RunOpExpectFailure(slice_shape, no_slice_data, begin,
+                                           new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "reshape cannot infer the missing input size for an empty tensor"))
+      << "Actual error: " << status.message();
+}
+
+// Negative case 9: begin is 1-D but has only one element.
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize1) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0};  // length 1
+  const std::vector<int64> new_shape = {2, 2};
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "begin must be 1D with at least 2 elements"))
+      << "Actual error: " << status.message();
+}
+
+// Negative case 10: begin is 1-D but has three elements.
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize3) {
+  const TensorShape slice_shape({2, 2});
+  const std::vector<int64> slice_data = {0, 1, 1, 0};
+  const std::vector<int32> begin = {0, 1, 2};  // length 3
+  const std::vector<int64> new_shape = {2, 2};
+  const std::vector<int64> pack_const = {2};
+
+  const Status status =
+      RunOpExpectFailure(slice_shape, slice_data, begin, new_shape, pack_const);
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "begin must be 1D with at least 2 elements"))
+      << "Actual error: " << status.message();
+}
+
+// Negative case 11: pack_const is fed as a 1-D tensor of shape {1} rather
+// than a 0-D scalar, which the kernel must reject.
+TEST_F(KPFusedSparseReshapeTest, Invalid_PackConstIsScalarButExpect1D) {
+  NodeDefBuilder builder("kp_fused_sparse_reshape", "KPFusedSparseReshape");
+  builder.Input(FakeInput(DT_INT64));  // slice_input
+  builder.Input(FakeInput(DT_INT32));  // begin
+  builder.Input(FakeInput(DT_INT64));  // new_shape
+  builder.Input(FakeInput(DT_INT64));  // pack_const
+  TF_CHECK_OK(builder.Finalize(node_def()));
+  TF_CHECK_OK(InitOp());
+
+  AddInputFromArray<int64>(TensorShape({2, 2}), {0, 1, 1, 0});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+  AddInputFromArray<int64>(TensorShape({2}), {2, 2});
+  AddInputFromArray<int64>(TensorShape({1}), {1});  // shape {1}, not a scalar
+
+  const Status status = RunOpKernel();
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "pack_const must be a scalar"))
+      << "Actual error: " << status.message();
+}
+
+} // namespace
new file mode 100644
@@ -0,0 +1,165 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+// CPU kernel for the "KPFusedSparseSegmentReduce" op, using NEON intrinsics
+// for the row arithmetic.
+//
+// Inputs, as read by Compute():
+//   0: 2-D float table whose rows are selected by `indices`.
+//   1: 1-D `indices` (Tidx); entry i names the table row used by entry i.
+//   2: 2-D int64 `slice_input`; column begin[1] holds each entry's segment id.
+//   3: `begin`, two int32 values; only element 1 (the column) is used.
+//   4: `begin_1`, one int32 value selecting what output 1 reports.
+// Outputs:
+//   0: [max_segment_id + 1, embedding_size] float tensor holding per-segment
+//      sums (combiner 0) or means (combiner 1); segments that receive no rows
+//      stay zero.
+//   1: int32 scalar: the number of output rows when begin_1 == 0, otherwise
+//      embedding_size.
+template <typename Tidx>
+class KPFusedSparseSegmentReduceOp : public OpKernel {
+public:
+  // Reads the "combiner" attribute: 0 selects sum, 1 selects mean.
+  explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    int combiner_mode;
+    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
+    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
+                errors::InvalidArgument("combiner must be 0 or 1"));
+    is_mean_ = (combiner_mode == 1);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_tensor = context->input(0);
+    const Tensor& indices = context->input(1);
+    const Tensor& slice_input = context->input(2);
+    const Tensor& begin = context->input(3);
+    const Tensor& begin_1 = context->input(4);
+
+    OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input must be 2-D"));
+    OP_REQUIRES(context, slice_input.dims() == 2, errors::InvalidArgument("slice input must be 2-D"));
+    OP_REQUIRES(context, begin.NumElements() == 2, errors::InvalidArgument("begin must have 2 elements"));
+    OP_REQUIRES(context, begin_1.NumElements() == 1, errors::InvalidArgument("begin_1 must have 1 element"));
+    // vec<Tidx>() below requires a rank-1 tensor; report a status instead of
+    // letting the accessor abort.
+    OP_REQUIRES(context, indices.dims() == 1, errors::InvalidArgument("indices must be 1-D"));
+
+    const int64_t num_indices = indices.dim_size(0);
+    const int64_t num_rows = input_tensor.dim_size(0);
+    const int64_t embedding_size = input_tensor.dim_size(1);
+    const int32 col = begin.flat<int32>().data()[1];
+    const int32 out_dim = begin_1.flat<int32>()(0);
+
+    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
+                errors::InvalidArgument("Column index out of range"));
+    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
+                errors::InvalidArgument("indices and slice_input.dim_zie(0) should have same size"));
+
+    const float* input_data = input_tensor.matrix<float>().data();
+    auto indices_vec = indices.vec<Tidx>();
+    auto slice_input_mat = slice_input.matrix<int64>();
+
+    // Single validation pass: find the largest segment id and make sure every
+    // segment id and gathered row is in range, so the accumulation loop below
+    // can never read or write outside its buffers.
+    int64 max_seg_id = 0;
+    for (int64_t i = 0; i < num_indices; ++i) {
+      const int64 seg_id = slice_input_mat(i, col);
+      const int64 data_row = static_cast<int64>(indices_vec(i));
+      OP_REQUIRES(context, seg_id >= 0,
+                  errors::InvalidArgument("segment id must be non-negative, got ",
+                                          seg_id));
+      OP_REQUIRES(context, data_row >= 0 && data_row < num_rows,
+                  errors::InvalidArgument("indices[", i, "] = ", data_row,
+                                          " is out of range [0, ", num_rows, ")"));
+      if (seg_id > max_seg_id) {
+        max_seg_id = seg_id;
+      }
+    }
+    const int64 batch_size = max_seg_id + 1;
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({batch_size, embedding_size}), &output));
+    output->flat<float>().setZero();
+    Tensor* slice_out = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({}), &slice_out));
+    slice_out->scalar<int32>()() =
+        static_cast<int32>(out_dim == 0 ? batch_size : embedding_size);
+
+    float* output_data = output->matrix<float>().data();
+
+    // Per-segment entry counts are only needed to turn sums into means.
+    Tensor counts;
+    int32* counts_data = nullptr;
+    if (is_mean_) {
+      counts = Tensor(DT_INT32, TensorShape({batch_size}));
+      counts.flat<int32>().setZero();
+      counts_data = counts.flat<int32>().data();
+    }
+
+    // Accumulate every selected table row into its segment's output row.
+    for (int64_t i = 0; i < num_indices; ++i) {
+      const int64 seg_id = slice_input_mat(i, col);
+      const int64 data_row = static_cast<int64>(indices_vec(i));
+      if (counts_data != nullptr) {
+        counts_data[seg_id] += 1;
+      }
+      AccumulateRow(output_data + seg_id * embedding_size,
+                    input_data + data_row * embedding_size, embedding_size);
+    }
+
+    if (is_mean_) {
+      // Divide each non-empty segment by its entry count.
+      for (int64_t seg = 0; seg < batch_size; ++seg) {
+        const int32 count = counts_data[seg];
+        if (count > 0) {
+          ScaleRow(output_data + seg * embedding_size,
+                   1.0f / static_cast<float>(count), embedding_size);
+        }
+      }
+    }
+  }
+
+private:
+  // dst[j] += src[j] for j in [0, n): four lanes at a time, then a scalar tail.
+  static void AccumulateRow(float* dst, const float* src, int64_t n) {
+    int64_t j = 0;
+    for (; j + 3 < n; j += 4) {
+      const float32x4_t acc = vaddq_f32(vld1q_f32(dst + j), vld1q_f32(src + j));
+      vst1q_f32(dst + j, acc);
+    }
+    for (; j < n; ++j) {
+      dst[j] += src[j];
+    }
+  }
+
+  // row[j] *= factor for j in [0, n): four lanes at a time, then a scalar tail.
+  static void ScaleRow(float* row, float factor, int64_t n) {
+    const float32x4_t factor_vec = vdupq_n_f32(factor);
+    int64_t j = 0;
+    for (; j + 3 < n; j += 4) {
+      vst1q_f32(row + j, vmulq_f32(vld1q_f32(row + j), factor_vec));
+    }
+    for (; j < n; ++j) {
+      row[j] *= factor;
+    }
+  }
+
+  bool is_mean_;
+};
+
+// Registers KPFusedSparseSegmentReduceOp<Tidx> as the CPU kernel for the
+// "KPFusedSparseSegmentReduce" op with attribute Tidx constrained to the given
+// type. The macro body already ends in ';', so the invocations below carry
+// none; the macro is undefined again after both registrations.
+#define REGISTER_KERNEL(Tidx) \
+ REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduce") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<Tidx>("Tidx"), \
+ KPFusedSparseSegmentReduceOp<Tidx>);
+REGISTER_KERNEL(int64)
+REGISTER_KERNEL(int32)
+#undef REGISTER_KERNEL
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,387 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "absl/container/flat_hash_map.h"
+
+using namespace tensorflow;
+
+// Fused CPU kernel (uses NEON intrinsics) that gathers entries of `data`,
+// reduces them per segment and returns only the non-zero results in sparse
+// form.
+//
+// Inputs:
+//   0: data        float, 1-D [rows] or 2-D [rows, embed_dim]
+//   1: indices     Tidx, 1-D; row of `data` gathered for each entry
+//   2: slice_input int64, 2-D [len(indices), cols]; one column holds the
+//                  segment id of each entry
+//   3: begin       int32, 2 elements; begin[1] selects that column
+// Outputs:
+//   0: dense shape, int32: [batch_size] or [batch_size, embed_dim], where
+//      batch_size = largest segment id + 1
+//   1: int32 coordinates of the non-zero results: [nnz, 1] or [nnz, 2]
+//   2: float non-zero values: [nnz]
+// Attr `combiner`: 0 = sum, 1 = mean.
+//
+// Segments are emitted in order of first appearance, not sorted by id.
+// Out-of-range gather indices and negative segment ids are rejected with
+// InvalidArgument instead of being used as raw memory offsets.
+template <typename Tidx>
+class KPFusedSparseSegmentReduceNonzeroOp : public OpKernel {
+ public:
+  explicit KPFusedSparseSegmentReduceNonzeroOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    int combiner_mode;
+    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
+    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
+                errors::InvalidArgument("combiner must be 0 or 1"));
+    is_mean_ = (combiner_mode == 1);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_tensor = context->input(0);
+    const Tensor& indices = context->input(1);
+    const Tensor& slice_input = context->input(2);
+    const Tensor& begin = context->input(3);
+
+    const int input_dims = input_tensor.dims();
+    OP_REQUIRES(context, input_dims == 1 || input_dims == 2,
+                errors::InvalidArgument(
+                    "Input data must be a 1-D vector or 2-D matrix"));
+    OP_REQUIRES(context, slice_input.dims() == 2,
+                errors::InvalidArgument("slice input must be 2-D"));
+    OP_REQUIRES(context, begin.NumElements() == 2,
+                errors::InvalidArgument("begin must have 2 elements"));
+    // indices is accessed through dim_size(0) and vec<Tidx>() below, both of
+    // which need a rank-1 tensor; report other ranks as a regular error.
+    OP_REQUIRES(context, indices.dims() == 1,
+                errors::InvalidArgument("indices must be 1-D"));
+
+    const int64 num_indices = indices.dim_size(0);
+    const int32 col = begin.flat<int32>().data()[1];
+
+    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
+                errors::InvalidArgument("Column index out of range"));
+    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
+                errors::InvalidArgument(
+                    "indices and slice_input.dim_size(0) should have same size"));
+
+    auto indices_vec = indices.vec<Tidx>();
+    auto slice_input_mat = slice_input.matrix<int64>();
+    const int64 num_rows = input_tensor.dim_size(0);
+
+    // Single validation pass: copy out the segment-id column, track the
+    // largest id, and make sure every gather index / segment id is usable
+    // before anything is read from `data`.
+    std::vector<int64> segment_ids(num_indices);
+    int64 max_seg_id = 0;
+    for (int64 i = 0; i < num_indices; ++i) {
+      const int64 data_row = static_cast<int64>(indices_vec(i));
+      OP_REQUIRES(context, data_row >= 0 && data_row < num_rows,
+                  errors::InvalidArgument("indices[", i, "] = ", data_row,
+                                          " is out of range [0, ", num_rows,
+                                          ")"));
+      const int64 seg_id = slice_input_mat(i, col);
+      OP_REQUIRES(context, seg_id >= 0,
+                  errors::InvalidArgument("segment id at row ", i,
+                                          " must be non-negative, got ",
+                                          seg_id));
+      segment_ids[i] = seg_id;
+      if (seg_id > max_seg_id) {
+        max_seg_id = seg_id;
+      }
+    }
+
+    // With no entries at all this is 1, because max_seg_id starts at 0.
+    const int64 batch_size = max_seg_id + 1;
+
+    // Worker pool used by the 2-D path for the per-segment phases.
+    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int num_threads = worker_threads ? worker_threads->num_threads : 1;
+
+    // Maps a segment id to its slot in the dense per-segment arrays; slots
+    // are handed out in order of first appearance.
+    absl::flat_hash_map<int64, int64> seg_id_to_idx;
+    seg_id_to_idx.reserve(num_indices / 4 + 1);
+    std::vector<int64> segment_order;  // slot -> segment id
+    std::vector<int32_t> segment_counts;  // slot -> number of entries
+
+    if (input_dims == 1) {
+      // -------------------------------------------------------
+      // 1-D input: each segment reduces to a scalar
+      //   output_shape:   [batch_size]
+      //   output_indices: [num_nonzero, 1]  (segment_id)
+      //   output_nonzero: [num_nonzero]
+      // -------------------------------------------------------
+      auto input_data = input_tensor.flat<float>();
+
+      Tensor* output_shape = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({1}),
+                                                       &output_shape));
+      output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);
+
+      // Aggregation: sequential, because several entries may hit the same
+      // segment slot.
+      std::vector<float> segment_values;  // slot -> sum (later: mean)
+      for (int64 i = 0; i < num_indices; ++i) {
+        const int64 seg_id = segment_ids[i];
+        const auto inserted = seg_id_to_idx.emplace(
+            seg_id, static_cast<int64>(segment_order.size()));
+        if (inserted.second) {
+          segment_order.push_back(seg_id);
+          segment_values.push_back(0.0f);
+          segment_counts.push_back(0);
+        }
+        const int64 slot = inserted.first->second;
+        segment_values[slot] += input_data(indices_vec(i));
+        segment_counts[slot]++;
+      }
+
+      if (is_mean_) {
+        // Every slot was created by at least one entry, so counts are >= 1.
+        // Multiply by the reciprocal (rather than divide) to keep the exact
+        // float results of the sum * (1 / count) formulation.
+        for (size_t s = 0; s < segment_values.size(); ++s) {
+          segment_values[s] *= 1.0f / static_cast<float>(segment_counts[s]);
+        }
+      }
+
+      // Count the non-zero results so the outputs can be sized exactly.
+      int64 num_nonzero = 0;
+      for (size_t s = 0; s < segment_values.size(); ++s) {
+        if (segment_values[s] != 0.0f) num_nonzero++;
+      }
+
+      Tensor* output_indices = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, TensorShape({num_nonzero, 1}),
+                                              &output_indices));
+      auto output_indices_data = output_indices->flat<int32>();
+
+      Tensor* output_nonzero = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(2, TensorShape({num_nonzero}),
+                                              &output_nonzero));
+      auto output_nonzero_data = output_nonzero->flat<float>();
+
+      // Write the non-zero results straight into the output tensors.
+      int64 out = 0;
+      for (size_t s = 0; s < segment_values.size(); ++s) {
+        const float val = segment_values[s];
+        if (val != 0.0f) {
+          output_indices_data(out) = static_cast<int32>(segment_order[s]);
+          output_nonzero_data(out) = val;
+          out++;
+        }
+      }
+    } else {
+      // -------------------------------------------------------
+      // 2-D input: each segment reduces to a vector of size embed_dim
+      //   output_shape:   [batch_size, embed_dim]
+      //   output_indices: [num_nonzero, 2]  (segment_id, dim_index)
+      //   output_nonzero: [num_nonzero]
+      // -------------------------------------------------------
+      const int64 embed_dim = input_tensor.dim_size(1);
+      auto input_data = input_tensor.matrix<float>();
+
+      Tensor* output_shape = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({2}),
+                                                       &output_shape));
+      output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);
+      output_shape->flat<int32>()(1) = static_cast<int32>(embed_dim);
+
+      // ===== Phase 1: sequential aggregation (the reduce dimension; running
+      // it on one thread avoids races on a shared segment row) =====
+      std::vector<float> segment_sums;  // [num_slots, embed_dim], row-major
+      for (int64 i = 0; i < num_indices; ++i) {
+        const int64 seg_id = segment_ids[i];
+        const int64 data_row = static_cast<int64>(indices_vec(i));
+
+        const auto inserted = seg_id_to_idx.emplace(
+            seg_id, static_cast<int64>(segment_order.size()));
+        const int64 slot = inserted.first->second;
+        if (inserted.second) {
+          segment_order.push_back(seg_id);
+          segment_sums.resize((slot + 1) * embed_dim, 0.0f);
+          if (is_mean_) segment_counts.push_back(0);
+        }
+
+        // Taken after the resize above, which may reallocate the buffer.
+        float* sum_ptr = segment_sums.data() + slot * embed_dim;
+        const float* row_ptr = input_data.data() + data_row * embed_dim;
+
+        // NEON: accumulate four floats per iteration.
+        int64 d = 0;
+        for (; d + 4 <= embed_dim; d += 4) {
+          const float32x4_t input_vec = vld1q_f32(row_ptr + d);
+          float32x4_t sum_vec = vld1q_f32(sum_ptr + d);
+          sum_vec = vaddq_f32(sum_vec, input_vec);
+          vst1q_f32(sum_ptr + d, sum_vec);
+        }
+        // Scalar tail for the remaining (< 4) elements.
+        for (; d < embed_dim; ++d) {
+          sum_ptr[d] += row_ptr[d];
+        }
+
+        if (is_mean_) {
+          segment_counts[slot]++;
+        }
+      }
+
+      const int64 num_unique_segments = segment_order.size();
+
+      // Per-slot scale factor: 1 / count for mean, exactly 1 for sum.
+      std::vector<float> inv_counts(num_unique_segments, 1.0f);
+      if (is_mean_) {
+        for (int64 s = 0; s < num_unique_segments; ++s) {
+          inv_counts[s] =
+              (segment_counts[s] > 0)
+                  ? (1.0f / static_cast<float>(segment_counts[s]))
+                  : 0.0f;
+        }
+      }
+
+      // Phases 2 and 3 touch each segment independently, so they are sharded
+      // across the worker pool when there is enough work; otherwise the same
+      // closure runs inline over the whole range.
+      const bool run_parallel = num_unique_segments >= 16 && num_threads > 1;
+
+      // ===== Phase 2: count the non-zero results of every segment =====
+      std::vector<int64> segment_nz_counts(num_unique_segments, 0);
+      auto count_work = [&](int64 start_seg, int64 end_seg) {
+        for (int64 s = start_seg; s < end_seg; ++s) {
+          const float* sums = segment_sums.data() + s * embed_dim;
+          const float scale = inv_counts[s];
+          int64 count = 0;
+          for (int64 d = 0; d < embed_dim; ++d) {
+            if (sums[d] * scale != 0.0f) count++;
+          }
+          segment_nz_counts[s] = count;
+        }
+      };
+      if (run_parallel) {
+        Shard(num_threads, worker_threads->workers, num_unique_segments,
+              num_unique_segments * embed_dim / num_threads, count_work);
+      } else {
+        count_work(0, num_unique_segments);
+      }
+
+      // Prefix sums give each segment the offset of its first output entry.
+      std::vector<int64> segment_offsets(num_unique_segments + 1, 0);
+      for (int64 s = 0; s < num_unique_segments; ++s) {
+        segment_offsets[s + 1] = segment_offsets[s] + segment_nz_counts[s];
+      }
+      const int64 total_nz = segment_offsets[num_unique_segments];
+
+      Tensor* output_indices = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, TensorShape({total_nz, 2}),
+                                              &output_indices));
+      auto output_indices_data = output_indices->matrix<int32>();
+
+      Tensor* output_nonzero = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(2, TensorShape({total_nz}),
+                                              &output_nonzero));
+      auto output_nonzero_data = output_nonzero->flat<float>();
+
+      // ===== Phase 3: fill the outputs; each segment writes only to its own
+      // [offset, next offset) range =====
+      auto output_work = [&](int64 start_seg, int64 end_seg) {
+        for (int64 s = start_seg; s < end_seg; ++s) {
+          const int32 seg_id = static_cast<int32>(segment_order[s]);
+          const float* sums = segment_sums.data() + s * embed_dim;
+          const float scale = inv_counts[s];
+          int64 out_idx = segment_offsets[s];
+
+          for (int64 d = 0; d < embed_dim; ++d) {
+            const float val = sums[d] * scale;
+            if (val != 0.0f) {
+              output_indices_data(out_idx, 0) = seg_id;
+              output_indices_data(out_idx, 1) = static_cast<int32>(d);
+              output_nonzero_data(out_idx) = val;
+              out_idx++;
+            }
+          }
+        }
+      };
+      if (run_parallel) {
+        Shard(num_threads, worker_threads->workers, num_unique_segments,
+              num_unique_segments * embed_dim / num_threads, output_work);
+      } else {
+        output_work(0, num_unique_segments);
+      }
+    }
+  }
+
+ private:
+  // True when the `combiner` attr is 1 (mean); false for 0 (sum).
+  bool is_mean_ = false;
+};
+
+// Registers the CPU kernel once for each supported type of the "Tidx" attr.
+// The helper macro is undefined again right after its last use.
+#define REGISTER_KERNEL(index_type) \
+  REGISTER_KERNEL_BUILDER( \
+      Name("KPFusedSparseSegmentReduceNonzero") \
+          .Device(DEVICE_CPU) \
+          .TypeConstraint<index_type>("Tidx"), \
+      KPFusedSparseSegmentReduceNonzeroOp<index_type>);
+REGISTER_KERNEL(int64)
+REGISTER_KERNEL(int32)
+#undef REGISTER_KERNEL
new file mode 100644
@@ -0,0 +1,183 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+// Fixture for the KPFusedSparseSegmentReduceNonzero kernel tests.
+class KPFusedSparseSegmentReduceNonzeroOpTest : public OpsTestBase {
+ protected:
+  // Creates and initializes the op under test. `combiner_mode` is forwarded
+  // to the "combiner" attr; inputs are, in order: data (float),
+  // indices (int32), slice_input (int64) and begin (int32).
+  void MakeOp(int combiner_mode) {
+    NodeDefBuilder builder("kp_fused_sparse_segment_reduce_nonzero",
+                           "KPFusedSparseSegmentReduceNonzero");
+    builder.Input(FakeInput(DT_FLOAT))
+        .Input(FakeInput(DT_INT32))
+        .Input(FakeInput(DT_INT64))
+        .Input(FakeInput(DT_INT32))
+        .Attr("combiner", combiner_mode);
+    TF_ASSERT_OK(builder.Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+// Mean combiner on 1-D data. Column 2 of slice_input gives segment ids
+// {2, 2, 3} and indices {0, 2, 1} gather the values {1, 3, 2}, so segment 2
+// averages to (1 + 3) / 2 = 2 and segment 3 to 2.
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceMean) {
+  MakeOp(/*combiner_mode=*/1);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({8}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  TF_ASSERT_OK(RunOpKernel());
+
+  // output_shape: largest segment id (3) + 1.
+  Tensor want_shape(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&want_shape, {4});
+  test::ExpectTensorEqual<int32>(want_shape, *GetOutput(0));
+
+  // output_indices: ids of the segments holding a non-zero result.
+  Tensor want_indices(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&want_indices, {2, 3});
+  test::ExpectTensorEqual<int32>(want_indices, *GetOutput(1));
+
+  // output_nonzero: the per-segment means.
+  Tensor want_values(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&want_values, {2, 2});
+  test::ExpectTensorEqual<float>(want_values, *GetOutput(2));
+}
+
+// Sum combiner on 1-D data. Column 2 of slice_input gives segment ids
+// {2, 2, 3} and indices {0, 2, 1} gather the values {1, 3, 2}, so segment 2
+// sums to 4 and segment 3 to 2.
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceSum) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({8}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  TF_ASSERT_OK(RunOpKernel());
+
+  // output_shape: largest segment id (3) + 1.
+  Tensor want_shape(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&want_shape, {4});
+  test::ExpectTensorEqual<int32>(want_shape, *GetOutput(0));
+
+  // output_indices: ids of the segments holding a non-zero result.
+  Tensor want_indices(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&want_indices, {2, 3});
+  test::ExpectTensorEqual<int32>(want_indices, *GetOutput(1));
+
+  // output_nonzero: the per-segment sums.
+  Tensor want_values(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&want_values, {4, 2});
+  test::ExpectTensorEqual<float>(want_values, *GetOutput(2));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidData) {
+ MakeOp(0);
+
+ AddInputFromArray<float>(TensorShape({2, 2, 2}),
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+ AddInputFromArray<int64>(TensorShape({3, 4}),
+ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+ Status s = RunOpKernel();
+ EXPECT_FALSE(s.ok());
+ EXPECT_TRUE(absl::StrContains(s.message(), "Input data must be a 1-D vector or 2-D matrix") !=
+ std::string::npos);
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidSliceinput) {
+ MakeOp(0);
+
+ AddInputFromArray<float>(TensorShape({8}),
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+ AddInputFromArray<int64>(TensorShape({3, 4, 1}),
+ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+ Status s = RunOpKernel();
+ EXPECT_FALSE(s.ok());
+ EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D") !=
+ std::string::npos);
+}
+
+// A `begin` tensor with three elements instead of two must be rejected.
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidbegin) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({8}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // begin: 3 elements
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "begin must have 2 elements"));
+}
+
+// begin[1] = 4 selects a column past the last one of the 3x4 slice_input.
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestColsOutOfBounds) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({8}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 4});  // begin: column 4
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "Column index out of range"));
+}
+
+// Two indices against a slice_input with three rows: the length mismatch
+// must be reported.
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestIndicesOutOfBounds) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({8}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // indices: 2 entries
+  AddInputFromArray<int64>(  // slice_input: 3 rows
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});  // begin
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(absl::StrContains(
+      run_status.message(),
+      "indices and slice_input.dim_size(0) should have same size"));
+}
+
+} // namespace
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,205 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+// Fixture for the KPFusedSparseSegmentReduce kernel tests.
+class KPFusedSparseSegmentReduceOpTest : public OpsTestBase {
+ protected:
+  // Creates and initializes the op under test. `combiner_mode` is forwarded
+  // to the "combiner" attr; inputs are, in order: data (float),
+  // indices (int32), slice_input (int64), begin (int32) and begin_1 (int32).
+  void MakeOp(int combiner_mode) {
+    NodeDefBuilder builder("kp_fused_sparse_segment_reduce",
+                           "KPFusedSparseSegmentReduce");
+    builder.Input(FakeInput(DT_FLOAT))
+        .Input(FakeInput(DT_INT32))
+        .Input(FakeInput(DT_INT64))
+        .Input(FakeInput(DT_INT32))
+        .Input(FakeInput(DT_INT32))
+        .Attr("combiner", combiner_mode);
+    TF_ASSERT_OK(builder.Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+// Mean combiner on a 4x2 data matrix. Column 2 of slice_input holds the
+// values {2, 2, 3}; the expected output has zeros in rows 0-1, the mean of
+// data rows 0 and 2 in row 2, and data row 1 in row 3.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceMean) {
+  MakeOp(/*combiner_mode=*/1);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({1}), {1});  // begin_1
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want_reduced(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&want_reduced,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 4.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(want_reduced, *GetOutput(0));
+
+  // Second output is a scalar; with begin_1 = 1 it is expected to be 2
+  // (presumably the output-shape entry selected by begin_1 - confirm against
+  // the kernel).
+  Tensor want_scalar(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_scalar, {2});
+  test::ExpectTensorEqual<int32>(want_scalar, *GetOutput(1));
+}
+
+// Sum combiner on a 4x2 data matrix. Column 2 of slice_input holds the
+// values {2, 2, 3}; the expected output has zeros in rows 0-1, the sum of
+// data rows 0 and 2 in row 2, and data row 1 in row 3.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceSum) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor want_reduced(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&want_reduced,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 8.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(want_reduced, *GetOutput(0));
+
+  // Second output is a scalar; with begin_1 = 0 it is expected to be 4
+  // (presumably the output-shape entry selected by begin_1 - confirm against
+  // the kernel).
+  Tensor want_scalar(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&want_scalar, {4});
+  test::ExpectTensorEqual<int32>(want_scalar, *GetOutput(1));
+}
+
+// begin[1] = 5 selects a column past the last one of the 3x4 slice_input.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestColsOutOfBounds) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 5});  // begin: column 5
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "Column index out of range"));
+}
+
+// Two indices against a slice_input with three rows: the length mismatch
+// must be reported.
+TEST_F(KPFusedSparseSegmentReduceOpTest, Test) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  // indices: num_indices (2) != slice_input.dim_size(0) (3)
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  // NOTE(review): "dim_zie" looks like a typo of "dim_size"; the expected
+  // text has to match the kernel's error message exactly, so confirm against
+  // the kernel source before changing either side.
+  EXPECT_TRUE(absl::StrContains(
+      run_status.message(),
+      "indices and slice_input.dim_zie(0) should have same size"));
+}
+
+// A rank-3 data tensor must be rejected.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidData) {
+  MakeOp(/*combiner_mode=*/0);
+
+  // data: data.dims() > 2
+  AddInputFromArray<float>(TensorShape({4, 2, 1}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(absl::StrContains(run_status.message(), "input must be 2-D"));
+}
+
+// A rank-3 slice_input tensor must be rejected.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidSliceinput) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  // slice_input: slice_input.dims() > 2
+  AddInputFromArray<int64>(TensorShape({3, 4, 1}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "slice input must be 2-D"));
+}
+
+// A `begin` tensor with three elements instead of two must be rejected.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // begin: 3 elements
+  AddInputFromArray<int32>(TensorShape({1}), {0});  // begin_1
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "begin must have 2 elements"));
+}
+
+// A `begin_1` tensor with two elements instead of one must be rejected.
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin1) {
+  MakeOp(/*combiner_mode=*/0);
+
+  AddInputFromArray<float>(  // data
+      TensorShape({4, 2}), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // indices
+  AddInputFromArray<int64>(  // slice_input
+      TensorShape({3, 4}), {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // begin
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});  // begin_1: 2 elements
+
+  const Status run_status = RunOpKernel();
+  EXPECT_FALSE(run_status.ok());
+  EXPECT_TRUE(
+      absl::StrContains(run_status.message(), "begin_1 must have 1 element"));
+}
+
+} // namespace
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,111 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include <algorithm>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/logging.h"
+
+using namespace tensorflow;
+
+// Fused element-wise select over int32 inputs. With a = input 0, b = input 1
+// and the scalars greater = input 3, equal1 = input 4, equal2 = input 5, the
+// kernel produces for every flat element i (N = number of elements of a):
+//   output 0 (int32, [N, 1]): a[i]
+//   output 1 (float, [N, 1]): 1.0 if b[i] == equal1 or b[i] == equal2 or
+//                             a[i] > greater, otherwise 0.0
+//   output 2 (float, [N, 2]): {output 1 value, 1.0}
+// Input 2 (c) only has its element count checked and input 6 (equal3) is
+// only logged; neither influences the outputs.
+class KPFusedSparseSelect : public OpKernel {
+ public:
+  explicit KPFusedSparseSelect(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_a = context->input(0);
+    const Tensor& input_b = context->input(1);
+    const Tensor& input_c = context->input(2);
+    const Tensor& greater = context->input(3);
+    const Tensor& equal1 = context->input(4);
+    const Tensor& equal2 = context->input(5);
+    const Tensor& equal3 = context->input(6);
+
+    // Element 0 of each scalar-like input is read below; an empty tensor
+    // would make that an out-of-bounds read, so reject it up front.
+    OP_REQUIRES(context,
+                greater.NumElements() > 0 && equal1.NumElements() > 0 &&
+                    equal2.NumElements() > 0 && equal3.NumElements() > 0,
+                errors::InvalidArgument(
+                    "greater, equal1, equal2 and equal3 must not be empty"));
+    OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(),
+                errors::InvalidArgument(
+                    "Input num elements of a and b must match"));
+    OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(),
+                errors::InvalidArgument(
+                    "Input num elements of a and c must match"));
+
+    const int32_t greater_val = greater.flat<int32_t>()(0);
+    const int32_t equal1_val = equal1.flat<int32_t>()(0);
+    const int32_t equal2_val = equal2.flat<int32_t>()(0);
+    const int32_t equal3_val = equal3.flat<int32_t>()(0);
+    VLOG(1) << "equal1_val: " << equal1_val;
+    VLOG(1) << "equal2_val: " << equal2_val;
+    VLOG(1) << "equal3_val: " << equal3_val;
+    VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
+    VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
+    VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
+
+    const auto N = input_a.NumElements();
+    const int32_t* a_data = input_a.flat<int32_t>().data();
+    const int32_t* b_data = input_b.flat<int32_t>().data();
+
+    Tensor* output_x = nullptr;
+    Tensor* output_y = nullptr;
+    Tensor* output_w = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({N, 1}),
+                                                     &output_x));
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}),
+                                                     &output_y));
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}),
+                                                     &output_w));
+
+    // Row-major flat views: x and y hold one value per row, w holds two.
+    int32_t* out_x = output_x->flat<int32_t>().data();
+    float* out_y = output_y->flat<float>().data();
+    float* out_w = output_w->flat<float>().data();
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit =
+        std::max<int64>(N / worker_threads->num_threads, 10);
+
+    auto work = [&](int64 start, int64 end) {
+      for (int64 i = start; i < end; i++) {
+        const int32_t a_val = a_data[i];
+        const int32_t b_val = b_data[i];
+        // Fused Greater + Cast and the two Equal/Select stages: either
+        // equality match forces 1.0, otherwise the result is (a > greater)
+        // as a float.
+        const bool selected = (b_val == equal2_val) || (b_val == equal1_val) ||
+                              (a_val > greater_val);
+        const float y = selected ? 1.0f : 0.0f;
+        out_x[i] = a_val;  // reshape of a to [N, 1]
+        out_y[i] = y;
+        out_w[2 * i] = y;  // multiplication by a constant 1.0 folded away
+        // The last select was eliminated; its result is the constant fill
+        // value 1.0.
+        out_w[2 * i + 1] = 1.0f;
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, N,
+          cost_per_unit, work);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
+ KPFusedSparseSelect);  // CPU registration; no type constraint is specified.
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,182 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::ExpectClose;
+using tensorflow::test::FillValues;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+
+// Fixture that drives the KPFusedSparseSelect CPU kernel through OpsTestBase.
+//
+// Every case feeds seven int32 inputs: three data tensors (input_a, input_b,
+// input_c) followed by four scalars (greater, equal1, equal2, equal3).
+class KPFusedSparseSelectTest : public OpsTestBase {
+ protected:
+  // Runs the op on well-formed inputs and verifies all three outputs.
+  //
+  // `shape` is applied to input_a, input_b and input_c. With
+  // N = expected_y.size(), the checks require:
+  //   * output_x == a_data viewed as [N, 1]
+  //   * output_y == expected_y viewed as [N, 1]
+  //   * output_w is [N, 2] with expected_w_col0 in column 0 and 1.0 in column 1
+  void RunValidCase(
+      const TensorShape& shape,
+      const std::vector<int32>& a_data,
+      const std::vector<int32>& b_data,
+      const std::vector<int32>& c_data,
+      int32_t greater_val,
+      int32_t equal1_val,
+      int32_t equal2_val,
+      const std::vector<float>& expected_y,
+      const std::vector<float>& expected_w_col0) {
+    // Use ASSERT rather than EXPECT: feeding inputs to an op that failed to
+    // initialise would crash instead of reporting the real failure.
+    TF_ASSERT_OK(BuildNodeDef());
+    TF_ASSERT_OK(InitOp());
+
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(shape, b_data);
+    AddInputFromArray<int32>(shape, c_data);
+    AddScalarInputs(greater_val, equal1_val, equal2_val);
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const int64 num_elements = static_cast<int64>(expected_y.size());
+
+    // output_x is input_a passed through, reshaped to a column.
+    ExpectTensorEqual<int32>(*GetOutput(0),
+                             AsTensor<int32>(a_data, {num_elements, 1}));
+    ExpectTensorEqual<float>(*GetOutput(1),
+                             AsTensor<float>(expected_y, {num_elements, 1}));
+
+    // Guard the per-row loop: without these checks a wrongly sized output
+    // would index past the end of expected_w_col0.
+    const auto w_mat = GetOutput(2)->matrix<float>();
+    ASSERT_EQ(w_mat.dimension(0), static_cast<int64>(expected_w_col0.size()));
+    ASSERT_EQ(w_mat.dimension(1), 2);
+    for (int64 row = 0; row < w_mat.dimension(0); ++row) {
+      EXPECT_FLOAT_EQ(w_mat(row, 0), expected_w_col0[row]);
+      EXPECT_FLOAT_EQ(w_mat(row, 1), 1.0f);  // second column must be 1.0
+    }
+  }
+
+  // Runs the op and returns the kernel status so callers can assert on the
+  // failure. input_a uses `shape`; input_b and input_c are fed as 1-D tensors
+  // sized from their own data, which lets their element counts differ from
+  // input_a's.
+  Status RunOpExpectFailure(
+      const TensorShape& shape,
+      const std::vector<int32>& a_data,
+      const std::vector<int32>& b_data,
+      const std::vector<int32>& c_data,
+      int32_t greater_val,
+      int32_t equal1_val,
+      int32_t equal2_val) {
+    TF_CHECK_OK(BuildNodeDef());
+    TF_CHECK_OK(InitOp());
+
+    const TensorShape b_shape({static_cast<int64>(b_data.size())});
+    const TensorShape c_shape({static_cast<int64>(c_data.size())});
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(b_shape, b_data);
+    AddInputFromArray<int32>(c_shape, c_data);
+    AddScalarInputs(greater_val, equal1_val, equal2_val);
+
+    return RunOpKernel();
+  }
+
+ private:
+  // Fills node_def() with a KPFusedSparseSelect node taking seven int32 inputs.
+  Status BuildNodeDef() {
+    return NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
+        .Input(FakeInput(DT_INT32))  // input_a
+        .Input(FakeInput(DT_INT32))  // input_b
+        .Input(FakeInput(DT_INT32))  // input_c
+        .Input(FakeInput(DT_INT32))  // greater
+        .Input(FakeInput(DT_INT32))  // equal1
+        .Input(FakeInput(DT_INT32))  // equal2
+        .Input(FakeInput(DT_INT32))  // equal3
+        .Finalize(node_def());
+  }
+
+  // Feeds the four trailing scalar inputs. equal3 is always fed as 0; the
+  // original test notes it as unused.
+  void AddScalarInputs(int32_t greater_val, int32_t equal1_val,
+                       int32_t equal2_val) {
+    AddInputFromArray<int32>(TensorShape({}), {greater_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal1_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal2_val});
+    AddInputFromArray<int32>(TensorShape({}), {0});
+  }
+};
+
+// ==================== Positive tests ====================
+// See fused_embedding_sparse_select_test.py for more positive-path coverage.
+// 1-D input with three elements; output_y and column 0 of output_w are
+// expected to match.
+TEST_F(KPFusedSparseSelectTest, Valid_NormalInput) {
+  const std::vector<int32> input_a = {5, 3, 8};
+  const std::vector<int32> input_b = {1, 2, 1};
+  const std::vector<int32> input_c = {9, 8, 7};  // noted as unused
+  const std::vector<float> expected = {1.0f, 0.0f, 1.0f};
+
+  RunValidCase(TensorShape({3}), input_a, input_b, input_c,
+               /*greater_val=*/4, /*equal1_val=*/1, /*equal2_val=*/3,
+               /*expected_y=*/expected, /*expected_w_col0=*/expected);
+}
+
+// 2x2 inputs; the outputs are checked against the four flattened elements.
+TEST_F(KPFusedSparseSelectTest, Valid_2DInput) {
+  const std::vector<int32> input_a = {6, 3, 8, 2};
+  const std::vector<int32> input_b = {2, 1, 3, 4};
+  const std::vector<int32> input_c = {0, 0, 0, 0};
+  const std::vector<float> expected = {1.0f, 0.0f, 1.0f, 0.0f};
+
+  RunValidCase(TensorShape({2, 2}), input_a, input_b, input_c,
+               /*greater_val=*/5, /*equal1_val=*/2, /*equal2_val=*/3,
+               /*expected_y=*/expected, /*expected_w_col0=*/expected);
+}
+// ==================== Negative tests ====================
+// Negative case 1: input_a and input_b have different element counts.
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AB) {
+  const std::vector<int32> a_vals = {1, 2, 3};  // 3 elements
+  const std::vector<int32> b_vals = {4, 5};     // 2 elements -> mismatch
+  const std::vector<int32> c_vals = {6, 7, 8};
+
+  const Status status =
+      RunOpExpectFailure(TensorShape({3}), a_vals, b_vals, c_vals, 0, 1, 2);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "Input num elements of a and b must match"));
+}
+
+// Negative case 2: input_a and input_c have different element counts.
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AC) {
+  const std::vector<int32> a_vals = {1, 2};
+  const std::vector<int32> b_vals = {3, 4};
+  const std::vector<int32> c_vals = {5};  // only 1 element -> mismatch
+
+  const Status status =
+      RunOpExpectFailure(TensorShape({2}), a_vals, b_vals, c_vals, 0, 1, 2);
+
+  EXPECT_FALSE(status.ok());
+  EXPECT_TRUE(absl::StrContains(status.message(), "Input num elements of a and c must match"));
+}
+
+}
\ No newline at end of file
@@ -62,6 +62,7 @@ tf_gen_op_libs(
"decode_proto_ops",
"encode_proto_ops",
"experimental_dataset_ops",
+ "embedding_fused_ops",
"filesystem_ops",
"function_ops",
"functional_ops",
@@ -357,7 +358,12 @@ cc_library(
":ktfop_ops_op_lib",
]) + if_fused_embedding([
":fused_embedding_ops_op_lib",
- ]),
+ ]) + select({
+ "@platforms//cpu:aarch64": [
+ ":embedding_fused_ops_op_lib"
+ ],
+ "//conditions:default": [],
+ }),
alwayslink = 1,
)
new file mode 100644
@@ -0,0 +1,134 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+// Op schema for KPFusedSparseSegmentReduce.
+// Inputs: float `data`, `indices` of type Tidx (int32 or int64, default
+// int32), int64 `slice_input`, and two int32 tensors `begin` and `begin_1`.
+// Attr `combiner` selects the reduction (0 = SUM, 1 = MEAN; default 1).
+// Outputs: float `output` and int32 `slice_output`.
+// The shape function is UnknownShape, so no static output shapes are inferred.
+REGISTER_OP("KPFusedSparseSegmentReduce")
+ .Input("data: float")
+ .Input("indices: Tidx")
+ .Input("slice_input: int64")
+ .Input("begin: int32")
+ .Input("begin_1: int32")
+ .Attr("combiner: int = 1") // 0 for SUM, 1 for MEAN
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Output("output: float")
+ .Output("slice_output: int32")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedSparseSegmentReduceNonzero.
+// Inputs: float `data`, `indices` of type Tidx (int32 or int64, default
+// int32), int64 `slice_input`, and int32 `begin`.
+// Attr `combiner` selects the reduction (0 = SUM, 1 = MEAN; default 1).
+// Outputs: int32 `output_shape`, int32 `output_indices`, float
+// `output_nonzero`.
+// The shape function is UnknownShape, so no static output shapes are inferred.
+REGISTER_OP("KPFusedSparseSegmentReduceNonzero")
+ .Input("data: float")
+ .Input("indices: Tidx")
+ .Input("slice_input: int64")
+ .Input("begin: int32")
+ .Attr("combiner: int = 1") // 0 for SUM, 1 for MEAN
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Output("output_shape: int32")
+ .Output("output_indices: int32")
+ .Output("output_nonzero: float")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedEmbeddingPaddingFast.
+// Inputs: int64 `input0`, float `input1`, int32 `input2`, int32 `input3`,
+// and int32 `pack`.
+// Outputs: two int32 tensors, `output0` and `output1`.
+// Shape inference declares both outputs to be scalars.
+REGISTER_OP("KPFusedEmbeddingPaddingFast")
+ .Input("input0: int64")
+ .Input("input1: float")
+ .Input("input2: int32")
+ .Input("input3: int32")
+ .Input("pack: int32")
+ .Output("output0: int32")
+ .Output("output1: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle scalar_shape = c->Scalar();
+ c->set_output(0, scalar_shape);
+ c->set_output(1, scalar_shape);
+ return OkStatus();
+ });
+
+// Op schema for KPFusedEmbeddingPadding.
+// Inputs: int64 `input0`, float `input1`, int32 `input2`, int32 `input3`,
+// and int32 `pack`.
+// Outputs: int32 `output0` and float `output1`.
+// Shape inference declares `output0` a scalar and derives the shape of
+// `output1` from the value of input 3 (`input3`) interpreted as a shape
+// tensor; an error from that conversion is propagated.
+REGISTER_OP("KPFusedEmbeddingPadding")
+ .Input("input0: int64")
+ .Input("input1: float")
+ .Input("input2: int32")
+ .Input("input3: int32")
+ .Input("pack: int32")
+ .Output("output0: int32")
+ .Output("output1: float")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle out;
+ ShapeHandle scalar_shape = c->Scalar();
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
+ c->set_output(0, scalar_shape);
+ c->set_output(1, out);
+ return OkStatus();
+ });
+
+// Op schema for KPFusedSparseSelect.
+// Inputs: seven int32 tensors: `input_a`, `input_b`, `input_c`, then
+// `greater`, `equal1`, `equal2` and `equal3`.
+// Outputs: int32 `output_x`, float `output_y`, float `output_w`.
+// The shape function is UnknownShape, so no static output shapes are inferred.
+REGISTER_OP("KPFusedSparseSelect")
+ .Input("input_a: int32")
+ .Input("input_b: int32")
+ .Input("input_c: int32")
+ .Input("greater: int32")
+ .Input("equal1: int32")
+ .Input("equal2: int32")
+ .Input("equal3: int32")
+ .Output("output_x: int32")
+ .Output("output_y: float")
+ .Output("output_w: float")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedSparseReshape.
+// Inputs: int64 `slice_input`, int32 `begin`, int64 `new_shape`, and
+// `pack_const` of type T (int32 or int64; no default, so T must be set or
+// inferred from the input).
+// Outputs: int64 `out_indices` and int64 `out_shape`.
+// The shape function is UnknownShape, so no static output shapes are inferred.
+REGISTER_OP("KPFusedSparseReshape")
+ .Input("slice_input: int64")
+ .Input("begin: int32")
+ .Input("new_shape: int64")
+ .Input("pack_const: T")
+ .Output("out_indices: int64")
+ .Output("out_shape: int64")
+ .Attr("T: {int32, int64}")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedSparseDynamicStitch.
+// Inputs: int64 `x` and a list of N float tensors `variables` (N >= 1).
+// Output: float `output`.
+// The shape function is UnknownShape, so no static output shape is inferred.
+REGISTER_OP("KPFusedSparseDynamicStitch")
+ .Input("x: int64")
+ .Input("variables: N * float")
+ .Output("output: float")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedGather.
+// Inputs: float `data`, int64 `slice_input`, and int32 `begin`.
+// Outputs: int64 `out_shape`, int32 `out_indices`, float `out_data`.
+// The shape function is UnknownShape, so no static output shapes are inferred.
+REGISTER_OP("KPFusedGather")
+ .Input("data: float")
+ .Input("slice_input: int64")
+ .Input("begin: int32")
+ .Output("out_shape: int64")
+ .Output("out_indices: int32")
+ .Output("out_data: float")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// Op schema for KPFusedEmbeddingActionIdGather.
+// Inputs: `input0` of type Tindices1 (int32 or int64, default int64), float
+// `input1`, `input2` of type Tindices2 (int32 or int64, default int32),
+// int32 `input3`, and int32 `pack`.
+// Output: float `output0`.
+// The shape function is UnknownShape, so no static output shape is inferred.
+REGISTER_OP("KPFusedEmbeddingActionIdGather")
+ .Input("input0: Tindices1")
+ .Input("input1: float")
+ .Input("input2: Tindices2")
+ .Input("input3: int32")
+ .Input("pack: int32")
+ .Attr("Tindices1: {int32, int64} = DT_INT64")
+ .Attr("Tindices2: {int32, int64} = DT_INT32")
+ .Output("output0: float")
+ .SetShapeFn(shape_inference::UnknownShape);
+} // namespace tensorflow
\ No newline at end of file
@@ -314,7 +314,12 @@ py_strict_library(
"//tensorflow/python/util:tf_decorator_export",
"//tensorflow/python/util:tf_export",
"//third_party/py/numpy",
- ],
+ ] + select({
+ "@platforms//cpu:aarch64": [
+ "//tensorflow/python/ops:embedding_fused_ops_gen",
+ ],
+ "//conditions:default": [],
+ }),
)
py_strict_library(
@@ -450,7 +455,12 @@ py_strict_library(
"//tensorflow/python/util:tf_decorator",
"//tensorflow/python/util:tf_decorator_export",
"//tensorflow/python/util:tf_export",
- ],
+ ] + select({
+ "@platforms//cpu:aarch64": [
+ "//tensorflow/python/ops:embedding_fused_ops_gen",
+ ],
+ "//conditions:default": [],
+ }),
)
# Necessary for the pywrap inclusion below.
new file mode 100644
@@ -0,0 +1,204 @@
+import os
+import time
+import shutil
+import glob
+import json
+import sys
+import re
+import struct
+import tensorflow.compat.v1 as tf
+from dataclasses import dataclass, field
+from typing import List, Callable, Dict, Any, Union, Optional
+from tensorflow.python.client import timeline
+import numpy as np
+import multiprocessing as mp
+import pickle
+from tensorflow.python.client import timeline
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+tf.disable_eager_execution()
+
+
+# Declarative description of one fused-op test, consumed by
+# UniversalOpBenchmark.run_function_test and run_performance_test.
+@dataclass
+class TestCase:
+ # Label used in log output and in the generated timeline file names.
+ name: str
+ # Called as op_fn(raw_inputs, meta); must return (fetches, feed_dict).
+ op_fn: Callable
+ # Called as input_fn(meta); its result is passed to op_fn as raw_inputs.
+ input_fn: Callable
+ # Optional comparator called as check_fn(unfused, fused, meta); when None,
+ # CheckFuncClass.check_fn_default is used.
+ check_fn: Optional[Callable] = None
+ # Timeline event name looked up to detect and time the fused op.
+ fused_op_name: str = ""
+ # Node names bounding the unfused subgraph; the unfused time is measured
+ # from the start of start_op_name to the end of end_op_name.
+ start_op_name: str = ""
+ end_op_name: str = ""
+ # Whether the fused op is expected to appear in the fused run's timeline.
+ is_fused: bool = True
+ # Number of timed iterations; 0 makes run_performance_test return early.
+ num_iters: int = 0
+ # Minimum required speed-up in percent, compared against
+ # 100 * (unfused_time / fused_time - 1).
+ optimize_percent: int = 0
+ # Free-form options forwarded to op_fn, input_fn and check_fn.
+ meta: Dict = field(default_factory=dict)
+
+class CheckFuncClass:
+ def check_fn_default(A, B, meta):
+ if (len(A) != len(B)):
+ return False
+ is_check_OK = True
+ for i in range(len(A)):
+ if (A[i].dtype == tf.float32):
+ is_check_OK &= np.testing.assert_allclose(A[i], B[i], rtol=1e-3, atol=1e-4) == None
+ else:
+ #整数类型严格判断相等, 当前浮点类型仅有float32
+ np.testing.assert_array_equal(A[i], B[i]) # 如果失败抛出异常, 如果成功无返回值,
+
+ return is_check_OK
+
+class UniversalOpBenchmark:
+ def __init__(self, log_dir="timeline"):
+ self.log_dir = log_dir
+ self.results = []
+
+ def generate_timeline(self, step_stats_list, timeline_file):
+ if not os.path.exists("timeline"):
+ os.makedirs("timeline")
+ ctf_list = []
+ for step_stats in step_stats_list:
+ tl = timeline.Timeline(step_stats)
+ ctf_list.append(json.loads(tl.generate_chrome_trace_format()))
+ with open(f"timeline/{timeline_file}", "w") as f:
+ json.dump(ctf_list, f, indent=2)
+
+ def is_fused_op_exist(self, timeline_file, op_name):
+ """从 timeline JSON 文件中检查指定算子(fusedOp)是否存在"""
+ with open(f"timeline/{timeline_file}", "r") as f:
+ trace_events = json.load(f)[0]["traceEvents"] # timeline.json的格式
+ op_exists = any(e.get("name") == op_name for e in trace_events if "dur" in e)
+ return op_exists
+
+ def extract_op_dur(self, timeline_file, op_name, times=1):
+ """从 timeline JSON 文件中提取指定算子(fusedOp)的平均耗时(μs)"""
+ with open(f"timeline/{timeline_file}", "r") as f:
+ trace_events_list = json.load(f) # timeline.json的格式
+ durations_list = []
+ for trace_events in trace_events_list:
+ durations = [e["dur"] for e in trace_events["traceEvents"] if e.get("name") == op_name and "dur" in e]
+ durations_list.append(durations[0])
+ if len(durations_list) != times:
+ raise ValueError(f"Expected {times} durations for {op_name}, but got {len(durations)}")
+ return np.mean(durations_list)
+
+
+ def extract_op_total_time(self, timeline_file, start_op, end_op, times):
+ """计算从 start_op 到 end_op 的总耗时的平均值(包含调度空隙)"""
+ with open(f"timeline/{timeline_file}", "r") as f:
+ trace_events_list = json.load(f)
+ time_list = []
+ for trace_events in trace_events_list:
+ start_event = next(e for e in trace_events["traceEvents"] if e.get("args", {}).get("name") == start_op)
+ end_event = next(e for e in trace_events["traceEvents"] if e.get("args", {}).get("name") == end_op)
+ start_time = start_event["ts"]
+ end_time = end_event["ts"] + end_event["dur"] # ts 是开始时间,dur是算子的持续时间
+ total_time = end_time - start_time
+ time_list.append(total_time)
+ if len(time_list) != times:
+ raise ValueError(f"Expected {times} total times for {start_op} to {end_op}, but got {len(time_list)}")
+ return np.mean(time_list)
+
+ def parse_performance_time(self, embedding_fused_enable, raw_inputs, test_case: TestCase):
+ print(f"Testing: {test_case.name} ...")
+ os.environ["ANNC_FUSED_ALL"] = str(embedding_fused_enable)
+
+ res_node, feed_dict = test_case.op_fn(raw_inputs, test_case.meta)
+ config = tf.compat.v1.ConfigProto()
+ config.inter_op_parallelism_threads = 16
+ config.intra_op_parallelism_threads = 16
+
+ run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+ run_metadata = tf.compat.v1.RunMetadata()
+ all_step_stats = []
+ filename = f"{test_case.name}_performance_{embedding_fused_enable}.timeline.json"
+ if embedding_fused_enable != 0:
+ # 开启融合条件下关闭部分开关防止中间进行部分非预期优化
+ config.graph_options.rewrite_options.constant_folding = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.rewrite_options.arithmetic_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.AGGRESSIVE
+ with tf.Session(config=config) as sess:
+ time.sleep(1)
+ for _ in range(10):
+ sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+ for _ in range(test_case.num_iters):
+ sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+ all_step_stats.append(run_metadata.step_stats)
+ self.generate_timeline(all_step_stats, filename)
+ if embedding_fused_enable != 0:
+ return self.extract_op_dur(filename, test_case.fused_op_name, test_case.num_iters)
+ else:
+ return self.extract_op_total_time(filename, test_case.start_op_name, test_case.end_op_name, test_case.num_iters)
+
+ def run_performance_test(self, test_case: TestCase):
+ if test_case.num_iters == 0:
+ return True
+
+ print(f"Performance testing: {test_case.name} ...")
+
+ optimize_percent = test_case.optimize_percent
+ operator_name = test_case.fused_op_name
+ raw_inputs = test_case.input_fn(test_case.meta)
+
+ no_fused_time = self.parse_performance_time(0, raw_inputs, test_case)
+ fused_time = self.parse_performance_time(1, raw_inputs, test_case)
+ if fused_time <= 0.0 or no_fused_time <= 0.0:
+ print(f"⚠️ 获取平均运行时长失败, 无法进行比较")
+ return False
+ real_percent = 100 * (no_fused_time/fused_time - 1)
+ if real_percent < optimize_percent:
+ raise Exception(f"⚠️ 性能提升{real_percent:.2f}%, 低于预期{optimize_percent}%, embedding算子融合开启状态下耗时{fused_time:.2f}us, 关闭状态下耗时{no_fused_time:.2f}us")
+ return False
+
+ print(f"性能测试通过, 性能提升{real_percent:.2f}%, embedding算子融合开启状态下耗时{fused_time:.2f}us, 关闭状态下耗时{no_fused_time:.2f}us")
+ return True
+
+ def run_function_test(self, test_case: TestCase):
+ print(f"Function testing: {test_case.name} ...")
+
+ raw_inputs = test_case.input_fn(test_case.meta)
+
+ def execute_variant(enable_embedding_fused, raw_inputs, test_case: TestCase):
+ os.environ["ANNC_FUSED_ALL"] = str(enable_embedding_fused)
+ tf.reset_default_graph()
+
+ placeholders = []
+ res_node, feed_dict = test_case.op_fn(raw_inputs, test_case.meta)
+ config = tf.ConfigProto(
+ inter_op_parallelism_threads=16,
+ intra_op_parallelism_threads=16
+ )
+
+ config.graph_options.rewrite_options.constant_folding = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.rewrite_options.arithmetic_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.AGGRESSIVE
+ with tf.Session(config=config) as sess:
+ is_fused = False
+ if enable_embedding_fused == 0:
+ result = sess.run(res_node, feed_dict=feed_dict)
+ else:
+ run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+ run_metadata = tf.compat.v1.RunMetadata()
+ filename = f"{test_case.name}_func_{enable_embedding_fused}.timeline.json"
+ result = sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+ self.generate_timeline([run_metadata.step_stats], filename)
+ is_fused = self.is_fused_op_exist(filename, test_case.fused_op_name)
+ return result, is_fused
+ return
+
+ embedding_data, is_fused = execute_variant(1, raw_inputs, test_case)
+ if is_fused != test_case.is_fused:
+ print(f"⚠️ {test_case.fused_op_name}算子未做融合")
+ raise Exception(f"⚠️ {test_case.fused_op_name}算子未做融合")
+ return False
+ no_embedding_data, _ = execute_variant(0, raw_inputs, test_case)
+ if test_case.check_fn is None:
+ is_correct = CheckFuncClass.check_fn_default(no_embedding_data, embedding_data, test_case.meta)
+ else:
+ is_correct = test_case.check_fn(no_embedding_data, embedding_data, test_case.meta)
+
+ if not is_correct:
+ print(f"⚠️ 误差较大")
+ raise Exception(f"⚠️ 误差较大")
+ return False
+ print("功能测试通过")
+ return True
new file mode 100644
@@ -0,0 +1,51 @@
+import argparse
+import importlib
+import pkgutil
+import traceback
+import os
+
+def main():
+ parser = argparse.ArgumentParser(description="TF Op Benchmark Framework with KDNN Control")
+ parser.add_argument('--op', type=str, help='指定要运行的算子模块名')
+ parser.add_argument('--list', action='store_true', help='列出所有模块')
+ parser.add_argument('--performance_test', choices=['True', 'False'], default='True', help='是否运行性能测试模块')
+
+ args = parser.parse_args()
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #日志太多, 关掉
+
+ import tensorflow as TF
+ from framework.runner import UniversalOpBenchmark
+ import ops
+
+ # 4. 扫描模块
+ available_ops = {
+ name: f'ops.{name}'
+ for loader, name, is_pkg in pkgutil.iter_modules(ops.__path__)
+ }
+
+ if args.list:
+ print("📁 可用算子列表:")
+ for name in available_ops: print(f" - {name}")
+ return
+
+ target_modules = [available_ops[args.op]] if args.op else list(available_ops.values())
+ root_log_dir = "bench_logs"
+ bench = UniversalOpBenchmark(log_dir=root_log_dir)
+
+ for module_path in target_modules:
+ try:
+ module = importlib.import_module(module_path)
+ if hasattr(module, 'get_test_cases'):
+ cases = module.get_test_cases()
+ for case in cases:
+ bench.run_function_test(case)
+ if args.performance_test == "True":
+ for case in cases:
+ bench.run_performance_test(case)
+ except Exception as e:
+ print(f"❌ 运行模块 {module_path} 失败: {e}")
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,150 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+ @tf.function
+ def KPFusedEmbeddingActionIdGatherOp(input0, input1, input2, input3, pack):
+ gather1 = tf.gather(input1, input0, axis=0)
+ gather2 = tf.gather(gather1, input2, axis=0)
+ pack1 = tf.stack([input3, pack], axis=0)
+ pack2 = tf.stack([input3, -1], axis=0)
+ reshape = tf.reshape(gather2, pack2)
+ fill = tf.fill(pack1, tf.constant(0, dtype=tf.float32))
+ output = tf.concat([reshape, fill], axis=-1)
+ return output
+
+ def KPFusedEmbeddingActionIdGather_graph(input, meta):
+ # 根据核心函数入参名创建placeholder
+ input0_type = np.int32
+ if "input0_type" in meta.keys():
+ input0_type = meta["input0_type"]
+ input2_type = np.int32
+ if "input2_type" in meta.keys():
+ input2_type = meta["input2_type"]
+ input0 = tf.compat.v1.placeholder(input0_type, shape=input['input0'].shape, name="input0")
+ input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+ input2 = tf.compat.v1.placeholder(input2_type, shape=input['input2'].shape, name="input2")
+ input3 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input3")
+ pack = tf.compat.v1.placeholder(tf.int32, shape=(), name="pack")
+
+ feed_dict = {
+ input0: input['input0'],
+ input1: input['input1'],
+ input2: input['input2'],
+ input3: input['input3'],
+ pack: input['pack']
+ }
+
+ result = KPFusedEmbeddingActionIdGatherOp(input0, input1, input2, input3, pack)
+ return result, feed_dict
+
+ def build_input_case_1(meta):
+ indices1_shape = (8, 10)
+ indices2_shape = (5, 6)
+ params_shape = (80, 300)
+ input = {}
+ input0_type = np.int32
+ if "input0_type" in meta.keys():
+ input0_type = meta["input0_type"]
+ input2_type = np.int32
+ if "input2_type" in meta.keys():
+ input2_type = meta["input2_type"]
+ # 根据核心函数入参名生成输入数据
+ input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(input0_type)
+ input["input1"] = np.random.random(params_shape).astype(np.float32)
+ input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(input2_type)
+ input["input3"] = params_shape[0]
+ input["pack"] = 1680
+ return input
+
+ def build_input_3D_case_1(meta):
+ indices1_shape = (8, 10, 2)
+ indices2_shape = (5, 6)
+ params_shape = (160, 300)
+ input = {}
+ # 根据核心函数入参名生成输入数据
+ input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(np.int32)
+ input["input1"] = np.random.random(params_shape).astype(np.float32)
+ input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(np.int32)
+ input["input3"] = params_shape[0]
+ input["pack"] = 1680
+ return input
+
+ def build_input_3D_case_2(meta):
+ indices1_shape = (8, 10)
+ indices2_shape = (5, 6, 2)
+ params_shape = (160, 300)
+ input = {}
+ # 根据核心函数入参名生成输入数据
+ input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(np.int32)
+ input["input1"] = np.random.random(params_shape).astype(np.float32)
+ input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(np.int32)
+ input["input3"] = params_shape[0]
+ input["pack"] = 1680
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_case_1_int32_int32",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_case_1,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ start_op_name = "PartitionedCall_1/stack_1",
+ end_op_name = "PartitionedCall_1/concat",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=50,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_case_1_int64_int32",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_case_1,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ start_op_name = "PartitionedCall_3/stack_1",
+ end_op_name = "PartitionedCall_3/concat",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=50,
+ meta = {"input0_type": np.int64},
+ ),
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_case_1_int32_int64",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_case_1,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ start_op_name = "PartitionedCall_5/stack_1",
+ end_op_name = "PartitionedCall_5/concat",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=50,
+ meta = {"input2_type": np.int64},
+ ),
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_case_1_int64_int64",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_case_1,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ start_op_name = "PartitionedCall_7/stack_1",
+ end_op_name = "PartitionedCall_7/concat",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=50,
+ meta = {"input0_type": np.int64, "input2_type": np.int64},
+ ),
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_3D_case_1",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_3D_case_1,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ is_fused=False,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingActionIdGather_3D_case_2",
+ op_fn = KPFusedEmbeddingActionIdGather_graph,
+ input_fn = build_input_3D_case_2,
+ fused_op_name = "KPFusedEmbeddingActionIdGather",
+ is_fused=False,
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,122 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+ @tf.function
+ def KPFusedEmbeddingPaddingFastOp(input0, input1, input2, input3, pack, dims):
+ cast = tf.cast(input0, tf.int32)
+ begin = tf.constant([0], dtype=tf.int32)
+ end = tf.constant([1], dtype=tf.int32)
+ strides = tf.constant([1], dtype=tf.int32)
+ hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1)
+ sub_out = hash_rows - input2
+ if dims == 3:
+ fill_shape = tf.stack([sub_out, pack, 5], axis=0)
+ fill = tf.fill(fill_shape, tf.constant(0, dtype=tf.float32))
+ else:
+ pack_op = tf.stack([sub_out, pack], axis=0)
+ fill = tf.fill(pack_op, tf.constant(0, dtype=tf.float32))
+ concat = tf.concat([input1, fill], 0)
+ reshape = tf.reshape(concat, input3)
+ shape_tensor = tf.shape(reshape)
+ output = tf.strided_slice(shape_tensor, begin=begin, end=end, strides=strides, shrink_axis_mask=1)
+ return output
+
+ def KPFusedEmbeddingPaddingFast_graph(input, meta):
+ # 根据核心函数入参名创建placeholder
+ input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0")
+ input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+ input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+ input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
+ pack = tf.constant(input['pack'], dtype=tf.int32)
+
+ feed_dict = {
+ input0: input['input0'],
+ input1: input['input1'],
+ input2: input['input2'],
+ input3: input['input3'],
+ pack: input['pack']
+ }
+
+ result = KPFusedEmbeddingPaddingFastOp(input0, input1, input2, input3, pack, len(input['input1'].shape))
+ return [result], feed_dict
+
+ def build_input_func(input_shape, pooling_shape, reshape):
+ input = {}
+ input["input0"] = np.array(input_shape).astype(np.int64)
+ input["input1"] = np.random.rand(*pooling_shape).astype(np.float32)
+ input["input2"] = pooling_shape[0]
+ input["input3"] = np.array(reshape).astype(np.int32)
+ input["pack"] = pooling_shape[1]
+ return input
+
+ def build_input_2D_case_1(meta):
+ return build_input_func((151 * 1, 10), (151 * 1, 10), (-1, 1510))
+
+ def build_input_2D_case_2(meta):
+ return build_input_func((151 * 1000, 10), (151 * 10, 10), (-1, 1510))
+
+ def build_input_2D_case_3(meta):
+ return build_input_func((2 * 1, 12), (2 * 1, 12), (-1, 24))
+
+ def build_input_2D_case_4(meta):
+ return build_input_func((2 * 1000, 12), (2 * 10, 12), (-1, 24))
+
+ def build_input_3D_case_1(meta):
+ return build_input_func((2 * 1000, 12), (2 * 10, 12, 5), (-1, 24))
+
+ return [
+ TestCase(
+ name="KPFusedEmbeddingPaddingFast_case_1",
+ op_fn=KPFusedEmbeddingPaddingFast_graph,
+ input_fn=build_input_2D_case_1,
+ fused_op_name="KPFusedEmbeddingPaddingFast",
+ start_op_name="PartitionedCall_1/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_1/StridedSlice_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=600,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPaddingFast_case_2",
+ op_fn=KPFusedEmbeddingPaddingFast_graph,
+ input_fn=build_input_2D_case_2,
+ fused_op_name="KPFusedEmbeddingPaddingFast",
+ start_op_name="PartitionedCall_3/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_3/StridedSlice_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=7000,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPaddingFast_case_3",
+ op_fn=KPFusedEmbeddingPaddingFast_graph,
+ input_fn=build_input_2D_case_3,
+ fused_op_name="KPFusedEmbeddingPaddingFast",
+ start_op_name="PartitionedCall_5/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_5/StridedSlice_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=600,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPaddingFast_case_4",
+ op_fn=KPFusedEmbeddingPaddingFast_graph,
+ input_fn=build_input_2D_case_4,
+ fused_op_name="KPFusedEmbeddingPaddingFast",
+ start_op_name="PartitionedCall_7/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_7/StridedSlice_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=800,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPadding_3D",
+ op_fn=KPFusedEmbeddingPaddingFast_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedEmbeddingPaddingFast",
+ is_fused=False,
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,121 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedEmbeddingPaddingOp(input0, input1, input2, input3, pack, dims):  # pattern under test, built from stock TF ops
+ cast = tf.cast(input0, tf.int32)
+ begin = tf.constant([0], dtype=tf.int32)
+ end = tf.constant([1], dtype=tf.int32)
+ strides = tf.constant([1], dtype=tf.int32)
+ hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1)  # scalar: element 0 of input0
+ sub_out = hash_rows - input2  # number of zero rows appended below input1
+
+ if dims == 3:
+ fill_shape = tf.stack([sub_out, pack, 5], axis=0)  # trailing 5 matches the last axis of the 3-D test input
+ fill = tf.fill(fill_shape, tf.constant(0, dtype=tf.float32))
+ else:
+ pack_op = tf.stack([sub_out, pack], axis=0)
+ fill = tf.fill(pack_op, tf.constant(0, dtype=tf.float32))
+ concat = tf.concat([input1, fill], 0)  # input1 followed by the zero padding along axis 0
+ output = tf.reshape(concat, input3)
+ return tf.concat([output, output], 1)  # reshaped result duplicated along axis 1
+
+ def KPFusedEmbeddingPadding_graph(input, meta):  # returns ([output tensor], feed_dict); meta is unused here
+ # Create placeholders named after the core function's parameters
+ input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0")
+ input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+ input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+ input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
+ pack = tf.constant(input['pack'], dtype=tf.int32)
+
+ feed_dict = {
+ input0: input['input0'],
+ input1: input['input1'],
+ input2: input['input2'],
+ input3: input['input3'],
+ pack: input['pack']  # NOTE(review): pack is a tf.constant, not a placeholder, yet it is fed here - confirm this is intended
+ }
+
+ result = KPFusedEmbeddingPaddingOp(input0, input1, input2, input3, pack, len(input['input1'].shape))  # dims = rank of input1
+ return [result], feed_dict
+
+ def build_input_func(input_shape, pooling_shape, reshape):  # shared builder for the cases below
+ input = {}
+ input["input0"] = np.array(input_shape).astype(np.int64)  # the shape tuple itself is the tensor value
+ input["input1"] = np.random.rand(*pooling_shape).astype(np.float32)
+ input["input2"] = pooling_shape[0]  # row count of input1
+ input["input3"] = np.array(reshape).astype(np.int32)  # target shape passed to tf.reshape
+ input["pack"] = pooling_shape[1]  # second dimension of input1
+ return input
+
+ def build_input_2D_case_1(meta):  # input0[0] == rows of input1, so no padding rows
+ return build_input_func((151 * 1, 10), (151 * 1, 10), (-1, 1510))
+
+ def build_input_2D_case_2(meta):  # 151000 - 1510 padding rows
+ return build_input_func((151 * 1000, 10), (151 * 10, 10), (-1, 1510))
+
+ def build_input_2D_case_3(meta):  # input0[0] == rows of input1, so no padding rows
+ return build_input_func((2 * 1, 12), (2 * 1, 12), (-1, 24))
+
+ def build_input_2D_case_4(meta):  # 2000 - 20 padding rows
+ return build_input_func((2 * 1000, 12), (2 * 10, 12), (-1, 24))
+
+ def build_input_3D_case_1(meta):  # 3-D input1 with last axis 5
+ return build_input_func((2 * 1000, 12), (2 * 10, 12, 5), (-1, 24))
+
+ return [
+ TestCase(
+ name="KPFusedEmbeddingPadding_case_1",
+ op_fn=KPFusedEmbeddingPadding_graph,
+ input_fn=build_input_2D_case_1,
+ fused_op_name="KPFusedEmbeddingPadding",
+ start_op_name="PartitionedCall_1/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/concat_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=400,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPadding_case_2",
+ op_fn=KPFusedEmbeddingPadding_graph,
+ input_fn=build_input_2D_case_2,
+ fused_op_name="KPFusedEmbeddingPadding",
+ start_op_name="PartitionedCall_3/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_3/concat_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=100,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPadding_case_3",
+ op_fn=KPFusedEmbeddingPadding_graph,
+ input_fn=build_input_2D_case_3,
+ fused_op_name="KPFusedEmbeddingPadding",
+ start_op_name="PartitionedCall_5/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_5/concat_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=450,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPadding_case_4",
+ op_fn=KPFusedEmbeddingPadding_graph,
+ input_fn=build_input_2D_case_4,
+ fused_op_name="KPFusedEmbeddingPadding",
+ start_op_name="PartitionedCall_7/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+ end_op_name="PartitionedCall_7/concat_1",
+ is_fused=True,
+ num_iters=500,
+ optimize_percent=700,
+ ),
+ TestCase(
+ name="KPFusedEmbeddingPadding_3D",
+ op_fn=KPFusedEmbeddingPadding_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedEmbeddingPadding",
+ is_fused=False,  # the 3-D input1 case is declared as not fused
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,103 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedGatherOp(data, slice_input, begin, dims):  # pattern under test, built from stock TF ops
+ if dims == 3:
+ end=[slice_input.shape.as_list()[0], begin[1] + 2, begin[2] + 2]  # needs a static leading dim on slice_input
+ strides=[1, 1, 1]
+ else:
+ end=[slice_input.shape.as_list()[0], begin[1] + 2]
+ strides=[1, 1]
+ slice_out = tf.strided_slice(
+ slice_input,
+ begin=begin,
+ end=end,
+ strides=strides,
+ begin_mask=1,  # bit 0: begin[0] is ignored for axis 0
+ end_mask=1,  # bit 0: end[0] is ignored for axis 0
+ shrink_axis_mask=2  # bit 1: axis 1 is indexed at begin[1] and dropped
+ )
+
+ if dims == 3:
+ value, indices = tf.unique(slice_out[0])  # tf.unique takes a 1-D tensor, so use the first row
+ else:
+ value, indices = tf.unique(slice_out)
+ value_1, indices_1 = tf.unique(value)  # second unique over the already-unique values
+ gather1 = tf.gather(data, value_1)
+ gather2 = tf.gather(gather1, indices_1)
+ return value, indices, gather2
+
+ def KPFusedGather_graph(input, meta):  # returns ([three output tensors], feed dict); meta is unused here
+ data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+ slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+ begin = tf.compat.v1.placeholder(tf.int32, shape=input['begin'].shape, name="begin")
+
+ feed = {
+ data: input['data'],
+ slice_input: input['slice_input'],
+ begin: input['begin']
+ }
+ shape, indices, data = KPFusedGatherOp(data, slice_input, begin, len(input['slice_input'].shape))  # rebinds data; feed already holds the placeholder
+
+ return [shape, indices, data], feed
+
+ def KPFusedGather_check_fn(input_a, input_b, meta):  # custom comparison of two result lists
+ np.testing.assert_array_equal(input_a[0], input_b[0])  # raises AssertionError on mismatch
+ np.testing.assert_array_equal(input_a[1], input_b[1])
+ return np.allclose(input_a[2], input_b[2], rtol=1e-3, atol=1e-5)  # gathered floats compared with tolerance
+
+ def build_input_2D_case_1(meta):
+ input = {}
+ input["data"] = np.random.rand(50, 12).astype(np.float32)
+ input["slice_input"] = np.array([[10, 7], [20, 7], [30, 7]], dtype=np.int64)  # column 1 (all 7) is what the slice selects
+ input["begin"] = np.array([0, 1], dtype=np.int32)
+ return input
+
+ def build_input_3D_case_1(meta):  # 3-D data, 2-D slice_input
+ input = {}
+ input["data"] = np.random.rand(50, 12, 5).astype(np.float32)
+ input["slice_input"] = np.array([[10, 7], [20, 7], [30, 7]], dtype=np.int64)
+ input["begin"] = np.array([0, 1], dtype=np.int32)
+ return input
+
+ def build_input_3D_case_2(meta):  # 2-D data, 3-D slice_input
+ input = {}
+ input["data"] = np.random.rand(50, 12).astype(np.float32)
+ input["slice_input"] = np.array([[[10, 7], [20, 7], [30, 7]], [[10, 2], [20, 3], [30, 4]]], dtype=np.int64)
+ input["begin"] = np.array([0, 1, 1], dtype=np.int32)
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedGather_input_2D_case_1",
+ op_fn=KPFusedGather_graph,
+ input_fn=build_input_2D_case_1,
+ check_fn = KPFusedGather_check_fn,
+ fused_op_name = "KPFusedGather",
+ start_op_name = "PartitionedCall_1/strided_slice",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name = "PartitionedCall_1/GatherV2",
+ is_fused = True,
+ num_iters=500,
+ optimize_percent = 400,
+ ),
+ TestCase(
+ name="KPFusedGather_input_3D_case_1",
+ op_fn=KPFusedGather_graph,
+ input_fn=build_input_3D_case_1,
+ check_fn = KPFusedGather_check_fn,
+ fused_op_name = "KPFusedGather",
+ is_fused = False,  # 3-D data case is declared as not fused
+ ),
+ TestCase(
+ name="KPFusedGather_input_3D_case_2",
+ op_fn=KPFusedGather_graph,
+ input_fn=build_input_3D_case_2,
+ check_fn = KPFusedGather_check_fn,
+ fused_op_name = "KPFusedGather",
+ is_fused = False,  # 3-D slice_input case is declared as not fused
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,111 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedSparseDynamicStitchOp(x, emb_tables, num_tables):  # pattern under test, built from stock TF ops
+ x_1 = tf.reshape(x, shape=[-1]) # flatten input x into the 1-D vector x_1
+ group_ids = tf.math.floormod(x_1, num_tables)  # which table each id maps to
+ group_ids = tf.cast(group_ids, dtype=np.int32)
+ chunk_indices = tf.math.floordiv(x_1, num_tables)  # row index inside the selected table
+ original_indices = tf.range(0, tf.size(x_1), 1)  # positions in x_1, used to restore the original order
+ a = tf.dynamic_partition(original_indices, group_ids, num_partitions=num_tables)
+ b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=num_tables)
+ c = [tf.gather(emb_tables[i], b[i]) for i in range(num_tables)]  # per-table embedding lookups
+ d = tf.raw_ops.ParallelDynamicStitch(indices=a, data=c)  # merge the lookups back at their original positions
+ return d
+
+ def KPFusedSparseDynamicStitch_graph(input, meta):  # returns (output tensor, feed_dict)
+ # Create placeholders named after the core function's parameters
+ x_shape = input['x'].shape
+ lst = list(x_shape) # 1. convert the shape tuple to a list
+ lst[1] = None  # leave axis 1 of the placeholder shape unknown
+ x_shape = tuple(lst)
+ num_tables = meta["num_tables"]
+ x = tf.compat.v1.placeholder(tf.int64, shape=x_shape, name="x")
+
+ emb_tables = []
+ feed_dict = {x: input['x']}
+
+ for i in range(num_tables):  # one float32 placeholder per embedding table
+ placeholder = tf.compat.v1.placeholder(tf.float32, shape=input[f'emb_table_{i}'].shape, name=f'emb_table_{i}')
+ emb_tables.append(placeholder)
+ feed_dict[placeholder] = input[f'emb_table_{i}']
+
+ result = KPFusedSparseDynamicStitchOp(x, emb_tables, num_tables)
+ return result, feed_dict
+
+ def build_input_case_1(meta):  # 2-D x of shape (1000, num_tables), 2-D tables
+ input = {}
+ num_tables = meta["num_tables"]
+ emb_dim = 10
+ max_val = float('inf')
+
+ for i in range(num_tables):
+ N = np.random.randint(1000000, 44739244)  # random row count per table
+ max_val = min(N, max_val)  # track the smallest table so every generated row index is in range
+ input[f'emb_table_{i}'] = np.random.rand(N, emb_dim).astype(np.float32)
+
+ input["x"] = np.random.randint(0, num_tables * max_val, size=(1000, num_tables)).astype(np.int32) # ids drawn from [0, num_tables * max_val); NOTE(review): int32 data fed to an int64 placeholder
+ return input
+
+ def build_input_case_2(meta):  # 3-D x of shape (10, 100, num_tables), 2-D tables
+ input = {}
+ num_tables = meta["num_tables"]
+ emb_dim = 10
+ max_val = float('inf')
+
+ for i in range(num_tables):
+ N = np.random.randint(1000000, 44739244)
+ max_val = min(N, max_val)
+ input[f'emb_table_{i}'] = np.random.rand(N, emb_dim).astype(np.float32)
+
+ input["x"] = np.random.randint(0, num_tables * max_val, size=(10, 100, num_tables)).astype(np.int32) # ids drawn from [0, num_tables * max_val)
+ return input
+
+ def build_input_3D_case_1(meta):  # 2-D x, 3-D tables of shape (N, emb_dim, 1)
+ input = {}
+ num_tables = meta["num_tables"]
+ emb_dim = 10
+ max_val = float('inf')
+
+ for i in range(num_tables):
+ N = np.random.randint(1000000, 44739244)
+ max_val = min(N, max_val)
+ input[f'emb_table_{i}'] = np.random.rand(N, emb_dim, 1).astype(np.float32)
+
+ input["x"] = np.random.randint(0, num_tables * max_val, size=(1000, num_tables)).astype(np.int32) # ids drawn from [0, num_tables * max_val)
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedSparseDynamicStitch_case_1",
+ op_fn=KPFusedSparseDynamicStitch_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseDynamicStitch",
+ start_op_name="PartitionedCall_1/Reshape",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/ParallelDynamicStitch",
+ is_fused=True,
+ num_iters=200,
+ optimize_percent=50,
+ meta={"num_tables": 12},
+ ),
+ TestCase(
+ name="KPFusedSparseDynamicStitch_case_2",
+ op_fn=KPFusedSparseDynamicStitch_graph,
+ input_fn=build_input_case_2,
+ fused_op_name="KPFusedSparseDynamicStitch",
+ is_fused=True,  # no start/end op names or iteration settings are given for this case
+ meta={"num_tables": 12},
+ ),
+ TestCase(
+ name="KPFusedSparseDynamicStitch_3D_case_1",
+ op_fn=KPFusedSparseDynamicStitch_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedSparseDynamicStitch",
+ is_fused=False,  # 3-D table case is declared as not fused
+ meta={"num_tables": 12},
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,132 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedSparseReshapeOp(slice_input, begin, newshape, pack_const, dims):  # pattern under test, built from stock TF ops
+ if dims == 3:
+ end = [0, 0, 2]
+ strides = [1, 1, 1]
+ else:
+ end = [0, 2]
+ strides = [1, 1]
+
+ slice67_out = tf.strided_slice(
+ slice_input,
+ begin=begin,
+ end=end,
+ strides=strides,
+ begin_mask=1,  # bit 0: begin[0] is ignored for axis 0
+ end_mask=1,  # bit 0: end[0] is ignored for axis 0
+ shrink_axis_mask=2  # bit 1: axis 1 is indexed at begin[1] and dropped
+ )
+
+ slice67_out = tf.reshape(slice67_out, [-1, 1])  # column vector of the sliced values
+ shape_out = tf.shape(slice_input)
+ slice57_out = tf.strided_slice(
+ shape_out,
+ begin=[0],
+ end=[1],
+ strides=[1],
+ shrink_axis_mask=1
+ )  # scalar: leading dimension of slice_input
+
+ input_shape = tf.stack([slice57_out, pack_const])
+ input_shape = tf.cast(input_shape, tf.int64)  # dense_shape of the sparse tensor built below
+
+ range_out = tf.range(0, slice57_out, 1)
+ range_out = tf.reshape(range_out, [-1, 1])
+ range_out_64 = tf.cast(range_out, dtype=tf.int64)
+ concat_out = tf.concat([range_out_64, slice67_out], axis=-1)  # (row index, sliced value) pairs used as sparse indices
+
+ values = np.arange(slice_input.shape[0], dtype=np.float32)  # numpy constant; needs a static leading dim on slice_input
+
+ sparse_tensor = tf.SparseTensor(
+ indices=concat_out,
+ values=values,
+ dense_shape=input_shape
+ )
+ sparse_tensor_out = tf.sparse.reshape(sparse_tensor, newshape)
+ return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out
+
+ def KPFusedSparseReshape_graph(input, meta):  # returns (output tensors, feed_dict); meta is unused here
+ # Create placeholders named after the core function's parameters
+ slice_shape = input['slice_input'].shape
+ lst = list(slice_shape) # 1. convert the shape tuple to a list
+ lst[-1] = None  # leave the last axis of the placeholder shape unknown
+ slice_shape = tuple(lst)
+ slice_input = tf.compat.v1.placeholder(tf.int64, shape=slice_shape, name="slice_input")
+ begin = tf.compat.v1.placeholder(tf.int32, shape=input['begin'].shape, name="begin")
+ newshape = tf.compat.v1.placeholder(tf.int64, shape=input['newshape'].shape, name="newshape")
+ pack_const = tf.compat.v1.placeholder(tf.int32, shape=(), name="pack_const")
+
+ feed_dict = {
+ slice_input: input['slice_input'],
+ begin: input['begin'],
+ newshape: input['newshape'],
+ pack_const: input['pack_const']
+ }
+
+ result = KPFusedSparseReshapeOp(slice_input, begin, newshape, pack_const, len(slice_shape))  # dims = rank of slice_input
+ return result, feed_dict
+
+ def build_input_case_1(meta):
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["slice_input"] = np.array([[0, 0], [0, 1], [1, 2], [3, 4]]).astype(np.int64)
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["newshape"] = np.array([2, 4]).astype(np.int64)
+ input["pack_const"] = 2
+ return input
+
+ def build_input_case_2(meta):  # single-row input with an inferred (-1) target dimension
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["slice_input"] = np.array([[0, 1]]).astype(np.int64)
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["newshape"] = np.array([-1, 1]).astype(np.int64)
+ input["pack_const"] = 1
+ return input
+
+ def build_input_3D_case_1(meta):  # 3-D slice_input
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["slice_input"] = np.array([[[0, 0], [0, 1], [1, 2], [3, 4]], [[1, 2], [2, 3], [3, 4], [4, 4]]]).astype(np.int64)
+ input["begin"] = np.array([0, 0, 1]).astype(np.int32)
+ input["newshape"] = np.array([2, 2]).astype(np.int64)
+ input["pack_const"] = 2
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedSparseReshape_case_1",
+ op_fn=KPFusedSparseReshape_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseReshape",
+ start_op_name="PartitionedCall_1/StridedSlice",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/SparseReshape",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=400,
+ ),
+ TestCase(
+ name="KPFusedSparseReshape_case_2",
+ op_fn=KPFusedSparseReshape_graph,
+ input_fn=build_input_case_2,
+ fused_op_name="KPFusedSparseReshape",
+ start_op_name="PartitionedCall_3/StridedSlice",
+ end_op_name="PartitionedCall_3/SparseReshape",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=800,
+ ),
+ TestCase(
+ name="KPFusedSparseReshape_3D_case_1",
+ op_fn=KPFusedSparseReshape_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedSparseReshape",
+ is_fused=False,  # 3-D slice_input case is declared as not fused
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,203 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedSparseSegmentReduceNonzeroOp(data, indices, slice_input, begin, end, strides, is_mean):  # pattern under test, built from stock TF ops
+ shrink_axis_mask = 2 ** len(begin) - 2 # slice_out must keep only one dimension: every mask bit except bit 0 is set
+ slice_out = tf.strided_slice(
+ slice_input,
+ begin= begin,
+ end= end,
+ strides= strides,
+ begin_mask=1,  # bit 0: begin[0] is ignored for axis 0
+ end_mask=1,  # bit 0: end[0] is ignored for axis 0
+ shrink_axis_mask=shrink_axis_mask
+ )
+
+ segment_ids = tf.cast(slice_out, dtype=tf.int32)
+
+ if is_mean:  # Python bool taken from meta, so the branch is chosen when the graph is traced
+ sparseseg_out = tf.sparse.segment_mean(
+ data = data,
+ indices = indices,
+ segment_ids= segment_ids
+ )
+ else:
+ sparseseg_out = tf.sparse.segment_sum(
+ data = data,
+ indices = indices,
+ segment_ids= segment_ids
+ )
+ zero = tf.zeros_like(sparseseg_out)
+ notequal = tf.not_equal(x=sparseseg_out, y = zero)
+ where_out = tf.where(notequal)  # coordinates of the non-zero reduced values
+ output_shape = tf.cast(where_out, dtype=tf.int32)  # despite the name, holds the non-zero coordinates
+ output_data = tf.gather_nd(params=sparseseg_out, indices=where_out)  # the non-zero values themselves
+ shape = tf.shape(sparseseg_out, out_type=tf.int64)
+ output_ids = tf.cast(shape, dtype=tf.int32)  # despite the name, holds the shape of the reduced tensor
+
+ return output_shape, output_ids, output_data
+
+ def KPFusedSparseSegmentReduceNonzero_graph(input, meta):  # returns (output tensors, feed_dict)
+ # Create placeholders named after the core function's parameters
+ indices_type = np.int32  # default index dtype unless meta overrides it
+ if "indices_type" in meta.keys():
+ indices_type = meta["indices_type"]
+ data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+ indices = tf.compat.v1.placeholder(indices_type, shape=input['indices'].shape, name="indices")
+ slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+ begin = tf.constant(input['begin'], dtype=tf.int32)
+ end = tf.constant(input['end'], dtype=tf.int32)
+ strides = tf.compat.v1.placeholder(tf.int32, shape=input['strides'].shape, name="strides")
+ is_mean = meta['is_mean']
+
+ feed_dict = {
+ data: input['data'],
+ indices: input['indices'],
+ slice_input: input['slice_input'],
+ begin: input['begin'],  # NOTE(review): begin and end are tf.constant tensors, yet they are fed here - confirm this is intended
+ end: input['end'],
+ strides: input['strides']
+ }
+
+ result = KPFusedSparseSegmentReduceNonzeroOp(data, indices, slice_input, begin, end, strides, is_mean)
+ return result, feed_dict
+
+ def build_input_case_1(meta):  # 1-D data, 2-D slice_input
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ indices_type = np.int32
+ if "indices_type" in meta.keys():
+ indices_type = meta["indices_type"]
+ data = np.random.rand(1449).astype(np.float32) * 10
+ zero_prob = 0.3
+ mask = np.random.rand(1449) > zero_prob
+ data[~mask] = 0  # zero out roughly zero_prob of the entries
+ input["data"] = data
+ input["indices"] = np.random.randint(0, 1449, size=5742, dtype=indices_type)
+
+ start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))  # sorted, distinct values
+ end_points = start_points + np.random.randint(1, 100, size=5742)
+ end_points = np.minimum(end_points, 15661)
+ slice_input = np.column_stack((start_points, end_points))
+ slice_input[:, 1] = slice_input[:, 0]  # column 1 is overwritten with the sorted start points, discarding end_points
+ input["slice_input"] = slice_input
+
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2]).astype(np.int32)
+ input["strides"] = np.array([1, 2]).astype(np.int32)
+ return input
+
+ def build_input_3D_case_1(meta):  # 3-D data of shape (1449, 2, 3), 2-D slice_input
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ data = np.random.rand(1449, 2, 3).astype(np.float32) * 10
+ zero_prob = 0.3
+ mask = np.random.rand(1449, 2, 3) > zero_prob
+ data[~mask] = 0
+ input["data"] = data
+ input["indices"] = np.random.randint(0, 1449, size=5742, dtype=np.int32)
+
+ start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))
+ end_points = start_points + np.random.randint(1, 100, size=5742)
+ end_points = np.minimum(end_points, 15661)
+ slice_input = np.column_stack((start_points, end_points))
+ slice_input[:, 1] = slice_input[:, 0]
+ input["slice_input"] = slice_input
+
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2]).astype(np.int32)
+ input["strides"] = np.array([1, 2]).astype(np.int32)
+ return input
+
+ def build_input_3D_case_2(meta):  # 1-D data, 3-D slice_input
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ data = np.random.rand(1449).astype(np.float32) * 10
+ zero_prob = 0.3
+ mask = np.random.rand(1449) > zero_prob
+ data[~mask] = 0
+ input["data"] = data
+ input["indices"] = np.random.randint(0, 1449, size=5742, dtype=np.int32)
+
+ start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))
+ end_points = start_points + np.random.randint(1, 100, size=5742)
+ end_points = np.minimum(end_points, 15661)
+ slice_input = np.column_stack((start_points, end_points))
+ slice_input[:, 1] = slice_input[:, 0]
+ input["slice_input"] = np.stack([slice_input, slice_input], axis=2) # 3-D data
+
+ input["begin"] = np.array([0, 1, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2, 1]).astype(np.int32)
+ input["strides"] = np.array([1, 2, 1]).astype(np.int32)
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_sum_case_1",
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ start_op_name="PartitionedCall_1/StridedSlice",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/GatherNd",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=30,
+ meta = {"is_mean": False}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_mean_case_1",
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ start_op_name="PartitionedCall_3/StridedSlice",
+ end_op_name="PartitionedCall_3/GatherNd",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=30,
+ meta = {"is_mean": True}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_sum_case_1_int64",
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ start_op_name="PartitionedCall_5/StridedSlice",
+ end_op_name="PartitionedCall_5/GatherNd",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=30,
+ meta = {"is_mean": False, "indices_type": np.int64}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_mean_case_1_int64",
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ start_op_name="PartitionedCall_7/StridedSlice",
+ end_op_name="PartitionedCall_7/GatherNd",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=30,
+ meta = {"is_mean": True, "indices_type": np.int64}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_sum_3D_case_1",
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ is_fused=False,  # 3-D data case is declared as not fused
+ meta = {"is_mean": False}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduceNonzero_sum_3D_case_2",  # NOTE(review): name says "sum" but meta sets is_mean=True - confirm which is intended
+ op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+ input_fn=build_input_3D_case_2,
+ fused_op_name="KPFusedSparseSegmentReduceNonzero",
+ is_fused=False,  # 3-D slice_input case is declared as not fused
+ meta = {"is_mean": True}
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,165 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedSparseSegmentReduceOp(data, indices, slice_input, begin, end, strides, is_mean):  # pattern under test, built from stock TF ops
+ shrink_axis_mask = 2 ** len(begin) - 2 # slice_out must keep only one dimension: every mask bit except bit 0 is set
+ slice_out = tf.strided_slice(
+ slice_input,
+ begin=begin,
+ end=end,
+ strides=strides,
+ begin_mask=1,  # bit 0: begin[0] is ignored for axis 0
+ end_mask=1,  # bit 0: end[0] is ignored for axis 0
+ shrink_axis_mask=shrink_axis_mask
+ )
+
+ segment_ids = tf.cast(slice_out, dtype=tf.int32)
+ if is_mean:  # Python bool taken from meta, so the branch is chosen when the graph is traced
+ output = tf.sparse.segment_mean(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids
+ )
+ else:
+ output = tf.sparse.segment_sum(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids
+ )
+
+ output_shape = tf.shape(output)
+ slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1], shrink_axis_mask=1)  # scalar: leading dimension of output
+
+ return output, slice_out
+
+ def KPFusedSparseSegmentReduce_graph(input, meta):  # returns (output tensors, feed_dict)
+ # Create placeholders named after the core function's parameters
+ indices_type = np.int32  # default index dtype unless meta overrides it
+ if "indices_type" in meta.keys():
+ indices_type = meta["indices_type"]
+ data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+ indices = tf.compat.v1.placeholder(indices_type, shape=input['indices'].shape, name="indices")
+ slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+ begin = tf.constant(input['begin'], dtype=tf.int32)
+ end = tf.constant(input['end'], dtype=tf.int32)
+ strides = tf.compat.v1.placeholder(tf.int32, shape=input['strides'].shape, name="strides")
+ is_mean = meta['is_mean']
+
+ feed_dict = {
+ data: input['data'],
+ indices: input['indices'],
+ slice_input: input['slice_input'],
+ begin: input['begin'],  # NOTE(review): begin and end are tf.constant tensors, yet they are fed here - confirm this is intended
+ end: input['end'],
+ strides: input['strides']
+ }
+
+ result = KPFusedSparseSegmentReduceOp(data, indices, slice_input, begin, end, strides, is_mean)
+ return result, feed_dict
+
+ def build_input_case_1(meta):  # 2-D data, 2-D slice_input
+ input = {}
+ indices_type = np.int32
+ if "indices_type" in meta.keys():
+ indices_type = meta["indices_type"]
+ # Generate input data keyed by the core function's parameter names
+ input["data"] = np.random.rand(4, 3).astype(np.float32)
+ input["indices"] = np.array([0, 1, 2]).astype(indices_type)
+ input["slice_input"] = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)  # column 1 is [0, 2, 2]
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2]).astype(np.int32)
+ input["strides"] = np.array([1, 2]).astype(np.int32)
+ return input
+
+ def build_input_3D_case_1(meta):  # 3-D data of shape (4, 3, 2), 2-D slice_input
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["data"] = np.random.rand(4, 3, 2).astype(np.float32)
+ input["indices"] = np.array([0, 1, 2]).astype(np.int32)
+ input["slice_input"] = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)
+ input["begin"] = np.array([0, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2]).astype(np.int32)
+ input["strides"] = np.array([1, 2]).astype(np.int32)
+ return input
+
+ def build_input_3D_case_2(meta):  # 2-D data, 3-D slice_input of shape (2, 3, 2)
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["data"] = np.random.rand(4, 3).astype(np.float32)
+ input["indices"] = np.array([0, 1]).astype(np.int32)
+ input["slice_input"] = np.array([[[0, 0], [0, 2], [1, 2]], [[0, 1], [1, 2], [2, 2]]], dtype=np.int64)
+ input["begin"] = np.array([0, 1, 1]).astype(np.int32)
+ input["end"] = np.array([0, 2, 1]).astype(np.int32)
+ input["strides"] = np.array([1, 2, 1]).astype(np.int32)
+ return input
+
+ return [
+ TestCase(
+ name="KPFusedSparseSegmentReduce_sum_case_1",
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ start_op_name="PartitionedCall_1/StridedSlice",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/StridedSlice_1",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=200,
+ meta = {"is_mean": False}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduce_mean_case_1",
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ start_op_name="PartitionedCall_3/StridedSlice",
+ end_op_name="PartitionedCall_3/StridedSlice_1",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=200,
+ meta = {"is_mean": True}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduce_sum_case_1_int64",
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ start_op_name="PartitionedCall_5/StridedSlice",
+ end_op_name="PartitionedCall_5/StridedSlice_1",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=200,
+ meta = {"is_mean": False, "indices_type": np.int64}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduce_mean_case_1_int64",
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ start_op_name="PartitionedCall_7/StridedSlice",
+ end_op_name="PartitionedCall_7/StridedSlice_1",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=200,
+ meta = {"is_mean": True, "indices_type": np.int64}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduce_sum_3D_case_1",
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_3D_case_1,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ is_fused=False,  # 3-D data case is declared as not fused
+ meta = {"is_mean": False}
+ ),
+ TestCase(
+ name="KPFusedSparseSegmentReduce_sum_3D_case_2",  # NOTE(review): name says "sum" but meta sets is_mean=True - confirm which is intended
+ op_fn=KPFusedSparseSegmentReduce_graph,
+ input_fn=build_input_3D_case_2,
+ fused_op_name="KPFusedSparseSegmentReduce",
+ is_fused=False,  # 3-D slice_input case is declared as not fused
+ meta = {"is_mean": True}
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,95 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """Every operator test file implements this same interface: return the list of TestCase objects."""
+ @tf.function
+ def KPFusedSparseSelectOp(input_a, input_b, input_c, greater, equal1, equal2, equal3):  # pattern under test, built from stock TF ops
+ a = tf.reshape(input_a, [-1, 1])  # all three inputs become column vectors
+ b = tf.reshape(input_b, [-1, 1])
+ c = tf.reshape(input_c, [-1, 1])
+ output_x = a
+
+ greater_a = tf.greater(a, greater)
+ shape_reshape_a1 = tf.shape(a)
+ fill_a1 = tf.fill(shape_reshape_a1, tf.constant(1, dtype=tf.float32))  # all-ones tensor shaped like a
+ realdiv = tf.realdiv(fill_a1, tf.constant(1, dtype=tf.float32))  # ones divided by 1
+ cast_a = tf.cast(greater_a, tf.float32)  # 1.0 where a > greater, else 0.0
+ shape_a = tf.shape(cast_a)
+ fill_a = tf.fill(shape_a, tf.constant(1, dtype=tf.float32))
+ equal_4563 = tf.equal(b, equal1)
+ equal_10831 = tf.equal(b, equal2)
+ equal_3 = tf.equal(c, equal3)
+ select_1 = tf.where(equal_4563, fill_a, cast_a)  # force 1.0 where b == equal1
+ select_2 = tf.where(equal_10831, fill_a, select_1)  # force 1.0 where b == equal2
+ output_y = select_2
+ select_3 = tf.where(equal_3, realdiv, fill_a1)  # both branches are built from all-ones tensors
+ output_z = tf.concat([select_2, select_3], axis=-1)
+ return output_x, output_y, output_z
+
+ def KPFusedSparseSelect_graph(input, meta):  # returns (output tensors, feed_dict); meta is unused here
+ # Create placeholders named after the core function's parameters
+ input_a = tf.compat.v1.placeholder(tf.int32, shape=input['input_a'].shape, name="input_a")
+ input_b = tf.compat.v1.placeholder(tf.int32, shape=input['input_b'].shape, name="input_b")
+ input_c = tf.compat.v1.placeholder(tf.int32, shape=input['input_c'].shape, name="input_c")
+ greater = tf.constant(input['greater'], dtype=tf.int32)
+ equal1 = tf.constant(input['equal1'], dtype=tf.int32)
+ equal2 = tf.constant(input['equal2'], dtype=tf.int32)
+ equal3 = tf.constant(input['equal3'], dtype=tf.int32)
+
+ feed_dict = {
+ input_a: input['input_a'],
+ input_b: input['input_b'],
+ input_c: input['input_c'],
+ greater: input['greater'],  # NOTE(review): greater/equal1/equal2/equal3 are tf.constant tensors, yet they are fed here - confirm this is intended
+ equal1: input['equal1'],
+ equal2: input['equal2'],
+ equal3: input['equal3']
+ }
+
+ result = KPFusedSparseSelectOp(input_a, input_b, input_c, greater, equal1, equal2, equal3)
+ return result, feed_dict
+
+ def build_input_case(a_shape, b_shape, c_shape):  # shared builder for the cases below
+ input = {}
+ # Generate input data keyed by the core function's parameter names
+ input["input_a"] = np.random.randint(0, 100, size=a_shape).astype(np.int32)
+ input["input_b"] = np.random.randint(0, 100, size=b_shape).astype(np.int32)  # values lie in [0, 100), so equal1/equal2 below never match
+ input["input_c"] = np.random.randint(0, 100, size=c_shape).astype(np.int32)
+ input["greater"] = np.array(0, dtype=np.int32)
+ input["equal1"] = np.array(4563, dtype=np.int32)
+ input["equal2"] = np.array(10831, dtype=np.int32)
+ input["equal3"] = np.array(3, dtype=np.int32)
+ return input
+
+ def build_input_case_1(meta):  # differently shaped 2-D inputs, 1000 elements each
+ return build_input_case((100, 10), (10, 100), (20, 50))
+
+ def build_input_case_2(meta):  # identically shaped 3-D inputs
+ return build_input_case((50, 50, 50), (50, 50, 50), (50, 50, 50))
+
+ return [
+ TestCase(
+ name="KPFusedSparseSelect_case_1",
+ op_fn=KPFusedSparseSelect_graph,
+ input_fn=build_input_case_1,
+ fused_op_name="KPFusedSparseSelect",
+ start_op_name="PartitionedCall_1/Reshape",  # NOTE(review): auto-generated node name; presumably depends on op creation order - verify after any graph change
+ end_op_name="PartitionedCall_1/concat",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=400,
+ ),
+ TestCase(
+ name="KPFusedSparseSelect_case_2",
+ op_fn=KPFusedSparseSelect_graph,
+ input_fn=build_input_case_2,
+ fused_op_name="KPFusedSparseSelect",
+ start_op_name="PartitionedCall_3/Reshape",
+ end_op_name="PartitionedCall_3/concat",
+ is_fused=True,
+ num_iters=1000,
+ optimize_percent=600,
+ ),
+ ]
\ No newline at end of file
@@ -228,6 +228,14 @@ tf_gen_op_strict_wrapper_private_py(
],
)
+tf_gen_op_strict_wrapper_private_py(
+ name = "embedding_fused_ops_gen",  # NOTE(review): declared ahead of collective_ops_gen - confirm the intended target ordering
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:embedding_fused_ops_op_lib",  # presumably the op library name added to the op-lib filegroup earlier in this patch - verify
+ ],
+)
+
tf_gen_op_strict_wrapper_private_py(
name = "collective_ops_gen",
visibility = ["//tensorflow:internal"],