diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b916623fd..071a82a4e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -673,7 +673,13 @@ cc_library(
     ]) + if_tensorrt([
         "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels",
         "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
-    ]) + tf_tpu_dependencies() + tf_dtensor_tpu_dependencies(),
+    ]) + tf_tpu_dependencies() + tf_dtensor_tpu_dependencies()
+    + select({
+        "@platforms//cpu:aarch64": [
+            "//tensorflow/core/kernels:embedding_fused_ops"
+        ],
+        "//conditions:default": [],
+    }),
 )
 
 cc_library(
@@ -920,6 +926,7 @@ filegroup(
         "candidate_sampling_ops_op_lib",
         "checkpoint_ops_op_lib",
         "clustering_ops_op_lib",
+        "embedding_fused_ops_op_lib",
         "collective_ops_op_lib",
         "control_flow_ops_op_lib",
         "count_ops_op_lib",
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt
new file mode 100644
index 000000000..bb221ad0b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingActionIdGather"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPadding.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPadding.pbtxt
new file mode 100644
index 000000000..5aef3786a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPadding.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingPadding"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt
new file mode 100644
index 000000000..17e7ebc3c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingPaddingFast"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedGather.pbtxt
new file mode 100644
index 000000000..63c4eef84
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedGather.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedGather"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedSparseDynamicStitch.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseDynamicStitch.pbtxt
new file mode 100644
index 000000000..5b01d70cf
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseDynamicStitch.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseDynamicStitch"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedSparseReshape.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseReshape.pbtxt
new file mode 100644
index 000000000..8fe3c2f6d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseReshape.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseReshape"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduce.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduce.pbtxt
new file mode 100644
index 000000000..b13aa7dc5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduce.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSegmentReduce"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt
new file mode 100644
index 000000000..c0945afe2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSegmentReduceNonzero"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSelect.pbtxt b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSelect.pbtxt
new file mode 100644
index 000000000..bafc9157e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_KPFusedSparseSelect.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSelect"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt
new file mode 100644
index 000000000..bb221ad0b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingActionIdGather.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingActionIdGather"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPadding.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPadding.pbtxt
new file mode 100644
index 000000000..5aef3786a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPadding.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingPadding"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt
new file mode 100644
index 000000000..17e7ebc3c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedEmbeddingPaddingFast.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedEmbeddingPaddingFast"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedGather.pbtxt
new file mode 100644
index 000000000..63c4eef84
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedGather.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedGather"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedSparseDynamicStitch.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseDynamicStitch.pbtxt
new file mode 100644
index 000000000..5b01d70cf
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseDynamicStitch.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseDynamicStitch"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedSparseReshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseReshape.pbtxt
new file mode 100644
index 000000000..8fe3c2f6d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseReshape.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseReshape"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduce.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduce.pbtxt
new file mode 100644
index 000000000..b13aa7dc5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduce.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSegmentReduce"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt
new file mode 100644
index 000000000..c0945afe2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSegmentReduceNonzero.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSegmentReduceNonzero"
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSelect.pbtxt b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSelect.pbtxt
new file mode 100644
index 000000000..bafc9157e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_KPFusedSparseSelect.pbtxt
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "KPFusedSparseSelect"
+}
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index ecd559734..52fa9abbb 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -904,7 +904,11 @@ tf_kernel_library(
         "//tensorflow/core/grappler/utils:symbolic_shapes",
         "//tensorflow/core/grappler/utils:topological_sort",
         "@com_google_absl//absl/container:flat_hash_set",
-    ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]),
+    ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"])
+    + select({
+        "@platforms//cpu:aarch64": ["//tensorflow/core/grappler/optimizers/graph_optimizer:annc_graph_opt"],
+        "//conditions:default": [],
+    }),
 )
 
 tf_cuda_cc_test(
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer/BUILD b/tensorflow/core/grappler/optimizers/graph_optimizer/BUILD
new file mode 100644
index 000000000..7d5ee228c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer/BUILD
@@ -0,0 +1,21 @@
+package(
+    default_visibility = [
+        "//visibility:public",
+    ],
+    licenses = ["notice"],
+)
+
+cc_library(
+    name = "annc_graph_opt",
+    srcs = glob(["*.cc"]),
+    hdrs = glob(["*.h"]),
+    linkstatic = True,
+    alwayslink = True,
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler/costs:graph_properties",
+    ],
+)
\ No newline at end of file
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.cc b/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.cc
new file mode 100644
index 000000000..2bd557d23
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.cc
@@ -0,0 +1,791 @@
+#include "graph_opt.h"
+
+using namespace tensorflow;
+using namespace tensorflow::grappler;
+
+namespace annc {
+void update_node_indexes(const GraphDef* graph,
+                         std::unordered_map<std::string, int>& node_indexes) {
+  for (int i = 0; i < graph->node_size(); ++i) {
+    node_indexes[graph->node(i).name()] = i;
+  }
+}
+
+void GraphOptimizer::register_rewriter(
+    std::unique_ptr<PatternRewriter> rewriter) {
+  rewriters_.push_back(std::move(rewriter));
+}
+
+void GraphOptimizer::optimize() {
+  update_node_indexes(graph_, node_indexes_);
+  int node_index = 0;
+  const int node_size = graph_->node_size();
+  while (node_index < node_size) {
+    const NodeDef& node = graph_->node(node_index);
+    const std::string& node_name = node.name();
+    for (auto& rewriter : rewriters_) {
+      if (rewriter->match_and_rewrite(&node, graph_, props_, node_indexes_)) {
+        update_node_indexes(graph_, node_indexes_);
+        const std::string new_node_name = node_name + fusion_appendix;
+        node_index = node_indexes_.at(new_node_name);
+        break;
+      }
+    }
+    node_index++;
+  }
+}
+
+std::string get_node_name(const std::string& name) {
+  size_t colon_pos = name.find_last_of(':');
+  std::string node_name = name;
+  if (colon_pos != std::string::npos) {
+    node_name = name.substr(0, colon_pos);
+  }
+  return node_name;
+}
+
+void set_fusedop_attributes(NodeDef* fused,
+                            const absl::Span<const absl::string_view> fused_ops,
+                            int num_args = 1, float epsilon = 0.0) {
+  auto* attr = fused->mutable_attr();
+  SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
+  SetAttrValue(num_args, &(*attr)["num_args"]);
+  SetAttrValue(epsilon, &(*attr)["epsilon"]);  // required only for BatchNorm
+}
+
+const NodeDef* PatternRewriter::get_node(const std::string& name) {
+  const std::string node_name = get_node_name(name);
+  const int node_index = indexes_->at(node_name);
+  return &graph_->node(node_index);
+}
+
+NodeDef* PatternRewriter::get_mutable_node(const std::string& name) {
+  const std::string node_name = get_node_name(name);
+  if (indexes_->find(node_name) == indexes_->end()) return nullptr;
+  const int node_index = indexes_->at(node_name);
+  return graph_->mutable_node(node_index);
+}
+
+NodeDef* PatternRewriter::get_operand(const NodeDef* node,
+                                      std::string op_type) {
+  for (int i = 0; i < node->input_size(); ++i) {
+    NodeDef* operand = get_mutable_node(node->input(i));
+    if (operand != nullptr && operand->op() == op_type) return operand;
+  }
+  return nullptr;
+}
+
+const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index,
+                                         const std::string& op_type) {
+  std::string node_name = node->name();
+  if (index) node_name = node_name + ":" + std::to_string(index);
+  for (int i = 0; i < graph_->node_size(); ++i) {
+    const NodeDef* node = &graph_->node(i);
+    for (int j = 0; j < node->input_size(); ++j) {
+      if (node->input(j) == node_name && node->op() == op_type) {
+        return node;
+      }
+    }
+  }
+  return nullptr;
+}
+
+void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
+                                             int old_index,
+                                             const NodeDef* new_node,
+                                             int new_index, GraphDef* graph) {
+  std::string old_name = old_node->name();
+  if (old_index) old_name = old_name + ":" + std::to_string(old_index);
+  std::string new_name = new_node->name();
+  if (new_index) new_name = new_name + ":" + std::to_string(new_index);
+  for (int i = 0; i < graph->node_size(); ++i) {
+    NodeDef* node = graph->mutable_node(i);
+    for (int j = 0; j < node->input_size(); ++j) {
+      if (node->input(j) == old_name) {
+        node->set_input(j, new_name);
+      }
+    }
+  }
+}
+
+class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseDynamicStitch"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
+                  node->input_size() % 2 == 0)
+    int num_inputs = node->input_size();
+    int num_partitions = num_inputs / 2;
+    // left branch
+    const NodeDef* partition = get_node(node->input(0));
+    CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
+                  partition->input_size() == 2)
+    const NodeDef* range = get_node(partition->input(0));
+    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+    // Range start=0, delta=1
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+    const NodeDef* size = get_node(range->input(1));
+    CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
+    const NodeDef* cast = get_node(partition->input(1));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* floor_mod = get_node(cast->input(0));
+    CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
+    CHECK_NODE_OK(check_const_value<int64_t>(
+        get_mutable_node(floor_mod->input(1)), {num_partitions}))
+
+    CHECK_NODE_OK(check_int_attr(node, "N", {num_partitions}))
+    CHECK_NODE_OK(check_int_attr(partition, "num_partitions", {num_partitions}))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(size->input(0));
+    // right branch
+    for (int i = num_partitions; i < num_inputs; ++i) {
+      const NodeDef* gather = get_node(node->input(i));
+      CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+      // Gather axis=0
+      CHECK_NODE_OK(
+          check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+      const NodeDef* partition_1 = get_node(gather->input(1));
+      CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
+                    partition_1->input_size() == 2)
+      CHECK_NODE_OK(
+          check_int_attr(partition_1, "num_partitions", {num_partitions}))
+      const NodeDef* floor_div = get_node(partition_1->input(0));
+      CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
+                    floor_div->input_size() == 2)
+      CHECK_NODE_OK(check_const_value<int64_t>(
+          get_mutable_node(floor_div->input(1)), {num_partitions}))
+      CHECK_NODE_OK(check_input_dims(props, gather, 0, 2));
+      fused_node->add_input(gather->input(0));
+    }
+    (*fused_node->mutable_attr())["N"].set_i(num_partitions);
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    return true;
+  }
+};
+
+class KPFusedSparseSegmentReduceRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseSegmentReduce"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
+    const NodeDef* shape = get_node(node->input(0));
+    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+    const NodeDef* ss_reduce = get_node(shape->input(0));
+    CHECK_NODE_OK(ss_reduce->input_size() == 3)
+    AttrValue combiner;
+    if (ss_reduce->op() == "SparseSegmentMean")
+      combiner.set_i(1);
+    else if (ss_reduce->op() == "SparseSegmentSum")
+      combiner.set_i(0);
+    else
+      return false;
+    const NodeDef* cast = get_node(ss_reduce->input(2));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* strided_slice = get_node(cast->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+
+    // check fusion conditions
+    CHECK_NODE_OK(
+        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+    CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1}))
+    CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
+
+    CHECK_NODE_OK(check_input_dims(props, ss_reduce, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(ss_reduce->input(0));      // input_tensor
+    fused_node->add_input(ss_reduce->input(1));      // indices
+    fused_node->add_input(strided_slice->input(0));  // slice_input
+    fused_node->add_input(strided_slice->input(1));  // begin
+    fused_node->add_input(node->input(1));           // begin_1
+    AddNodeAttr("combiner", combiner, fused_node);
+    AddNodeAttr("Tidx", ss_reduce->attr().at("Tidx"), fused_node);
+
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(ss_reduce, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 0, fused_node, 1, graph);
+    return true;
+  }
+};
+
+class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
+ public:
+  std::string name() const override {
+    return "KPFusedSparseSegmentReduceNonzero";
+  }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "GatherNd" &&
+                  node->input_size() == 2)  // output:2
+    const NodeDef* ss_reduce = get_node(node->input(0));
+    CHECK_NODE_OK(ss_reduce->input_size() == 3)
+    AttrValue combiner;
+    if (ss_reduce->op() == "SparseSegmentMean") {
+      combiner.set_i(1);
+    } else if (ss_reduce->op() == "SparseSegmentSum") {
+      combiner.set_i(0);
+    } else {
+      return false;
+    }
+    const NodeDef* where = get_node(node->input(1));
+    CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
+    const NodeDef* cast = get_user(where, 0, "Cast");
+    CHECK_NODE_OK(cast != nullptr)  // output: 1
+    const NodeDef* notequal = get_node(where->input(0));
+    CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2);
+    const NodeDef* zerolike = get_node(notequal->input(1));
+    CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)
+    const NodeDef* cast_1 = get_node(ss_reduce->input(2));
+    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+    const NodeDef* strided_slice = get_node(cast_1->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
+    CHECK_NODE_OK(shape != nullptr)
+    const NodeDef* cast_2 = get_user(shape, 0, "Cast");  // output: 0
+    CHECK_NODE_OK(cast_2 != nullptr)
+
+    CHECK_NODE_OK(
+        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+    CHECK_NODE_OK(check_input_dims(props, ss_reduce, 0, 1) ||
+                  check_input_dims(props, ss_reduce, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(ss_reduce->input(0));
+    fused_node->add_input(ss_reduce->input(1));
+    fused_node->add_input(strided_slice->input(0));
+    fused_node->add_input(strided_slice->input(1));
+    AddNodeAttr("combiner", combiner, fused_node);
+    AddNodeAttr("Tidx", ss_reduce->attr().at("Tidx"), fused_node);
+
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name() << "\n";
+    replace_all_users_with(cast_2, 0, fused_node, 0, graph);
+    replace_all_users_with(cast, 0, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
+class KPFusedEmbeddingPaddingRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedEmbeddingPadding"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsReshape(*node))
+    const NodeDef* user = get_user(node, 0, "ConcatV2");
+    CHECK_NODE_OK(user != nullptr)
+    const NodeDef* concat = get_node(node->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(concat->input(2)), {0}))
+    const NodeDef* fill = get_node(concat->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    NodeDef* pack = get_operand(fill, "Pack");
+    CHECK_NODE_OK(pack != nullptr && IsPack(*pack) && pack->input_size() == 2)
+    NodeDef* fill_const = get_operand(fill, "Const");
+    CHECK_NODE_OK(fill_const != nullptr &&
+                  check_const_value<int>(fill_const, {0}))
+    NodeDef* sub = get_operand(pack, "Sub");
+    CHECK_NODE_OK(sub != nullptr && IsSub(*sub) && sub->input_size() == 2)
+    const NodeDef* strided_slice = get_node(sub->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* cast = get_node(strided_slice->input(0));
+    CHECK_NODE_OK(IsCast(*cast))
+
+    CHECK_NODE_OK(check_input_dims(props, cast, 0, 1))
+    CHECK_NODE_OK(check_input_dims(props, concat, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, sub, 1, 0))
+    CHECK_NODE_OK(check_input_dims(props, node, 1, 1))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(cast->input(0));    // origin_shape
+    fused_node->add_input(concat->input(0));  // input
+    fused_node->add_input(sub->input(1));     // input_rows
+    fused_node->add_input(node->input(1));    // reshape_sizes
+    const NodeDef* pack_left = get_node(pack->input(0));
+    const NodeDef* pack_right = get_node(pack->input(1));
+    AddNodeAttr("T", pack->attr().at("T"), fused_node);
+    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
+      fused_node->add_input(pack->input(0));
+    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
+      fused_node->add_input(pack->input(1));
+    } else {
+      return false;
+    }
+
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(sub, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 0, fused_node, 1, graph);
+    return true;
+  }
+};
+
+class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedEmbeddingPaddingFast"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(1)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {1}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(3)), {1}))
+    CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
+    const NodeDef* shape = get_node(node->input(0));
+    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+    const NodeDef* reshape = get_node(shape->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    const NodeDef* concat = get_node(reshape->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    const NodeDef* fill = get_node(concat->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    const NodeDef* pack = get_node(fill->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+    const NodeDef* sub = get_node(pack->input(0));
+    CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2)
+    const NodeDef* strided_slice = get_node(sub->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* cast = get_node(strided_slice->input(0));
+    CHECK_NODE_OK(IsCast(*cast))
+
+    CHECK_NODE_OK(check_input_dims(props, cast, 0, 1))
+    CHECK_NODE_OK(check_input_dims(props, concat, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, sub, 1, 0))
+    CHECK_NODE_OK(check_input_dims(props, reshape, 1, 1))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(cast->input(0));
+    fused_node->add_input(concat->input(0));
+    fused_node->add_input(sub->input(1));
+    fused_node->add_input(reshape->input(1));
+    const NodeDef* pack_left = get_node(pack->input(0));
+    const NodeDef* pack_right = get_node(pack->input(1));
+    if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
+      fused_node->add_input(pack->input(0));
+    } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
+      fused_node->add_input(pack->input(1));
+    } else {
+      return false;
+    }
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(sub, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 0, fused_node, 1, graph);
+    return true;
+  }
+};
+
+class KPFusedSparseSelectRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseSelect"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+    const NodeDef* select_0 = get_node(node->input(0));
+    CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
+    const NodeDef* select_1 = get_node(select_0->input(2));
+    CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
+    const NodeDef* fill = get_node(select_1->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
+    const NodeDef* cast = get_node(select_1->input(2));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* equal = get_node(select_1->input(0));
+    CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
+    NodeDef* reshape = get_operand(equal, "Reshape");
+    CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape))
+    const NodeDef* greater = get_node(cast->input(0));
+    CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
+    const NodeDef* reshape_4 = get_node(greater->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
+    const NodeDef* equal_1 = get_node(select_0->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
+    NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
+    CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+
+    // right branch
+    const NodeDef* select_2 = get_node(node->input(1));
+    CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
+    const NodeDef* equal_2 = get_node(select_2->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
+    const NodeDef* fill_1 = get_node(select_2->input(2));
+    CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
+    const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
+    CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(reshape_4->input(0));
+    fused_node->add_input(reshape_1->input(0));
+    fused_node->add_input(reshape_2->input(0));
+    std::vector<const NodeDef*> const_inputs = {
+        greater,
+        equal,
+        equal_1,
+        equal_2,
+    };
+    for (const NodeDef* const_node : const_inputs) {
+      const NodeDef* left = get_node(const_node->input(0));
+      const NodeDef* right = get_node(const_node->input(1));
+      if (IsConstant(*left) || IsHostConstant(*left)) {
+        fused_node->add_input(const_node->input(0));
+      } else if (IsConstant(*right) || IsHostConstant(*right)) {
+        fused_node->add_input(const_node->input(1));
+      } else {
+        return false;
+      }
+    }
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(reshape_4, 0, fused_node, 0, graph);
+    replace_all_users_with(select_0, 0, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
+class KPFusedGatherRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedGather"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {0}))
+    const NodeDef* gather = get_node(node->input(0));
+    CHECK_NODE_OK(gather->op() == "GatherV2" &&
+                  gather->input_size() == 3)  // input:0
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+    const NodeDef* unique = get_node(node->input(1));  // output:1
+    CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
+    const NodeDef* unique_1 = get_node(unique->input(0));
+    CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
+    const NodeDef* strided_slice = get_node(unique_1->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice))  // input:1 2
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+    CHECK_NODE_OK(check_input_dims(props, gather, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, strided_slice, 0, 2))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(gather->input(0));
+    fused_node->add_input(strided_slice->input(0));
+    fused_node->add_input(strided_slice->input(1));
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(unique_1, 0, fused_node, 0, graph);
+    replace_all_users_with(unique_1, 1, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
+class KPFusedSparseReshapeRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseReshape"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
+    const NodeDef* concat = get_node(node->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
+    const NodeDef* reshape = get_node(concat->input(1));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
+    const NodeDef* strided_slice = get_node(reshape->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+    const NodeDef* cast_1 = get_node(concat->input(0));
+    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+    const NodeDef* reshape_1 = get_node(cast_1->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+    const NodeDef* range = get_node(reshape_1->input(0));
+    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+    // Range start=0, delta=1
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+    const NodeDef* cast = get_node(node->input(1));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* pack = get_node(cast->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+    const NodeDef* strided_slice_1 = get_node(pack->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
+                  strided_slice_1->input_size() == 4)
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(1)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(2)), {1}))
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(3)), {1}))
+    CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
+    const NodeDef* shape = get_node(strided_slice_1->input(0));
+    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+
+    CHECK_NODE_OK(check_input_dims(props, shape, 0, 2))
+    CHECK_NODE_OK(check_input_dims(props, pack, 1, 0))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(shape->input(0));          // slice_input
+    fused_node->add_input(strided_slice->input(1));  // begin
+    fused_node->add_input(node->input(2));           // new_shape
+    fused_node->add_input(pack->input(1));           // pack_const
+    AddNodeAttr("T", pack->attr().at("T"), fused_node);
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 1, fused_node, 1, graph);
+    return true;
+  }
+};
+
+class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedEmbeddingActionIdGather"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      const GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
+    const NodeDef* reshape = get_node(node->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    const NodeDef* gather = get_node(reshape->input(0));
+    const NodeDef* pack_1 = get_node(reshape->input(1));
+    CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
+    CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+    const NodeDef* gather_1 = get_node(gather->input(0));
+    CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
+    const NodeDef* fill = get_node(node->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
+    const NodeDef* pack = get_node(fill->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+
+    CHECK_NODE_OK(check_input_dims(props, gather_1, 1, 1) ||
+                  check_input_dims(props, gather_1, 1, 2))
+    CHECK_NODE_OK(check_input_dims(props, gather, 1, 1 ) ||
+                  check_input_dims(props, gather, 1, 2))
+    CHECK_NODE_OK(check_input_dims(props, gather_1, 0, 2) ||
+                  check_input_dims(props, gather_1, 0, 3))
+    CHECK_NODE_OK(check_input_dims(props, pack, 0, 0))
+    CHECK_NODE_OK(check_input_dims(props, pack, 1, 0))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(gather_1->input(1));  // indices1
+    fused_node->add_input(gather_1->input(0));  // params
+    fused_node->add_input(gather->input(1));    // indices2
+    fused_node->add_input(pack->input(0));      // pack_dim
+    fused_node->add_input(pack->input(1));      // pack
+    AddNodeAttr("Tindices1", gather_1->attr().at("Tindices"), fused_node);
+    AddNodeAttr("Tindices2", gather->attr().at("Tindices"), fused_node);
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    return true;
+  }
+};
+
+void run_graph_optimization(GraphDef* graph, GraphProperties props) {
+  GraphOptimizer optimizer(graph, props);
+
+  const char* annc_fused_all = getenv("ANNC_FUSED_ALL");
+  const char* annc_fused_sps_stitch = getenv("ANNC_FUSED_SPS_STITCH");
+  const char* annc_fused_sps_reduce = getenv("ANNC_FUSED_SPS_REDUCE");
+  const char* annc_fused_emb_padding = getenv("ANNC_FUSED_EMD_PADDING");
+  const char* annc_fused_emb_padding_fast =
+      getenv("ANNC_FUSED_EMD_PADDING_FAST");
+  const char* annc_fused_sps_select = getenv("ANNC_FUSED_SPS_SELECT");
+  const char* annc_fused_gather = getenv("ANNC_FUSED_GATHER");
+  const char* annc_fused_sps_reshape = getenv("ANNC_FUSED_SPS_RESHAPE");
+  const char* annc_fused_emb_actionid_gather =
+      getenv("ANNC_FUSED_EMB_ACTIONID_GATHER");
+  const char* annc_fused_sps_reduce_nonzero =
+      getenv("ANNC_FUSED_SPS_REDUCE_NONZERO");
+
+  bool enable_all =
+      (annc_fused_all != nullptr) && strcmp(annc_fused_all, "1") == 0;
+
+  // default enable all rewriters
+  if (enable_all || (annc_fused_sps_stitch != nullptr &&
+                     strcmp(annc_fused_sps_stitch, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseDynamicStitchRewriter>());
+  if (enable_all || (annc_fused_sps_reduce != nullptr &&
+                     strcmp(annc_fused_sps_reduce, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSegmentReduceRewriter>());
+  if (enable_all || (annc_fused_emb_padding_fast != nullptr &&
+                     strcmp(annc_fused_emb_padding_fast, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingPaddingFastRewriter>());
+  if (enable_all || (annc_fused_emb_padding != nullptr &&
+                     strcmp(annc_fused_emb_padding, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingPaddingRewriter>());
+  if (enable_all || (annc_fused_sps_select != nullptr &&
+                     strcmp(annc_fused_sps_select, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSelectRewriter>());
+  if (enable_all ||
+      (annc_fused_gather != nullptr && strcmp(annc_fused_gather, "1") == 0))
+    optimizer.register_rewriter(std::make_unique<KPFusedGatherRewriter>());
+  if (enable_all || (annc_fused_sps_reshape != nullptr &&
+                     strcmp(annc_fused_sps_reshape, "1") == 0))
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseReshapeRewriter>());
+  if (enable_all || annc_fused_emb_actionid_gather != nullptr &&
+      strcmp(annc_fused_emb_actionid_gather, "1") == 0)
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedEmbeddingActionIdGatherRewriter>());
+  if (enable_all || annc_fused_sps_reduce_nonzero != nullptr &&
+      strcmp(annc_fused_sps_reduce_nonzero, "1") == 0)
+    optimizer.register_rewriter(
+        std::make_unique<KPFusedSparseSegmentReduceNonzeroRewriter>());
+  optimizer.optimize();
+}
+}  // namespace annc
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h b/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h
new file mode 100644
index 000000000..c65127785
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h
@@ -0,0 +1,178 @@
+#ifndef ANNC_TF_GRAPH_OPT_H_
+#define ANNC_TF_GRAPH_OPT_H_
+#include <type_traits>
+#include <unordered_map>
+
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+
+namespace annc {
+#define CHECK_NODE_OK(x) \
+  if (!(x)) {            \
+    return false;        \
+  }
+
+static const std::string fusion_appendix = "/kp_fused";
+
+void update_node_indexes(const tensorflow::GraphDef* graph,
+                         std::unordered_map<std::string, int>& node_indexes);
+
+class PatternRewriter {
+ public:
+  PatternRewriter() {}
+  virtual ~PatternRewriter() = default;
+
+  virtual bool match_and_rewrite(
+      const tensorflow::NodeDef* node, tensorflow::GraphDef* graph,
+      const tensorflow::grappler::GraphProperties& props,
+      std::unordered_map<std::string, int>& node_indexes) = 0;
+
+  virtual std::string name() const { return "PatternRewriter"; };
+
+  const tensorflow::NodeDef* get_node(const std::string& name);
+  tensorflow::NodeDef* get_mutable_node(const std::string& name);
+
+  tensorflow::NodeDef* get_operand(const tensorflow::NodeDef* node, std::string op_type);
+
+  const tensorflow::NodeDef* get_user(const tensorflow::NodeDef* node, int index,
+                          const std::string& op_type);
+
+  void replace_all_users_with(const tensorflow::NodeDef* old_node, int old_index,
+                              const tensorflow::NodeDef* new_node, int new_index,
+                              tensorflow::GraphDef* graph);
+
+  bool check_input_dims(const tensorflow::grappler::GraphProperties& graph_properties,
+                      const tensorflow::NodeDef* op, int input_index,
+                      int expected_dim_size) {
+    const auto& input_props = graph_properties.GetInputProperties(op->name());
+    if (input_index >= static_cast<int>(input_props.size())) {
+      return false;
+    }
+    const tensorflow::TensorShapeProto& shape = input_props[input_index].shape();
+    std::string shape_str = "[";
+    if (shape.unknown_rank()) {
+      shape_str = "[?]";  //  rank
+    } else {
+      for (int i = 0; i < shape.dim_size(); ++i) {
+        if (i > 0) shape_str += ", ";
+        // -1 "?"
+        auto dim_size = shape.dim(i).size();
+        if (dim_size == -1) {
+          shape_str += "?";
+        } else {
+          shape_str += std::to_string(dim_size);
+        }
+      }
+      shape_str += "]";
+    }
+
+    LOG(INFO) << "  Full input shape: " << shape_str
+              << ", rank: " << (shape.unknown_rank() ? -1 : shape.dim_size())
+              << ", expected rank: " << expected_dim_size;
+
+    return shape.dim_size() == expected_dim_size;
+  }
+  bool check_const_dims(tensorflow::NodeDef* op, int dim_size) {
+    if (!((tensorflow::grappler::IsConstant(*op) || tensorflow::grappler::IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    tensorflow::TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    if (shape.dim_size() != static_cast<int>(dim_size)) return false;
+    return true;
+  }
+
+  bool check_const_shape(tensorflow::NodeDef* op, std::vector<int> dims) {
+    if (!((tensorflow::grappler::IsConstant(*op) || tensorflow::grappler::IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    tensorflow::TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    if (shape.dim_size() != static_cast<int>(dims.size())) return false;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      if (shape.dim(i).size() != dims[i]) return false;
+    }
+    return true;
+  }
+
+  template <typename T>
+  bool check_const_value(tensorflow::NodeDef* op, std::vector<T> cmp) {
+    if (!((tensorflow::grappler::IsConstant(*op) || tensorflow::grappler::IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    tensorflow::TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    int dim_size = 1;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      dim_size *= shape.dim(i).size();
+    }
+    if (dim_size < static_cast<int>(cmp.size())) return false;
+
+    if (std::is_same<T, float>::value) {
+      const float* data = tensor->mutable_float_val()->data();
+      if (data == nullptr)
+        data = reinterpret_cast<const float*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (std::fabs(data[i] - cmp[i]) >= 1e-5f) return false;
+      }
+    } else if (std::is_same<T, int>::value) {
+      const int* data = tensor->mutable_int_val()->data();
+      if (data == nullptr)
+        data = reinterpret_cast<const int*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (data[i] != cmp[i]) return false;
+      }
+    } else if (std::is_same<T, int64_t>::value) {
+      const int64_t* data = tensor->mutable_int64_val()->data();
+      if (data == nullptr)
+        data =
+            reinterpret_cast<const int64_t*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (data[i] != cmp[i]) return false;
+      }
+    } else {
+      // data type do not support
+      return false;
+    }
+    return true;
+  }
+
+  bool check_int_attr(const tensorflow::NodeDef* op, std::string name, int value) {
+    if (HasNodeAttr(*op, name)) {
+      tensorflow::AttrValue attr = op->attr().at(name);
+      if (attr.value_case() == tensorflow::AttrValue::kI && attr.i() == value) return true;
+    }
+    return false;
+  }
+
+  tensorflow::GraphDef* graph_;
+  std::unordered_map<std::string, int>* indexes_;
+};
+
+class GraphOptimizer {
+ public:
+  GraphOptimizer(tensorflow::GraphDef* graph, tensorflow::grappler::GraphProperties graph_properties) : graph_(graph), props_(graph_properties) {}
+  virtual ~GraphOptimizer() = default;
+
+  void register_rewriter(std::unique_ptr<PatternRewriter> rewriter);
+
+  void optimize();
+
+ private:
+  tensorflow::GraphDef* graph_;
+  tensorflow::grappler::GraphProperties props_;
+  std::unordered_map<std::string, int> node_indexes_;
+  std::vector<std::unique_ptr<PatternRewriter>> rewriters_;
+};
+
+void run_graph_optimization(tensorflow::GraphDef* graph, tensorflow::grappler::GraphProperties props);
+}  // namespace annc
+#endif  // ANNC_TF_GRAPH_OPT_H_
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 3c37150f4..f6a59a68f 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -43,6 +43,9 @@ limitations under the License.
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tsl/platform/errors.h"
+#if defined(__aarch64__)
+#include "tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h"
+#endif  // __aarch64__
 #ifdef INTEL_MKL
 #include "tensorflow/core/util/mkl_heuristics.h"
 #endif  // INTEL_MKL
@@ -4941,6 +4944,19 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   TF_RETURN_IF_ERROR(mutation->Apply());
 
+#if defined(__aarch64__)
+  // ==========  infer shape  ==========
+  tensorflow::grappler::GraphProperties graph_props(mutable_item);
+  bool assume_valid_feeds = (opt_level_ == RewriterConfig::AGGRESSIVE);
+  TF_RETURN_IF_ERROR(graph_props.InferStatically(
+      assume_valid_feeds,
+      /*aggressive_shape_inference=*/false,
+      /*include_input_tensor_values=*/true,
+      /*include_output_tensor_values=*/true));
+
+  annc::run_graph_optimization(&mutable_item.graph, graph_props);
+#endif  // __aarch64__
+
   *optimized_graph = std::move(mutable_item.graph);
 
   return OkStatus();
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 7cdceb549..1844f0322 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3664,6 +3664,97 @@ tf_kernel_library(
     ]) + [":fft_impl"],
 )
 
+tf_kernel_library(
+    name = "embedding_fused_action_id_gather_op",
+    srcs = ["embedding_fused_action_id_gather.cc"],
+    deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+    name = "embedding_fused_gather_op",
+    srcs = ["embedding_fused_gather.cc"],
+    deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+    name = "embedding_fused_padding_op",
+    srcs = ["embedding_fused_padding.cc"],
+    deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+    name = "embedding_fused_sparse_dynamic_stitch_op",
+    srcs = ["embedding_fused_sparse_dynamic_stitch.cc"],
+    deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+    name = "embedding_fused_reshape_op",
+    srcs = ["embedding_fused_sparse_reshape.cc"],
+    deps = MATH_DEPS + [
+        ":reshape_util",
+    ],
+)
+
+tf_kernel_library(
+    name = "embedding_fused_sparse_segment_reduce_op",
+    srcs = ["embedding_fused_sparse_segment_reduce.cc"],
+    deps = MATH_DEPS,
+)
+
+tf_kernel_library(
+    name = "embedding_fused_sparse_segment_reduce_nonzero_op",
+    srcs = ["embedding_fused_sparse_segment_reduce_nonzero.cc"],
+    deps = MATH_DEPS + ["@com_google_absl//absl/container:flat_hash_map"],
+)
+
+tf_kernel_library(
+    name = "embedding_fused_sparse_select_op",
+    srcs = ["embedding_fused_sparse_select.cc"],
+    deps = MATH_DEPS,
+)
+
+cc_library(
+    name = "embedding_fused_ops",
+    deps = [
+        ":embedding_fused_action_id_gather_op",
+        ":embedding_fused_gather_op",
+        ":embedding_fused_padding_op",
+        ":embedding_fused_sparse_dynamic_stitch_op",
+        ":embedding_fused_reshape_op",
+        ":embedding_fused_sparse_segment_reduce_op",
+        ":embedding_fused_sparse_segment_reduce_nonzero_op",
+        ":embedding_fused_sparse_select_op",
+    ],
+)
+
+tf_cc_test(
+    name = "embedding_fused_ops_test",
+    size = "small",
+    srcs = [
+        "embedding_fused_action_id_gather_test.cc",
+        "embedding_fused_sparse_dynamic_stitch_test.cc",
+        "embedding_fused_sparse_segment_reduce_test.cc",
+        "embedding_fused_sparse_segment_reduce_nonzero_test.cc",
+        "embedding_fused_padding_test.cc",
+        "embedding_fused_sparse_select_test.cc",
+        "embedding_fused_gather_test.cc",
+        "embedding_fused_sparse_reshape_test.cc",
+    ],
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        ":embedding_fused_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_kernel_library(
     name = "reduction_ops",
     gpu_srcs = ["reduction_gpu_kernels.cu.h"],
diff --git a/tensorflow/core/kernels/embedding_fused_action_id_gather.cc b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
new file mode 100644
index 000000000..4e1381a50
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
@@ -0,0 +1,263 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+// Fused double gather:
+//   Step 1: temp[i, j, :] = params[indices1[i, j], :]  -> shape [I10, I11, P1]
+//   Step 2: output[i, j, k, :] = temp[indices2[i, j], k, :]  -> shape [I20, I21, I11, P1]
+// Fused: output[i, j, k, :] = params[indices1[indices2[i, j], k], :]
+template <typename Tindices1, typename Tindices2>
+static void FusedDoubleGatherImpl(OpKernelContext* context,
+    const float* params_data, const TensorShape& params_shape,
+    const Tindices1* indices1_data, const TensorShape& indices1_shape,
+    const Tindices2* indices2_data, const TensorShape& indices2_shape,
+    Tensor* output) {
+  // params shape: [P0, P1] (2D) or [P0, P1, P2] (3D)
+  // indices1 shape: [I10, I11] (2D) or [I10] (1D), values in [0, P0)
+  // indices2 shape: [I20, I21] (2D) or [I21] (1D), values in [0, I10) for 2D, [0, P0) for 1D
+  // output shape:
+  //   2D+2D: [I20, I21, I11, P_row]
+  //   1D+2D: [I20, I21, P_row]
+  //   2D+1D: [I21, I11, P_row]
+  //   1D+1D: [I21, P_row]
+  // where P_row = P1 (2D params) or P1*P2 (3D params)
+  OP_REQUIRES(context, params_shape.dims() >= 2 && params_shape.dims() <= 3,
+      errors::InvalidArgument("params must be 2D or 3D matrix"));
+  OP_REQUIRES(context, indices1_shape.dims() >= 1 && indices1_shape.dims() <= 2,
+      errors::InvalidArgument("indices1 must be 1D or 2D matrix"));
+  OP_REQUIRES(context, indices2_shape.dims() >= 1 && indices2_shape.dims() <= 2,
+      errors::InvalidArgument("indices2 must be 1D or 2D matrix"));
+
+  const int P0 = params_shape.dim_size(0);
+  // P_row: number of floats per params row (P1 for 2D, P1*P2 for 3D)
+  int64_t P_row = params_shape.dim_size(1);
+  if (params_shape.dims() == 3) {
+    P_row *= params_shape.dim_size(2);
+  }
+  
+  // Handle indices1 dimensions: 1D [I10] or 2D [I10, I11]
+  const int indices1_dims = indices1_shape.dims();
+  const int I10 = indices1_shape.dim_size(0);
+  const int I11 = (indices1_dims == 2) ? indices1_shape.dim_size(1) : 1;
+  
+  // Handle indices2 dimensions: 1D [I21] or 2D [I20, I21]
+  const int indices2_dims = indices2_shape.dims();
+  const int I20 = (indices2_dims == 2) ? indices2_shape.dim_size(0) : 1;
+  const int I21 = (indices2_dims == 2) ? indices2_shape.dim_size(1) : indices2_shape.dim_size(0);
+
+  // Build output shape based on indices dimensions
+  TensorShape output_shape;
+  if (indices2_dims == 2) {
+    output_shape.AddDim(I20);
+  }
+  output_shape.AddDim(I21);
+  if (indices1_dims == 2) {
+    output_shape.AddDim(I11);
+  }
+  output_shape.AddDim(P_row);
+
+  OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, output_shape, output));
+  VLOG(1) << "fused gather output shape: " << output->shape().DebugString();
+
+  float* output_data = output->flat<float>().data();
+
+  // Fused double gather (parallelized over the outer I20 * I21 work units):
+  // 2D+2D: output[i, j, k, :] = params[indices1[indices2[i, j], k], :]
+  // 1D+2D: output[i, j, :] = params[indices1[indices2[i, j]], :]
+  // 2D+1D: output[j, k, :] = params[indices1[indices2[j], k], :]
+  // 1D+1D: output[j, :] = params[indices1[indices2[j]], :]
+  //
+  // Each work unit (i, j) writes to a disjoint region of output, so no
+  // synchronization is needed between parallel tasks.
+  // cost_per_unit: 2D indices1 -> I11 * P_row floats copied; 1D -> P_row floats.
+  const int64_t total_units = static_cast<int64_t>(I20) * I21;
+  const int64_t cost_per_unit =
+      (indices1_dims == 2) ? static_cast<int64_t>(I11) * P_row
+                           : P_row;
+
+  auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+  worker_threads->workers->ParallelFor(
+      total_units, cost_per_unit,
+      [&](int64_t begin, int64_t end) {
+        for (int64_t flat = begin; flat < end; ++flat) {
+          const int i = static_cast<int>(flat / I21);
+          const int j = static_cast<int>(flat % I21);
+
+          Tindices2 idx2 = indices2_data[(indices2_dims == 2) ? (i * I21 + j) : j];
+
+          if (indices1_dims == 2) {
+            // 2D indices1: idx2 indexes into first dimension (I10)
+            if (TF_PREDICT_FALSE(idx2 < 0 || idx2 >= I10)) {
+              context->CtxFailure(errors::InvalidArgument(
+                  "FusedGather: indices2[",
+                  (indices2_dims == 2) ? i : 0, ",", j, "]=", idx2,
+                  " out of range [0, ", I10, ")"));
+              return;
+            }
+            for (int k = 0; k < I11; ++k) {
+              Tindices1 idx1 = indices1_data[idx2 * I11 + k];
+              if (TF_PREDICT_FALSE(idx1 < 0 || idx1 >= P0)) {
+                context->CtxFailure(errors::InvalidArgument(
+                    "GatherV2 axis=0: index out of range"));
+                return;
+              }
+              int64_t output_offset;
+              if (indices2_dims == 2) {
+                // 2D+2D: [i, j, k, :]
+                output_offset = ((i * I21 + j) * I11 + k) * P_row;
+              } else {
+                // 2D+1D: [j, k, :]
+                output_offset = (j * I11 + k) * P_row;
+              }
+              std::memcpy(
+                  output_data + output_offset,
+                  params_data + idx1 * P_row,
+                  sizeof(float) * P_row
+              );
+            }
+          } else {
+            // 1D indices1: idx2 indexes into indices1, then idx1 indexes into params
+            if (TF_PREDICT_FALSE(idx2 < 0 || idx2 >= I10)) {
+              context->CtxFailure(errors::InvalidArgument(
+                  "FusedGather: indices2[",
+                  (indices2_dims == 2) ? i : 0, ",", j, "]=", idx2,
+                  " out of range [0, ", I10, ")"));
+              return;
+            }
+            Tindices1 idx1 = indices1_data[idx2];
+            if (TF_PREDICT_FALSE(idx1 < 0 || idx1 >= P0)) {
+              context->CtxFailure(errors::InvalidArgument(
+                  "GatherV2 axis=0: index out of range"));
+              return;
+            }
+            int64_t output_offset;
+            if (indices2_dims == 2) {
+              // 1D+2D: [i, j, :]
+              output_offset = (i * I21 + j) * P_row;
+            } else {
+              // 1D+1D: [j, :]
+              output_offset = j * P_row;
+            }
+            std::memcpy(
+                output_data + output_offset,
+                params_data + idx1 * P_row,
+                sizeof(float) * P_row
+            );
+          }
+        }
+      });
+}
+
+
+template <typename Tindices1, typename Tindices2>
+class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
+public:
+  explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor
+    const Tensor& indices1 = context->input(0);
+    const Tensor& params = context->input(1);
+    const Tensor& indices2 = context->input(2);
+    const Tensor& pack_dim = context->input(3);
+
+    const Tensor& pack = context->input(4);
+
+    VLOG(1) << "indices1 shape: " << indices1.shape().DebugString();
+    VLOG(1) << "params shape: " << params.shape().DebugString();
+    VLOG(1) << "indices2 shape: " << indices2.shape().DebugString();
+    OP_REQUIRES(
+      context,
+      indices1.dims() >= 1 && indices1.dims() <= 2,
+      errors::InvalidArgument("indices1 dims must be 1 or 2")
+    );
+    OP_REQUIRES(
+      context,
+      indices2.dims() >= 1 && indices2.dims() <= 2,
+      errors::InvalidArgument("indices2 dims must be 1 or 2")
+    );
+    OP_REQUIRES(
+      context,
+      params.dims() >= 2 && params.dims() <= 3,
+      errors::InvalidArgument("params dims must = 2 or 3")
+    );
+    OP_REQUIRES(
+      context,
+      TensorShapeUtils::IsScalar(pack_dim.shape()),
+      errors::InvalidArgument("pack_dim is scalar")
+    );
+    OP_REQUIRES(
+      context,
+      TensorShapeUtils::IsScalar(pack.shape()),
+      errors::InvalidArgument("pack const is scalar")
+    );
+
+    // Fused double gather: directly compute params[indices1[indices2[i]]]
+    Tensor gathered;
+    FusedDoubleGatherImpl<Tindices1, Tindices2>(
+        context,
+        params.flat<float>().data(), params.shape(),
+        indices1.flat<Tindices1>().data(), indices1.shape(),
+        indices2.flat<Tindices2>().data(), indices2.shape(),
+        &gathered);
+    int pack_size = pack_dim.scalar<int32>()();
+    int pack_const = pack.scalar<int32>()();
+    OP_REQUIRES(context, pack_size > 0, errors::InvalidArgument("pack_size must > 0"));
+    int a_reshaped_cols = gathered.NumElements() / pack_size;
+    auto a_reshaped = gathered.shaped<float, 2>({pack_size, a_reshaped_cols});
+    Tensor* output;
+    int output_cols = a_reshaped_cols + pack_const;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
+    auto a_reshaped_data = a_reshaped.data();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64_t pack_cost_per_row =
+        static_cast<int64_t>(a_reshaped_cols) + pack_const;
+    worker_threads->workers->ParallelFor(
+        pack_size, pack_cost_per_row,
+        [&](int64_t start_row, int64_t end_row) {
+          float* base = output->matrix<float>().data();
+          for (int64_t row = start_row; row < end_row; ++row) {
+            float* dst_row = base + row * (a_reshaped_cols + pack_const);
+            std::memcpy(
+                dst_row, a_reshaped_data + row * a_reshaped_cols,
+                sizeof(float) * a_reshaped_cols
+            );
+            std::memset(
+                dst_row + a_reshaped_cols, 0, sizeof(float) * pack_const
+            );
+          }
+        });
+  }
+};
+
+#define REGISTER_CPU_KERNEL(Tindices1, Tindices2)                                \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingActionIdGather") \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<Tindices1>("Tindices1") \
+                              .TypeConstraint<Tindices2>("Tindices2"), \
+                          KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>);
+
+REGISTER_CPU_KERNEL(int64, int32)
+REGISTER_CPU_KERNEL(int32, int32)
+REGISTER_CPU_KERNEL(int64, int64)
+REGISTER_CPU_KERNEL(int32, int64)
+
+}
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_action_id_gather_test.cc b/tensorflow/core/kernels/embedding_fused_action_id_gather_test.cc
new file mode 100644
index 000000000..16d96eff3
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_action_id_gather_test.cc
@@ -0,0 +1,369 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *     Unless required by applicable law or agreed to in writing, software
+ *     distributed under the License is distributed on an "AS IS" BASIS,
+ *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *     See the License for the specific language governing permissions and
+ *     limitations under the License.
+ *     ==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+class KPFusedEmbeddingActionIdGatherTest : public OpsTestBase {
+ protected:
+  void MakeOp(DataType indices1_type, DataType indices2_type) {
+    TF_ASSERT_OK(NodeDefBuilder("fused_embedding_action_id_gather",
+                                "KPFusedEmbeddingActionIdGather")
+                     .Input(FakeInput(indices1_type))  // indices1
+                     .Input(FakeInput(DT_FLOAT))       // params
+                     .Input(FakeInput(indices2_type))  // indices2
+                     .Input(FakeInput(DT_INT32))       // pack_dim
+                     .Input(FakeInput(DT_INT32))       // pack
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  template <typename Tindices1, typename Tindices2>
+  Status FeedAndRun(const std::vector<Tindices1>& indices1_data,
+                    const TensorShape& indices1_shape,
+                    const std::vector<float>& params_data,
+                    const TensorShape& params_shape,
+                    const std::vector<Tindices2>& indices2_data,
+                    const TensorShape& indices2_shape, int pack_dim_value,
+                    int pack_value) {
+    inputs_.clear();
+    input_types_.clear();
+
+    MakeOp(DataTypeToEnum<Tindices1>::v(), DataTypeToEnum<Tindices2>::v());
+    AddInputFromArray<Tindices1>(indices1_shape, indices1_data);
+    AddInputFromArray<float>(params_shape, params_data);
+    AddInputFromArray<Tindices2>(indices2_shape, indices2_data);
+    AddInputFromArray<int32>(TensorShape({}), {pack_dim_value});
+    AddInputFromArray<int32>(TensorShape({}), {pack_value});
+    return RunOpKernel();
+  }
+};
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, NormalCase) {
+  std::vector<int64> indices1_data = {0, 2};
+  TensorShape indices1_shape({2, 1});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TensorShape params_shape({3, 2});
+
+  std::vector<int32> indices2_data = {1, 0};
+  TensorShape indices2_shape({2, 1});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+  test::FillValues<float>(&expected, {5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, OneDIndices1AndOneDIndices2) {
+  std::vector<int64> indices1_data = {0, 1, 2};
+  TensorShape indices1_shape({3});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TensorShape params_shape({3, 2});
+
+  std::vector<int32> indices2_data = {1, 0};
+  TensorShape indices2_shape({2});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+  test::FillValues<float>(&expected, {3.0f, 4.0f, 0.0f, 1.0f, 2.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, OneDIndices1AndTwoDIndices2) {
+  std::vector<int64> indices1_data = {0, 2, 1};
+  TensorShape indices1_shape({3});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TensorShape params_shape({3, 2});
+
+  std::vector<int32> indices2_data = {1, 0, 2, 1};
+  TensorShape indices2_shape({2, 2});
+
+  int pack_dim_value = 4;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3}));
+  test::FillValues<float>(&expected, {5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f,
+                                       3.0f, 4.0f, 0.0f, 5.0f, 6.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, TwoDIndices1AndOneDIndices2) {
+  std::vector<int64> indices1_data = {0, 2, 1, 0};
+  TensorShape indices1_shape({2, 2});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TensorShape params_shape({3, 2});
+
+  std::vector<int32> indices2_data = {1, 0};
+  TensorShape indices2_shape({2});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&expected, {3.0f, 4.0f, 1.0f, 2.0f, 0.0f,
+                                       1.0f, 2.0f, 5.0f, 6.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, DifferentIndexTypes) {
+  // int64int32
+  {
+    std::vector<int64> indices1 = {0, 2};
+    std::vector<int32> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int64, int32>(indices1, {2, 1},
+                                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+                                           {3, 2}, indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // int32int32
+  {
+    std::vector<int32> indices1 = {0, 2};
+    std::vector<int32> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int32, int32>(indices1, {2, 1},
+                                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+                                           {3, 2}, indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // int64int64
+  {
+    std::vector<int64> indices1 = {0, 2};
+    std::vector<int64> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int64, int64>(indices1, {2, 1},
+                                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+                                           {3, 2}, indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // int32int64
+  {
+    std::vector<int32> indices1 = {0, 2};
+    std::vector<int64> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int32, int64>(indices1, {2, 1},
+                                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+                                           {3, 2}, indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidParamsDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f};
+  AddInputFromArray<float>(TensorShape({4}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "params dims must = 2 or 3")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDimDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({1}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_dim is scalar")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({1}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack const is scalar")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackSize) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {0});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_size must > 0")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, IndexOutOfRange) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 5};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(
+      absl::StrContains(s.ToString(), "GatherV2 axis=0: index out of range"))
+      << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, ThreeDParams2DIndices1And2DIndices2) {
+  std::vector<int64> indices1_data = {0, 1};
+  TensorShape indices1_shape({2, 1});
+
+  // params[0] = [[1,2],[3,4]], params[1] = [[5,6],[7,8]], stored row-major
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f,
+                                     5.0f, 6.0f, 7.0f, 8.0f};
+  TensorShape params_shape({2, 2, 2});
+
+  std::vector<int32> indices2_data = {1, 0};
+  TensorShape indices2_shape({2, 1});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&expected,
+                           {5.0f, 6.0f, 7.0f, 8.0f, 0.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, ThreeDParams1DIndices1And1DIndices2) {
+  std::vector<int64> indices1_data = {0, 1, 2};
+  TensorShape indices1_shape({3});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f,
+                                     5.0f, 6.0f, 7.0f, 8.0f,
+                                     9.0f, 10.0f, 11.0f, 12.0f};
+  TensorShape params_shape({3, 2, 2});
+
+  std::vector<int32> indices2_data = {2, 0};
+  TensorShape indices2_shape({2});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 5}));
+  test::FillValues<float>(&expected,
+                           {9.0f, 10.0f, 11.0f, 12.0f, 0.0f,
+                            1.0f, 2.0f, 3.0f, 4.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc
new file mode 100644
index 000000000..14de30e35
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_gather.cc
@@ -0,0 +1,90 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+class KPFusedGather : public OpKernel {
+ public:
+  explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) { }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& data = context->input(0);
+    const Tensor& keys = context->input(1);
+    const Tensor& begin = context->input(2);
+    VLOG(1) << "Embedding table size: " << data.shape().DebugString();
+    VLOG(1) << "Input key shape: " << keys.shape().DebugString();
+    VLOG(1) << "Slice begin value: " << begin.DebugString();
+
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsMatrix(keys.shape()), 
+                errors::Internal("Input key must be 2D"));
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsMatrix(data.shape()),
+                errors::Internal("Embedding table shape must be 2D"));
+    OP_REQUIRES(context, begin.NumElements() == 2, errors::Internal("begin must be same as keys rank"));
+    int32 col = begin.flat<int32>().data()[1];
+    OP_REQUIRES(context, col < keys.dim_size(1), errors::Internal("slice cols out of keys range"));
+    
+    Tensor* out_indices = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                   1, TensorShape({static_cast<int32>(keys.dim_size(0))}), &out_indices));
+    int32 *out_indices_data = out_indices->flat<int32>().data();
+
+    auto keys_mat = keys.matrix<int64>();
+    std::vector<int64_t> unique_values;
+    std::unordered_map<int64_t, int32_t> value_to_index;
+    int current_index = 0;
+    for (int64_t i = 0; i < keys.dim_size(0); ++i) {
+        auto it = value_to_index.find(keys_mat(i, col));
+        if (it == value_to_index.end()) {
+            value_to_index[keys_mat(i, col)] = current_index;
+            unique_values.push_back(keys_mat(i, col));
+            out_indices_data[i] = current_index;
+            ++current_index;
+        } else {
+            out_indices_data[i] = it->second;
+        }
+    }
+
+    Tensor* out_unique_value = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                   0, TensorShape({static_cast<int32>(unique_values.size())}), &out_unique_value));
+    std::memcpy(out_unique_value->data(), unique_values.data(), unique_values.size() * sizeof(int64_t));
+
+    Tensor* out_data = nullptr;
+    int embedding_dims = data.dim_size(1);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                   2, TensorShape({static_cast<int32>(unique_values.size()), embedding_dims}), &out_data));
+
+    const float *data_mat = data.flat<float>().data();
+    for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) {
+        int64_t idx = unique_values[cur_row];
+        OP_REQUIRES(context, idx < data.dim_size(0) && idx >= 0, errors::Internal("idx out of table range"));
+        const float* src = data_mat + idx * embedding_dims;
+        float* dst = out_data->flat<float>().data() + cur_row * embedding_dims;
+        std::memcpy(dst, src, embedding_dims * sizeof(float));
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedGather").Device(DEVICE_CPU),
+                        KPFusedGather);
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_gather_test.cc b/tensorflow/core/kernels/embedding_fused_gather_test.cc
new file mode 100644
index 000000000..c94770910
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_gather_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::ExpectClose;
+using tensorflow::test::FillValues;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+
+class KPFusedGatherTest : public OpsTestBase {
+  protected:
+    void RunValidCase(const TensorShape& data_shape,
+                      const TensorShape& slice_shape,
+                      const std::vector<int32>& begin_val,
+                      const std::vector<int64>& slice_data,
+                      const std::vector<float>& data_data,
+                      const std::vector<int64>& expected_unique,
+                      const std::vector<int32>& expected_indices,
+                      const std::vector<float>& expected_output_data) {
+      TF_EXPECT_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
+                       .Input(FakeInput(DT_FLOAT))
+                       .Input(FakeInput(DT_INT64))
+                       .Input(FakeInput(DT_INT32))
+                       .Finalize(node_def()));
+      TF_EXPECT_OK(InitOp());
+
+      AddInputFromArray<float>(data_shape, data_data);
+      AddInputFromArray<int64>(slice_shape, slice_data);
+      AddInputFromArray<int32>(TensorShape({2}), begin_val);
+
+      TF_ASSERT_OK(RunOpKernel());
+      
+      const Tensor& out_unique = *GetOutput(0);
+      const Tensor& out_indices = *GetOutput(1);
+      const Tensor& out_data = *GetOutput(2);
+
+      // 验证输出0: unique_values
+      Tensor expected_unique_tensor(
+        allocator(), DT_INT64,
+        TensorShape({static_cast<int64>(expected_unique.size())})
+      );
+      FillValues<int64>(&expected_unique_tensor, expected_unique);
+      ExpectTensorEqual<int64>(expected_unique_tensor, out_unique);
+
+      // 验证输出1: indices
+      Tensor expected_indices_tensor(
+        allocator(), DT_INT32,
+        TensorShape({static_cast<int64_t>(expected_indices.size())})
+      );
+      FillValues<int32>(&expected_indices_tensor, expected_indices);
+      ExpectTensorEqual<int32>(expected_indices_tensor, out_indices); 
+
+      // 验证输出2: out_data
+      Tensor expected_data_tensor(allocator(), DT_FLOAT,
+                                TensorShape({static_cast<int64>(expected_unique.size()), 12}));
+      FillValues<float>(&expected_data_tensor, expected_output_data);
+      ExpectClose(expected_data_tensor, out_data);  // float 用 ExpectClose
+    }
+
+    Status RunOpExpectFailure(const TensorShape& data_shape,
+                              const TensorShape& slice_shape,
+                              const std::vector<int32>& begin_val,
+                              const std::vector<int64>& slice_data,
+                              const std::vector<float>& data_data) {
+      TF_CHECK_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
+                      .Input(FakeInput(DT_FLOAT))
+                      .Input(FakeInput(DT_INT64))
+                      .Input(FakeInput(DT_INT32))
+                      .Finalize(node_def()));
+      TF_CHECK_OK(InitOp());
+
+      AddInputFromArray<float>(data_shape, data_data);
+      AddInputFromArray<int64>(slice_shape, slice_data);
+      AddInputFromArray<int32>(TensorShape({2}), begin_val);
+      
+      return RunOpKernel();
+    }
+};
+
+// 正向测试:正常输入
+TEST_F(KPFusedGatherTest, Valid_NormalInput) {
+  RunValidCase(
+      TensorShape({2, 12}),               // data shape
+      TensorShape({4, 3}),                // slice_input shape
+      {0, 1},                             // begin[1] = 1 → 取第1列
+      {1, 1, 3,
+       0, 1, 5,
+       1, 0, 7,
+       0, 1, 9},                          // slice_input 数据
+      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+       13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f},
+      {1, 0},                             // unique values from col=1
+      {0, 0, 1, 0},                       // indices mapping
+      {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,   // data[1]
+       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}            // data[0]
+  );
+}
+
+// data不是2维
+TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) {
+  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
+  Status s = RunOpExpectFailure(
+      TensorShape({4}),         // data 不是二维
+      TensorShape({2, 2}),
+      {0, 0},
+      {0, 1, 2, 3},
+      data
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Embedding table shape must be 2D"));
+}
+
+// key 不是2维
+TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) {
+  std::vector<float> data(2 * 12, 1.0f);
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 12}),
+      TensorShape({4}),         // 1D slice_input
+      {0, 0},
+      {0, 1, 2, 3},
+      data
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input key must be 2D"));
+}
+
+// begin[1] 超出列范围
+TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) {
+  std::vector<float> data(2 * 12, 1.0f);
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 12}),
+      TensorShape({2, 2}),
+      {0, 2},                      // begin[1] = 2,但只有 2 列 → 索引 0,1
+      {0, 1, 2, 3},
+      data
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),"slice cols out of keys range"));
+}
+
+// gather 索引超出 data 行数
+TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) {
+  std::vector<float> data(2 * 12, 1.0f);
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 12}),
+      TensorShape({2, 2}),
+      {0, 0},
+      {0, 1,
+       2, 3},  // 索引 2 超出 data 行数(只有 0,1)
+      data
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),"idx out of table range"));
+}
+
+}
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_padding.cc b/tensorflow/core/kernels/embedding_fused_padding.cc
new file mode 100644
index 000000000..91f775e68
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_padding.cc
@@ -0,0 +1,126 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+class KPFusedEmbeddingPaddingOp : public OpKernel {
+public:
+  explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context) : OpKernel(context) {
+    fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensor
+    const Tensor& origin_shape = context->input(0);
+    const Tensor& input = context->input(1);
+    const Tensor& input_rows = context->input(2);
+    const Tensor& reshape_sizes = context->input(3);
+
+    const Tensor& pack = context->input(4);
+
+    VLOG(1) << "Input shape: " << input.shape().DebugString();
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsVector(origin_shape.shape()), 
+                errors::InvalidArgument("origin_shape dims must 1D, not ", origin_shape.shape().DebugString())
+    );
+    OP_REQUIRES(context,
+                origin_shape.NumElements() == 2,
+                errors::InvalidArgument("origin_shape NumElements must == 2, not ", origin_shape.NumElements())
+    );
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsMatrix(input.shape()),
+                errors::InvalidArgument("input dims must 2D, not ", input.shape().DebugString()));
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsScalar(input_rows.shape()),
+                errors::InvalidArgument("input_rows must be a scalar")
+    );
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsVector(reshape_sizes.shape()),
+                errors::InvalidArgument("sizes input must be 1-D, not ", reshape_sizes.shape().DebugString())
+    );
+    OP_REQUIRES(context, 
+                reshape_sizes.NumElements() == 2,
+                errors::InvalidArgument("reshape_sizes NumElements must == 2"));
+
+    int input_rows_value = input_rows.scalar<int32>()();
+    int padding_rows = static_cast<int32>(origin_shape.flat<int64>()(0)) - input_rows_value;
+    auto reshape_cols = reshape_sizes.flat<int32>()(1);
+    int output_rows = padding_rows + input.dim_size(0);
+    int output_cols = input.dim_size(1);
+    OP_REQUIRES(context,
+                padding_rows >= 0,
+                errors::InvalidArgument("Pooling size(", input_rows_value,
+                ") is greater than Input size(", static_cast<int32>(origin_shape.flat<int64>()(0)), ")"));
+    OP_REQUIRES(context,
+                reshape_cols > 0,
+                errors::InvalidArgument("reshape_cols must > 0"));
+    OP_REQUIRES(context,
+                reshape_sizes.flat<int32>()(0) == -1,
+                errors::InvalidArgument("reshape[0] is not -1"));
+    OP_REQUIRES(context,
+                pack.scalar<int32>()() == output_cols,
+                errors::InvalidArgument("pack(", pack.scalar<int32>()(), ") is not equal to embedding dims"));
+
+    Tensor* output0 = nullptr;
+    Tensor* output1 = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({}),
+                                            &output0));
+    output0->scalar<int32>()() = padding_rows;
+    OP_REQUIRES(context,
+                output_rows * output_cols % reshape_cols == 0,
+                errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
+    );
+    int reshape_rows = output_rows * output_cols / reshape_cols;
+    if (fast_) {
+      OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output1));
+      output1->scalar<int32>()() = reshape_rows;
+      return;
+    }
+
+    TensorShape reshaped_shape({reshape_rows, reshape_cols});
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, reshaped_shape, &output1));
+    float* output_data = output1->flat<float>().data();
+    const float* input_data = input.flat<float>().data();
+    std::memcpy(output_data, input_data, input.dim_size(0) * output_cols * sizeof(float));
+    std::memset(output_data + input.dim_size(0) * output_cols,
+                0.0f,
+                padding_rows * output_cols * sizeof(float));
+  }
+
+private:
+    bool fast_;
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPadding").Device(DEVICE_CPU),
+                        KPFusedEmbeddingPaddingOp);
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPaddingFast").Device(DEVICE_CPU),
+                        KPFusedEmbeddingPaddingOp);
+
+}
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_padding_test.cc b/tensorflow/core/kernels/embedding_fused_padding_test.cc
new file mode 100644
index 000000000..5137d5130
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_padding_test.cc
@@ -0,0 +1,307 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+class KPFusedEmbeddingPaddingTest : public OpsTestBase {
+ protected:
+  void MakeOp(DataType input_shape_type, DataType pooling_type, DataType reshape_type, DataType const_type) {
+    TF_ASSERT_OK(NodeDefBuilder("fused_padding", "KPFusedEmbeddingPadding")
+                     .Input(FakeInput(input_shape_type))
+                     .Input(FakeInput(pooling_type))
+                     .Input(FakeInput(const_type))
+                     .Input(FakeInput(reshape_type))
+                     .Input(FakeInput(const_type))
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  Status FeedAndRun(const int embedding_dims, const int table_size,
+                    const int pooling_size, const int reshape_size) {
+    MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+    AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+    AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float { 
+      return static_cast<float>(i + 1); 
+    });
+    AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+    AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+    AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+    return RunOpKernel();
+  }
+
+  void MakeFastOp(DataType input_shape_type, DataType pooling_type, DataType reshape_type, DataType const_type) {
+    TF_ASSERT_OK(NodeDefBuilder("fused_padding_fast", "KPFusedEmbeddingPaddingFast")
+                     .Input(FakeInput(input_shape_type))
+                     .Input(FakeInput(pooling_type))
+                     .Input(FakeInput(const_type))
+                     .Input(FakeInput(reshape_type))
+                     .Input(FakeInput(const_type))
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  Status FeedAndRunFast(const int embedding_dims, const int table_size,
+                        const int pooling_size, const int reshape_size) {
+    MakeFastOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+    AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+    AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float { 
+      return static_cast<float>(i + 1); 
+    });
+    AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+    AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+    AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+    return RunOpKernel();
+  }
+};
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_0) {
+  // Feed and run
+  const int embedding_dims = 10;
+  const int table_size = 151;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillFn<float>(&expected2, [=](int i) -> float { 
+    if (i < pooling_size * embedding_dims) {
+      return static_cast<float>(i + 1); 
+    } else {
+      return 0.0f;
+    }
+  });
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_1) {
+  // Feed and run
+  const int embedding_dims = 10;
+  const int table_size = 1510;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillFn<float>(&expected2, [=](int i) -> float { 
+    if (i < pooling_size * embedding_dims) {
+      return static_cast<float>(i + 1); 
+    } else {
+      return 0.0f;
+    }
+  });
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_0) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 2;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillFn<float>(&expected2, [=](int i) -> float { 
+    if (i < pooling_size * embedding_dims) {
+      return static_cast<float>(i + 1); 
+    } else {
+      return 0.0f;
+    }
+  });
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_1) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillFn<float>(&expected2, [=](int i) -> float { 
+    if (i < pooling_size * embedding_dims) {
+      return static_cast<float>(i + 1); 
+    } else {
+      return 0.0f;
+    }
+  });
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_0) {
+  // Feed and run
+  const int embedding_dims = 10;
+  const int table_size = 151;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_1) {
+  // Feed and run
+  const int embedding_dims = 10;
+  const int table_size = 1510;
+  const int pooling_size = 151;
+  const int reshape_size = 1510;
+  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_0) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 2;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_1) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
+
+  // Check the output.
+  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
+  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected1, {table_size - pooling_size});
+  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
+  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
+  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectReshape) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float { 
+    return static_cast<float>(i + 1); 
+  });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  AddInputFromArray<int32>(TensorShape({2}), {10, reshape_size});
+  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+  Status s = RunOpKernel();
+  EXPECT_TRUE(
+      absl::StrContains(s.ToString(), "reshape[0] is not -1"))
+      << s;
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectPack) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 2;
+  const int reshape_size = 24;
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float { 
+    return static_cast<float>(i + 1); 
+  });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+  AddInputFromArray<int32>(TensorShape({}), {10});
+  Status s = RunOpKernel();
+  EXPECT_TRUE(
+      absl::StrContains(s.ToString(), "pack(10) is not equal to embedding dims"))
+      << s;
+}
+
+TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithPoolingSizeGreaterInput) {
+  // Feed and run
+  const int embedding_dims = 12;
+  const int table_size = 200;
+  const int pooling_size = 201;
+  const int reshape_size = 24;
+  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
+  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
+  AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float { 
+    return static_cast<float>(i + 1); 
+  });
+  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
+  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
+  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
+  Status s = RunOpKernel();
+  EXPECT_TRUE(
+      absl::StrContains(s.ToString(), "Pooling size(201) is greater than Input size(200)"))
+      << s;
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
new file mode 100644
index 000000000..05c7b2fa9
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
@@ -0,0 +1,87 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+class KPFusedSparseDynamicStitchOp : public OpKernel {
+public:
+  explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& x = context->input(0);
+    auto x_flat = x.flat<int64>();
+    int64_t num_elems = x_flat.size();
+
+    const int num_inputs = context->num_inputs();
+    const int num_partitions = num_inputs - 1;
+    OP_REQUIRES(context, num_partitions > 1, errors::InvalidArgument("num partitions must > 1"));
+    int64_t output_stride = 0;
+    std::vector<const float*> variables(num_partitions);
+    std::vector<int64_t> variable_rows(num_partitions);
+    for (int i = 1; i < num_inputs; ++i) {
+      const Tensor& input_tensor = context->input(i);
+      OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input dims must == 2"));
+      if (i == 1) {
+        output_stride = input_tensor.dim_size(1);
+      } else {
+        OP_REQUIRES(context, input_tensor.dim_size(1)  == output_stride,
+                    errors::InvalidArgument("All inputs must have same second dimension"));
+      }
+      variables[i - 1] = context->input(i).flat<float>().data();
+      variable_rows[i - 1] = input_tensor.dim_size(0);
+    }
+
+    OP_REQUIRES(context, output_stride > 0, errors::InvalidArgument("output_stride must > 0"));
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({num_elems, output_stride}),
+                                            &output_tensor));
+    float* output = (float*)output_tensor->tensor_data().data();
+
+    const size_t copy_size = output_stride * sizeof(float);
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = 120; // Actual single cycle execution time 
+    auto work = [&](int start, int end) {
+      for (int i = start; i < end; ++i) {
+        const int64_t global_id = x_flat(i);
+        const int64_t table_id = global_id % num_partitions;
+        const int64_t row_id = global_id / num_partitions;
+
+        OP_REQUIRES(context, row_id < variable_rows[table_id] && row_id >= 0, errors::InvalidArgument(
+          "row_id out of range."));
+
+        std::memcpy(output + i * output_stride,
+                    variables[table_id] + row_id * output_stride, copy_size);
+      }
+    };
+
+    Shard(worker_threads->num_threads, worker_threads->workers, num_elems,
+          cost_per_unit, work);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
+                        KPFusedSparseDynamicStitchOp);
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
new file mode 100644
index 000000000..74fdd2503
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *     Unless required by applicable law or agreed to in writing, software
+ *     distributed under the License is distributed on an "AS IS" BASIS,
+ *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *     See the License for the specific language governing permissions and
+ *     limitations under the License.
+ *     ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class KPFusedSparseDynamicStitchOpTest : public OpsTestBase {
+ protected:
+  void MakeOp(int N) {
+    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_dynamic_stitch",
+                                "KPFusedSparseDynamicStitch")
+                     .Input(FakeInput(DT_INT64))
+                     .Input(FakeInput(N, DT_FLOAT))
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestTwoTables) {
+  MakeOp(2);  // num_partitions = 2
+
+  AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected,
+                          {1.0f, 2.0f, 9.0f, 10.0f, 3.0f, 4.0f, 7.0f, 8.0f});
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestDifferentStride) {
+  MakeOp(2);
+
+  AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<float>(TensorShape({1, 4}), {7.0f, 8.0f, 9.0f, 10.0f});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),"All inputs must have same second dimension"));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestIndicesOutOfBounds) {
+  MakeOp(2);
+
+  AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
+  AddInputFromArray<float>(TensorShape({3, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),"row_id out of range"));
+}
+
+TEST_F(KPFusedSparseDynamicStitchOpTest, TestInputDims) {
+  MakeOp(2);
+
+  AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
+  AddInputFromArray<float>(TensorShape({3, 2, 1}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  AddInputFromArray<float>(TensorShape({2, 2, 1}), {7.0f, 8.0f, 9.0f, 10.0f});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),"input dims must == 2"));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
new file mode 100644
index 000000000..9b03e429b
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
@@ -0,0 +1,203 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow/core/kernels/reshape_util.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+using namespace tensorflow;
+
+static void ReshapeKp(OpKernelContext *context, const Tensor &input_indices_in,
+             const Tensor &input_shape_in, const Tensor &target_shape_in,
+             int output_indices_idx, int output_shape_idx) {
+  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
+              errors::InvalidArgument(
+                  "Input indices should be a matrix but received shape ",
+                  input_indices_in.shape().DebugString()));
+  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
+              errors::InvalidArgument(
+                  "Input shape should be a vector but received shape ",
+                  input_shape_in.shape().DebugString()));
+  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
+              errors::InvalidArgument(
+                  "Target shape should be a vector but received shape ",
+                  target_shape_in.shape().DebugString()));
+
+  const int64 input_rank = input_shape_in.NumElements();
+  const int64 output_rank = target_shape_in.NumElements();
+  const TensorShape input_shape(input_shape_in.vec<int64>());
+  const int64 dense_size = input_shape.num_elements();
+  const int64 nnz = input_indices_in.shape().dim_size(0);
+
+  TensorShape output_shape;
+  int64 product = 1;
+  int unknown_index = -1;
+  auto target_shape = target_shape_in.vec<int64>();
+  for (int d = 0; d < output_rank; ++d) {
+    const int64 size = target_shape(d);
+    if (size == -1) {
+      OP_REQUIRES(
+          context, unknown_index == -1,
+          errors::InvalidArgument("only one output dimension may be -1, "
+                                  "not both ",
+                                  unknown_index, " and ", d));
+      unknown_index = d;
+      output_shape.AddDim(1);
+    } else {
+      OP_REQUIRES(context, size >= 0,
+                  errors::InvalidArgument("size ", d,
+                                          " must be non-negative, not ", size));
+      product *= size;
+      output_shape.AddDim(size);
+    }
+  }
+  if (unknown_index != -1) {
+    OP_REQUIRES(
+        context, product > 0,
+        errors::InvalidArgument("reshape cannot infer the missing "
+                                "input size for an empty tensor unless all "
+                                "specified input sizes are non-zero"));
+    const int64 missing = dense_size / product;
+    OP_REQUIRES(
+        context, product * missing == dense_size,
+        errors::InvalidArgument(
+            "Input to reshape is a SparseTensor with ", dense_size,
+            " dense values, but the requested shape requires a multiple of ",
+            product, ". input_shape=", input_shape.DebugString(),
+            " output_shape=", output_shape.DebugString()));
+    output_shape.set_dim(unknown_index, missing);
+  }
+
+  OP_REQUIRES(
+      context, output_shape.num_elements() == dense_size,
+      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
+                              " dense values, but the requested shape has ",
+                              output_shape.num_elements(),
+                              ". input_shape=", input_shape.DebugString(),
+                              " output_shape=", output_shape.DebugString()));
+
+  if (input_shape == output_shape) {
+    context->set_output(output_indices_idx, input_indices_in);
+    context->set_output(output_shape_idx, input_shape_in);
+    return;
+  }
+
+  gtl::InlinedVector<int64, 8> input_strides(input_rank);
+  if (input_rank > 0) {
+    input_strides[input_rank - 1] = 1;
+    for (int d = input_rank - 2; d >= 0; --d) {
+      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+    }
+  }
+
+  gtl::InlinedVector<int64, 8> output_strides(output_rank);
+  if (output_rank > 0) {
+    output_strides[output_rank - 1] = 1;
+    for (int d = output_rank - 2; d >= 0; --d) {
+      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
+    }
+  }
+
+  Tensor *result_indices = nullptr;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(output_indices_idx,
+                                          TensorShape({nnz, output_rank}),
+                                          &result_indices));
+  auto input_ind = input_indices_in.matrix<int64>();
+  auto output_ind = result_indices->matrix<int64>();
+  for (int i = 0; i < nnz; ++i) {
+    int64 id = 0;
+    for (int j = 0; j < input_rank; ++j) {
+      id += input_ind(i, j) * input_strides[j];
+    }
+    for (int j = 0; j < output_rank; ++j) {
+      output_ind(i, j) = id / output_strides[j];
+      id %= output_strides[j];
+    }
+  }
+
+  Tensor *result_shape = nullptr;
+  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
+                                                   TensorShape({output_rank}),
+                                                   &result_shape));
+  auto output_shape_vec = result_shape->vec<int64>();
+  for (int j = 0; j < output_shape.dims(); ++j) {
+    output_shape_vec(j) = output_shape.dim_size(j);
+  }
+}
+
+template <typename T>
+class KPFusedSparseReshapeOp : public OpKernel {
+ public:
+  explicit KPFusedSparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) { }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& slice_input = context->input(0);
+    const Tensor& begin = context->input(1);
+    const Tensor& new_shape = context->input(2);
+    const Tensor& pack_const = context->input(3);
+
+    OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2"));
+    OP_REQUIRES(context, new_shape.dim_size(0) == 2, errors::Internal("new_shape dim size must == 2"));
+    OP_REQUIRES(context, pack_const.dims() == 0, 
+                errors::InvalidArgument("pack_const must be a scalar"));
+    VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
+    VLOG(1) << "Input begin value: " << begin.DebugString();
+    VLOG(1) << "Input new_shape value: " << new_shape.DebugString();
+
+    OP_REQUIRES(context, begin.dims() == 1 && begin.dim_size(0) == 2,
+                errors::InvalidArgument("begin must be 1D with at least 2 elements"));
+    int32 col = begin.flat<int32>().data()[1];
+    OP_REQUIRES(context, col < slice_input.dim_size(1), errors::Internal("begin[1] must < slice_input.dim_size(1)"));
+    int64_t num_rows = slice_input.dim_size(0);
+    auto slice_input_mat = slice_input.matrix<int64>();
+
+    VLOG(1) << "num_rows: " << num_rows;
+    VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0);
+    VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1);
+    VLOG(1) << "Column index from begin: " << col;
+
+    Tensor shape_in(DT_INT64, TensorShape({2}));
+    auto tensor_flat = shape_in.flat<int64>();
+    tensor_flat(0) = num_rows;
+    tensor_flat(1) = static_cast<int64>(pack_const.scalar<T>()());
+    
+    Tensor indices_in(DT_INT64, TensorShape({num_rows, 2}));
+    auto indices_in_mat = indices_in.matrix<int64>();
+    for (int i = 0; i < num_rows; ++i) {
+        indices_in_mat(i, 0) = i;
+        indices_in_mat(i, 1) = slice_input_mat(i, col);
+    }
+
+    ReshapeKp(context, indices_in, shape_in, new_shape, 0, 1);
+  }
+};
+
+#define REGISTER_KERNEL(type)                             \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedSparseReshape")    \
+                              .Device(DEVICE_CPU)         \
+                              .TypeConstraint<type>("T"), \
+                              KPFusedSparseReshapeOp<type>)
+
+REGISTER_KERNEL(int64);
+REGISTER_KERNEL(int32);
+#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc
new file mode 100644
index 000000000..874d48a5e
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc
@@ -0,0 +1,281 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::FillValues;
+using tensorflow::test::ExpectTensorEqual;
+
+class KPFusedSparseReshapeTest : public OpsTestBase {
+ protected:
+  void RunValidCase(const TensorShape& slice_shape,
+                    const std::vector<int64>& slice_data,
+                    const std::vector<int32>& begin_val,
+                    const std::vector<int64>& new_shape_val,
+                    const std::vector<int64>& pack_const_val,
+                    const TensorShape& expected_indices_shape,
+                    const std::vector<int64>& expected_shape_val) {
+    TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
+                     .Input(FakeInput(DT_INT64))   // slice_input
+                     .Input(FakeInput(DT_INT32))   // begin
+                     .Input(FakeInput(DT_INT64))   // new_shape
+                     .Input(FakeInput(DT_INT64))   // pack_const
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<int64>(slice_shape, slice_data);
+    AddInputFromArray<int32>(TensorShape({2}), begin_val);
+    AddInputFromArray<int64>(TensorShape({2}), new_shape_val);
+    AddInputFromArray<int64>(TensorShape({}), pack_const_val);
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    // 输出0: result_indices
+    const Tensor& out_indices = *GetOutput(0);
+    EXPECT_EQ(out_indices.shape(), expected_indices_shape);
+
+    // 输出1: result_shape
+    const Tensor& out_shape = *GetOutput(1);
+    Tensor expected_shape_tensor(DT_INT64,
+                                 TensorShape({static_cast<int64>(expected_shape_val.size())}));
+    FillValues<int64>(&expected_shape_tensor, expected_shape_val);
+    ExpectTensorEqual<int64>(expected_shape_tensor, out_shape);
+  }
+
+  Status RunOpExpectFailure(const TensorShape& slice_shape,
+                            const std::vector<int64>& slice_data,
+                            const std::vector<int32>& begin_val,
+                            const std::vector<int64>& new_shape_val,
+                            const std::vector<int64>& pack_const_val) {
+    TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
+                    .Input(FakeInput(DT_INT64))   // slice_input
+                    .Input(FakeInput(DT_INT32))   // begin
+                    .Input(FakeInput(DT_INT64))   // new_shape
+                    .Input(FakeInput(DT_INT64))   // pack_const
+                    .Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+
+    AddInputFromArray<int64>(slice_shape, slice_data);
+    AddInputFromArray<int32>(TensorShape({static_cast<int64>(begin_val.size())}), begin_val);
+    AddInputFromArray<int64>(TensorShape({static_cast<int64>(new_shape_val.size())}), new_shape_val);
+    AddInputFromArray<int64>(TensorShape({}), pack_const_val);
+
+    return RunOpKernel();
+  }
+};
+
+// ==================== 正向测试 ====================
+
+// 正常 reshape 案例
+// pack_const=2
+TEST_F(KPFusedSparseReshapeTest, Valid_NormalInput) {
+  RunValidCase(
+      TensorShape({4, 2}),   // slice_input shape
+      {0, 1,
+       1, 2,
+       2, 3,
+       3, 0},                // slice_input 数据
+      {0, 1},                // begin = (0,1),选第1列
+      {2, 4},                // new_shape = [2,4]
+      {2},                   // pack_const = [2]
+      TensorShape({4, 2}),   // 预期 indices 形状
+      {2, 4});               // 预期 shape
+}
+
+// pack_const = 1
+TEST_F(KPFusedSparseReshapeTest, Valid_PackConst1) {
+  RunValidCase(
+      TensorShape({1, 2}),   // slice_input shape
+      {0, 1},                // slice_input 数据
+      {0, 1},                // begin = (0,1),选第1列
+      {-1, 1},               // new_shape = [-1,1]
+      {1},                   // pack_const = [1]
+      TensorShape({1, 2}),   // 预期 indices 形状
+      {1, 1});               // 预期 shape
+}
+
+// ==================== 反向测试 ====================
+
+// 反例1:slice_input 不是二维
+TEST_F(KPFusedSparseReshapeTest, Invalid_SliceInputNot2D) {
+  Status s = RunOpExpectFailure(
+      TensorShape({4}), {0, 1, 2, 3},
+      {0, 0},
+      {2, 2},
+      {4});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "slice_input dims must == 2"));
+}
+
+// 反例2:new_shape dim size 不是 2
+TEST_F(KPFusedSparseReshapeTest, Invalid_NewShapeNotLen2) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 0},
+      {4, 2, 1},   // new_shape 多了1个元素
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "new_shape dim size must == 2"));
+}
+
+// 反例3:begin[1] 超出 slice_input 列数
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginOutOfRange) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 2},     // 超过列数
+      {2, 2},
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin[1] must < slice_input.dim_size(1)"));
+}
+
+// 反例4:target shape 有多个 -1
+TEST_F(KPFusedSparseReshapeTest, Invalid_MultipleUnknownDims) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 1},
+      {-1, -1},   // 两个 -1
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "only one output dimension may be -1"));
+}
+
+// 反例5:reshape 推断维度时,总元素数不能整除,导致无法匹配 --> product * missing != dense_size
+TEST_F(KPFusedSparseReshapeTest, Invalid_InferredShapeDoesNotMatch) {
+  TensorShape input_indices_shape({6, 2});  // 6 个非零元素,rank=2
+  std::vector<int64> input_indices_data = {
+    0, 0,
+    0, 1,
+    0, 2,
+    1, 0,
+    1, 1,
+    1, 2
+  };  // 对应 2x3 的 dense tensor
+
+  std::vector<int32> begin_val = {0, 0};         // 假设的 begin 输入
+  std::vector<int64> new_shape_val = {-1, 4};    // reshape 到 ?x4
+  std::vector<int64> pack_const_val = {1};
+
+  Status s = RunOpExpectFailure(
+      input_indices_shape,
+      input_indices_data,
+      begin_val,
+      new_shape_val,
+      pack_const_val);
+
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a SparseTensor with"));
+}
+
+// 反例6:reshape 后元素数量不匹配 --> output_shape.num_elements() != dense_size
+TEST_F(KPFusedSparseReshapeTest, Invalid_SizeMismatch) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 1},
+      {3, 3},   // 期望 9 元素,但输入 dense size = 4
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a tensor with"));
+}
+
+// 反例7:target_shape 包含负数但不是 -1
+TEST_F(KPFusedSparseReshapeTest, Invalid_NegativeDimNotMinusOne) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 0},
+      {2, -2},   // -2 是非法的
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "size 1 must be non-negative, not -2"))
+      << "Actual error: " << s.message();
+}
+
+// 反例8:target_shape 有 -1,但其他维度乘积为 0
+TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) {
+  // dense_size = 0(空 SparseTensor),target_shape = [-1, 0]
+  // product = 0 → 不允许 infer
+  Status s = RunOpExpectFailure(
+      TensorShape({0, 2}), {},           // 空的 slice_input
+      {0, 0},
+      {-1, 0},   // product = 0
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "reshape cannot infer the missing input size for an empty tensor"))
+      << "Actual error: " << s.message();
+}
+
+// 反例9:begin 是 1D 但长度为 1(不够 2 个元素)
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize1) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0},                // begin = [0],长度为 1
+      {2, 2},
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with at least 2 elements"))
+      << "Actual error: " << s.message();
+}
+
+// 反例10:begin 是 1D 但长度为 3(超过 2)
+TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize3) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2, 2}), {0, 1, 1, 0},
+      {0, 1, 2},          // begin = [0,1,2],长度为 3
+      {2, 2},
+      {2});
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with at least 2 elements"))
+      << "Actual error: " << s.message();
+}
+
+// 反例11:pack_const 是标量(0维)
+TEST_F(KPFusedSparseReshapeTest, Invalid_PackConstIsScalarButExpect1D) {
+  TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
+                  .Input(FakeInput(DT_INT64))   // slice_input
+                  .Input(FakeInput(DT_INT32))   // begin
+                  .Input(FakeInput(DT_INT64))   // new_shape
+                  .Input(FakeInput(DT_INT64))   // pack_const
+                  .Finalize(node_def()));
+  TF_CHECK_OK(InitOp());
+
+  AddInputFromArray<int64>(TensorShape({2, 2}), {0, 1, 1, 0});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+  AddInputFromArray<int64>(TensorShape({2}), {2, 2});
+  AddInputFromArray<int64>(TensorShape({1}), {1});  // pack_const = 标量 1(0维)
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "pack_const must be a scalar"))
+      << "Actual error: " << s.message();
+}
+
+}  // namespace
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
new file mode 100644
index 000000000..33bbd312b
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
@@ -0,0 +1,165 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+template <typename Tidx>
+class KPFusedSparseSegmentReduceOp : public OpKernel {
+public:
+  explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    int combiner_mode;
+    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
+    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
+                errors::InvalidArgument("combiner must be 0 or 1"));
+    is_mean_ = (combiner_mode == 1);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_tensor = context->input(0);
+    const Tensor& indices = context->input(1);
+    const Tensor& slice_input = context->input(2);
+    const Tensor& begin = context->input(3);
+    const Tensor& begin_1 = context->input(4);
+
+    OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input must be 2-D"));
+    OP_REQUIRES(context, slice_input.dims() == 2, errors::InvalidArgument("slice input must be 2-D"));
+    OP_REQUIRES(context, begin.NumElements() == 2,  errors::InvalidArgument("begin must have 2 elements"));
+    OP_REQUIRES(context, begin_1.NumElements() == 1, errors::InvalidArgument("begin_1 must have 1 element"));
+    int64_t num_indices = indices.dim_size(0);
+    int64_t embedding_size = input_tensor.dim_size(1);
+    int32 col = begin.flat<int32>().data()[1];
+    int32 out_dim = static_cast<int32>(begin_1.flat<int32>()(0));
+
+    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1), 
+                 errors::InvalidArgument("Column index out of range"));
+    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
+                errors::InvalidArgument("indices and slice_input.dim_zie(0) should have same size"));
+
+    auto input_data = input_tensor.matrix<float>().data();
+    auto indices_vec = indices.vec<Tidx>();
+    auto slice_input_mat = slice_input.matrix<int64>();
+
+    // Calculate max segment_id
+    int64 max_seg_id = 0;
+    for (int32 i = 0; i < num_indices; ++i) {
+      int64 seg_id = slice_input_mat(i, col);
+      if (seg_id > max_seg_id) {
+        max_seg_id = seg_id;
+      }
+    }
+    const int64 batch_size = max_seg_id + 1;
+
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({batch_size, embedding_size}), &output));
+    output->flat<float>().setZero();
+    Tensor* slice_out = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({}), &slice_out));
+    if (out_dim == 0)
+      slice_out->scalar<int32>()() = batch_size;
+    else slice_out->scalar<int32>()() = embedding_size;
+
+    auto output_data = output->matrix<float>().data();
+
+    if (is_mean_) {
+      Tensor counts(DT_INT32, TensorShape({batch_size}));
+      counts.flat<int32>().setZero();
+      auto counts_vec = counts.flat<int32>();
+
+      for (int64 i = 0; i < num_indices; ++i) {
+        const int64 seg_id = slice_input_mat(i, col);
+        const Tidx data_row = indices_vec(i);
+        counts_vec(seg_id) += 1;
+
+        float* output_row = output_data + seg_id * embedding_size;
+        const float* input_data_row = input_data + data_row * embedding_size;
+        int64 j = 0;
+        for (; j + 3 < embedding_size; j += 4) {
+          float32x4_t out = vld1q_f32(output_row + j);
+          float32x4_t data = vld1q_f32(input_data_row + j);
+          out = vaddq_f32(out, data);
+          vst1q_f32(output_row + j, out);
+        }
+
+        for (; j < embedding_size; ++j) {
+          output_row[j] += input_data_row[j];
+        }
+      }
+
+      for (int64_t seg = 0; seg < batch_size; ++seg) {
+        const int32_t count = counts_vec(seg);
+        if (count > 0) {
+          const float inv_count = 1.0f / static_cast<float>(count);
+          const float32x4_t inv_count_vec = vdupq_n_f32(inv_count);
+
+          float* row_start = output_data + seg * embedding_size;
+          int64_t j = 0;
+
+          for (; j + 3 < embedding_size; j += 4) {
+            float32x4_t val = vld1q_f32(row_start + j);
+            val = vmulq_f32(val, inv_count_vec);
+            vst1q_f32(row_start + j, val);
+          }
+
+          for (; j < embedding_size; ++j) {
+            row_start[j] *= inv_count;
+          }
+        }
+      }
+    } else {
+      for (int64 i = 0; i < num_indices; ++i) {
+        const int64 seg_id = slice_input_mat(i, col);
+        const Tidx data_row = indices_vec(i);
+
+        float* output_row = output_data + seg_id * embedding_size;
+        const float* input_data_row = input_data + data_row * embedding_size;
+        int64 j = 0;
+        for (; j + 3 < embedding_size; j += 4) {
+          float32x4_t out = vld1q_f32(output_row + j);
+          float32x4_t data = vld1q_f32(input_data_row + j);
+          out = vaddq_f32(out, data);
+          vst1q_f32(output_row + j, out);
+        }
+
+        for (; j < embedding_size; ++j) {
+          output_row[j] += input_data_row[j];
+        }
+      }
+    }
+  }
+
+private:
+  bool is_mean_;
+};
+
+#define REGISTER_KERNEL(Tidx)                                \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduce") \
+                              .Device(DEVICE_CPU)            \
+                              .TypeConstraint<Tidx>("Tidx"), \
+                          KPFusedSparseSegmentReduceOp<Tidx>);
+REGISTER_KERNEL(int64)
+REGISTER_KERNEL(int32)
+#undef REGISTER_KERNEL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
new file mode 100644
index 000000000..3f0a4dfd8
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
@@ -0,0 +1,387 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <arm_neon.h>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "absl/container/flat_hash_map.h"
+
+using namespace tensorflow;
+
+template <typename Tidx>
+class KPFusedSparseSegmentReduceNonzeroOp : public OpKernel {
+public:
+  explicit KPFusedSparseSegmentReduceNonzeroOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    int combiner_mode;
+    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
+    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
+                errors::InvalidArgument("combiner must be 0 or 1"));
+    is_mean_ = (combiner_mode == 1);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_tensor = context->input(0);
+    const Tensor& indices = context->input(1);
+    const Tensor& slice_input = context->input(2);
+    const Tensor& begin = context->input(3);
+
+    const int input_dims = input_tensor.dims();
+    OP_REQUIRES(context, input_dims == 1 || input_dims == 2,
+                errors::InvalidArgument("Input data must be a 1-D vector or 2-D matrix"));
+    OP_REQUIRES(context, slice_input.dims() == 2, errors::InvalidArgument("slice input must be 2-D"));
+    OP_REQUIRES(context, begin.NumElements() == 2,  errors::InvalidArgument("begin must have 2 elements"));
+
+    int64 num_indices = indices.dim_size(0);
+    int32 col = begin.flat<int32>().data()[1];
+
+    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
+                 errors::InvalidArgument("Column index out of range"));
+    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
+                errors::InvalidArgument("indices and slice_input.dim_size(0) should have same size"));
+
+    auto indices_vec = indices.vec<Tidx>();
+    auto slice_input_mat = slice_input.matrix<int64>();
+
+    // Calculate max segment_id
+    std::vector<int64> segment_ids(num_indices);
+    int64 max_seg_id = 0;
+    for (int64 i = 0; i < num_indices; ++i) {
+      int64 seg_id = slice_input_mat(i, col);
+      segment_ids[i] = seg_id;
+      if (seg_id > max_seg_id) {
+        max_seg_id = seg_id;
+      }
+    }
+
+    const int64 batch_size = max_seg_id + 1;
+
+    // 获取线程池
+    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int num_threads = worker_threads ? worker_threads->num_threads : 1;
+
+    if (input_dims == 1) {
+      // -------------------------------------------------------
+      // 1-D input: each segment reduces to a scalar
+      // output_shape: [batch_size]
+      // output_indices: [num_nonzero, 1]  (segment_id)
+      // output_nonzero: [num_nonzero]
+      // -------------------------------------------------------
+      auto input_data = input_tensor.flat<float>();
+
+      Tensor* output_shape = nullptr;
+      OP_REQUIRES_OK(
+          context, context->allocate_output(0, TensorShape({1}), &output_shape));
+      output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);
+
+      // 优化:使用map管理索引 + 连续内存存储
+      absl::flat_hash_map<int64, int64> seg_id_to_idx;
+      seg_id_to_idx.reserve(num_indices / 4 + 1);
+      
+      std::vector<float> segment_sums_data;
+      std::vector<int32_t> segment_counts_data;
+      std::vector<int64> segment_order;
+
+      if (is_mean_) {
+        // 聚合阶段 - 单线程(reduce维度)
+        for (int64 i = 0; i < num_indices; ++i) {
+          const int64 seg_id = segment_ids[i];
+          const Tidx data_row = indices_vec(i);
+          
+          auto it = seg_id_to_idx.find(seg_id);
+          int64 idx;
+          if (it == seg_id_to_idx.end()) {
+            idx = segment_order.size();
+            seg_id_to_idx.emplace(seg_id, idx);
+            segment_order.push_back(seg_id);
+            segment_sums_data.push_back(0.0f);
+            segment_counts_data.push_back(0);
+          } else {
+            idx = it->second;
+          }
+          
+          segment_sums_data[idx] += input_data(data_row);
+          segment_counts_data[idx]++;
+        }
+        
+        // 预计算逆元
+        std::vector<float> inv_counts(segment_order.size());
+        for (size_t s = 0; s < segment_order.size(); ++s) {
+          inv_counts[s] = (segment_counts_data[s] > 0) ? 
+              (1.0f / static_cast<float>(segment_counts_data[s])) : 0.0f;
+        }
+        
+        // 统计非零数量
+        int64 num_nonzero = 0;
+        for (size_t s = 0; s < segment_order.size(); ++s) {
+          float val = segment_sums_data[s] * inv_counts[s];
+          if (val != 0.0f) num_nonzero++;
+        }
+        
+        Tensor* output_indices = nullptr;
+        OP_REQUIRES_OK(context,
+                       context->allocate_output(1, TensorShape({num_nonzero, 1}),
+                                                &output_indices));
+        auto output_indices_data = output_indices->flat<int32>();
+
+        Tensor* output_nonzero = nullptr;
+        OP_REQUIRES_OK(context,
+                       context->allocate_output(2, TensorShape({num_nonzero}),
+                                                &output_nonzero));
+        auto output_nonzero_data = output_nonzero->flat<float>();
+        
+        // 直接填充输出 - 无中间容器
+        int64 idx = 0;
+        for (size_t s = 0; s < segment_order.size(); ++s) {
+          float val = segment_sums_data[s] * inv_counts[s];
+          if (val != 0.0f) {
+            output_indices_data(idx) = static_cast<int32>(segment_order[s]);
+            output_nonzero_data(idx) = val;
+            idx++;
+          }
+        }
+      } else {
+        // Sum模式
+        for (int64 i = 0; i < num_indices; ++i) {
+          const int64 seg_id = segment_ids[i];
+          const Tidx data_row = indices_vec(i);
+          
+          auto it = seg_id_to_idx.find(seg_id);
+          int64 idx;
+          if (it == seg_id_to_idx.end()) {
+            idx = segment_order.size();
+            seg_id_to_idx.emplace(seg_id, idx);
+            segment_order.push_back(seg_id);
+            segment_sums_data.push_back(0.0f);
+          } else {
+            idx = it->second;
+          }
+          
+          segment_sums_data[idx] += input_data(data_row);
+        }
+        
+        // 统计非零数量
+        int64 num_nonzero = 0;
+        for (size_t s = 0; s < segment_order.size(); ++s) {
+          if (segment_sums_data[s] != 0.0f) num_nonzero++;
+        }
+        
+        Tensor* output_indices = nullptr;
+        OP_REQUIRES_OK(context,
+                       context->allocate_output(1, TensorShape({num_nonzero, 1}),
+                                                &output_indices));
+        auto output_indices_data = output_indices->flat<int32>();
+
+        Tensor* output_nonzero = nullptr;
+        OP_REQUIRES_OK(context,
+                       context->allocate_output(2, TensorShape({num_nonzero}),
+                                                &output_nonzero));
+        auto output_nonzero_data = output_nonzero->flat<float>();
+        
+        // 直接填充输出
+        int64 idx = 0;
+        for (size_t s = 0; s < segment_order.size(); ++s) {
+          float val = segment_sums_data[s];
+          if (val != 0.0f) {
+            output_indices_data(idx) = static_cast<int32>(segment_order[s]);
+            output_nonzero_data(idx) = val;
+            idx++;
+          }
+        }
+      }
+
+    } else {
+      // -------------------------------------------------------
+      // 2-D input: each segment reduces to a vector of size embed_dim
+      // output_shape: [batch_size, embed_dim]
+      // output_indices: [num_nonzero, 2]  (segment_id, dim_index)
+      // output_nonzero: [num_nonzero]
+      // -------------------------------------------------------
+      const int64 embed_dim = input_tensor.dim_size(1);
+      auto input_data = input_tensor.matrix<float>();
+
+      Tensor* output_shape = nullptr;
+      OP_REQUIRES_OK(
+          context, context->allocate_output(0, TensorShape({2}), &output_shape));
+      output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);
+      output_shape->flat<int32>()(1) = static_cast<int32>(embed_dim);
+
+      // 优化:map管理索引 + 连续内存存储数据
+      absl::flat_hash_map<int64, int64> seg_id_to_idx;
+      seg_id_to_idx.reserve(num_indices / 4 + 1);
+      
+      std::vector<float> segment_sums_data;
+      std::vector<int32_t> segment_counts_data;
+      std::vector<int64> segment_order;
+
+      // ========== 阶段1:单线程聚合(Reduce维度,避免数据竞争)==========
+      for (int64 i = 0; i < num_indices; ++i) {
+        const int64 seg_id = segment_ids[i];
+        const Tidx data_row = indices_vec(i);
+        
+        auto it = seg_id_to_idx.find(seg_id);
+        int64 idx;
+        if (it == seg_id_to_idx.end()) {
+          idx = segment_order.size();
+          seg_id_to_idx.emplace(seg_id, idx);
+          segment_order.push_back(seg_id);
+          segment_sums_data.resize((idx + 1) * embed_dim, 0.0f);
+          if (is_mean_) segment_counts_data.push_back(0);
+        } else {
+          idx = it->second;
+        }
+        
+        float* sum_ptr = &segment_sums_data[idx * embed_dim];
+        
+        // SIMD 加速聚合
+        int64 d = 0;
+        for (; d + 4 <= embed_dim; d += 4) {
+          float32x4_t input_vec = vld1q_f32(&input_data(data_row, d));
+          float32x4_t sum_vec = vld1q_f32(&sum_ptr[d]);
+          sum_vec = vaddq_f32(sum_vec, input_vec);
+          vst1q_f32(&sum_ptr[d], sum_vec);
+        }
+        // 处理剩余元素
+        for (; d < embed_dim; ++d) {
+          sum_ptr[d] += input_data(data_row, d);
+        }
+        
+        if (is_mean_) {
+          segment_counts_data[idx]++;
+        }
+      }
+      
+      const int64 num_unique_segments = segment_order.size();
+      
+      // 预计算逆元
+      std::vector<float> inv_counts(num_unique_segments);
+      if (is_mean_) {
+        for (int64 s = 0; s < num_unique_segments; ++s) {
+          inv_counts[s] = (segment_counts_data[s] > 0) ? 
+              (1.0f / static_cast<float>(segment_counts_data[s])) : 0.0f;
+        }
+      }
+      
+      // ========== 阶段2:多线程统计非零数量(非Reduce维度)==========
+      std::vector<int64> segment_nz_counts(num_unique_segments, 0);
+      
+      if (num_unique_segments >= 16 && num_threads > 1) {
+        auto count_work = [&](int64 start_seg, int64 end_seg) {
+          for (int64 s = start_seg; s < end_seg; ++s) {
+            float* sum_ptr = &segment_sums_data[s * embed_dim];
+            float inv_count = is_mean_ ? inv_counts[s] : 1.0f;
+            int64 count = 0;
+            for (int64 d = 0; d < embed_dim; ++d) {
+              if (sum_ptr[d] * inv_count != 0.0f) count++;
+            }
+            segment_nz_counts[s] = count;
+          }
+        };
+        
+        Shard(num_threads, worker_threads->workers, num_unique_segments,
+              num_unique_segments * embed_dim / num_threads, count_work);
+      } else {
+        for (int64 s = 0; s < num_unique_segments; ++s) {
+          float* sum_ptr = &segment_sums_data[s * embed_dim];
+          float inv_count = is_mean_ ? inv_counts[s] : 1.0f;
+          int64 count = 0;
+          for (int64 d = 0; d < embed_dim; ++d) {
+            if (sum_ptr[d] * inv_count != 0.0f) count++;
+          }
+          segment_nz_counts[s] = count;
+        }
+      }
+      
+      // 计算前缀和
+      std::vector<int64> segment_offsets(num_unique_segments + 1, 0);
+      for (int64 s = 0; s < num_unique_segments; ++s) {
+        segment_offsets[s + 1] = segment_offsets[s] + segment_nz_counts[s];
+      }
+      int64 total_nz = segment_offsets[num_unique_segments];
+      
+      // 分配输出 Tensor
+      Tensor* output_indices = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(1, TensorShape({total_nz, 2}),
+                                              &output_indices));
+      auto output_indices_data = output_indices->matrix<int32>();
+
+      Tensor* output_nonzero = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(2, TensorShape({total_nz}),
+                                              &output_nonzero));
+      auto output_nonzero_data = output_nonzero->flat<float>();
+      
+      // ========== 阶段3:多线程填充输出(非Reduce维度)==========
+      if (num_unique_segments >= 16 && num_threads > 1) {
+        auto output_work = [&](int64 start_seg, int64 end_seg) {
+          for (int64 s = start_seg; s < end_seg; ++s) {
+            int64 seg_id = segment_order[s];
+            float* sum_ptr = &segment_sums_data[s * embed_dim];
+            float inv_count = is_mean_ ? inv_counts[s] : 1.0f;
+            int64 out_idx = segment_offsets[s];
+            
+            for (int64 d = 0; d < embed_dim; ++d) {
+              float val = sum_ptr[d] * inv_count;
+              if (val != 0.0f) {
+                output_indices_data(out_idx, 0) = static_cast<int32>(seg_id);
+                output_indices_data(out_idx, 1) = static_cast<int32>(d);
+                output_nonzero_data(out_idx) = val;
+                out_idx++;
+              }
+            }
+          }
+        };
+        
+        Shard(num_threads, worker_threads->workers, num_unique_segments,
+              num_unique_segments * embed_dim / num_threads, output_work);
+      } else {
+        int64 idx = 0;
+        for (int64 s = 0; s < num_unique_segments; ++s) {
+          int64 seg_id = segment_order[s];
+          float* sum_ptr = &segment_sums_data[s * embed_dim];
+          float inv_count = is_mean_ ? inv_counts[s] : 1.0f;
+          
+          for (int64 d = 0; d < embed_dim; ++d) {
+            float val = sum_ptr[d] * inv_count;
+            if (val != 0.0f) {
+              output_indices_data(idx, 0) = static_cast<int32>(seg_id);
+              output_indices_data(idx, 1) = static_cast<int32>(d);
+              output_nonzero_data(idx) = val;
+              idx++;
+            }
+          }
+        }
+      }
+    }
+
+  }
+
+ private:
+  bool is_mean_;
+};
+
+#define REGISTER_KERNEL(Tidx)                                       \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduceNonzero") \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<Tidx>("Tidx"),        \
+                          KPFusedSparseSegmentReduceNonzeroOp<Tidx>);
+REGISTER_KERNEL(int64)
+REGISTER_KERNEL(int32)
+#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
new file mode 100644
index 000000000..1a83beb99
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *     Unless required by applicable law or agreed to in writing, software
+ *     distributed under the License is distributed on an "AS IS" BASIS,
+ *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *     See the License for the specific language governing permissions and
+ *     limitations under the License.
+ *     ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class KPFusedSparseSegmentReduceNonzeroOpTest : public OpsTestBase {
+ protected:
+  void MakeOp(int combiner_mode) {
+    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce_nonzero",
+                                "KPFusedSparseSegmentReduceNonzero")
+                     .Input(FakeInput(DT_FLOAT))  // data
+                     .Input(FakeInput(DT_INT32))  // indices
+                     .Input(FakeInput(DT_INT64))  // slice_input
+                     .Input(FakeInput(DT_INT32))  // begin
+                     .Attr("combiner", combiner_mode)
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceMean) {
+  MakeOp(1);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&expected, {4});
+  test::ExpectTensorEqual<int32>(expected, *GetOutput(0));  // output_shape
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&expected_1, {2, 3});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));  // output_indices
+
+  Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_2, {2, 2});
+  test::ExpectTensorEqual<float>(expected_2, *GetOutput(2));  // output_nonzero
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceSum) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&expected, {4});
+  test::ExpectTensorEqual<int32>(expected, *GetOutput(0));  // output_shape
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&expected_1, {2, 3});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));  // output_indices
+
+  Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_2, {4, 2});
+  test::ExpectTensorEqual<float>(expected_2, *GetOutput(2));  // output_nonzero
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidData) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({2, 2, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input data must be a 1-D vector or 2-D matrix") !=
+              std::string::npos);
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidSliceinput) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4, 1}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D") !=
+              std::string::npos);
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidbegin) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestColsOutOfBounds) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 4});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestIndicesOutOfBounds) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), 
+                  "indices and slice_input.dim_size(0) should have same size"));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_test.cc
new file mode 100644
index 000000000..558ab8550
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce_test.cc
@@ -0,0 +1,205 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *     Unless required by applicable law or agreed to in writing, software
+ *     distributed under the License is distributed on an "AS IS" BASIS,
+ *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *     See the License for the specific language governing permissions and
+ *     limitations under the License.
+ *     ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class KPFusedSparseSegmentReduceOpTest : public OpsTestBase {
+ protected:
+  void MakeOp(int combiner_mode) {
+    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce",
+                                "KPFusedSparseSegmentReduce")
+                     .Input(FakeInput(DT_FLOAT))  // data
+                     .Input(FakeInput(DT_INT32))  // indices
+                     .Input(FakeInput(DT_INT64))  // slice_input
+                     .Input(FakeInput(DT_INT32))  // begin
+                     .Input(FakeInput(DT_INT32))  // begin_1
+                     .Attr("combiner", combiner_mode)
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceMean) {
+  MakeOp(1);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {1});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 4.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected_1, {2});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceSum) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 8.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected_1, {4});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestColsOutOfBounds) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 5});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, Test) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({2}),
+                           {0, 2});  //  num_indices != slice_input.dim_size(0)
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), 
+                  "indices and slice_input.dim_zie(0) should have same size"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidData) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(
+      TensorShape({4, 2, 1}),
+      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});  // data.dims() > 2
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "input must be 2-D"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidSliceinput) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(
+      TensorShape({3, 4, 1}),
+      {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});  // slice_input.dims() > 2
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}),
+                           {0, 2, 1});  // begin has 3 elements
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin1) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});  // begin_1 has 2 elements
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin_1 must have 1 element"));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
new file mode 100644
index 000000000..ce9173306
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
@@ -0,0 +1,111 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include <algorithm>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/logging.h"
+
+using namespace tensorflow;
+
+class KPFusedSparseSelect : public OpKernel {
+public:
+  explicit KPFusedSparseSelect(OpKernelConstruction* context) : OpKernel(context) {
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_a = context->input(0);
+    const Tensor& input_b = context->input(1);
+    const Tensor& input_c = context->input(2);
+    const Tensor& greater = context->input(3); 
+    const Tensor& equal1 = context->input(4);
+    const Tensor& equal2 = context->input(5);
+    const Tensor& equal3 = context->input(6);
+
+    int32_t equal1_val = equal1.flat<int32_t>()(0);
+    int32_t equal2_val = equal2.flat<int32_t>()(0);
+    int32_t equal3_val = equal3.flat<int32_t>()(0);
+    VLOG(1) << "equal1_val: " << equal1_val;
+    VLOG(1) << "equal2_val: " << equal2_val;
+    VLOG(1) << "equal3_val: " << equal3_val;
+
+    int32_t greater_val = greater.flat<int32_t>()(0);
+    auto a_flat = input_a.flat<int32_t>();
+    auto b_flat = input_b.flat<int32_t>();
+    auto c_flat = input_c.flat<int32_t>();
+    VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
+    VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
+    VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
+    OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(),
+                errors::InvalidArgument("Input num elements of a and b must match"));
+    OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(),
+                errors::InvalidArgument("Input num elements of a and c must match"));
+    auto N = input_a.NumElements();
+
+    Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> a_reshaped_tensor(a_flat.data(), N, 1);
+    Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> b_reshaped_tensor(b_flat.data(), N, 1);
+    Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> c_reshaped_tensor(c_flat.data(), N, 1);
+
+    Tensor* output_x = nullptr;
+    Tensor* output_y = nullptr;
+    Tensor* output_w = nullptr;
+
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({N, 1}), &output_x));
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}), &output_y));
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}), &output_w));
+
+    Eigen::TensorMap<Eigen::Tensor<int32_t, 2, Eigen::RowMajor>> out_x(
+        output_x->flat<int32_t>().data(),
+        output_x->dim_size(0),
+        output_x->dim_size(1)
+    );
+
+    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_y(
+        output_y->flat<float>().data(),
+        output_y->dim_size(0),
+        output_y->dim_size(1)
+    );
+
+    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_w(
+        output_w->flat<float>().data(),
+        output_w->dim_size(0),
+        output_w->dim_size(1)
+    );
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = std::max(N / worker_threads->num_threads, int64(10));
+    
+    auto work = [&](int64 start, int64 end) {
+      for (int64 i = start; i < end; i++) {
+        // Greater(bool)+Cast.2406(float) --> 1.0f / 0.0f
+        float a_greater = (a_reshaped_tensor(i, 0) > greater_val) ? 1.0f : 0.0f;
+        float res_equal1 = (b_reshaped_tensor(i, 0) == equal1_val) ? 1.0f : a_greater;  // Fill.2409-->1.0f
+        float res_equal2 = (b_reshaped_tensor(i, 0) == equal2_val) ? 1.0f : res_equal1;  // Fill.2409-->1.0f
+        out_x(i, 0) = a_reshaped_tensor(i, 0);  // Reshape.2401
+        out_y(i, 0) = res_equal2;
+        out_w(i, 0) = res_equal2;  // Mul.2419 硬编码 1.0f * input
+        out_w(i, 1) = 1.0f;  // select_2427被消除,直接使用Fill.2422-->1.0f
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit, work);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
+                        KPFusedSparseSelect);
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc
new file mode 100644
index 000000000..d649d63e6
--- /dev/null
+++ b/tensorflow/core/kernels/embedding_fused_sparse_select_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::ExpectClose;
+using tensorflow::test::FillValues;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+
+class KPFusedSparseSelectTest : public OpsTestBase {
+ protected:
+  void RunValidCase(
+      const TensorShape& shape,
+      const std::vector<int32>& a_data,
+      const std::vector<int32>& b_data,
+      const std::vector<int32>& c_data,
+      int32_t greater_val,
+      int32_t equal1_val,
+      int32_t equal2_val,
+      const std::vector<float>& expected_y,
+      const std::vector<float>& expected_w_col0) {
+    
+    TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))  // greater
+                     .Input(FakeInput(DT_INT32))  // equal1
+                     .Input(FakeInput(DT_INT32))  // equal2
+                     .Input(FakeInput(DT_INT32))  // equal3
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(shape, b_data);
+    AddInputFromArray<int32>(shape, c_data);
+    AddInputFromArray<int32>(TensorShape({}), {greater_val});  // scalar
+    AddInputFromArray<int32>(TensorShape({}), {equal1_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal2_val});
+    AddInputFromArray<int32>(TensorShape({}), {0});  // equal3_val (未使用)
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& out_x = *GetOutput(0);
+    const Tensor& out_y = *GetOutput(1);
+    const Tensor& out_w = *GetOutput(2);
+
+    int32 Num_elements = expected_y.size();
+    // 验证 output_x: 就是 input_a
+    std::vector<int32_t> a_data_int(a_data.begin(), a_data.end());
+    ExpectTensorEqual<int32>(out_x, AsTensor<int32>(a_data_int, {Num_elements, 1}));
+
+    // 验证 output_y
+    ExpectTensorEqual<float>(out_y, AsTensor<float>(expected_y, {Num_elements, 1}));
+    // 验证 output_w 第一列
+    auto w_mat = out_w.matrix<float>();
+    for (int i = 0; i < w_mat.dimension(0); ++i) {
+      EXPECT_FLOAT_EQ(w_mat(i, 0), expected_w_col0[i]);
+      EXPECT_FLOAT_EQ(w_mat(i, 1), 1.0f);  // 第二列必须是 1.0
+    }
+  }
+
+  Status RunOpExpectFailure(
+      const TensorShape& shape,
+      const std::vector<int32>& a_data,
+      const std::vector<int32>& b_data,
+      const std::vector<int32>& c_data,
+      int32_t greater_val,
+      int32_t equal1_val,
+      int32_t equal2_val) {
+    
+    TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+    TensorShape b_shape({static_cast<int64>(b_data.size())});
+    TensorShape c_shape({static_cast<int64>(c_data.size())});
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(b_shape, b_data);
+    AddInputFromArray<int32>(c_shape, c_data);
+    AddInputFromArray<int32>(TensorShape({}), {greater_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal1_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal2_val});
+    AddInputFromArray<int32>(TensorShape({}), {0});
+
+    return RunOpKernel();
+  }
+};
+
+// ==================== 正向测试 ====================
+// 更多正向验证参考 fused_embedding_sparse_select_test.py
+TEST_F(KPFusedSparseSelectTest, Valid_NormalInput) {
+  RunValidCase(
+      TensorShape({3}),                   // shape
+      {5, 3, 8},                          // input_a
+      {1, 2, 1},                          // input_b
+      {9, 8, 7},                          // input_c (未使用)
+      4,                                  // greater_val
+      1,                                  // equal1_val
+      3,                                  // equal2_val
+      {1.0f, 0.0f, 1.0f},                 // expected_y
+      {1.0f, 0.0f, 1.0f}                  // expected_w_col0
+  );
+}
+
+TEST_F(KPFusedSparseSelectTest, Valid_2DInput) {
+  RunValidCase(
+      TensorShape({2, 2}),
+      {6, 3, 8, 2},
+      {2, 1, 3, 4},
+      {0, 0, 0, 0},
+      5,
+      2,
+      3,
+      {1.0f, 0.0f, 1.0f, 0.0f},
+      {1.0f, 0.0f, 1.0f, 0.0f}
+  );
+}
+// ==================== 反向测试 ====================
+// 反例1:input_a 与 input_b 元素数不匹配
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AB) {
+  Status s = RunOpExpectFailure(
+      TensorShape({3}),           // a 有 3 个元素
+      {1, 2, 3},
+      {4, 5},                     // b 有 2 个元素 → 不匹配!
+      {6, 7, 8},
+      0, 1, 2
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input num elements of a and b must match"));
+}
+
+// 反例2:input_a 与 input_c 元素数不匹配
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AC) {
+  Status s = RunOpExpectFailure(
+      TensorShape({2}),
+      {1, 2},
+      {3, 4},
+      {5},                        // c 只有 1 个元素 → 不匹配!
+      0, 1, 2
+  );
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input num elements of a and c must match"));
+}
+
+}
\ No newline at end of file
diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD
index 4659963e9..cf99e340f 100644
--- a/tensorflow/core/ops/BUILD
+++ b/tensorflow/core/ops/BUILD
@@ -62,6 +62,7 @@ tf_gen_op_libs(
         "decode_proto_ops",
         "encode_proto_ops",
         "experimental_dataset_ops",
+        "embedding_fused_ops",
         "filesystem_ops",
         "function_ops",
         "functional_ops",
@@ -357,7 +358,12 @@ cc_library(
         ":ktfop_ops_op_lib",
     ]) + if_fused_embedding([
         ":fused_embedding_ops_op_lib",
-    ]),
+    ]) + select({
+        "@platforms//cpu:aarch64": [
+            ":embedding_fused_ops_op_lib"
+        ],
+        "//conditions:default": [],
+    }),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/core/ops/embedding_fused_ops.cc b/tensorflow/core/ops/embedding_fused_ops.cc
new file mode 100644
index 000000000..f5e199641
--- /dev/null
+++ b/tensorflow/core/ops/embedding_fused_ops.cc
@@ -0,0 +1,134 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPFusedSparseSegmentReduce")
+    .Input("data: float")
+    .Input("indices: Tidx")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Input("begin_1: int32")
+    .Attr("combiner: int = 1")  // 0 for SUM, 1 for MEAN
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Output("output: float")
+    .Output("slice_output: int32")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseSegmentReduceNonzero")
+    .Input("data: float")
+    .Input("indices: Tidx")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Attr("combiner: int = 1")  // 0 for SUM, 1 for MEAN
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Output("output_shape: int32")
+    .Output("output_indices: int32")
+    .Output("output_nonzero: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedEmbeddingPaddingFast")
+    .Input("input0: int64")
+    .Input("input1: float")
+    .Input("input2: int32")
+    .Input("input3: int32")
+    .Input("pack: int32")
+    .Output("output0: int32")
+    .Output("output1: int32")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle scalar_shape = c->Scalar();
+      c->set_output(0, scalar_shape);
+      c->set_output(1, scalar_shape);
+      return OkStatus();
+    });
+
+REGISTER_OP("KPFusedEmbeddingPadding")
+    .Input("input0: int64")
+    .Input("input1: float")
+    .Input("input2: int32")
+    .Input("input3: int32")
+    .Input("pack: int32")
+    .Output("output0: int32")
+    .Output("output1: float")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle out;
+      ShapeHandle scalar_shape = c->Scalar();
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
+      c->set_output(0, scalar_shape);
+      c->set_output(1, out);
+      return OkStatus();
+    });
+
+REGISTER_OP("KPFusedSparseSelect")
+    .Input("input_a: int32")
+    .Input("input_b: int32")
+    .Input("input_c: int32")
+    .Input("greater: int32")
+    .Input("equal1: int32")
+    .Input("equal2: int32")
+    .Input("equal3: int32")
+    .Output("output_x: int32")
+    .Output("output_y: float")
+    .Output("output_w: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseReshape")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Input("new_shape: int64")
+    .Input("pack_const: T")
+    .Output("out_indices: int64")
+    .Output("out_shape: int64")
+    .Attr("T: {int32, int64}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseDynamicStitch")
+    .Input("x: int64")
+    .Input("variables: N * float")
+    .Output("output: float")
+    .Attr("N: int >= 1")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedGather")
+    .Input("data: float")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Output("out_shape: int64")
+    .Output("out_indices: int32")
+    .Output("out_data: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedEmbeddingActionIdGather")
+    .Input("input0: Tindices1")
+    .Input("input1: float")
+    .Input("input2: Tindices2")
+    .Input("input3: int32")
+    .Input("pack: int32")
+    .Attr("Tindices1: {int32, int64} = DT_INT64")
+    .Attr("Tindices2: {int32, int64} = DT_INT32")
+    .Output("output0: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 9810f8acd..2bf06943e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -314,7 +314,12 @@ py_strict_library(
         "//tensorflow/python/util:tf_decorator_export",
         "//tensorflow/python/util:tf_export",
         "//third_party/py/numpy",
-    ],
+    ] + select({
+        "@platforms//cpu:aarch64": [
+            "//tensorflow/python/ops:embedding_fused_ops_gen",
+        ],
+        "//conditions:default": [],
+    }),
 )
 
 py_strict_library(
@@ -450,7 +455,12 @@ py_strict_library(
         "//tensorflow/python/util:tf_decorator",
         "//tensorflow/python/util:tf_decorator_export",
         "//tensorflow/python/util:tf_export",
-    ],
+    ] + select({
+        "@platforms//cpu:aarch64": [
+            "//tensorflow/python/ops:embedding_fused_ops_gen",
+        ],
+        "//conditions:default": [],
+    }),
 )
 
 # Necessary for the pywrap inclusion below.
diff --git a/tensorflow/python/grappler/embedding_fused_test/framework/runner.py b/tensorflow/python/grappler/embedding_fused_test/framework/runner.py
new file mode 100644
index 000000000..94752287b
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/framework/runner.py
@@ -0,0 +1,204 @@
+import os
+import time
+import shutil
+import glob
+import json
+import sys
+import re
+import struct
+import tensorflow.compat.v1 as tf
+from dataclasses import dataclass, field
+from typing import List, Callable, Dict, Any, Union, Optional
+from tensorflow.python.client import timeline
+import numpy as np
+import multiprocessing as mp
+import pickle
+from tensorflow.python.client import timeline
+from tensorflow.core.protobuf import rewriter_config_pb2
+
+tf.disable_eager_execution()
+
+
+@dataclass
+class TestCase:
+    name: str
+    op_fn: Callable
+    input_fn: Callable
+    check_fn: Optional[Callable] = None
+    fused_op_name: str = ""
+    start_op_name: str = ""
+    end_op_name: str = ""
+    is_fused: bool = True
+    num_iters: int = 0
+    optimize_percent: int = 0
+    meta: Dict = field(default_factory=dict)
+
+class CheckFuncClass:
+    def check_fn_default(A, B, meta):
+        if (len(A) != len(B)):
+            return False
+        is_check_OK = True
+        for i in range(len(A)):
+            if (A[i].dtype == tf.float32):
+                is_check_OK &= np.testing.assert_allclose(A[i], B[i], rtol=1e-3, atol=1e-4) == None
+            else:
+                #整数类型严格判断相等, 当前浮点类型仅有float32
+                np.testing.assert_array_equal(A[i], B[i]) # 如果失败抛出异常, 如果成功无返回值,
+
+        return is_check_OK
+
+class UniversalOpBenchmark:
+    def __init__(self, log_dir="timeline"):
+        self.log_dir = log_dir
+        self.results = []
+
+    def generate_timeline(self, step_stats_list, timeline_file):
+        if not os.path.exists("timeline"):
+            os.makedirs("timeline")
+        ctf_list = []
+        for step_stats in step_stats_list:
+            tl = timeline.Timeline(step_stats)
+            ctf_list.append(json.loads(tl.generate_chrome_trace_format()))
+        with open(f"timeline/{timeline_file}", "w") as f:
+            json.dump(ctf_list, f, indent=2)
+
+    def is_fused_op_exist(self, timeline_file, op_name):
+        """从 timeline JSON 文件中检查指定算子(fusedOp)是否存在"""
+        with open(f"timeline/{timeline_file}", "r") as f:
+            trace_events = json.load(f)[0]["traceEvents"]  # timeline.json的格式
+        op_exists = any(e.get("name") == op_name for e in trace_events if "dur" in e)
+        return op_exists
+
+    def extract_op_dur(self, timeline_file, op_name, times=1):
+        """从 timeline JSON 文件中提取指定算子(fusedOp)的平均耗时(μs)"""
+        with open(f"timeline/{timeline_file}", "r") as f:
+            trace_events_list = json.load(f)  # timeline.json的格式
+        durations_list = []
+        for trace_events in trace_events_list:
+            durations = [e["dur"] for e in trace_events["traceEvents"] if e.get("name") == op_name and "dur" in e]
+            durations_list.append(durations[0])
+        if len(durations_list) != times:
+            raise ValueError(f"Expected {times} durations for {op_name}, but got {len(durations)}")
+        return np.mean(durations_list)
+
+
+    def extract_op_total_time(self, timeline_file, start_op, end_op, times):
+        """计算从 start_op 到 end_op 的总耗时的平均值(包含调度空隙)"""
+        with open(f"timeline/{timeline_file}", "r") as f:
+            trace_events_list = json.load(f)
+        time_list = []
+        for trace_events in trace_events_list:
+            start_event = next(e for e in trace_events["traceEvents"] if e.get("args", {}).get("name") == start_op)
+            end_event = next(e for e in trace_events["traceEvents"] if e.get("args", {}).get("name") == end_op)
+            start_time = start_event["ts"]
+            end_time = end_event["ts"] + end_event["dur"]  # ts 是开始时间,dur是算子的持续时间
+            total_time = end_time - start_time
+            time_list.append(total_time)
+        if len(time_list) != times:
+            raise ValueError(f"Expected {times} total times for {start_op} to {end_op}, but got {len(time_list)}")
+        return np.mean(time_list)
+
+    def parse_performance_time(self, embedding_fused_enable, raw_inputs, test_case: TestCase):
+        print(f"Testing: {test_case.name} ...")
+        os.environ["ANNC_FUSED_ALL"] = str(embedding_fused_enable)
+
+        res_node, feed_dict = test_case.op_fn(raw_inputs, test_case.meta)
+        config = tf.compat.v1.ConfigProto()
+        config.inter_op_parallelism_threads = 16
+        config.intra_op_parallelism_threads = 16
+
+        run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+        run_metadata = tf.compat.v1.RunMetadata()
+        all_step_stats = []
+        filename = f"{test_case.name}_performance_{embedding_fused_enable}.timeline.json"
+        if embedding_fused_enable != 0:
+            # 开启融合条件下关闭部分开关防止中间进行部分非预期优化
+            config.graph_options.rewrite_options.constant_folding = rewriter_config_pb2.RewriterConfig.OFF
+            config.graph_options.rewrite_options.arithmetic_optimization = rewriter_config_pb2.RewriterConfig.OFF
+            config.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.AGGRESSIVE
+        with tf.Session(config=config) as sess:
+            time.sleep(1)
+            for _ in range(10):
+                sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+            for _ in range(test_case.num_iters):
+                sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+                all_step_stats.append(run_metadata.step_stats)
+            self.generate_timeline(all_step_stats, filename)
+            if embedding_fused_enable != 0:
+                return self.extract_op_dur(filename, test_case.fused_op_name, test_case.num_iters)
+            else:
+                return self.extract_op_total_time(filename, test_case.start_op_name, test_case.end_op_name, test_case.num_iters)
+    
+    def run_performance_test(self, test_case: TestCase):
+        if test_case.num_iters == 0:
+            return True
+
+        print(f"Performance testing: {test_case.name} ...")
+
+        optimize_percent = test_case.optimize_percent
+        operator_name = test_case.fused_op_name
+        raw_inputs = test_case.input_fn(test_case.meta)
+
+        no_fused_time = self.parse_performance_time(0, raw_inputs, test_case)
+        fused_time = self.parse_performance_time(1, raw_inputs, test_case)
+        if fused_time <= 0.0 or no_fused_time <= 0.0:
+            print(f"⚠️ 获取平均运行时长失败, 无法进行比较")
+            return False
+        real_percent = 100 * (no_fused_time/fused_time - 1)
+        if real_percent < optimize_percent:
+            raise Exception(f"⚠️ 性能提升{real_percent:.2f}%, 低于预期{optimize_percent}%, embedding算子融合开启状态下耗时{fused_time:.2f}us, 关闭状态下耗时{no_fused_time:.2f}us")
+            return False
+
+        print(f"性能测试通过, 性能提升{real_percent:.2f}%, embedding算子融合开启状态下耗时{fused_time:.2f}us, 关闭状态下耗时{no_fused_time:.2f}us")
+        return True
+
+    def run_function_test(self, test_case: TestCase):
+        print(f"Function testing: {test_case.name} ...")
+
+        raw_inputs = test_case.input_fn(test_case.meta)
+
+        def execute_variant(enable_embedding_fused, raw_inputs, test_case: TestCase):
+            os.environ["ANNC_FUSED_ALL"] = str(enable_embedding_fused)
+            tf.reset_default_graph()
+
+            placeholders = []
+            res_node, feed_dict = test_case.op_fn(raw_inputs, test_case.meta)
+            config = tf.ConfigProto(
+                inter_op_parallelism_threads=16,
+                intra_op_parallelism_threads=16
+            )
+
+            config.graph_options.rewrite_options.constant_folding = rewriter_config_pb2.RewriterConfig.OFF
+            config.graph_options.rewrite_options.arithmetic_optimization = rewriter_config_pb2.RewriterConfig.OFF
+            config.graph_options.rewrite_options.remapping = rewriter_config_pb2.RewriterConfig.AGGRESSIVE
+            with tf.Session(config=config) as sess:
+                is_fused = False
+                if enable_embedding_fused == 0:
+                    result = sess.run(res_node, feed_dict=feed_dict)
+                else:
+                    run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+                    run_metadata = tf.compat.v1.RunMetadata()
+                    filename = f"{test_case.name}_func_{enable_embedding_fused}.timeline.json"
+                    result = sess.run(res_node, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
+                    self.generate_timeline([run_metadata.step_stats], filename)
+                    is_fused = self.is_fused_op_exist(filename, test_case.fused_op_name)
+                return result, is_fused
+            return
+
+        embedding_data, is_fused = execute_variant(1, raw_inputs, test_case)
+        if is_fused != test_case.is_fused:
+            print(f"⚠️ {test_case.fused_op_name}算子未做融合")
+            raise Exception(f"⚠️ {test_case.fused_op_name}算子未做融合")
+            return False
+        no_embedding_data, _ = execute_variant(0, raw_inputs, test_case)
+        if test_case.check_fn is None:
+            is_correct = CheckFuncClass.check_fn_default(no_embedding_data, embedding_data, test_case.meta)
+        else:
+            is_correct = test_case.check_fn(no_embedding_data, embedding_data, test_case.meta)
+
+        if not is_correct:
+            print(f"⚠️ 误差较大")
+            raise Exception(f"⚠️ 误差较大")
+            return False
+        print("功能测试通过")
+        return True
diff --git a/tensorflow/python/grappler/embedding_fused_test/main.py b/tensorflow/python/grappler/embedding_fused_test/main.py
new file mode 100644
index 000000000..b061f5863
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/main.py
@@ -0,0 +1,51 @@
+import argparse
+import importlib
+import pkgutil
+import traceback
+import os
+
+def main():
+    parser = argparse.ArgumentParser(description="TF Op Benchmark Framework with KDNN Control")
+    parser.add_argument('--op', type=str, help='指定要运行的算子模块名')
+    parser.add_argument('--list', action='store_true', help='列出所有模块')
+    parser.add_argument('--performance_test', choices=['True', 'False'], default='True', help='是否运行性能测试模块')
+
+    args = parser.parse_args()
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #日志太多, 关掉
+
+    import tensorflow as TF
+    from framework.runner import UniversalOpBenchmark
+    import ops
+
+    # 4. 扫描模块
+    available_ops = {
+        name: f'ops.{name}' 
+        for loader, name, is_pkg in pkgutil.iter_modules(ops.__path__)
+    }
+
+    if args.list:
+        print("📁 可用算子列表:")
+        for name in available_ops: print(f"  - {name}")
+        return
+
+    target_modules = [available_ops[args.op]] if args.op else list(available_ops.values())
+    root_log_dir = "bench_logs"
+    bench = UniversalOpBenchmark(log_dir=root_log_dir)
+    
+    for module_path in target_modules:
+        try:
+            module = importlib.import_module(module_path)
+            if hasattr(module, 'get_test_cases'):
+                cases = module.get_test_cases()
+                for case in cases:
+                    bench.run_function_test(case)
+                if args.performance_test == "True":
+                    for case in cases:
+                        bench.run_performance_test(case)
+        except Exception as e:
+            print(f"❌ 运行模块 {module_path} 失败: {e}")
+            traceback.print_exc()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingActionIdGather_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingActionIdGather_op.py
new file mode 100644
index 000000000..41bcdf103
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingActionIdGather_op.py
@@ -0,0 +1,150 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedEmbeddingActionIdGatherOp(input0, input1, input2, input3, pack):
+        gather1 = tf.gather(input1, input0, axis=0)
+        gather2 = tf.gather(gather1, input2, axis=0)
+        pack1 = tf.stack([input3, pack], axis=0)
+        pack2 = tf.stack([input3, -1], axis=0)
+        reshape = tf.reshape(gather2, pack2)
+        fill = tf.fill(pack1, tf.constant(0, dtype=tf.float32))
+        output = tf.concat([reshape, fill], axis=-1)
+        return output
+
+    def KPFusedEmbeddingActionIdGather_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        input0_type = np.int32
+        if "input0_type" in meta.keys():
+            input0_type = meta["input0_type"]
+        input2_type = np.int32
+        if "input2_type" in meta.keys():
+            input2_type = meta["input2_type"]
+        input0 = tf.compat.v1.placeholder(input0_type, shape=input['input0'].shape, name="input0")
+        input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+        input2 = tf.compat.v1.placeholder(input2_type, shape=input['input2'].shape, name="input2")
+        input3 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input3")
+        pack = tf.compat.v1.placeholder(tf.int32, shape=(), name="pack")
+
+        feed_dict = {
+            input0: input['input0'],
+            input1: input['input1'],
+            input2: input['input2'],
+            input3: input['input3'],
+            pack: input['pack']
+        }
+        
+        result = KPFusedEmbeddingActionIdGatherOp(input0, input1, input2, input3, pack)
+        return result, feed_dict
+
+    def build_input_case_1(meta):
+        indices1_shape = (8, 10)
+        indices2_shape = (5, 6)
+        params_shape = (80, 300)
+        input = {}
+        input0_type = np.int32
+        if "input0_type" in meta.keys():
+            input0_type = meta["input0_type"]
+        input2_type = np.int32
+        if "input2_type" in meta.keys():
+            input2_type = meta["input2_type"]
+        # 根据核心函数入参名生成输入数据
+        input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(input0_type)
+        input["input1"] = np.random.random(params_shape).astype(np.float32)
+        input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(input2_type)
+        input["input3"] = params_shape[0]
+        input["pack"] = 1680
+        return input
+
+    def build_input_3D_case_1(meta):
+        indices1_shape = (8, 10, 2)
+        indices2_shape = (5, 6)
+        params_shape = (160, 300)
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(np.int32)
+        input["input1"] = np.random.random(params_shape).astype(np.float32)
+        input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(np.int32)
+        input["input3"] = params_shape[0]
+        input["pack"] = 1680
+        return input
+
+    def build_input_3D_case_2(meta):
+        indices1_shape = (8, 10)
+        indices2_shape = (5, 6, 2)
+        params_shape = (160, 300)
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["input0"] = np.random.randint(0, params_shape[0], indices1_shape).astype(np.int32)
+        input["input1"] = np.random.random(params_shape).astype(np.float32)
+        input["input2"] = np.random.randint(0, indices1_shape[0], indices2_shape).astype(np.int32)
+        input["input3"] = params_shape[0]
+        input["pack"] = 1680
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_case_1_int32_int32",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_case_1,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            start_op_name = "PartitionedCall_1/stack_1",
+            end_op_name = "PartitionedCall_1/concat",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=50,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_case_1_int64_int32",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_case_1,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            start_op_name = "PartitionedCall_3/stack_1",
+            end_op_name = "PartitionedCall_3/concat",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=50,
+            meta = {"input0_type": np.int64},
+        ),
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_case_1_int32_int64",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_case_1,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            start_op_name = "PartitionedCall_5/stack_1",
+            end_op_name = "PartitionedCall_5/concat",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=50,
+            meta = {"input2_type": np.int64},
+        ),
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_case_1_int64_int64",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_case_1,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            start_op_name = "PartitionedCall_7/stack_1",
+            end_op_name = "PartitionedCall_7/concat",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=50,
+            meta = {"input0_type": np.int64, "input2_type": np.int64},
+        ),
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_3D_case_1",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_3D_case_1,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            is_fused=False,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingActionIdGather_3D_case_2",
+            op_fn = KPFusedEmbeddingActionIdGather_graph,
+            input_fn = build_input_3D_case_2,
+            fused_op_name = "KPFusedEmbeddingActionIdGather",
+            is_fused=False,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPaddingFast_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPaddingFast_op.py
new file mode 100644
index 000000000..be370331e
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPaddingFast_op.py
@@ -0,0 +1,122 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedEmbeddingPaddingFastOp(input0, input1, input2, input3, pack, dims):
+        cast = tf.cast(input0, tf.int32)
+        begin = tf.constant([0], dtype=tf.int32)
+        end = tf.constant([1], dtype=tf.int32)
+        strides = tf.constant([1], dtype=tf.int32)
+        hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1)
+        sub_out = hash_rows - input2
+        if dims == 3:
+            fill_shape = tf.stack([sub_out, pack, 5], axis=0)
+            fill = tf.fill(fill_shape, tf.constant(0, dtype=tf.float32))
+        else:
+            pack_op = tf.stack([sub_out, pack], axis=0)
+            fill = tf.fill(pack_op, tf.constant(0, dtype=tf.float32))
+        concat = tf.concat([input1, fill], 0)
+        reshape = tf.reshape(concat, input3)
+        shape_tensor = tf.shape(reshape)
+        output = tf.strided_slice(shape_tensor, begin=begin, end=end, strides=strides, shrink_axis_mask=1)
+        return output
+
+    def KPFusedEmbeddingPaddingFast_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0")
+        input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+        input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+        input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
+        pack = tf.constant(input['pack'], dtype=tf.int32)
+
+        feed_dict = {
+            input0: input['input0'],
+            input1: input['input1'],
+            input2: input['input2'],
+            input3: input['input3'],
+            pack: input['pack']
+        }
+
+        result = KPFusedEmbeddingPaddingFastOp(input0, input1, input2, input3, pack, len(input['input1'].shape))
+        return [result], feed_dict
+
+    def build_input_func(input_shape, pooling_shape, reshape):
+        input = {}
+        input["input0"] = np.array(input_shape).astype(np.int64)
+        input["input1"] = np.random.rand(*pooling_shape).astype(np.float32)
+        input["input2"] = pooling_shape[0]
+        input["input3"] = np.array(reshape).astype(np.int32)
+        input["pack"] = pooling_shape[1]
+        return input
+
+    def build_input_2D_case_1(meta):
+        return build_input_func((151 * 1, 10), (151 * 1, 10), (-1, 1510))
+
+    def build_input_2D_case_2(meta):
+        return build_input_func((151 * 1000, 10), (151 * 10, 10), (-1, 1510))
+
+    def build_input_2D_case_3(meta):
+        return build_input_func((2 * 1, 12), (2 * 1, 12), (-1, 24))
+
+    def build_input_2D_case_4(meta):
+        return build_input_func((2 * 1000, 12), (2 * 10, 12), (-1, 24))
+
+    def build_input_3D_case_1(meta):
+        return build_input_func((2 * 1000, 12), (2 * 10, 12, 5), (-1, 24))
+
+    return [
+        TestCase(
+            name="KPFusedEmbeddingPaddingFast_case_1",
+            op_fn=KPFusedEmbeddingPaddingFast_graph,
+            input_fn=build_input_2D_case_1,
+            fused_op_name="KPFusedEmbeddingPaddingFast",
+            start_op_name="PartitionedCall_1/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_1/StridedSlice_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=600,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPaddingFast_case_2",
+            op_fn=KPFusedEmbeddingPaddingFast_graph,
+            input_fn=build_input_2D_case_2,
+            fused_op_name="KPFusedEmbeddingPaddingFast",
+            start_op_name="PartitionedCall_3/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_3/StridedSlice_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=7000,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPaddingFast_case_3",
+            op_fn=KPFusedEmbeddingPaddingFast_graph,
+            input_fn=build_input_2D_case_3,
+            fused_op_name="KPFusedEmbeddingPaddingFast",
+            start_op_name="PartitionedCall_5/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_5/StridedSlice_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=600,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPaddingFast_case_4",
+            op_fn=KPFusedEmbeddingPaddingFast_graph,
+            input_fn=build_input_2D_case_4,
+            fused_op_name="KPFusedEmbeddingPaddingFast",
+            start_op_name="PartitionedCall_7/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_7/StridedSlice_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=800,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPadding_3D",
+            op_fn=KPFusedEmbeddingPaddingFast_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedEmbeddingPaddingFast",
+            is_fused=False,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPadding_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPadding_op.py
new file mode 100644
index 000000000..b20662a72
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedEmbeddingPadding_op.py
@@ -0,0 +1,121 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedEmbeddingPaddingOp(input0, input1, input2, input3, pack, dims):
+        cast = tf.cast(input0, tf.int32)
+        begin = tf.constant([0], dtype=tf.int32)
+        end = tf.constant([1], dtype=tf.int32)
+        strides = tf.constant([1], dtype=tf.int32)
+        hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1)
+        sub_out = hash_rows - input2
+
+        if dims == 3:
+            fill_shape = tf.stack([sub_out, pack, 5], axis=0)
+            fill = tf.fill(fill_shape, tf.constant(0, dtype=tf.float32))
+        else:
+            pack_op = tf.stack([sub_out, pack], axis=0)
+            fill = tf.fill(pack_op, tf.constant(0, dtype=tf.float32))
+        concat = tf.concat([input1, fill], 0)
+        output = tf.reshape(concat, input3)
+        return tf.concat([output, output], 1)
+
+    def KPFusedEmbeddingPadding_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0")
+        input1 = tf.compat.v1.placeholder(tf.float32, shape=input['input1'].shape, name="input1")
+        input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+        input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
+        pack = tf.constant(input['pack'], dtype=tf.int32)
+
+        feed_dict = {
+            input0: input['input0'],
+            input1: input['input1'],
+            input2: input['input2'],
+            input3: input['input3'],
+            pack: input['pack']
+        }
+        
+        result = KPFusedEmbeddingPaddingOp(input0, input1, input2, input3, pack, len(input['input1'].shape))
+        return [result], feed_dict
+
+    def build_input_func(input_shape, pooling_shape, reshape):
+        input = {}
+        input["input0"] = np.array(input_shape).astype(np.int64)
+        input["input1"] = np.random.rand(*pooling_shape).astype(np.float32)
+        input["input2"] = pooling_shape[0]
+        input["input3"] = np.array(reshape).astype(np.int32)
+        input["pack"] = pooling_shape[1]
+        return input
+
+    def build_input_2D_case_1(meta):
+        return build_input_func((151 * 1, 10), (151 * 1, 10), (-1, 1510))
+
+    def build_input_2D_case_2(meta):
+        return build_input_func((151 * 1000, 10), (151 * 10, 10), (-1, 1510))
+
+    def build_input_2D_case_3(meta):
+        return build_input_func((2 * 1, 12), (2 * 1, 12), (-1, 24))
+
+    def build_input_2D_case_4(meta):
+        return build_input_func((2 * 1000, 12), (2 * 10, 12), (-1, 24))
+
+    def build_input_3D_case_1(meta):
+        return build_input_func((2 * 1000, 12), (2 * 10, 12, 5), (-1, 24))
+
+    return [
+        TestCase(
+            name="KPFusedEmbeddingPadding_case_1",
+            op_fn=KPFusedEmbeddingPadding_graph,
+            input_fn=build_input_2D_case_1,
+            fused_op_name="KPFusedEmbeddingPadding",
+            start_op_name="PartitionedCall_1/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_1/concat_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=400,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPadding_case_2",
+            op_fn=KPFusedEmbeddingPadding_graph,
+            input_fn=build_input_2D_case_2,
+            fused_op_name="KPFusedEmbeddingPadding",
+            start_op_name="PartitionedCall_3/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_3/concat_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=100,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPadding_case_3",
+            op_fn=KPFusedEmbeddingPadding_graph,
+            input_fn=build_input_2D_case_3,
+            fused_op_name="KPFusedEmbeddingPadding",
+            start_op_name="PartitionedCall_5/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_5/concat_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=450,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPadding_case_4",
+            op_fn=KPFusedEmbeddingPadding_graph,
+            input_fn=build_input_2D_case_4,
+            fused_op_name="KPFusedEmbeddingPadding",
+            start_op_name="PartitionedCall_7/ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_int64_Cast",
+            end_op_name="PartitionedCall_7/concat_1",
+            is_fused=True,
+            num_iters=500,
+            optimize_percent=700,
+        ),
+        TestCase(
+            name="KPFusedEmbeddingPadding_3D",
+            op_fn=KPFusedEmbeddingPadding_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedEmbeddingPadding",
+            is_fused=False,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedGather_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedGather_op.py
new file mode 100644
index 000000000..a20ffe5e6
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedGather_op.py
@@ -0,0 +1,103 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedGatherOp(data, slice_input, begin, dims):
+        if dims == 3:
+            end=[slice_input.shape.as_list()[0], begin[1] + 2, begin[2] + 2]
+            strides=[1, 1, 1]
+        else:
+            end=[slice_input.shape.as_list()[0], begin[1] + 2]
+            strides=[1, 1]
+        slice_out = tf.strided_slice(
+            slice_input,
+            begin=begin,
+            end=end,
+            strides=strides,
+            begin_mask=1,
+            end_mask=1,
+            shrink_axis_mask=2
+        )
+
+        if dims == 3:
+            value, indices = tf.unique(slice_out[0])
+        else:
+            value, indices = tf.unique(slice_out)
+        value_1, indices_1 = tf.unique(value)
+        gather1 = tf.gather(data, value_1)
+        gather2 = tf.gather(gather1, indices_1)
+        return value, indices, gather2
+
+    def KPFusedGather_graph(input, meta):
+        data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+        slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+        begin = tf.compat.v1.placeholder(tf.int32, shape=input['begin'].shape, name="begin")
+
+        feed = {
+            data: input['data'],
+            slice_input: input['slice_input'],
+            begin: input['begin']
+        }
+        shape, indices, data = KPFusedGatherOp(data, slice_input, begin, len(input['slice_input'].shape))
+        
+        return [shape, indices, data], feed
+
+    def KPFusedGather_check_fn(input_a, input_b, meta):
+        np.testing.assert_array_equal(input_a[0], input_b[0])
+        np.testing.assert_array_equal(input_a[1], input_b[1])
+        return np.allclose(input_a[2], input_b[2], rtol=1e-3, atol=1e-5)
+
+    def build_input_2D_case_1(meta):
+        input = {}
+        input["data"] = np.random.rand(50, 12).astype(np.float32)
+        input["slice_input"] = np.array([[10, 7], [20, 7], [30, 7]], dtype=np.int64)
+        input["begin"] = np.array([0, 1], dtype=np.int32)
+        return input
+
+    def build_input_3D_case_1(meta):
+        input = {}
+        input["data"] = np.random.rand(50, 12, 5).astype(np.float32)
+        input["slice_input"] = np.array([[10, 7], [20, 7], [30, 7]], dtype=np.int64)
+        input["begin"] = np.array([0, 1], dtype=np.int32)
+        return input
+
+    def build_input_3D_case_2(meta):
+        input = {}
+        input["data"] = np.random.rand(50, 12).astype(np.float32)
+        input["slice_input"] = np.array([[[10, 7], [20, 7], [30, 7]], [[10, 2], [20, 3], [30, 4]]], dtype=np.int64)
+        input["begin"] = np.array([0, 1, 1], dtype=np.int32)
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedGather_input_2D_case_1",
+            op_fn=KPFusedGather_graph,
+            input_fn=build_input_2D_case_1,
+            check_fn = KPFusedGather_check_fn,
+            fused_op_name = "KPFusedGather",
+            start_op_name = "PartitionedCall_1/strided_slice",
+            end_op_name = "PartitionedCall_1/GatherV2",
+            is_fused = True,
+            num_iters=500,
+            optimize_percent = 400,
+        ),
+        TestCase(
+            name="KPFusedGather_input_3D_case_1",
+            op_fn=KPFusedGather_graph,
+            input_fn=build_input_3D_case_1,
+            check_fn = KPFusedGather_check_fn,
+            fused_op_name = "KPFusedGather",
+            is_fused = False,
+        ),
+        TestCase(
+            name="KPFusedGather_input_3D_case_2",
+            op_fn=KPFusedGather_graph,
+            input_fn=build_input_3D_case_2,
+            check_fn = KPFusedGather_check_fn,
+            fused_op_name = "KPFusedGather",
+            is_fused = False,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseDynamicStitch_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseDynamicStitch_op.py
new file mode 100644
index 000000000..8eea425a0
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseDynamicStitch_op.py
@@ -0,0 +1,111 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedSparseDynamicStitchOp(x, emb_tables, num_tables):
+        x_1 = tf.reshape(x, shape=[-1])  # 将输入 x 展平成一维向量 x_1
+        group_ids = tf.math.floormod(x_1, num_tables)
+        group_ids = tf.cast(group_ids, dtype=np.int32)
+        chunk_indices = tf.math.floordiv(x_1, num_tables)
+        original_indices = tf.range(0, tf.size(x_1), 1)
+        a = tf.dynamic_partition(original_indices, group_ids, num_partitions=num_tables)
+        b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=num_tables)
+        c = [tf.gather(emb_tables[i], b[i]) for i in range(num_tables)]
+        d = tf.raw_ops.ParallelDynamicStitch(indices=a, data=c)
+        return d
+
+    def KPFusedSparseDynamicStitch_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        x_shape = input['x'].shape
+        lst = list(x_shape)    # 1. 转 list
+        lst[1] = None
+        x_shape = tuple(lst)
+        num_tables = meta["num_tables"]
+        x = tf.compat.v1.placeholder(tf.int64, shape=x_shape, name="x")
+        
+        emb_tables = []
+        feed_dict = {x: input['x']}
+        
+        for i in range(num_tables):
+            placeholder = tf.compat.v1.placeholder(tf.float32, shape=input[f'emb_table_{i}'].shape, name=f'emb_table_{i}')
+            emb_tables.append(placeholder)
+            feed_dict[placeholder] = input[f'emb_table_{i}']
+        
+        result = KPFusedSparseDynamicStitchOp(x, emb_tables, num_tables)
+        return result, feed_dict
+
+    def build_input_case_1(meta):
+        input = {}
+        num_tables = meta["num_tables"]
+        emb_dim = 10
+        max_val = float('inf')
+
+        for i in range(num_tables):
+            N = np.random.randint(1000000, 44739244)
+            max_val = min(N, max_val)
+            input[f'emb_table_{i}'] = np.random.rand(N, emb_dim).astype(np.float32)
+
+        input["x"] = np.random.randint(0, num_tables * max_val, size=(1000, num_tables)).astype(np.int32)  # 12个元素,每个元素范围0-143
+        return input
+
+    def build_input_case_2(meta):
+        input = {}
+        num_tables = meta["num_tables"]
+        emb_dim = 10
+        max_val = float('inf')
+
+        for i in range(num_tables):
+            N = np.random.randint(1000000, 44739244)
+            max_val = min(N, max_val)
+            input[f'emb_table_{i}'] = np.random.rand(N, emb_dim).astype(np.float32)
+
+        input["x"] = np.random.randint(0, num_tables * max_val, size=(10, 100, num_tables)).astype(np.int32)  # 12个元素,每个元素范围0-143
+        return input
+
+    def build_input_3D_case_1(meta):
+        input = {}
+        num_tables = meta["num_tables"]
+        emb_dim = 10
+        max_val = float('inf')
+
+        for i in range(num_tables):
+            N = np.random.randint(1000000, 44739244)
+            max_val = min(N, max_val)
+            input[f'emb_table_{i}'] = np.random.rand(N, emb_dim, 1).astype(np.float32)
+
+        input["x"] = np.random.randint(0, num_tables * max_val, size=(1000, num_tables)).astype(np.int32)  # 12个元素,每个元素范围0-143
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedSparseDynamicStitch_case_1",
+            op_fn=KPFusedSparseDynamicStitch_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseDynamicStitch",
+            start_op_name="PartitionedCall_1/Reshape",
+            end_op_name="PartitionedCall_1/ParallelDynamicStitch",
+            is_fused=True,
+            num_iters=200,
+            optimize_percent=50,
+            meta={"num_tables": 12},
+        ),
+        TestCase(
+            name="KPFusedSparseDynamicStitch_case_2",
+            op_fn=KPFusedSparseDynamicStitch_graph,
+            input_fn=build_input_case_2,
+            fused_op_name="KPFusedSparseDynamicStitch",
+            is_fused=True,
+            meta={"num_tables": 12},
+        ),
+        TestCase(
+            name="KPFusedSparseDynamicStitch_3D_case_1",
+            op_fn=KPFusedSparseDynamicStitch_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedSparseDynamicStitch",
+            is_fused=False,
+            meta={"num_tables": 12},
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseReshape_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseReshape_op.py
new file mode 100644
index 000000000..6cee67355
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseReshape_op.py
@@ -0,0 +1,132 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedSparseReshapeOp(slice_input, begin, newshape, pack_const, dims):
+        if dims == 3:
+            end = [0, 0, 2]
+            strides = [1, 1, 1]
+        else:
+            end = [0, 2]
+            strides = [1, 1]
+
+        slice67_out = tf.strided_slice(
+                slice_input,
+                begin=begin,
+                end=end,
+                strides=strides,
+                begin_mask=1,
+                end_mask=1,
+                shrink_axis_mask=2
+            )
+
+        slice67_out = tf.reshape(slice67_out, [-1, 1])
+        shape_out = tf.shape(slice_input)
+        slice57_out = tf.strided_slice(
+            shape_out, 
+            begin=[0],
+            end=[1],
+            strides=[1],
+            shrink_axis_mask=1
+        )
+        
+        input_shape = tf.stack([slice57_out, pack_const])
+        input_shape = tf.cast(input_shape, tf.int64)
+
+        range_out = tf.range(0, slice57_out, 1)
+        range_out = tf.reshape(range_out, [-1, 1])
+        range_out_64 = tf.cast(range_out, dtype=tf.int64)
+        concat_out = tf.concat([range_out_64, slice67_out], axis=-1)
+        
+        values = np.arange(slice_input.shape[0], dtype=np.float32)
+        
+        sparse_tensor = tf.SparseTensor(
+            indices=concat_out,
+            values=values,
+            dense_shape=input_shape
+        )
+        sparse_tensor_out = tf.sparse.reshape(sparse_tensor, newshape)
+        return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out
+
+    def KPFusedSparseReshape_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        slice_shape = input['slice_input'].shape
+        lst = list(slice_shape)    # 1. 转 list
+        lst[-1] = None
+        slice_shape = tuple(lst)
+        slice_input = tf.compat.v1.placeholder(tf.int64, shape=slice_shape, name="slice_input")
+        begin = tf.compat.v1.placeholder(tf.int32, shape=input['begin'].shape, name="begin")
+        newshape = tf.compat.v1.placeholder(tf.int64, shape=input['newshape'].shape, name="newshape")
+        pack_const = tf.compat.v1.placeholder(tf.int32, shape=(), name="pack_const")
+
+        feed_dict = {
+            slice_input: input['slice_input'],
+            begin: input['begin'],
+            newshape: input['newshape'],
+            pack_const: input['pack_const']
+        }
+        
+        result = KPFusedSparseReshapeOp(slice_input, begin, newshape, pack_const, len(slice_shape))
+        return result, feed_dict
+
+    def build_input_case_1(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["slice_input"] = np.array([[0, 0], [0, 1], [1, 2], [3, 4]]).astype(np.int64)
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["newshape"] = np.array([2, 4]).astype(np.int64)
+        input["pack_const"] = 2
+        return input
+
+    def build_input_case_2(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["slice_input"] = np.array([[0, 1]]).astype(np.int64)
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["newshape"] = np.array([-1, 1]).astype(np.int64)
+        input["pack_const"] = 1
+        return input
+
+    def build_input_3D_case_1(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["slice_input"] = np.array([[[0, 0], [0, 1], [1, 2], [3, 4]], [[1, 2], [2, 3], [3, 4], [4, 4]]]).astype(np.int64)
+        input["begin"] = np.array([0, 0, 1]).astype(np.int32)
+        input["newshape"] = np.array([2, 2]).astype(np.int64)
+        input["pack_const"] = 2
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedSparseReshape_case_1",
+            op_fn=KPFusedSparseReshape_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseReshape",
+            start_op_name="PartitionedCall_1/StridedSlice",
+            end_op_name="PartitionedCall_1/SparseReshape",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=400,
+        ),
+        TestCase(
+            name="KPFusedSparseReshape_case_2",
+            op_fn=KPFusedSparseReshape_graph,
+            input_fn=build_input_case_2,
+            fused_op_name="KPFusedSparseReshape",
+            start_op_name="PartitionedCall_3/StridedSlice",
+            end_op_name="PartitionedCall_3/SparseReshape",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=800,
+        ),
+        TestCase(
+            name="KPFusedSparseReshape_3D_case_1",
+            op_fn=KPFusedSparseReshape_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedSparseReshape",
+            is_fused=False,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduceNonzero_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduceNonzero_op.py
new file mode 100644
index 000000000..5ff3c648a
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduceNonzero_op.py
@@ -0,0 +1,203 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedSparseSegmentReduceNonzeroOp(data, indices, slice_input, begin, end, strides, is_mean):
+        shrink_axis_mask = 2 ** len(begin) - 2 #slice_out需要仅保留一维
+        slice_out = tf.strided_slice(
+                    slice_input,
+                    begin= begin,
+                    end= end,
+                    strides= strides,
+                    begin_mask=1,
+                    end_mask=1,
+                    shrink_axis_mask=shrink_axis_mask
+                )
+
+        segment_ids = tf.cast(slice_out, dtype=tf.int32)
+
+        if is_mean:
+            sparseseg_out = tf.sparse.segment_mean(
+                    data = data,
+                    indices = indices,
+                    segment_ids= segment_ids
+                )
+        else:
+            sparseseg_out = tf.sparse.segment_sum(
+                    data = data,
+                    indices = indices,
+                    segment_ids= segment_ids
+                )
+        zero = tf.zeros_like(sparseseg_out)
+        notequal = tf.not_equal(x=sparseseg_out, y = zero)
+        where_out = tf.where(notequal)
+        output_shape = tf.cast(where_out, dtype=tf.int32)
+        output_data = tf.gather_nd(params=sparseseg_out, indices=where_out)
+        shape = tf.shape(sparseseg_out, out_type=tf.int64)
+        output_ids = tf.cast(shape, dtype=tf.int32)
+
+        return output_shape, output_ids, output_data
+
+    def KPFusedSparseSegmentReduceNonzero_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        indices_type = np.int32
+        if "indices_type" in meta.keys():
+            indices_type = meta["indices_type"]
+        data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+        indices = tf.compat.v1.placeholder(indices_type, shape=input['indices'].shape, name="indices")
+        slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+        begin = tf.constant(input['begin'], dtype=tf.int32)
+        end = tf.constant(input['end'], dtype=tf.int32)
+        strides = tf.compat.v1.placeholder(tf.int32, shape=input['strides'].shape, name="strides")
+        is_mean = meta['is_mean']
+
+        feed_dict = {
+            data: input['data'],
+            indices: input['indices'],
+            slice_input: input['slice_input'],
+            begin: input['begin'],
+            end: input['end'],
+            strides: input['strides']
+        }
+        
+        result = KPFusedSparseSegmentReduceNonzeroOp(data, indices, slice_input, begin, end, strides, is_mean)
+        return result, feed_dict
+
+    def build_input_case_1(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        indices_type = np.int32
+        if "indices_type" in meta.keys():
+            indices_type = meta["indices_type"]
+        data = np.random.rand(1449).astype(np.float32) * 10
+        zero_prob = 0.3
+        mask = np.random.rand(1449) > zero_prob
+        data[~mask] = 0
+        input["data"] = data
+        input["indices"] = np.random.randint(0, 1449, size=5742, dtype=indices_type)
+                
+        start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))
+        end_points = start_points + np.random.randint(1, 100, size=5742)
+        end_points = np.minimum(end_points, 15661)
+        slice_input = np.column_stack((start_points, end_points))
+        slice_input[:, 1] = slice_input[:, 0]
+        input["slice_input"] = slice_input
+
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2]).astype(np.int32)
+        input["strides"] = np.array([1, 2]).astype(np.int32)
+        return input
+
+    def build_input_3D_case_1(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        data = np.random.rand(1449, 2, 3).astype(np.float32) * 10
+        zero_prob = 0.3
+        mask = np.random.rand(1449, 2, 3) > zero_prob
+        data[~mask] = 0
+        input["data"] = data
+        input["indices"] = np.random.randint(0, 1449, size=5742, dtype=np.int32)
+                
+        start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))
+        end_points = start_points + np.random.randint(1, 100, size=5742)
+        end_points = np.minimum(end_points, 15661)
+        slice_input = np.column_stack((start_points, end_points))
+        slice_input[:, 1] = slice_input[:, 0]
+        input["slice_input"] = slice_input
+
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2]).astype(np.int32)
+        input["strides"] = np.array([1, 2]).astype(np.int32)
+        return input
+
+    def build_input_3D_case_2(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        data = np.random.rand(1449).astype(np.float32) * 10
+        zero_prob = 0.3
+        mask = np.random.rand(1449) > zero_prob
+        data[~mask] = 0
+        input["data"] = data
+        input["indices"] = np.random.randint(0, 1449, size=5742, dtype=np.int32)
+
+        start_points = np.sort(np.random.choice(np.arange(0, 15660), size=5742, replace=False))
+        end_points = start_points + np.random.randint(1, 100, size=5742)
+        end_points = np.minimum(end_points, 15661)
+        slice_input = np.column_stack((start_points, end_points))
+        slice_input[:, 1] = slice_input[:, 0]
+        input["slice_input"] = np.stack([slice_input, slice_input], axis=2) #3D数据
+
+        input["begin"] = np.array([0, 1, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2, 1]).astype(np.int32)
+        input["strides"] = np.array([1, 2, 1]).astype(np.int32)
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_sum_case_1",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            start_op_name="PartitionedCall_1/StridedSlice",
+            end_op_name="PartitionedCall_1/GatherNd",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=30,
+            meta = {"is_mean": False}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_mean_case_1",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            start_op_name="PartitionedCall_3/StridedSlice",
+            end_op_name="PartitionedCall_3/GatherNd",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=30,
+            meta = {"is_mean": True}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_sum_case_1_int64",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            start_op_name="PartitionedCall_5/StridedSlice",
+            end_op_name="PartitionedCall_5/GatherNd",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=30,
+            meta = {"is_mean": False, "indices_type": np.int64}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_mean_case_1_int64",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            start_op_name="PartitionedCall_7/StridedSlice",
+            end_op_name="PartitionedCall_7/GatherNd",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=30,
+            meta = {"is_mean": True, "indices_type": np.int64}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_sum_3D_case_1",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            is_fused=False,
+            meta = {"is_mean": False}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduceNonzero_sum_3D_case_2",
+            op_fn=KPFusedSparseSegmentReduceNonzero_graph,
+            input_fn=build_input_3D_case_2,
+            fused_op_name="KPFusedSparseSegmentReduceNonzero",
+            is_fused=False,
+            meta = {"is_mean": True}
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduce_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduce_op.py
new file mode 100644
index 000000000..ecc67e290
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSegmentReduce_op.py
@@ -0,0 +1,165 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedSparseSegmentReduceOp(data, indices, slice_input, begin, end, strides, is_mean):
+        shrink_axis_mask = 2 ** len(begin) - 2 #slice_out需要仅保留一维
+        slice_out = tf.strided_slice(
+                slice_input,
+                begin=begin,
+                end=end,
+                strides=strides,
+                begin_mask=1,
+                end_mask=1,
+                shrink_axis_mask=shrink_axis_mask
+            )
+
+        segment_ids = tf.cast(slice_out, dtype=tf.int32)
+        if is_mean:
+            output = tf.sparse.segment_mean(
+                data=data,
+                indices=indices,
+                segment_ids=segment_ids
+            )
+        else:
+            output = tf.sparse.segment_sum(
+                data=data,
+                indices=indices,
+                segment_ids=segment_ids
+            )
+        
+        output_shape = tf.shape(output)
+        slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1], shrink_axis_mask=1)
+        
+        return output, slice_out
+
+    def KPFusedSparseSegmentReduce_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        indices_type = np.int32
+        if "indices_type" in meta.keys():
+            indices_type = meta["indices_type"]
+        data = tf.compat.v1.placeholder(tf.float32, shape=input['data'].shape, name="data")
+        indices = tf.compat.v1.placeholder(indices_type, shape=input['indices'].shape, name="indices")
+        slice_input = tf.compat.v1.placeholder(tf.int64, shape=input['slice_input'].shape, name="slice_input")
+        begin = tf.constant(input['begin'], dtype=tf.int32)
+        end = tf.constant(input['end'], dtype=tf.int32)
+        strides = tf.compat.v1.placeholder(tf.int32, shape=input['strides'].shape, name="strides")
+        is_mean = meta['is_mean']
+
+        feed_dict = {
+            data: input['data'],
+            indices: input['indices'],
+            slice_input: input['slice_input'],
+            begin: input['begin'],
+            end: input['end'],
+            strides: input['strides']
+        }
+        
+        result = KPFusedSparseSegmentReduceOp(data, indices, slice_input, begin, end, strides, is_mean)
+        return result, feed_dict
+
+    def build_input_case_1(meta):
+        input = {}
+        indices_type = np.int32
+        if "indices_type" in meta.keys():
+            indices_type = meta["indices_type"]
+        # 根据核心函数入参名生成输入数据
+        input["data"] = np.random.rand(4, 3).astype(np.float32)
+        input["indices"] = np.array([0, 1, 2]).astype(indices_type)
+        input["slice_input"] = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64) 
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2]).astype(np.int32)
+        input["strides"] = np.array([1, 2]).astype(np.int32)
+        return input
+
+    def build_input_3D_case_1(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["data"] = np.random.rand(4, 3, 2).astype(np.float32)
+        input["indices"] = np.array([0, 1, 2]).astype(np.int32)
+        input["slice_input"] = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)
+        input["begin"] = np.array([0, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2]).astype(np.int32)
+        input["strides"] = np.array([1, 2]).astype(np.int32)
+        return input
+
+    def build_input_3D_case_2(meta):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["data"] = np.random.rand(4, 3).astype(np.float32)
+        input["indices"] = np.array([0, 1]).astype(np.int32)
+        input["slice_input"] = np.array([[[0, 0], [0, 2], [1, 2]], [[0, 1], [1, 2], [2, 2]]], dtype=np.int64)
+        input["begin"] = np.array([0, 1, 1]).astype(np.int32)
+        input["end"] = np.array([0, 2, 1]).astype(np.int32)
+        input["strides"] = np.array([1, 2, 1]).astype(np.int32)
+        return input
+
+    return [
+        TestCase(
+            name="KPFusedSparseSegmentReduce_sum_case_1",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            start_op_name="PartitionedCall_1/StridedSlice",
+            end_op_name="PartitionedCall_1/StridedSlice_1",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=200,
+            meta = {"is_mean": False}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduce_mean_case_1",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            start_op_name="PartitionedCall_3/StridedSlice",
+            end_op_name="PartitionedCall_3/StridedSlice_1",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=200,
+            meta = {"is_mean": True}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduce_sum_case_1_int64",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            start_op_name="PartitionedCall_5/StridedSlice",
+            end_op_name="PartitionedCall_5/StridedSlice_1",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=200,
+            meta = {"is_mean": False, "indices_type": np.int64}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduce_mean_case_1_int64",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            start_op_name="PartitionedCall_7/StridedSlice",
+            end_op_name="PartitionedCall_7/StridedSlice_1",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=200,
+            meta = {"is_mean": True, "indices_type": np.int64}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduce_sum_3D_case_1",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_3D_case_1,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            is_fused=False,
+            meta = {"is_mean": False}
+        ),
+        TestCase(
+            name="KPFusedSparseSegmentReduce_sum_3D_case_2",
+            op_fn=KPFusedSparseSegmentReduce_graph,
+            input_fn=build_input_3D_case_2,
+            fused_op_name="KPFusedSparseSegmentReduce",
+            is_fused=False,
+            meta = {"is_mean": True}
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSelect_op.py b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSelect_op.py
new file mode 100644
index 000000000..10eda2d61
--- /dev/null
+++ b/tensorflow/python/grappler/embedding_fused_test/ops/KPFusedSparseSelect_op.py
@@ -0,0 +1,95 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    @tf.function
+    def KPFusedSparseSelectOp(input_a, input_b, input_c, greater, equal1, equal2, equal3):
+        a = tf.reshape(input_a, [-1, 1])
+        b = tf.reshape(input_b, [-1, 1])
+        c = tf.reshape(input_c, [-1, 1])
+        output_x = a
+
+        greater_a = tf.greater(a, greater)
+        shape_reshape_a1 = tf.shape(a)
+        fill_a1 = tf.fill(shape_reshape_a1, tf.constant(1, dtype=tf.float32))
+        realdiv = tf.realdiv(fill_a1, tf.constant(1, dtype=tf.float32))
+        cast_a = tf.cast(greater_a, tf.float32)
+        shape_a = tf.shape(cast_a)
+        fill_a = tf.fill(shape_a, tf.constant(1, dtype=tf.float32))
+        equal_4563 = tf.equal(b, equal1)
+        equal_10831 = tf.equal(b, equal2)
+        equal_3 = tf.equal(c, equal3)
+        select_1 = tf.where(equal_4563, fill_a, cast_a)
+        select_2 = tf.where(equal_10831, fill_a, select_1)
+        output_y = select_2
+        select_3 = tf.where(equal_3, realdiv, fill_a1)
+        output_z = tf.concat([select_2, select_3], axis=-1)
+        return output_x, output_y, output_z
+
+    def KPFusedSparseSelect_graph(input, meta):
+        # 根据核心函数入参名创建placeholder
+        input_a = tf.compat.v1.placeholder(tf.int32, shape=input['input_a'].shape, name="input_a")
+        input_b = tf.compat.v1.placeholder(tf.int32, shape=input['input_b'].shape, name="input_b")
+        input_c = tf.compat.v1.placeholder(tf.int32, shape=input['input_c'].shape, name="input_c")
+        greater = tf.constant(input['greater'], dtype=tf.int32)
+        equal1 = tf.constant(input['equal1'], dtype=tf.int32)
+        equal2 = tf.constant(input['equal2'], dtype=tf.int32)
+        equal3 = tf.constant(input['equal3'], dtype=tf.int32)
+
+        feed_dict = {
+            input_a: input['input_a'],
+            input_b: input['input_b'],
+            input_c: input['input_c'],
+            greater: input['greater'],
+            equal1: input['equal1'],
+            equal2: input['equal2'],
+            equal3: input['equal3']
+        }
+        
+        result = KPFusedSparseSelectOp(input_a, input_b, input_c, greater, equal1, equal2, equal3)
+        return result, feed_dict
+
+    def build_input_case(a_shape, b_shape, c_shape):
+        input = {}
+        # 根据核心函数入参名生成输入数据
+        input["input_a"] = np.random.randint(0, 100, size=a_shape).astype(np.int32)
+        input["input_b"] = np.random.randint(0, 100, size=b_shape).astype(np.int32)
+        input["input_c"] = np.random.randint(0, 100, size=c_shape).astype(np.int32)
+        input["greater"] = np.array(0, dtype=np.int32)
+        input["equal1"] = np.array(4563, dtype=np.int32)
+        input["equal2"] = np.array(10831, dtype=np.int32)
+        input["equal3"] = np.array(3, dtype=np.int32)
+        return input
+
+    def build_input_case_1(meta):
+        return build_input_case((100, 10), (10, 100), (20, 50))
+
+    def build_input_case_2(meta):
+        return build_input_case((50, 50, 50), (50, 50, 50), (50, 50, 50))
+
+    return [
+        TestCase(
+            name="KPFusedSparseSelect_case_1",
+            op_fn=KPFusedSparseSelect_graph,
+            input_fn=build_input_case_1,
+            fused_op_name="KPFusedSparseSelect",
+            start_op_name="PartitionedCall_1/Reshape",
+            end_op_name="PartitionedCall_1/concat",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=400,
+        ),
+        TestCase(
+            name="KPFusedSparseSelect_case_2",
+            op_fn=KPFusedSparseSelect_graph,
+            input_fn=build_input_case_2,
+            fused_op_name="KPFusedSparseSelect",
+            start_op_name="PartitionedCall_3/Reshape",
+            end_op_name="PartitionedCall_3/concat",
+            is_fused=True,
+            num_iters=1000,
+            optimize_percent=600,
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD
index b3a3d612a..185da7da4 100644
--- a/tensorflow/python/ops/BUILD
+++ b/tensorflow/python/ops/BUILD
@@ -228,6 +228,14 @@ tf_gen_op_strict_wrapper_private_py(
     ],
 )
 
+tf_gen_op_strict_wrapper_private_py(
+    name = "embedding_fused_ops_gen",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/core:embedding_fused_ops_op_lib",
+    ],
+)
+
 tf_gen_op_strict_wrapper_private_py(
     name = "collective_ops_gen",
     visibility = ["//tensorflow:internal"],