diff --git a/.bazelrc b/.bazelrc
index d18b2f8..6a7face 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -223,6 +223,14 @@ build:mkl_aarch64 -c opt
 build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
 build:mkl_aarch64_threadpool -c opt
 
+# Config setting to build ktfop.
+build:ktfop --define=build_with_ktfop=true
+build:ktfop --define=build_with_kblas=true
+build:ktfop -c opt
+
+# Config setting to build fused_embedding.
+build:fused_embedding --define=build_with_fused_embedding=true
+
 # CUDA: This config refers to building CUDA op kernels with nvcc.
 build:cuda --repo_env TF_NEED_CUDA=1
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
diff --git a/.gitignore b/.gitignore
index cebef4f..b798f79 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ tensorflow/contrib/cmake/_build/
 /build/
 [Bb]uild/
 /build_output/
+/test_output
 /tensorflow/core/util/version_info.cc
 /tensorflow/python/framework/fast_tensor_util.cpp
 /tensorflow/lite/gen/**
diff --git a/tensorflow/build.sh b/build.sh
new file mode 100644
index 0000000..c4e3b0f
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,130 @@
+#/bin/bash
+set -ex
+
+TENSORFLOW_DIR=""
+ENABLE_GCC12=false
+ENABLE_KDNN=false
+KDNN_OPTIONS=""
+
+usage() {
+    echo "Usage: $0 [--features <feature1,feature2>]"
+    echo "Example: $0 --features gcc12,kdnn"
+    echo "Notes: --features gcc12 is only suitable for openeuler 22.03 to set gcc12 insdead of default gcc10"
+    exit 1
+}
+
+if [ -f /etc/os-release ]; then
+    source /etc/os-release
+fi
+
+echo "current os: $NAME $VERSION_ID"
+
+case "$VERSION_ID" in
+    "22.03")
+        echo "config gcc12 path"
+        GCC12_PATH=/opt/openEuler/gcc-toolset-12/root/usr/bin/
+        GCC12_LD_LIBRARY_PATH=/opt/openEuler/gcc-toolset-12/root/usr/lib64
+        ;;
+    "24.03")
+        echo "use default gcc"
+        ;;
+    *)
+        echo "unsupported os version: $VERSION_ID"
+        exit 1
+        ;;
+esac
+
+# 定义目标软链接路径
+TARGET_LINK="/usr/bin/aarch64-linux-gnu-gcc"
+
+# 检查文件或链接是否已经存在
+if [ ! -f "$TARGET_LINK" ] && [ ! -L "$TARGET_LINK" ]; then
+    echo "未检测到 $TARGET_LINK,正在创建软链接..."
+    
+    # 获取系统自带 gcc 的路径
+    GCC_PATH=$(which gcc)
+    
+    if [ -n "$GCC_PATH" ]; then
+        # 使用 sudo 权限创建链接(如果是 root 用户可去掉 sudo)
+        ln -s "$GCC_PATH" "$TARGET_LINK"
+        echo "软链接创建成功: $TARGET_LINK -> $GCC_PATH"
+    else
+        echo "错误: 系统未安装 gcc,请先运行 yum install gcc"
+        exit 1
+    fi
+else
+    echo "检测到 $TARGET_LINK 已存在,跳过创建步骤。"
+fi
+
+while [[ "$#" -gt 0 ]]; do
+    case "$1" in
+        --features)
+            if [[ -z "$2" ]]; then
+                echo "Error: --features requires a value"
+                usage
+            fi
+            IFS=',' read -ra features_array <<< "$2"
+            for feature in "${features_array[@]}"; do
+                case "$feature" in
+                    "gcc12")
+                        ENABLE_GCC12=true
+                        ;;
+                    "kdnn")
+                        ENABLE_KDNN=true
+                        ;;
+                    *) 
+                        echo "Warning: Unknown feature '$feature', ignoring"
+                        ;;
+                esac
+            done
+            shift 2
+            ;;
+        -h|--help)
+            usage
+            ;;
+        *)
+            echo "Unknown parameter: $1"
+            usage
+            ;;
+    esac
+done
+
+TENSORFLOW_ROOT=$(pwd)
+DIST_DIR=$TENSORFLOW_ROOT/download
+PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin
+
+export PATH=$BAZEL_PATH:$PATH
+DIST_DIR="${DISTDIR:-$DIST_DIR}"
+BAZEL_COMPILE_CACHE="${BUILD_CACHE_DIR:-$TENSORFLOW_ROOT/output}"
+
+if ! command -v bazel &> /dev/null; then
+    echo "Error: Bazel is not installed. Please install Bazel and try again."
+    exit 1
+fi
+
+bazel version
+
+if [ "$ENABLE_GCC12" == true ]; then
+    export PATH=$GCC12_PATH:$PATH
+    export LD_LIBRARY_PATH=$GCC12_LD_LIBRARY_PATH
+    GCC_VERSION=$(gcc -dumpversion | cut -d. -f1)
+    if [[ "$GCC_VERSION" != "12" ]]; then
+        echo "Error: GCC version is $GCC_VERSION. Please install GCC 12. Consider use command: yum install gcc-toolset-12-gcc*"
+        exit 1
+    fi
+fi
+
+if [ "$ENABLE_KDNN" == true ]; then
+    KDNN_OPTIONS="--define=enable_kdnn=true"
+fi
+
+gcc --version
+cd $TENSORFLOW_ROOT && \
+PATH=$PATH \
+LD_LIBRARY_PATH=$LD_LIBRARY_PATH \
+bazel --output_user_root=$BAZEL_COMPILE_CACHE build --distdir=$DIST_DIR \
+--host_copt=-march=armv8.3-a --copt=-march=armv8.3-a --define with_default_optimizations=true \
+--copt=-Wno-sign-compare --config=v2 --config=noaws \
+$KDNN_OPTIONS \
+//tensorflow/tools/pip_package:build_pip_package
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package ./output-release
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 8e39908..efd7d24 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -327,6 +327,7 @@ StatusOr<xla::ExecutionOutput> RunExecutable(
   run_options.set_allocator(allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
   run_options.set_rng_seed(GetXLARandomSeed());
+  run_options.set_run_in_tf_kernel(ctx->executor_policy() == ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR);
 
   StatusOr<xla::ExecutionOutput> execution_output;
   bool run_synchronous =
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5385743..b916623 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -111,6 +111,14 @@ load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
 )
+load(
+    "//third_party/ktfop:build_defs.bzl",
+    "if_ktfop",
+)
+load(
+    "//third_party/fused_embedding:build_defs.bzl",
+    "if_fused_embedding",
+)
 load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
 
 package(
@@ -650,6 +658,11 @@ cc_library(
         "//tensorflow/core/kernels/mkl:mkl_matmul_op",
         "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
         "//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
+    ]) + if_ktfop([
+        "//tensorflow/core/kernels/ktfop:fused_embedding_ops",
+        "//tensorflow/core/kernels/ktfop:softmax_ops",
+    ]) + if_fused_embedding([
+        "//tensorflow/core/kernels/fused_embedding:fused_embedding_ops",
     ]) + if_cuda_or_rocm([
         "//tensorflow/core/kernels:cudnn_rnn_kernels",
     ]) + if_cuda([
diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index 2c625ec..1012ffb 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -638,6 +638,18 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "kernel_stat",
+    hdrs = ["kernel_stat.h"],
+    copts = tf_copts(),
+    deps = [
+        ":graph_view",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
 cc_library(
     name = "executor",
     srcs = ["executor.cc"],
@@ -649,6 +661,7 @@ cc_library(
         ":entry",
         ":executor_factory",
         ":graph_view",
+        ":kernel_stat",
         ":immutable_executor_state",
         ":local_executor_params",
         ":pending_counts",
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index db5984c..db4089d 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -327,6 +327,12 @@ DirectSession::DirectSession(const SessionOptions& options,
       factory_(factory),
       cancellation_manager_(new CancellationManager()),
       operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
+
+  if (options_.config.use_batch_op_scheduling()) {
+    use_batch_scheduling_executor_ = true;
+    LOG(INFO) << "enable batch scheduling executor";
+  }
+
   const int thread_pool_size =
       options_.config.session_inter_op_thread_pool_size();
   if (thread_pool_size > 0) {
@@ -679,7 +685,11 @@ Status DirectSession::RunInternal(
   args.run_all_kernels_inline = pool == nullptr;
   args.start_time_usecs = start_time_usecs;
   args.deadline = deadline;
-
+  if (use_batch_scheduling_executor_) {
+    args.executor_policy = ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR;
+  } else {
+    args.executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
+  }
   const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
 
   bool update_cost_model = false;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index a81a307..2f7e8ca 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -437,6 +437,7 @@ class DirectSession : public Session {
   // Otherwise run in global thread pool, session owned thread pool or handler
   // pool according to other specifications of RunOptions and ConfigProto.
   bool run_in_caller_thread_ = false;
+  bool use_batch_scheduling_executor_ = false;
 
   DirectSession(const DirectSession&) = delete;
   void operator=(const DirectSession&) = delete;
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index a539a6e..9807322 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/executor_factory.h"
 #include "tensorflow/core/common_runtime/graph_view.h"
 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
+#include "tensorflow/core/common_runtime/kernel_stat.h"
 #include "tensorflow/core/common_runtime/pending_counts.h"
 #include "tensorflow/core/common_runtime/propagator_state.h"
 #include "tensorflow/core/common_runtime/renamed_device.h"
@@ -155,76 +156,21 @@ class ExecutorImpl : public Executor {
     return OkStatus();
   }
 
+  ImmutableExecutorState& GetImmutableState() {
+    return immutable_state_;
+  }
+
+  ExecutorInternal::KernelStats* GetKernelStat() {
+    return &kernel_stats_;
+  }
  private:
   void RunAsyncInternal(const Args& args, DoneCallback done) override;
 
   template <class PropagatorStateType>
   friend class ExecutorState;
 
-  // Stores execution time information about the kernels in an executor's graph.
-  class KernelStats {
-   public:
-    KernelStats() = default;
-
-    void Initialize(const GraphView& gview) {
-      is_expensive_.resize(gview.num_nodes());
-      cost_estimates_ =
-          std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
-      for (int32_t i = 0; i < gview.num_nodes(); ++i) {
-        if (gview.node(i)) {
-          is_expensive_[i] =
-              gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
-          cost_estimates_[i] = kInitialCostEstimateCycles;
-        }
-      }
-    }
-
-    // Returns true iff the given node is considered "expensive". The
-    // executor uses this flag to optimize graph execution, for example
-    // by "inlining" inexpensive kernels.
-    bool IsExpensive(const NodeItem& node) const {
-      return is_expensive_[node.node_id] &&
-             (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
-              kOpIsExpensiveThresholdCycles);
-    }
-
-    // Returns the value of kernel->IsExpensive().
-    bool HasExpensiveMarker(const NodeItem& node) const {
-      return is_expensive_[node.node_id];
-    }
-
-    // Updates the dynamic cost estimate, which is used to determine whether the
-    // given node is expensive. The new cost estimate is a weighted average of
-    // the old cost estimate and the latest cost. We only update cost estimates
-    // for kernels for which IsExpensive() return true.
-    void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
-      // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
-      // updates may result in one or more updates being ignored.  This does not
-      // affect correctness but may slow down the update frequency.
-      std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
-      auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
-
-      uint64 new_estimate =
-          ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
-
-      cost_estimate.store(new_estimate, std::memory_order_relaxed);
-    }
-
-   private:
-    // Initial time (in CPU cycles) we expect an operation to take.  Used to
-    // determine whether an operation should be place in a threadpool.
-    // Operations start out "expensive".
-    static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
-    static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
-    static constexpr uint64 kCostDecay = 10;
-
-    std::vector<bool> is_expensive_;
-    // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
-    std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
-  };
-
   ImmutableExecutorState immutable_state_;
-  KernelStats kernel_stats_;
+  ExecutorInternal::KernelStats kernel_stats_;
 
   ExecutorImpl(const ExecutorImpl&) = delete;
   void operator=(const ExecutorImpl&) = delete;
@@ -284,12 +230,12 @@ class ExecutorState {
  public:
   ExecutorState(const Executor::Args& args,
                 const ImmutableExecutorState& immutable_state_,
-                ExecutorImpl::KernelStats* kernel_stats_);
+                ExecutorInternal::KernelStats* kernel_stats_);
   ~ExecutorState();
 
   void RunAsync(Executor::DoneCallback done);
 
- private:
+ protected:
   // Use `TaggedNode` types defined by `PropagatorStateType`.
   typedef typename PropagatorStateType::TaggedNode TaggedNode;
   typedef
@@ -338,7 +284,7 @@ class ExecutorState {
   // This method will clear `*ready` before returning.
   //
   // REQUIRES: `!ready->empty()`.
-  void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
+  virtual void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
 
   // A wrapper for runner_ to keep track of the pending queue length. Op
   // execution should dispatch work using this function instead of using runner_
@@ -388,7 +334,7 @@ class ExecutorState {
   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
   CallFrameInterface* call_frame_;
   const ImmutableExecutorState& immutable_state_;
-  ExecutorImpl::KernelStats* const kernel_stats_;
+  ExecutorInternal::KernelStats* const kernel_stats_;
   CancellationManager* cancellation_manager_;
   tsl::CoordinationServiceAgent* coordination_service_agent_;
   absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt;
@@ -397,6 +343,7 @@ class ExecutorState {
   Executor::Args::Runner runner_;
   bool sync_on_finish_;
   const bool run_all_kernels_inline_;
+  ExecutorPolicy executor_policy_ = ExecutorPolicy::USE_NORMAL_EXECUTOR;
 
   PropagatorStateType propagator_;
 
@@ -418,7 +365,7 @@ class ExecutorState {
 template <class PropagatorStateType>
 ExecutorState<PropagatorStateType>::ExecutorState(
     const Executor::Args& args, const ImmutableExecutorState& immutable_state,
-    ExecutorImpl::KernelStats* kernel_stats)
+    ExecutorInternal::KernelStats* kernel_stats)
     : vlog_(VLOG_IS_ON(1)),
       log_memory_(LogMemory::IsEnabled()),
       step_id_(args.step_id),
@@ -446,6 +393,7 @@ ExecutorState<PropagatorStateType>::ExecutorState(
       runner_(args.runner),
       sync_on_finish_(args.sync_on_finish),
       run_all_kernels_inline_(args.run_all_kernels_inline),
+      executor_policy_(args.executor_policy),
       propagator_(immutable_state, step_id_, vlog_),
       num_outstanding_ops_(0) {
   if (args.user_intra_op_threadpool != nullptr) {
@@ -463,6 +411,42 @@ ExecutorState<PropagatorStateType>::~ExecutorState() {
   delete slice_reader_cache_;
 }
 
+template <class PropagatorStateType>
+class BatchSchedulingExecutorState : public ExecutorState<PropagatorStateType> {
+  public:
+    BatchSchedulingExecutorState(
+      const Executor::Args& args, const ImmutableExecutorState& immutable_state_,
+      ExecutorInternal::KernelStats* kernel_stats_)
+      : ExecutorState<PropagatorStateType>(args, immutable_state_, kernel_stats_) {}
+    ~BatchSchedulingExecutorState() {}
+
+  protected:
+    typedef typename PropagatorStateType::TaggedNode TaggedNode;
+    typedef
+        typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
+    typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
+
+    virtual void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
+};
+
+class ExecutorStateFactory {
+  public:
+    template <class PropagatorStateType>
+    static ExecutorState<PropagatorStateType>* Create(
+      const Executor::Args& args, ExecutorImpl* impl) {
+      ImmutableExecutorState& immutable_state = impl->GetImmutableState();
+      ExecutorInternal::KernelStats* kernel_stats = impl->GetKernelStat();
+
+      if (args.executor_policy == ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR) {
+        return new BatchSchedulingExecutorState<PropagatorStateType>(
+          args, immutable_state, kernel_stats);
+      } else {
+        return new ExecutorState<PropagatorStateType>(
+          args, immutable_state, kernel_stats);
+      }
+    }
+};
+
 template <class PropagatorStateType>
 template <typename Closure>
 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) {
@@ -511,7 +495,7 @@ void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
   } else {
     done_cb_ = std::move(done);
     // Schedule to run all the ready ops in thread pool.
-    ScheduleReady(&ready, nullptr);
+    this->ScheduleReady(&ready, nullptr);
   }
 }
 
@@ -730,7 +714,7 @@ void ExecutorState<PropagatorStateType>::Process(const TaggedNode& tagged_node,
                             profiler::TraceMeLevel::kVerbose);
   TaggedNodeReadyQueue inline_ready;
   inline_ready.push_back(tagged_node);
-  return ProcessInline(&inline_ready, scheduled_nsec);
+  return this->ProcessInline(&inline_ready, scheduled_nsec);
 }
 
 template <class PropagatorStateType>
@@ -773,6 +757,7 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
   params->slice_reader_cache = slice_reader_cache_;
   params->runner = &runner_;
   params->run_all_kernels_inline = run_all_kernels_inline_;
+  params->executor_policy = executor_policy_;
   params->stats_collector = stats_collector_;
   params->inc_num_deferred_ops_function = [this]() {
     mutex_lock lock(num_deferred_ops_mu_);
@@ -886,13 +871,13 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
     if (tagged_node.get_is_dead() && !item.is_transfer_node) {
       if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
     } else if (TF_PREDICT_FALSE(item.is_noop)) {
-      ProcessNoop(stats);
+      this->ProcessNoop(stats);
     } else if (item.const_tensor != nullptr && !params->track_allocations) {
-      ProcessConstTensor(item, &outputs, stats);
+      this->ProcessConstTensor(item, &outputs, stats);
     } else {
       // Prepares inputs.
       bool is_input_dead = false;
-      s = PrepareInputs(item, first_input, inputs.get(), &input_alloc_attrs,
+      s = this->PrepareInputs(item, first_input, inputs.get(), &input_alloc_attrs,
                         &is_input_dead);
       if (!s.ok()) {
         // Clear inputs.
@@ -903,7 +888,7 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
         propagator_.MaybeMarkCompleted(tagged_node);
         activity_watcher::ActivityEnd(activity_id);
         // Continue to process the nodes in 'inline_ready'.
-        completed = NodeDone(s, ready.get(), stats, inline_ready);
+        completed = this->NodeDone(s, ready.get(), stats, inline_ready);
         continue;
       }
 
@@ -918,11 +903,11 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
       params->input_alloc_attrs = input_alloc_attrs;
 
       if (item.kernel_is_async) {
-        ProcessAsync(item, *params, tagged_node, first_input, stats,
+        this->ProcessAsync(item, *params, tagged_node, first_input, stats,
                      activity_id);
         launched_asynchronously = true;
       } else {
-        s = ProcessSync(item, params.get(), &outputs, stats);
+        s = this->ProcessSync(item, params.get(), &outputs, stats);
       }
     }
 
@@ -957,12 +942,12 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
         scheduled_nsec = nodestats::NowInNsec();
       }
       // Postprocess.
-      completed = NodeDone(s, ready.get(), stats, inline_ready);
+      completed = this->NodeDone(s, ready.get(), stats, inline_ready);
     }
   }  // while !inline_ready.empty()
 
   // This thread of computation is done if completed = true.
-  if (completed) ScheduleFinish();
+  if (completed) this->ScheduleFinish();
 }
 
 template <class PropagatorStateType>
@@ -1202,7 +1187,7 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
       }
 
       // Schedule the ready nodes in 'ready'.
-      ScheduleReady(ready, inline_ready);
+      this->ScheduleReady(ready, inline_ready);
 
       return false;
     }
@@ -1388,7 +1373,7 @@ void ExecutorState<PropagatorStateType>::ScheduleFinish() {
   // Finish is always called exactly once per ExecutorState, either here if
   // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
   // there are deferred ops.
-  Finish();
+  this->Finish();
 }
 
 template <class PropagatorStateType>
@@ -1507,17 +1492,44 @@ void ExecutorState<PropagatorStateType>::Finish() {
   }
 }
 
+template <class PropagatorStateType>
+void BatchSchedulingExecutorState<PropagatorStateType>::ScheduleReady(
+    TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
+  DCHECK(!ready->empty());
+
+  int64_t scheduled_nsec = 0;
+  if (this->stats_collector_) {
+    scheduled_nsec = nodestats::NowInNsec();
+  }
+
+  if (inline_ready == nullptr) {
+    // Schedule all ready kernels from a single closure. This ensure that,
+    // regardless of the `runner_` implementation, all kernels will run
+    // sequentially on the same thread, and thread wakeup overhead and
+    // executor mutex contention will be minimized.
+    this->RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
+      for (auto& tagged_node : ready) {
+        this->Process(tagged_node, scheduled_nsec);
+      }
+    });
+  } else {
+    for (auto& tagged_node : *ready) {
+      inline_ready->push_back(tagged_node);
+    }
+  }
+
+  ready->clear();
+}
+
 void ExecutorImpl::RunAsyncInternal(const Args& args, DoneCallback done) {
   if (OpOrderDeterminismRequired()) {
-    (new ExecutorState<OrderedPropagatorState>(args, immutable_state_,
-                                               &kernel_stats_))
+    (ExecutorStateFactory::Create<OrderedPropagatorState>(args, this))
         ->RunAsync(std::move(done));
   } else if (immutable_state_.requires_control_flow_support()) {
-    (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
+    (ExecutorStateFactory::Create<PropagatorState>(args, this))
         ->RunAsync(std::move(done));
   } else {
-    (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
-                                              &kernel_stats_))
+    (ExecutorStateFactory::Create<SimplePropagatorState>(args, this))
         ->RunAsync(std::move(done));
   }
 }
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 607cf9b..4f3b8f6 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -124,6 +124,7 @@ class Executor {
     // If true, all kernels will be treated as "inexpensive", and hence executed
     // on the scheduling thread.
     bool run_all_kernels_inline = false;
+    ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
   };
   typedef std::function<void(const Status&)> DoneCallback;
 
diff --git a/tensorflow/tensorflow/core/common_runtime/kernel_stat.h b/tensorflow/core/common_runtime/kernel_stat.h
new file mode 100644
index 0000000..670a9f2
--- /dev/null
+++ b/tensorflow/core/common_runtime/kernel_stat.h
@@ -0,0 +1,104 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
+
+#include <atomic>
+#include <memory>
+#include <queue>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/time/time.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/common_runtime/graph_view.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+namespace ExecutorInternal {
+
+// Stores execution time information about the kernels in an executor's graph.
+class KernelStats {
+  public:
+  KernelStats() = default;
+
+  void Initialize(const GraphView& gview) {
+    is_expensive_.resize(gview.num_nodes());
+    cost_estimates_ =
+        std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
+    for (int32_t i = 0; i < gview.num_nodes(); ++i) {
+      if (gview.node(i)) {
+        is_expensive_[i] =
+            gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
+        cost_estimates_[i] = kInitialCostEstimateCycles;
+      }
+    }
+  }
+
+  // Returns true iff the given node is considered "expensive". The
+  // executor uses this flag to optimize graph execution, for example
+  // by "inlining" inexpensive kernels.
+  bool IsExpensive(const NodeItem& node) const {
+    return is_expensive_[node.node_id] &&
+            (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
+            kOpIsExpensiveThresholdCycles);
+  }
+
+  // Returns the value of kernel->IsExpensive().
+  bool HasExpensiveMarker(const NodeItem& node) const {
+    return is_expensive_[node.node_id];
+  }
+
+  // Updates the dynamic cost estimate, which is used to determine whether the
+  // given node is expensive. The new cost estimate is a weighted average of
+  // the old cost estimate and the latest cost. We only update cost estimates
+  // for kernels for which IsExpensive() return true.
+  void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
+    // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
+    // updates may result in one or more updates being ignored.  This does not
+    // affect correctness but may slow down the update frequency.
+    std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
+    auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
+
+    uint64 new_estimate =
+        ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
+
+    cost_estimate.store(new_estimate, std::memory_order_relaxed);
+  }
+
+  private:
+  // Initial time (in CPU cycles) we expect an operation to take.  Used to
+  // determine whether an operation should be place in a threadpool.
+  // Operations start out "expensive".
+  static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
+  static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
+  static constexpr uint64 kCostDecay = 10;
+
+  std::vector<bool> is_expensive_;
+  // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
+  std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
+};
+} // namespace ExecutorInternal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
\ No newline at end of file
diff --git a/tensorflow/core/common_runtime/next_pluggable_device/flags.h b/tensorflow/core/common_runtime/next_pluggable_device/flags.h
index 681155e..29b056b 100644
--- a/tensorflow/core/common_runtime/next_pluggable_device/flags.h
+++ b/tensorflow/core/common_runtime/next_pluggable_device/flags.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_FLAGS_H_
 
 #include "absl/flags/declare.h"
+#include "absl/flags/flag.h"
 
 ABSL_DECLARE_FLAG(bool, next_pluggable_device_use_c_api);
 
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index a70034b..492bfcf 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -672,6 +672,7 @@ class OpKernelContext {
     StepStatsCollectorInterface* stats_collector = nullptr;
     GraphCollector* graph_collector = nullptr;
     bool run_all_kernels_inline = false;
+    ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
     const std::string* executor_type = nullptr;
 
     // TensorSliceReaderCache support.
@@ -833,6 +834,9 @@ class OpKernelContext {
     return params_->run_all_kernels_inline;
   }
 
+  ExecutorPolicy executor_policy() const {
+    return params_->executor_policy;
+  }
   // Returns the registered name for the executor type that is executing the
   // current kernel. If empty, the default executor is used.
   const std::string& executor_type() const;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 22617ef..7cdceb5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -38,6 +38,15 @@ load(
     "if_mkl",
     "mkl_deps",
 )
+load(
+    "//third_party/KDNN:build_defs.bzl",
+    "if_enable_kdnn",
+    "kdnn_deps",
+)
+load(
+    "//third_party/ktfop:build_defs.bzl",
+    "if_ktfop",
+)
 load(
     "@local_config_rocm//rocm:build_defs.bzl",
     "if_rocm",
@@ -855,7 +864,7 @@ ARRAY_DEPS = [
     "//tensorflow/core/framework:bounds_check",
     "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation",
     "@eigen_archive//:eigen3",
-]
+] + kdnn_deps()
 
 tf_kernel_library(
     name = "immutable_constant_op",
@@ -3488,7 +3497,7 @@ tf_kernel_library(
         "//tensorflow/core/protobuf:autotuning_proto_cc",
         "//tensorflow/core/util/autotune_maps:conv_parameters",
         "//tensorflow/core/util/proto:proto_utils",
-    ]),
+    ]) + kdnn_deps(),
 )
 
 cc_library(
@@ -3568,7 +3577,7 @@ tf_kernel_library(
     prefix = "cwise_op",
     deps = MATH_DEPS + [
         "//tensorflow/core/kernels/mlir_generated:cwise_op",
-    ],
+    ] + kdnn_deps(),
 )
 
 tf_kernel_library(
@@ -4296,7 +4305,7 @@ tf_kernel_library(
     ]) + [
         ":gpu_prim_hdrs",
         ":loose_headers",
-    ],
+    ] + kdnn_deps(),
 )
 
 tf_kernel_library(
@@ -5136,7 +5145,7 @@ tf_kernel_library(
         "//tensorflow/core/framework:bounds_check",
         "//tensorflow/core/util:determinism_for_kernels",
         "@eigen_archive//:eigen3",
-    ],
+    ] + kdnn_deps(),
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 0b66076..39f8e5c 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -31,6 +31,11 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
 
 namespace tensorflow {
 
@@ -172,6 +177,13 @@ class ConcatBaseOp : public OpKernel {
         return;
       }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if defined(ENABLE_KDNN)
+      KDNN::Element::TypeT kdnnType = KDNN::Element::TypeAdapter<T>::value;
+      if(IsKDNNEnabled() && kdnnType != KDNN::Element::TypeT::UNDEFINED) {
+        KDNNConcatImpl<T>(c, inputs_flat, &output_flat);
+        return;
+      }
+#endif
       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
     }
   }
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index fd7ee45..8d6554a 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -35,6 +35,10 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/bcast.h"
 
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -109,6 +113,15 @@ class BinaryOp : public BinaryOpShared {
       Tensor* out;
       OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                               {0, 1}, 0, input_0.shape(), &out));
+#if defined(ENABLE_KDNN)
+    if constexpr (std::is_same<Functor, functor::safe_floor_mod<int64_t>>::value ||
+        std::is_same<Functor, functor::floor_fmod<float>>::value) {
+        if (IsKDNNEnabled()) {
+            kdnnFloormodOp<Functor>(ctx, input_0, input_1, out);
+            return;
+        }
+    }
+#endif
       functor::BinaryFunctor<Device, Functor, 1>()(
           eigen_device, out->template flat<Tout>(),
           input_0.template flat<Tin>(), input_1.template flat<Tin>(),
@@ -322,6 +335,16 @@ class UnaryOp : public OpKernel {
     } else {
       OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
     }
+
+#if defined(ENABLE_KDNN)
+    if constexpr (std::is_same<Functor, functor::sigmoid<float>>::value) {
+        if (IsKDNNEnabled() & inp.NumElements() > 0) {
+            kdnnSigmoidOp<Functor>(ctx, inp, out);
+            return;
+        }
+    }
+#endif
+
     functor::UnaryFunctor<Device, Functor>()(
         ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
   }
diff --git a/tensorflow/tensorflow/core/kernels/fused_embedding/BUILD b/tensorflow/core/kernels/fused_embedding/BUILD
new file mode 100644
index 0000000..3774125
--- /dev/null
+++ b/tensorflow/core/kernels/fused_embedding/BUILD
@@ -0,0 +1,45 @@
+load("//tensorflow:tensorflow.default.bzl", "tf_kernel_library")
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+)
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+)
+
+DYNAMIC_DEPS = [
+    "//tensorflow/core/framework:bounds_check",
+    "//tensorflow/core:core_cpu",
+    "//tensorflow/core:framework",
+    "//tensorflow/core:lib",
+    "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+    name = "fused_embedding_ops",
+    srcs = [
+        "embedding_lookup_hash_op.cc",
+        "lookup_embedding_by_hash.h",
+    ],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@farmhash_archive//:farmhash",
+    ] + DYNAMIC_DEPS
+)
+
+tf_cc_test(
+    name = "fused_embedding_ops_test",
+    srcs = ["fused_embedding_with_hash_bucket_ops_test.cc"],
+    deps = [
+        ":fused_embedding_ops",
+        "//tensorflow/core/ops:fused_embedding_ops_op_lib",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
\ No newline at end of file
diff --git a/tensorflow/tensorflow/core/kernels/fused_embedding/embedding_lookup_hash_op.cc b/tensorflow/core/kernels/fused_embedding/embedding_lookup_hash_op.cc
new file mode 100644
index 0000000..e5140d9
--- /dev/null
+++ b/tensorflow/core/kernels/fused_embedding/embedding_lookup_hash_op.cc
@@ -0,0 +1,92 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/bounds_check.h"
+
+#include "lookup_embedding_by_hash.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+struct LookupEmbeddingByHashFunctor {
+  int operator()(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table, 
+                 int64_t embedding_size, int32_t embedding_dims, float *output) {
+    return KPLookupEmbeddingByHashImpl::Compute(lookup_embedding, lookup_length, batch_size, embedding_table, 
+                                                embedding_size, embedding_dims, output);
+  }
+};
+
+class KPLookupEmbeddingByHashOp : public OpKernel {
+    public:
+      explicit KPLookupEmbeddingByHashOp(OpKernelConstruction* context)
+               : OpKernel(context) {
+        OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+        node_name = context->def().name();
+      }
+
+      void Compute(OpKernelContext* context) override {
+        float *weight;
+        const Tensor& input_tensor = context->input(0);
+        const Tensor* weight_tensor = &context->input(1);
+        
+        if (weight_tensor->dtype() == DT_RESOURCE) {
+          Var* variable;
+          OP_REQUIRES_OK(context,
+                        LookupResource(context, HandleFromInput(context, 1), 
+                                        &variable));
+          core::ScopedUnref s(variable);
+          weight_tensor = variable->tensor();
+          OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+                      errors::InvalidArgument("Expect float weight in ",
+                                              node_name));
+        }
+        
+        auto input = input_tensor.flat<tstring>();
+        weight = (float *)weight_tensor->tensor_data().data();
+        int64_t batch = input_tensor.dim_size(0);
+        int64_t embedding_dims = weight_tensor->dim_size(1);
+        uintptr_t cstr_addresses[batch];
+        size_t cstr_length[batch];
+        for (int i = 0; i < batch; ++i) {
+          cstr_addresses[i] = reinterpret_cast<uintptr_t>(input(i).c_str());
+          cstr_length[i] = input(i).length();
+        }
+        Tensor* output_tensor = nullptr;
+        OP_REQUIRES_OK(context, 
+                       context->allocate_output(
+                        0, TensorShape({batch, embedding_dims}), 
+                        &output_tensor));
+        float *output = (float *)output_tensor->tensor_data().data();
+        LookupEmbeddingByHashFunctor lookfunctor;
+        int result = lookfunctor(cstr_addresses, cstr_length, batch,
+                                 weight, num_buckets_, embedding_dims, output);
+        OP_REQUIRES(context, (result == 0),
+                errors::InvalidArgument("Invalid argument, error code: ", result));
+      }
+
+    private:
+        int64_t num_buckets_;
+        std::string node_name;
+};
+REGISTER_KERNEL_BUILDER(Name("KPLookupEmbeddingByHash").Device(DEVICE_CPU), KPLookupEmbeddingByHashOp);
+}  // namespace tensorflow
diff --git a/tensorflow/tensorflow/core/kernels/fused_embedding/fused_embedding_with_hash_bucket_ops_test.cc b/tensorflow/core/kernels/fused_embedding/fused_embedding_with_hash_bucket_ops_test.cc
new file mode 100644
index 0000000..03ed29a
--- /dev/null
+++ b/tensorflow/core/kernels/fused_embedding/fused_embedding_with_hash_bucket_ops_test.cc
@@ -0,0 +1,59 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class KPLookupEmbeddingByHashOpTest : public OpsTestBase {};
+
+TEST_F(KPLookupEmbeddingByHashOpTest, WithHashBucket) {
+  TF_EXPECT_OK(
+      NodeDefBuilder("fused_embedding_with_hash_bucket_op", "KPLookupEmbeddingByHash")
+          .Input(FakeInput(DT_STRING))
+          .Input(FakeInput(DT_FLOAT))
+          .Attr("num_buckets", 5)
+          .Attr("combiner", 1)
+          .Attr("T_weight", DT_FLOAT)
+          .Finalize(node_def()));
+  TF_EXPECT_OK(InitOp());
+  AddInputFromArray<tstring>(TensorShape({2, 1}),
+                           {"ktfop", "fused_embedding_with_hash_bucket"});
+  AddInputFromArray<float>(TensorShape({5, 10}), {3.21, 7.89, 1.45, 9.32, 0.67, 5.43, 2.98, 8.76, 4.12, 6.54,
+                                                  0.23, 9.87, 3.56, 7.01, 2.34, 8.09, 5.67, 1.89, 6.78, 4.45,
+                                                  0.98, 9.01, 3.45, 7.23, 2.67, 8.34, 5.89, 1.23, 6.45, 4.78,
+                                                  0.56, 9.45, 3.78, 7.56, 2.12, 8.67, 5.34, 1.67, 6.23, 4.89,
+                                                  0.34, 9.78, 3.12, 7.45, 2.89, 8.23, 5.78, 1.45, 6.67, 4.23});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 10}));
+  test::FillValues<float>(
+      &expected, {3.21, 7.89, 1.45, 9.32, 0.67, 5.43, 2.98, 8.76, 4.12, 6.54,
+                  0.98, 9.01, 3.45, 7.23, 2.67, 8.34, 5.89, 1.23, 6.45, 4.78});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.0);
+}
+}  // namespace tensorflow
diff --git a/tensorflow/tensorflow/core/kernels/fused_embedding/lookup_embedding_by_hash.h b/tensorflow/core/kernels/fused_embedding/lookup_embedding_by_hash.h
new file mode 100644
index 0000000..380faf1
--- /dev/null
+++ b/tensorflow/core/kernels/fused_embedding/lookup_embedding_by_hash.h
@@ -0,0 +1,82 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <farmhash.h>
+#include <cstring>
+
+namespace tensorflow {
+
+enum ReturnCode {
+    OK = 0,
+    NULL_POINTER = 1,
+    INVALID_PARAMETER = 2
+};
+
+static inline int Lookup1D(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table, 
+                           int64_t embedding_size, int32_t embedding_dims, float *output)
+{
+    uint64_t embedding_size_u64 = static_cast<uint64_t>(embedding_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+        if (lookup_length[i] != 0) {
+            uint64_t hash_value = ::util::Fingerprint64((char *)(lookup_embedding[i]), lookup_length[i]);
+            uint64_t x = hash_value % embedding_size_u64;
+            output[i] = embedding_table[x];
+        } else {
+            output[i] = 0;
+        }
+    }
+    return OK;
+}
+
+static inline int RegularLookup(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table, 
+                                int64_t embedding_size, int32_t embedding_dims, float *output)
+{
+    uint64_t embedding_dims_u64 = static_cast<uint64_t>(embedding_dims);
+    uint64_t embedding_size_u64 = static_cast<uint64_t>(embedding_size);
+    for (int64_t i = 0; i < batch_size; ++i) {
+        if (lookup_length[i] != 0) {
+            uint64_t hash_value = ::util::Fingerprint64((char *)(lookup_embedding[i]), lookup_length[i]);
+
+            uint64_t x = hash_value % embedding_size_u64;
+            for (uint64_t j = 0; j < embedding_dims; ++j) {
+                output[j] = embedding_table[x * embedding_dims + j];
+            }
+            output += embedding_dims_u64;
+        }
+    }
+    return OK;
+}
+
+struct KPLookupEmbeddingByHashImpl {
+    static int Compute(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table, 
+                int64_t embedding_size, int32_t embedding_dims, float *output) {
+    if (output == nullptr || lookup_embedding == nullptr || embedding_table == nullptr || lookup_length == nullptr) {
+        return NULL_POINTER;
+    }
+    if (batch_size < 0 || embedding_size <= 0 || embedding_dims <= 0) {
+        return INVALID_PARAMETER;
+    }
+
+    if (embedding_dims == 1) {
+        return Lookup1D(lookup_embedding, lookup_length, batch_size, embedding_table, 
+                        embedding_size, embedding_dims, output);
+    } else {
+        return RegularLookup(lookup_embedding, lookup_length, batch_size, embedding_table, 
+                             embedding_size, embedding_dims, output);
+    }
+    }
+};
+
+} // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/tensorflow/core/kernels/ktfop/BUILD b/tensorflow/core/kernels/ktfop/BUILD
new file mode 100644
index 0000000..a6fa0e7
--- /dev/null
+++ b/tensorflow/core/kernels/ktfop/BUILD
@@ -0,0 +1,36 @@
+load("//tensorflow:tensorflow.default.bzl", "tf_kernel_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//tensorflow:internal"],
+)
+
+DYNAMIC_DEPS = [
+    "//tensorflow/core/framework:bounds_check",
+    "//tensorflow/core:core_cpu",
+    "//tensorflow/core:framework",
+    "//tensorflow/core:lib",
+    "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+    name = "fused_embedding_ops",
+    srcs = [
+        "embedding_lookup_op.cc",
+    ],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@ktfop_archive//:ktfop",
+    ] + DYNAMIC_DEPS
+)
+
+tf_kernel_library(
+    name = "softmax_ops",
+    srcs = [
+        "softmax.cc",
+    ],
+    deps = [
+        "@eigen_archive//:eigen3",
+        "@ktfop_archive//:ktfop",
+    ] + DYNAMIC_DEPS
+)
\ No newline at end of file
diff --git a/tensorflow/tensorflow/core/kernels/ktfop/embedding_lookup_op.cc b/tensorflow/core/kernels/ktfop/embedding_lookup_op.cc
new file mode 100644
index 0000000..1a4d4dc
--- /dev/null
+++ b/tensorflow/core/kernels/ktfop/embedding_lookup_op.cc
@@ -0,0 +1,164 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/bounds_check.h"
+
+#include "ktfop.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+class KPFusedEmbeddingOp : public OpKernel {
+public:
+  explicit KPFusedEmbeddingOp(OpKernelConstruction* context)
+           : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_));
+    node_name = context->def().name();
+  }
+
+  ~KPFusedEmbeddingOp() {}
+
+  void Compute(OpKernelContext* context) override {
+    float *weight;
+    const Tensor* weight_tensor = &context->input(0);
+
+    if (weight_tensor->dtype() == DT_RESOURCE) {
+      Var* variable;
+      OP_REQUIRES_OK(context,
+                     LookupResource(context, HandleFromInput(context, 0), 
+                                    &variable));
+      core::ScopedUnref s(variable);
+      weight_tensor = variable->tensor();
+      OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+                  errors::InvalidArgument("Expect float weight in ",
+                                          node_name));
+    }
+
+    weight = (float *)weight_tensor->tensor_data().data();
+    
+    const Tensor& input_tensor = context->input(1);
+    int64 *input = (int64 *)input_tensor.tensor_data().data();
+    const Tensor& shape_tensor = context->input(2);
+    int64 *shape = (int64 *)shape_tensor.tensor_data().data();
+
+    OP_REQUIRES(context, (shape_tensor.dims() == 1),
+                errors::InvalidArgument("Shape tensor is not valid (dims != 1)"));
+    OP_REQUIRES(context, (shape_tensor.dim_size(0) >= 2),
+                errors::InvalidArgument("Shape tensor is not valid (dim_size(0) < 2)"));
+    
+    int64 input_size = 1;
+    for (int i = 0; i < input_tensor.dims(); ++i) {
+      input_size *= input_tensor.dim_size(i);
+    }
+    int input_dims = shape_tensor.dim_size(0);
+    int cols = shape[input_dims - 1];
+    int batch_size = 1;
+    for (int i = 0; i < input_dims - 1; ++i) {
+      batch_size *= shape[i];
+    }
+    OP_REQUIRES(context, (input_size == batch_size * cols),
+                errors::InvalidArgument("input id is dense"));
+    int embedding_dims = weight_tensor->dim_size(1);
+    bool is_mean = (combiner_ == 1);
+
+    Tensor* output_tensor = NULL;
+    TensorShape output_shape({batch_size, embedding_dims});
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+                                                     &output_tensor));
+    float *output = (float *)output_tensor->tensor_data().data();
+    ktfop::EmbeddingParams params(input,
+                                  batch_size,
+                                  cols,
+                                  weight,
+                                  embedding_dims,
+                                  is_mean);
+    int result = ktfop::FusedEmbedding(params, output);
+    OP_REQUIRES(context, (result == 0),
+                errors::InvalidArgument("Invalid argument, error code: ", result));
+  }
+
+private:
+  int combiner_;
+  std::string node_name;
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbedding").Device(DEVICE_CPU), KPFusedEmbeddingOp);
+
+class KPFusedEmbeddingWithHashBucketOp : public OpKernel {
+    public:
+      explicit KPFusedEmbeddingWithHashBucketOp(OpKernelConstruction* context)
+               : OpKernel(context) {
+        OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+        node_name = context->def().name();
+      }
+
+      void Compute(OpKernelContext* context) override {
+        float *weight;
+        const Tensor& input_tensor = context->input(0);
+        const Tensor* weight_tensor = &context->input(1);
+        
+        if (weight_tensor->dtype() == DT_RESOURCE) {
+          Var* variable;
+          OP_REQUIRES_OK(context,
+                        LookupResource(context, HandleFromInput(context, 1), 
+                                        &variable));
+          core::ScopedUnref s(variable);
+          weight_tensor = variable->tensor();
+          OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+                      errors::InvalidArgument("Expect float weight in ",
+                                          node_name));
+        }
+        
+        auto input = input_tensor.flat<tstring>();
+        weight = (float *)weight_tensor->tensor_data().data();
+        int64_t batch = input_tensor.dim_size(0);
+        int64_t embedding_dims = weight_tensor->dim_size(1);
+        uintptr_t cstr_addresses[batch];
+        size_t cstr_length[batch];
+        for (int i = 0; i < batch; ++i) {
+          cstr_addresses[i] = reinterpret_cast<uintptr_t>(input(i).c_str());
+          cstr_length[i] = input(i).length();
+        }
+        Tensor* output_tensor = nullptr;
+        OP_REQUIRES_OK(context, 
+                       context->allocate_output(
+                        0, TensorShape({batch, embedding_dims}), 
+                        &output_tensor));
+        float *output = (float *)output_tensor->tensor_data().data();
+        ktfop::EmbeddingParamsWithHash params(cstr_addresses,
+                                              cstr_length,
+                                              batch,
+                                              weight,
+                                              num_buckets_,
+                                              embedding_dims);
+        int result = ktfop::FusedEmbeddingWithHashBucket(params, output);
+        OP_REQUIRES(context, (result == 0),
+                errors::InvalidArgument("Invalid argument, error code: ", result));
+      }
+
+    private:
+        int64_t num_buckets_;
+        std::string node_name;
+};
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingWithHashBucket").Device(DEVICE_CPU), KPFusedEmbeddingWithHashBucketOp);
+}  // namespace tensorflow
diff --git a/tensorflow/tensorflow/core/kernels/ktfop/softmax.cc b/tensorflow/core/kernels/ktfop/softmax.cc
new file mode 100644
index 0000000..9384a85
--- /dev/null
+++ b/tensorflow/core/kernels/ktfop/softmax.cc
@@ -0,0 +1,52 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+#include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
+#include "ktfop.h"
+
+namespace tensorflow {
+class KPSoftmaxOp : public OpKernel {
+ public:
+  explicit KPSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& logits_in = context->input(0);
+    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in.shape()),
+                errors::InvalidArgument("logits must have >= 1 dimension, got ",
+                                        logits_in.shape().DebugString()));
+    Tensor* softmax_out = nullptr;
+    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+                                {0}, 0, logits_in.shape(), &softmax_out));
+    if (logits_in.NumElements() > 0) {
+      typename TTypes<float>::ConstMatrix input = logits_in.flat_inner_dims<float>();
+      float* input_data = (float *)logits_in.data();
+      float* output_data = (float *)softmax_out->data();
+      int result = ktfop::Softmax(input_data, output_data, input.dimension(0), input.dimension(1));
+      OP_REQUIRES(context, (result == 0),
+                errors::InvalidArgument("Invalid argument, error code: ", result));
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("KPSoftmax").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+    KPSoftmaxOp);
+
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h
index 99c1306..8cf4b39 100644
--- a/tensorflow/core/kernels/linalg/einsum_op_impl.h
+++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h
@@ -42,6 +42,11 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
 #include "tensorflow/core/util/einsum_op_util.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
@@ -466,6 +471,16 @@ struct EinsumHelper {
       set_zero(ctx->eigen_device<Device>(), output->flat<T>());
       return OkStatus();
     }
+#if defined(ENABLE_KDNN)
+    if (IsKDNNEnabled()
+        && std::is_same<T, float>::value
+        && inputs.size() == 2
+        && inputs[0].dims() >= 2 && inputs[0].dims() <= 5
+        && inputs[1].dims() >= 2 && inputs[1].dims() <= 5) {
+      kdnnBatchGemm(ctx, inputs[0], inputs[1], output, trans_x, trans_y);
+      return OkStatus();
+    }
+#endif
     Tensor output_reshaped;
     TF_RETURN_IF_ERROR(
         ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index 3908c79..474ee77 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -46,6 +46,7 @@ limitations under the License.
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/util/matmul_autotune.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/port.h"
 
 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
 #include "tsl/framework/contraction/eigen_contraction_kernel.h"
@@ -69,6 +70,10 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -93,6 +98,16 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
     OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
                 errors::InvalidArgument("_FusedMatMul doesn't support DT_HALF "
                                         "data type on CPU devices."));
+#if defined(ENABLE_KDNN)
+    bool transpose_a_ = dim_pair[0].first == 0;
+    bool transpose_b_ = dim_pair[0].second == 1;
+    bool fusion_relu = fusion == FusedComputationType::kBiasAddWithRelu;
+    bool kdnn_enable_fusion = (fusion == FusedComputationType::kBiasAdd) || fusion_relu;
+    if (IsKDNNEnabled() && std::is_same<T, float>::value && kdnn_enable_fusion && !transpose_a_) {
+      kdnnFusedGemm(context, a, b, output, fusion_relu, transpose_a_, transpose_b_);
+      return;
+    }
+#endif
     auto lhs = a.matrix<T>();
     auto rhs = b.matrix<T>();
     auto out = output->matrix<T>();
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index c647780..4ceb42a 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -41,6 +41,7 @@ limitations under the License.
 #include "tensorflow/core/util/matmul_autotune.h"
 #include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow/core/util/port.h"
 
 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
 #include "tsl/framework/contraction/eigen_contraction_kernel.h"
@@ -64,6 +65,10 @@ limitations under the License.
 #endif
 #endif
 
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -927,6 +932,24 @@ class BaseBatchMatMulOp : public OpKernel {
       f(ctx->eigen_device<Device>(), out->flat<Tout>());
       return;
     }
+#if defined(ENABLE_KDNN)
+    int dims = in0.shape().dims();
+    bool trans_x = (adj_x_ || trans_x_);
+    bool trans_y = (adj_y_ || trans_y_);
+
+    if (IsKDNNEnabled()
+        && std::is_same<Ta, float>::value
+        && std::is_same<Tb, float>::value
+        && std::is_same<Tout, float>::value
+        && dims >= 2 && dims <= 5) {
+      if (batch_size == 1) {
+        kdnnGemm(ctx, in0_reshaped, in1_reshaped, out, trans_x, trans_y);
+      } else {
+        kdnnBatchGemm(ctx, in0, in1, out, trans_x, trans_y);
+      }
+      return;
+    }
+#endif
     Tensor out_reshaped;
     OP_REQUIRES(ctx,
                 out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index ee99bd2..a403848 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -25,6 +25,10 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/kernels/softmax_op_functor.h"
 
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -61,6 +65,15 @@ class SoftmaxOp : public OpKernel {
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {0}, 0, logits_in.shape(), &softmax_out));
     if (logits_in.NumElements() > 0) {
+#if defined(ENABLE_KDNN)
+    if constexpr (std::is_same<T, float>::value) {
+        if (!log_ && IsKDNNEnabled()) {
+            kdnnSoftmaxOp<T>(context, logits_in, softmax_out);
+            return;
+        }
+    }
+#endif
+
       functor::SoftmaxFunctor<Device, T> functor;
       functor(context->eigen_device<Device>(), logits_in.flat_inner_dims<T>(),
               softmax_out->flat_inner_dims<T>(), log_);
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 73870c5..50544ad 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -24,6 +24,11 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/platform/bfloat16.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
 
 namespace tensorflow {
 
@@ -133,6 +138,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
       return;
     }
 
+#if defined(ENABLE_KDNN)
+#define KDNN_ADJOINT(ADJ_A, ADJ_B)                                             \
+  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                            \
+    Status functor_status = functor::KDNNSparseMatMulFunctor<                  \
+        CPUDevice, float, Tindices, ADJ_A,                                     \
+        ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<float>(),     \
+                        a_indices->matrix<Tindices>(), a_values->vec<float>(), \
+                        b->matrix<float>());                                   \
+    OP_REQUIRES_OK(ctx, functor_status);                                       \
+  }
+    const int int32max = std::numeric_limits<int>::max();
+    if (IsKDNNEnabled() && std::is_same<T, float>::value && adjoint_a_ == false &&
+        FastBoundsCheck(inner_left, int32max) &&
+        FastBoundsCheck(inner_right, int32max) &&
+        FastBoundsCheck(outer_left, int32max)) {
+      KDNN_ADJOINT(false, false);
+      KDNN_ADJOINT(false, true);
+      return;
+    }
+#undef KDNN_ADJOINT
+#endif
+
 #define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                           \
   if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                           \
     Status functor_status = functor::SparseTensorDenseMatMulFunctor<          \
@@ -355,6 +382,56 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   }
 };
 
+#if defined(ENABLE_KDNN)
+template <typename Tindices, bool ADJ_A, bool ADJ_B>
+struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
+  static const std::size_t kNumVectorize = 32;
+
+  static Status Compute(const CPUDevice& d, typename TTypes<float>::Matrix out,
+                        typename TTypes<Tindices>::ConstMatrix a_indices,
+                        typename TTypes<float>::ConstVec a_values,
+                        typename TTypes<float>::ConstMatrix b) {
+    const std::size_t nnz = a_values.size();
+    const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
+    const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
+    const int lhs_index_a = ADJ_A ? 1 : 0;
+    const int rhs_index_a = ADJ_A ? 0 : 1;
+
+    out.setZero();
+
+    if (rhs_right < kNumVectorize) {
+      // Disable vectorization if the RHS of output is too small
+      auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
+
+      for (std::size_t i = 0; i < nnz; ++i) {
+        const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
+        const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
+        if (!FastBoundsCheck(k, lhs_right)) {
+          return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
+        }
+        if (!FastBoundsCheck(m, out.dimension(0))) {
+          return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
+        }
+        const float a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
+        for (std::size_t n = 0; n < rhs_right; ++n) {
+          const float b_value = maybe_adjoint_b(k, n);
+          out(m, n) += a_value * b_value;
+        }
+      }
+    } else {
+      const float* b_data = b.data();
+      Eigen::Tensor<float, 2, Eigen::ColMajor> col_major_conj_b;
+      if (ADJ_B) {
+        Eigen::array<int, 2> shuffle(1, 0);
+        col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate().eval();
+        b_data = col_major_conj_b.data();
+      }
+      kdnnSparseMatmul<Tindices>(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
+    }
+    return OkStatus();
+  }
+};
+#endif
 }  // namespace functor
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index 3cab997..3de9474 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -35,6 +35,17 @@ struct SparseTensorDenseMatMulFunctor {
       typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
 };
 
+#if defined(ENABLE_KDNN)
+template <typename Device, typename T, typename Tindices, bool ADJ_A,
+          bool ADJ_B>
+struct KDNNSparseMatMulFunctor {
+  static EIGEN_ALWAYS_INLINE Status Compute(
+      const Device& d, typename TTypes<T>::Matrix out,
+      typename TTypes<Tindices>::ConstMatrix a_indices,
+      typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
+};
+#endif
+
 template <typename MATRIX, bool ADJ>
 class MaybeAdjoint;
 
diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD
index 91d80b6..4659963 100644
--- a/tensorflow/core/ops/BUILD
+++ b/tensorflow/core/ops/BUILD
@@ -18,7 +18,14 @@ load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
 )
-
+load(
+    "//third_party/ktfop:build_defs.bzl",
+    "if_ktfop",
+)
+load(
+    "//third_party/fused_embedding:build_defs.bzl",
+    "if_fused_embedding",
+)
 # A lot of packages try to minimize binary size by depending on individual ops,\
 # so they need access here.
 package(
@@ -60,6 +67,8 @@ tf_gen_op_libs(
         "functional_ops",
         "image_ops",
         "io_ops",
+        "ktfop_ops",
+        "fused_embedding_ops",
         "linalg_ops",
         "list_ops",
         "map_ops",
@@ -344,6 +353,10 @@ cc_library(
     }) + if_mkl([
         ":mkl_array_ops_op_lib",
         ":mkl_nn_ops_op_lib",
+    ]) + if_ktfop([
+        ":ktfop_ops_op_lib",
+    ]) + if_fused_embedding([
+        ":fused_embedding_ops_op_lib",
     ]),
     alwayslink = 1,
 )
diff --git a/tensorflow/tensorflow/core/ops/fused_embedding_ops.cc b/tensorflow/core/ops/fused_embedding_ops.cc
new file mode 100644
index 0000000..80c9022
--- /dev/null
+++ b/tensorflow/core/ops/fused_embedding_ops.cc
@@ -0,0 +1,48 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPLookupEmbeddingByHash")
+    .Input("lookup: string")
+    .Input("weights: T_weight")
+    .Attr("num_buckets: int >= 1")
+    .Attr("combiner: int")
+    .Attr("T_weight: {resource, float}")
+    .Output("output: float")
+    .SetShapeFn([](InferenceContext* ctx) {
+      ShapeHandle temp;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &temp));
+      DimensionHandle emb_size_dim = ctx->UnknownDim();
+      DimensionHandle batch_dim = ctx->UnknownDim();
+
+      ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+      ctx->set_output(0, output_shape);
+
+      return OkStatus();
+    });
+
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/tensorflow/core/ops/ktfop_ops.cc b/tensorflow/core/ops/ktfop_ops.cc
new file mode 100644
index 0000000..5463a7a
--- /dev/null
+++ b/tensorflow/core/ops/ktfop_ops.cc
@@ -0,0 +1,81 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPFusedEmbedding")
+    .Input("weights: float")
+    .Input("lookup: int64")
+    .Input("dense_shape: int64")
+    .Input("indices: int64")
+    .Output("output: float")
+    .Attr("combiner: int")
+
+    .SetShapeFn([](InferenceContext* ctx) {
+      ShapeHandle temp;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(1), 1, &temp));
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(3), 2, &temp));
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(2), 1, &temp));
+      ShapeHandle emb_var_shape;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 2, &emb_var_shape));
+
+      DimensionHandle emb_size_dim = ctx->Dim(emb_var_shape, 1);
+      DimensionHandle batch_dim = ctx->UnknownDim();
+
+      ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+      ctx->set_output(0, output_shape);
+
+      return OkStatus();
+    });
+
+REGISTER_OP("KPFusedEmbeddingWithHashBucket")
+    .Input("lookup: string")
+    .Input("weights: T_weight")
+    .Attr("num_buckets: int >= 1")
+    .Attr("combiner: int")
+    .Attr("T_weight: {resource, float}")
+    .Output("output: float")
+    .SetShapeFn([](InferenceContext* ctx) {
+      ShapeHandle temp;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &temp));
+      DimensionHandle emb_size_dim = ctx->UnknownDim();
+      DimensionHandle batch_dim = ctx->UnknownDim();
+
+      ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+      ctx->set_output(0, output_shape);
+
+      return OkStatus();
+    });
+    
+REGISTER_OP("KPSoftmax")
+    .Input("logits: T")
+    .Output("softmax: T")
+    .Attr("T: {float}")
+    .SetShapeFn([](InferenceContext* c) {
+      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
+    });
+
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/core/platform/numa.h b/tensorflow/core/platform/numa.h
index 6333c01..f68b046 100644
--- a/tensorflow/core/platform/numa.h
+++ b/tensorflow/core/platform/numa.h
@@ -30,6 +30,7 @@ using tsl::port::NUMAGetThreadNodeAffinity;
 using tsl::port::NUMAMalloc;
 using tsl::port::NUMANumNodes;
 using tsl::port::NUMASetThreadNodeAffinity;
+using tsl::port::ThreadAffinity;
 }  // namespace port
 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_PLATFORM_NUMA_H_
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index de23bad..85a8fc0 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -17,6 +17,13 @@ option java_multiple_files = true;
 option java_package = "org.tensorflow.framework";
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
 
+enum ExecutorPolicy {
+  // Tensorflow default executor
+  USE_NORMAL_EXECUTOR = 0;
+  // Kunpeng batch scheduling executor
+  USE_BATCH_SCHEDULING_EXECUTOR = 1;
+}
+
 message GPUOptions {
   // Fraction of the total GPU memory to allocate for each process.
   // 1 means to allocate all of the GPU memory, 0.5 means the process
@@ -716,7 +723,8 @@ message ConfigProto {
 
   Experimental experimental = 16;
 
-  // Next: 18
+  bool use_batch_op_scheduling = 18;
+  // Next: 19
 }
 
 // Options for a single Run() call.
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index bcc5f6a..f1c7598 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -148,4 +148,31 @@ bool IsZenDnnEnabled() {
 #endif  // !AMD_ZENDNN
 }
 
+bool IsKDNNEnabled() {
+#ifndef ENABLE_KDNN
+  return false;
+#else
+  static absl::once_flag once;
+  static bool KDNN_enabled = true;
+  absl::call_once(once, [&] {
+    auto status = ReadBoolFromEnvVar("TF_ENABLE_KDNN_OPTS", KDNN_enabled,
+                                     &KDNN_enabled);
+
+    if (!status.ok()) {
+      LOG(WARNING) << "TF_ENABLE_KDNN_OPTS is not set to either '0', 'false',"
+                   << " '1', or 'true'. Using the default setting: "
+                   << KDNN_enabled;
+    }
+    if (KDNN_enabled) {
+      LOG(INFO) << "KDNN custom operations are on. "
+                << "You may see slightly different numerical results due to "
+                << "floating-point round-off errors from different computation "
+                << "orders. To turn them off, set the environment variable "
+                << "`TF_ENABLE_KDNN_OPTS=0`.";
+    }
+  });
+  return KDNN_enabled;
+#endif  // !KDNN
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
index d0a755c..ea13ffb 100644
--- a/tensorflow/core/util/port.h
+++ b/tensorflow/core/util/port.h
@@ -47,6 +47,8 @@ bool IsMklEnabled();
 // Returns true if TF_ENABLE_ZENDNN_OPTS is set to 1
 bool IsZenDnnEnabled();
 
+// Returns true if TF_ENABLE_KDNN_OPTS is set to 1
+bool IsKDNNEnabled();
 }  // end namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_UTIL_PORT_H_
diff --git a/tensorflow/lite/g3doc/images/build/build_workflow_diag.png b/tensorflow/lite/g3doc/images/build/build_workflow_diag.png
index efd30b1..e69de29 100644
Binary files a/tensorflow/lite/g3doc/images/build/build_workflow_diag.png and b/tensorflow/lite/g3doc/images/build/build_workflow_diag.png differ
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/framework/_init_.py b/tensorflow/python/kernel_tests/benchmark/framework/_init_.py
new file mode 100644
index 0000000..e69de29
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/framework/runner.py b/tensorflow/python/kernel_tests/benchmark/framework/runner.py
new file mode 100644
index 0000000..df906c9
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/framework/runner.py
@@ -0,0 +1,329 @@
+import os
+import time
+import shutil
+import glob
+import json
+import sys
+import re
+import struct
+import tensorflow.compat.v1 as tf
+from dataclasses import dataclass, field
+from typing import List, Callable, Dict, Any, Union, Optional
+from tensorflow.python.profiler.internal import _pywrap_profiler
+from tensorflow.python.profiler import profiler_v2 as profiler
+from tensorflow.python.profiler.profiler_v2 import ProfilerOptions
+from tensorflow.core.profiler.protobuf.xplane_pb2 import XSpace
+import numpy as np
+import multiprocessing as mp
+import pickle
+
+tf.disable_eager_execution()
+
+
+@dataclass
+class TestCase:
+    name: str
+    op_fn: Callable
+    input_fn: Callable
+    check_fn: Optional[Callable] = None
+    num_iters: int = 0
+    optimize_percent: int = 0
+    operator_name: str = ""
+    meta: Dict = field(default_factory=dict)
+
+class CheckFuncClass:
+    def check_fn_float(A, B, meta):
+        atol = meta.get('atol')
+        rtol = meta.get('rtol')
+        if rtol is None:
+            rtol = 1e-3
+        if atol is None:
+            atol = 1e-4
+        # ∣a−b∣≤ atol + rtol ×∣b∣
+        return np.allclose(A, B, rtol=rtol, atol=atol)
+
+    def check_fn_equal(A, B, meta):
+        np.testing.assert_array_equal(A, B) # 如果失败抛出异常, 如果成功无返回值,
+        return True
+
+class UniversalOpBenchmark:
+    def __init__(self, log_dir="bench_logs"):
+        self.log_dir = log_dir
+        self.results = []
+
+    def _subprocess_worker(
+        self,
+        queue: mp.Queue,
+        func: Callable,
+        is_kdnn_enable: bool,
+        raw_inputs: list,
+        func_args: TestCase,
+        temp_dir: str
+    ) -> None:
+        """
+        子进程工作函数:创建临时目录,执行目标函数,将结果/异常放入队列
+        """
+        try:
+            # 1. 自动创建临时目录(不存在则创建,存在则不报错)
+            os.makedirs(temp_dir, exist_ok=True)
+            
+            # 2. 执行目标函数A
+            result = func(is_kdnn_enable, raw_inputs, func_args)
+            try:
+                pickle.dumps(result)
+            except Exception as e:
+                raise Exception(f"返回结果无法序列化,无法传递给主进程:{str(e)}") from e
+            # 3. 将执行结果放入队列(传递给主进程)
+            try:
+                queue.put(('success', result), block=True, timeout=10)
+            except mp.Queue.Full:
+                raise Exception("队列缓冲区已满,无法写入执行结果")
+        
+        except Exception as e:
+            # 捕获所有异常,传递给主进程
+            queue.put(('error', e))
+            return
+
+    def run_func_in_subprocess(
+        self,
+        func_A: Callable,
+        is_kdnn_enable: bool,
+        raw_inputs: list,
+        func_args: TestCase,
+        temp_dir: str = "./temp_workdir"
+    ) -> str | np.ndarray:
+        
+        # 2. 创建进程间通信队列(用于传递子进程执行结果/异常)
+        result_queue = mp.Queue()
+        
+        # 3. 构建子进程
+        sub_process = mp.Process(
+            target=self._subprocess_worker,
+            args=(result_queue, func_A, is_kdnn_enable, raw_inputs, func_args, temp_dir)
+        )
+        
+        # 4. 启动并等待子进程执行完成
+        sub_process.start()
+        
+        # 5. 从队列中获取子进程执行结果
+        process_status, result = result_queue.get()
+        if process_status == 'error':
+            raise Exception(f"子进程中函数执行失败:{str(result)}") from result
+
+        sub_process.join()  # 阻塞主进程,等待子进程退出
+        return result
+
+    def _get_event_name_by_metadata_id(self, metadata_id, plane):
+        if not hasattr(plane, 'event_metadata'):
+            return f"unknown_operator_{metadata_id}"
+
+        metadata=plane.event_metadata[metadata_id]
+        if metadata.id == metadata_id:
+            return metadata.name if (metadata.name and metadata.name.strip()) else f"unknown_operator_{metadata_id}"
+        
+        # 未匹配到返回默认名称
+        return f"unknown_operator_{metadata_id}"
+
+    def parse_xplane_file_optimized(self, dir_path):
+        if not os.path.isdir(dir_path):
+            print(f"dir is not exist: {dir_path}")
+            return None
+
+        xplane_files = []
+        for file_name in os.listdir(dir_path):
+            if file_name.endswith(".xplane.pb"):
+                xplane_files.append(os.path.join(dir_path, file_name))
+
+        if not xplane_files:
+            print(f"file is not exist in {dir_path}")
+            return None
+        elif len(xplane_files) > 1:
+            print(f"Found multiple.xplane.pb files: {xplane_files}; the first file will be used.")
+
+        file_path = xplane_files[0]
+        with open(file_path, 'rb') as f:
+            content = f.read()
+    
+        xspace = XSpace()
+        xspace.ParseFromString(content)
+        if len(xspace.planes) > 0:
+            return xspace
+
+        print("parse_xplane_file_optimized failed")
+        return None
+
+    def get_average_wall_during(self, input_dir, operator_name, test_case: TestCase):
+        # 1.提取数据
+        xplane_data = self.parse_xplane_file_optimized(input_dir)
+        if not xplane_data:
+            return 0.0
+
+        op_name_list = []
+        # 2. 获取wall_duration
+        wall_durations_us_list = []
+        for plane in xplane_data.planes:
+            if hasattr(plane, 'lines') and len(plane.lines) > 0:
+                for line in plane.lines:
+                    if hasattr(line, 'events') and len(line.events) > 0:
+                        for event in line.events:
+                            # 1. 获取事件时长(ps转换为us)
+                            duration_ps = event.duration_ps if hasattr(event, 'duration_ps') else 0
+                            if duration_ps <= 0:
+                                continue  # 过滤无效时长事件
+                            #转换成us
+                            wall_duration_us = duration_ps / 1000000
+
+                            # 2. 通过metadata_id获取算子名称
+                            metadata_id = event.metadata_id if hasattr(event, 'metadata_id') else 0
+                            op_name = self._get_event_name_by_metadata_id(metadata_id, plane)
+                            if op_name not in op_name_list:
+                                op_name_list.append(op_name)
+                            if operator_name in op_name:
+                                wall_durations_us_list.append(wall_duration_us)
+        # print(f"解析到的算子名称列表:{op_name_list}") # 打印所有算子名称
+        if test_case.num_iters != np.size(wall_durations_us_list):
+            print(f"解析到有效数据{np.size(wall_durations_us_list)}条, 预期{test_case.num_iters}条")
+            return 0.0
+
+        return float(np.mean(wall_durations_us_list)), np.var(wall_durations_us_list, ddof=1)
+
+    def parse_performance_data(self, kdnn_enable, raw_inputs, test_case: TestCase):
+        print(f"Testing: {test_case.name} ...")
+        os.environ['TF_ENABLE_KDNN_OPTS'] = str(kdnn_enable)
+        
+        placeholders = []
+        feed_dict = {}
+        
+        for i, val in enumerate(raw_inputs):
+            np_val = np.array(val)
+            
+            p = tf.placeholder(
+                dtype=tf.as_dtype(np_val.dtype),
+                shape=np_val.shape,
+                name=f"input_{i}"
+            )
+            placeholders.append(p)
+            feed_dict[p] = np_val
+
+        res_node = test_case.op_fn(placeholders, test_case.meta)
+        config = tf.ConfigProto(
+            inter_op_parallelism_threads=16,
+            intra_op_parallelism_threads=16
+        )
+        options = ProfilerOptions(
+            host_tracer_level=2,
+            device_tracer_level=1,
+            python_tracer_level=0  # 建议保持 0 以减少对 Kernel 测量的干扰
+        )
+        with tf.Session(config=config) as sess:
+            for _ in range(10):
+                sess.run(res_node, feed_dict=feed_dict)
+
+            profiler.start(self.log_dir, options)
+            time.sleep(3)
+            for _ in range(test_case.num_iters):
+                sess.run(res_node, feed_dict=feed_dict)
+            profiler.stop()
+
+        try:
+            kdnn_status = os.environ.get('TF_ENABLE_KDNN_OPTS', '1')
+            tag = f"{test_case.name}_kdnn_{kdnn_status}"  # 结果为 kdnn_0 或 kdnn_1
+            profile_root = os.path.join(self.log_dir, "plugins", "profile")
+            timestamp_dirs = [d for d in os.listdir(profile_root) 
+                             if os.path.isdir(os.path.join(profile_root, d)) and d.startswith("202")]
+            
+            if timestamp_dirs:
+                src_dir = os.path.join(profile_root, timestamp_dirs[0])
+                # 目标路径: bench_logs/MatMul_Test/plugins/profile/kdnn_1
+                target_dir = os.path.join(profile_root, tag)
+                
+                if os.path.exists(target_dir):
+                    shutil.rmtree(target_dir)
+                
+                os.rename(src_dir, target_dir)
+                print(f"✅ Profile 已固化至: {target_dir}")
+                return target_dir
+
+                # 可选:清理掉空的时间戳父目录或日志文件
+                # (TensorFlow 有时会留下一些空的 event 文件,通常不影响解析)
+        except Exception as e:
+            print(f"⚠️ 路径固化失败: {e}")
+
+    def run_performance_test(self, test_case: TestCase):
+        if test_case.num_iters == 0:
+            return True
+
+        print(f"Performance testing: {test_case.name} ...")
+
+        optimize_percent = test_case.optimize_percent
+        operator_name = test_case.operator_name
+        raw_inputs = test_case.input_fn()
+        if not isinstance(raw_inputs, (list, tuple)):
+            raw_inputs = [raw_inputs]
+
+        no_kdnn_path = self.run_func_in_subprocess(self.parse_performance_data, 0, raw_inputs, test_case)
+        no_kdnn_wall_during, _ = self.get_average_wall_during(no_kdnn_path, operator_name, test_case)
+        kdnn_path = self.run_func_in_subprocess(self.parse_performance_data, 1, raw_inputs, test_case)
+        kdnn_wall_during, _ = self.get_average_wall_during(kdnn_path, operator_name, test_case)
+        if kdnn_wall_during <= 0.0 or no_kdnn_wall_during <= 0.0:
+            print(f"⚠️ 获取平均运行时长失败, 无法进行比较")
+            return False
+        real_percent = 100 * (no_kdnn_wall_during/kdnn_wall_during - 1)
+        if real_percent < optimize_percent:
+            raise Exception(f"⚠️ 性能提升{real_percent:.2f}%, 低于预期{optimize_percent}%, KDNN开启状态下耗时{kdnn_wall_during:.2f}us, 关闭状态下耗时{no_kdnn_wall_during:.2f}us")
+            return False
+
+        print(f"性能测试通过, 性能提升{real_percent:.2f}%, KDNN开启状态下耗时{kdnn_wall_during:.2f}us, 关闭状态下耗时{no_kdnn_wall_during:.2f}us")
+        return True
+
+    def run_function_test(self, test_case: TestCase):
+        print(f"Function testing: {test_case.name} ...")
+        
+        raw_inputs = test_case.input_fn()
+        if not isinstance(raw_inputs, list):
+            raw_inputs = [raw_inputs]
+
+        def execute_variant(enable_kdnn, raw_inputs, test_case: TestCase):
+            os.environ['TF_ENABLE_KDNN_OPTS'] = str(enable_kdnn)
+            tf.reset_default_graph()
+
+            placeholders = []
+            feed_dict = {}
+
+            for i, val in enumerate(raw_inputs):
+                np_val = np.array(val)
+                
+                p = tf.placeholder(
+                    dtype=tf.as_dtype(np_val.dtype),
+                    shape=np_val.shape,
+                    name=f"input_{i}"
+                )
+                placeholders.append(p)
+                feed_dict[p] = np_val
+            res_node = test_case.op_fn(placeholders, test_case.meta)
+            config = tf.ConfigProto(
+                inter_op_parallelism_threads=16,
+                intra_op_parallelism_threads=16
+            )
+            options = ProfilerOptions(
+                host_tracer_level=2,
+                device_tracer_level=1,
+                python_tracer_level=0  # 建议保持 0 以减少对 Kernel 测量的干扰
+            )
+            with tf.Session(config=config) as sess:
+                return sess.run(res_node, feed_dict=feed_dict)
+            return
+
+        kdnn_data = self.run_func_in_subprocess(execute_variant, 1, raw_inputs, test_case)
+        no_kdnn_data = self.run_func_in_subprocess(execute_variant, 0, raw_inputs, test_case)
+        if test_case.check_fn is None:
+            is_correct = CheckFuncClass.check_fn_float(no_kdnn_data, kdnn_data, test_case.meta)
+        else:
+            is_correct = test_case.check_fn(no_kdnn_data, kdnn_data, test_case.meta)
+
+        if not is_correct:
+            print(f"⚠️ 误差较大")
+            raise Exception(f"⚠️ 误差较大")
+            return False
+        print("功能测试通过")
+        return True
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/main.py b/tensorflow/python/kernel_tests/benchmark/main.py
new file mode 100644
index 0000000..70b3da5
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/main.py
@@ -0,0 +1,51 @@
+import argparse
+import importlib
+import pkgutil
+import traceback
+import os
+
+def main():
+    parser = argparse.ArgumentParser(description="TF Op Benchmark Framework with KDNN Control")
+    parser.add_argument('--op', type=str, help='指定要运行的算子模块名')
+    parser.add_argument('--list', action='store_true', help='列出所有模块')
+    parser.add_argument('--performance_test', choices=['True', 'False'], default='True', help='是否运行性能测试模块')
+
+    args = parser.parse_args()
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' #日志太多, 关掉
+
+    import tensorflow as TF
+    from framework.runner import UniversalOpBenchmark
+    import ops
+
+    # 4. 扫描模块
+    available_ops = {
+        name: f'ops.{name}' 
+        for loader, name, is_pkg in pkgutil.iter_modules(ops.__path__)
+    }
+
+    if args.list:
+        print("📁 可用算子列表:")
+        for name in available_ops: print(f"  - {name}")
+        return
+
+    target_modules = [available_ops[args.op]] if args.op else list(available_ops.values())
+    root_log_dir = "bench_logs"
+    bench = UniversalOpBenchmark(log_dir=root_log_dir)
+    
+    for module_path in target_modules:
+        try:
+            module = importlib.import_module(module_path)
+            if hasattr(module, 'get_test_cases'):
+                cases = module.get_test_cases()
+                for case in cases:
+                    bench.run_function_test(case)
+                if args.performance_test == "True":
+                    for case in cases:
+                        bench.run_performance_test(case)
+        except Exception as e:
+            print(f"❌ 运行模块 {module_path} 失败: {e}")
+            traceback.print_exc()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/batchmatmul_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/batchmatmul_ops.py
new file mode 100644
index 0000000..0046966
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/batchmatmul_ops.py
@@ -0,0 +1,176 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    
+    def matmul_op(inputs, meta):
+        return tf.matmul(inputs[0], inputs[1], transpose_a=meta.get('trans_a', False), transpose_b=meta.get('trans_b', False))
+
+    def input_2D_float(m, k, n):
+        return [
+            np.random.uniform(0, 1, (m, k)).astype(np.float32), 
+            np.random.uniform(0, 1, (k, n)).astype(np.float32),
+        ]
+    
+    def input_3D_no_broadcast_float(b1, m, k, n):
+        return [
+            np.random.uniform(0, 1, (b1, m, k)).astype(np.float32), 
+            np.random.uniform(0, 1, (b1, k, n)).astype(np.float32),
+        ]
+    
+    def input_3D_broadcast_float(b1, m, k, n):
+        return [
+            np.random.uniform(0, 1, (b1, m, k)).astype(np.float32), 
+            np.random.uniform(0, 1, (1, k, n)).astype(np.float32),
+        ]
+
+    def build_batch_matmul_input(shape_a, shape_b, trans_a, trans_b, dtype):
+        """
+        构造满足矩阵乘法条件的两个输入 Tensor
+        """
+        # 如果转置,需要交换最后两个维度来生成原始数据
+        final_shape_a = list(shape_a)
+        if trans_a:
+            final_shape_a[-1], final_shape_a[-2] = final_shape_a[-2], final_shape_a[-1]
+            
+        final_shape_b = list(shape_b)
+        if trans_b:
+            final_shape_b[-1], final_shape_b[-2] = final_shape_b[-2], final_shape_b[-1]
+
+        a = np.random.uniform(-1, 1, final_shape_a).astype(dtype)
+        b = np.random.uniform(-1, 1, final_shape_b).astype(dtype)
+        return [a, b]
+
+    class BatchMatMulTestCaseFactory:
+        @staticmethod
+        def create_func_test():
+            # 基础组合维度定义
+            dims_to_test = [2, 3, 4, 5]
+            trans_options = [True, False]
+            broadcast_options = [True, False]
+            dtypes = {"fp32": np.float32}
+            
+            # 矩阵乘法的核心维度 M, K, N
+            M, K, N = 64, 32, 48
+            
+            all_cases = []
+            
+            for dim in dims_to_test:
+                for is_bc in broadcast_options:
+                    # 2D 场景不存在广播概念,跳过重复
+                    if dim == 2 and is_bc: continue 
+                    
+                    for ta in trans_options:
+                        for tb in trans_options:
+                            for dt_name, dt_val in dtypes.items():
+                                
+                                # 构造形状逻辑
+                                # A: (..., M, K), B: (..., K, N)
+                                # 如果转置参数为真,op内部会处理,但输入生成需匹配
+                                batch_dims_a = [2, 3, 1][:dim-2] if dim > 2 else []
+                                if is_bc:
+                                    # 构造广播场景:A的batch维有1,B的batch维正常
+                                    batch_dims_a = [1] * (dim - 2)
+                                    batch_dims_b = [2] * (dim - 2)
+                                else:
+                                    batch_dims_b = batch_dims_a
+                                    
+                                shape_a = tuple(batch_dims_a + [M, K])
+                                shape_b = tuple(batch_dims_b + [K, N])
+                                
+                                case_name = f"batch_matmul_{dim}D_{dt_name}_bc{is_bc}_ta{ta}_tb{tb}"
+                                
+                                # 预绑定输入构造函数
+                                input_fn = partial(
+                                    build_batch_matmul_input,
+                                    shape_a=shape_a,
+                                    shape_b=shape_b,
+                                    trans_a=ta,
+                                    trans_b=tb,
+                                    dtype=dt_val
+                                )
+                                
+                                all_cases.append(TestCase(
+                                    name=case_name,
+                                    op_fn=matmul_op,
+                                    input_fn=input_fn,
+                                    num_iters=0,
+                                    meta={
+                                        'trans_a': ta, # 对应算子属性:是否转置第一个输入
+                                        'trans_b': tb  # 对应算子属性:是否转置第二个输入
+                                    }
+                                ))
+            return all_cases
+    test_case = []
+    test_case.extend(BatchMatMulTestCaseFactory.create_func_test())
+    # perf test case
+    test_case.extend([
+        TestCase(
+            name="MatMul_2D_79_1570_256_False_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_2D_float, m=79, k=1570, n=256),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_2D_79_1570_128_False_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_2D_float, m=79, k=1570, n=128),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_2D_4480_32_16_False_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_2D_float, m=4480, k=32, n=16),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_2D_64_256_128_False_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_2D_float, m=64, k=256, n=128),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_2D_128_592_128_False_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_2D_float, m=128, k=592, n=128),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_3D_No_Broadcast_64_64_256_256_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_3D_no_broadcast_float, b1=64, m=64, k=256, n=256),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        TestCase(
+            name="MatMul_3D_No_Broadcast_64_32_64_64_False",
+            op_fn=matmul_op,
+            input_fn=partial(input_3D_no_broadcast_float, b1=64, m=32, k=64, n=64),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "MatMul",
+            meta={'trans_a': False, 'trans_b': False},
+        ),
+        
+    ])
+    return test_case
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/concat_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/concat_ops.py
new file mode 100644
index 0000000..4cbc14c
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/concat_ops.py
@@ -0,0 +1,278 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase, CheckFuncClass
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    def concat_op(inputs, meta):
+        return tf.concat(inputs, meta["axis"])
+
+    def build_18input_2D(dtype=np.float32):
+        batch = 150
+        shapes = [
+            (batch, 960),
+            (batch, 24),
+            (batch, 4),
+            (batch, 4),
+            (batch, 4),
+            (batch, 64),
+            (batch, 4),
+            (batch, 4),
+            (batch, 32),
+            (batch, 4),
+            (batch, 4),
+            (batch, 24),
+            (batch, 4),
+            (batch, 4),
+            (batch, 24),
+            (batch, 4),
+            (batch, 4),
+            (batch, 16),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+    def build_12input_2D(dtype=np.float32):
+        batch = 150
+        shapes = [
+            (batch, 192),
+            (batch, 288),
+            (batch, 64),
+            (batch, 48),
+            (batch, 48),
+            (batch, 8),
+            (batch, 48),
+            (batch, 48),
+            (batch, 8),
+            (batch, 256),
+            (batch, 128),
+            (batch, 128),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+    def build_8input_3D(dtype=np.float32):
+        batch = 16
+        seq = 64
+        shapes = [
+            (batch, seq, 32),
+            (batch, seq, 64),
+            (batch, seq, 16),
+            (batch, seq, 16),
+            (batch, seq, 16),
+            (batch, seq, 16),
+            (batch, seq, 16),
+            (batch, seq, 16),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+    def build_4input_2D(dtype=np.float32):
+        batch = 256
+        shapes = [
+            (batch, 192),
+            (batch, 288),
+            (batch, 64),
+            (batch, 48),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+    
+    def build_3input_2D(dtype=np.float32):
+        shapes = [
+            (128, 1510),
+            (128, 36),
+            (128, 24),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+    
+    def build_2input_2D(batch, input1_dim, input2_dim, dtype=np.float32):
+        shapes = [
+            (batch, input1_dim),
+            (batch, input2_dim),
+        ]
+        return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+    def build_dynamic_input(num_inputs, dtype, shape_template, concat_axis):
+        """
+        通用输入构造器
+        :param num_inputs: 输入 Tensor 的个数 (1, 2, 8, 12, 18)
+        :param dtype: 数据类型 (np.float32, np.int32 等)
+        :param shape_template: 基础 3D 形状,例如 (batch, seq, feature)
+        :param concat_axis: 拼接的轴,用于微调形状确保可以 concat
+        """
+        inputs = []
+        for i in range(num_inputs):
+            curr_shape = list(shape_template)
+            # 为了模拟真实情况,我们可以让拼接轴上的长度略有不同
+            curr_shape[concat_axis] = np.random.randint(1, 10) # 如果需要动态长度可开启
+            
+            if np.issubdtype(dtype, np.integer):
+                data = np.random.randint(0, 100, curr_shape).astype(dtype)
+            else:
+                data = np.random.uniform(0, 1, curr_shape).astype(dtype)
+            inputs.append(data)
+        return inputs
+
+    class ConcatTestCaseFactory:
+        @staticmethod
+        def create_func_test():
+            all_cases = []
+            input_counts = [1, 2, 8, 12, 18]
+            dtypes = {
+                "fp32": np.float32, 
+                "bf16": np.uint16, # 工业界常用 uint16 模拟 bf16 存储
+                "int32": np.int32, 
+                "int8": np.int8, 
+                "uint8": np.uint8
+            }
+            axes = [0, 1, 2, -1, -2, -3]
+            base_3d_shape = (32, 64, 128) # 示例 3D 基础形状
+            for num in input_counts:
+                for dtype_name, dtype_val in dtypes.items():
+                    for axis in axes:
+                        case_name = f"concat_3D_{num}in_{dtype_name}_axis{axis}"
+                        
+                        # 使用 partial 预绑定所有参数
+                        input_fn = partial(
+                            build_dynamic_input, 
+                            num_inputs=num, 
+                            dtype=dtype_val, 
+                            shape_template=base_3d_shape, 
+                            concat_axis=axis
+                        )
+                        
+                        case = TestCase(
+                            name=case_name,
+                            op_fn=concat_op,
+                            input_fn=input_fn,
+                            num_iters=0,
+                            check_fn=CheckFuncClass.check_fn_equal,
+                            meta={'axis': axis}
+                        )
+                        all_cases.append(case)
+            return all_cases
+    
+    test_case = []
+    test_case.extend(ConcatTestCaseFactory.create_func_test())
+    # perf test case
+    test_case.extend([
+        TestCase(
+            name="concat_18input_2D_float",
+            op_fn=concat_op,
+            input_fn=partial(build_18input_2D, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_12input_2D_float",
+            op_fn=concat_op,
+            input_fn=partial(build_12input_2D, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_8input_3D_float",
+            op_fn=concat_op,
+            input_fn=partial(build_8input_3D, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_4input_2D_float",
+            op_fn=concat_op,
+            input_fn=partial(build_4input_2D, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_3input_2D_float",
+            op_fn=concat_op,
+            input_fn=partial(build_3input_2D, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_2input_2D_small_float",
+            op_fn=concat_op,
+            input_fn=partial(build_2input_2D, batch=128, input1_dim=400, input2_dim=400, dtype=np.float32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_18input_2D_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_18input_2D, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_12input_2D_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_12input_2D, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_8input_3D_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_8input_3D, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_4input_2D_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_4input_2D, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_3input_2D_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_3input_2D, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+        TestCase(
+            name="concat_2input_2D_small_int32",
+            op_fn=concat_op,
+            input_fn=partial(build_2input_2D, batch=128, input1_dim=400, input2_dim=400, dtype=np.int32),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "ConcatV2",
+            check_fn=CheckFuncClass.check_fn_equal,
+            meta={'axis': -1}
+        ),
+    ])
+    return test_case
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/einsum_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/einsum_ops.py
new file mode 100644
index 0000000..7ad3695
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/einsum_ops.py
@@ -0,0 +1,118 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    
+    def einsum_op(inputs, meta):
+        return tf.einsum(meta["label"], inputs[0], inputs[1])
+
+    def input_3D_bh_bsh_bs_float(B, H, S):
+        return [
+            np.random.uniform(0, 1, (B, H)).astype(np.float32),
+            np.random.uniform(0, 1, (B, S, H)).astype(np.float32),
+        ]
+    
+    def input_3D_bs_bsh_bh_float(B, H, S):
+        return [
+            np.random.uniform(0, 1, (B, S)).astype(np.float32),
+            np.random.uniform(0, 1, (B, S, H)).astype(np.float32),
+        ]
+    
+    def build_einsum_input(shapes, dtype):
+        """
+        根据 einsum 表达式和形状列表构造输入
+        """
+        inputs = []
+        for shape in shapes:
+            if np.issubdtype(dtype, np.integer):
+                data = np.random.randint(0, 10, shape).astype(dtype)
+            else:
+                data = np.random.uniform(-1, 1, shape).astype(dtype)
+            inputs.append(data)
+        return inputs
+
+    class EinsumTestCaseFactory:
+        @staticmethod
+        def create_func_test():
+            all_cases = []
+            # 定义 B, H, S 的测试点
+            B_list = [1, 79, 256]
+            H_list = [1, 79, 128]
+            S_list = [1, 79, 512]
+            
+            # 1. Scoring 场景: bh, bsh -> bs
+            for b in B_list:
+                for h in H_list:
+                    for s in S_list:
+                        eq = "bh,bsh->bs"
+                        shapes = [(b, h), (b, s, h)]
+                        all_cases.append(EinsumTestCaseFactory._create_case(
+                            "Scoring", eq, shapes, np.float32, b, h, s
+                        ))
+
+            # 2. Pooling 场景: bs, bsh -> bh
+            for b in B_list:
+                for h in H_list:
+                    for s in S_list:
+                        eq = "bs,bsh->bh"
+                        shapes = [(b, s), (b, s, h)]
+                        all_cases.append(EinsumTestCaseFactory._create_case(
+                            "Pooling", eq, shapes, np.float32, b, h, s
+                        ))
+
+            # 3. 批量矩阵乘法场景: bij, bjk -> bik
+            ikj_variants = [
+                (64, 128, 64), (79, 128, 64), (1, 128, 64), 
+                (128, 1, 64), (128, 128, 1)
+            ]
+            for b in B_list:
+                for i, k, j in ikj_variants:
+                    eq = "bij,bjk->bik"
+                    shapes = [(b, i, j), (b, j, k)]
+                    all_cases.append(EinsumTestCaseFactory._create_case(
+                        "BatchMatMul", eq, shapes, np.float32, b, i, j, k
+                    ))
+            
+            return all_cases
+
+        @staticmethod
+        def _create_case(scene, eq, shapes, dtype, *args):
+            dims_str = "_".join([str(a) for a in args])
+            case_name = f"einsum_{scene}_{eq.replace(',', '_').replace('->', '_')}_params_{dims_str}"
+            
+            return TestCase(
+                name=case_name,
+                op_fn=einsum_op,
+                input_fn=partial(build_einsum_input, shapes=shapes, dtype=dtype),
+                num_iters=0,
+                operator_name="Einsum",
+                meta={'label': eq}
+            )
+
+    test_case = []
+    test_case.extend(EinsumTestCaseFactory.create_func_test())
+    # perf test case
+    test_case.extend([
+        TestCase(
+            name="Einsum_3D_float_bh_bsh_bs_128_256_512",
+            op_fn=einsum_op,
+            input_fn=partial(input_3D_bh_bsh_bs_float, B=128, H=256, S=512),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "Einsum",
+            meta={'label': "bh,bsh->bs"}
+        ),
+        TestCase(
+            name="Einsum_3D_float_bs_bsh_bh_128_128_512",
+            op_fn=einsum_op,
+            input_fn=partial(input_3D_bs_bsh_bh_float, B=128, H=128, S=512),
+            num_iters=1000,
+            optimize_percent = 5,
+            operator_name = "Einsum",
+            meta={'label': "bs,bsh->bh"}
+        ),
+    ])
+    return test_case
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/floormod_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/floormod_ops.py
new file mode 100644
index 0000000..fe92a36
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/floormod_ops.py
@@ -0,0 +1,67 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+from framework.runner import CheckFuncClass
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    
+    def reduce_floormod_op(inputs, meta):
+        return tf.math.floormod(inputs[0], inputs[1])
+
+    def build_input_1D_float():
+        shape = [500000]
+        return [np.random.uniform(-1, 1, shape).astype(np.float32) for i in range(2)]
+
+    def build_input_3D_float():
+        shape = [64, 64, 128]
+        return [np.random.uniform(0, 1, shape).astype(np.float32) for i in range(2)]
+
+    def build_input_1D_int64():
+        shape = [1000000]
+        return [np.random.randint(-2**63, 2**63-1, size = shape, dtype=np.int64) for i in range(2)]
+
+    def build_input_3D_int64():
+        shape = [128, 64, 128]
+        return [np.random.randint(0, 2**63-1, size = shape, dtype=np.int64) for i in range(2)]
+
+    return [
+        TestCase(
+            name="floormod_input_1D_float",
+            op_fn=reduce_floormod_op,
+            input_fn=build_input_3D_float,
+            num_iters=1000,
+            optimize_percent = 150,
+            operator_name = "FloorMod:FloorMod",
+            meta={}
+        ),
+        TestCase(
+            name="floormod_input_3D_float",
+            op_fn=reduce_floormod_op,
+            input_fn=build_input_3D_float,
+            num_iters=1000,
+            optimize_percent = 150,
+            operator_name = "FloorMod:FloorMod",
+            meta={}
+        ),
+        TestCase(
+            name="floormod_input_1D_int64",
+            op_fn=reduce_floormod_op,
+            input_fn=build_input_1D_int64,
+            check_fn=CheckFuncClass.check_fn_equal,
+            num_iters=1000,
+            optimize_percent = 25,
+            operator_name = "FloorMod:FloorMod",
+            meta={}
+        ),
+        TestCase(
+            name="floormod_input_3D_int64",
+            op_fn=reduce_floormod_op,
+            input_fn=build_input_3D_int64,
+            check_fn=CheckFuncClass.check_fn_equal,
+            num_iters=1000,
+            optimize_percent = 25,
+            operator_name = "FloorMod:FloorMod",
+            meta={}
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/sigmoid_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/sigmoid_ops.py
new file mode 100644
index 0000000..c48e47e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/sigmoid_ops.py
@@ -0,0 +1,44 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+    
+    def sigmoid_op(inputs, meta):
+        return tf.nn.sigmoid(inputs)
+
+    def build_input_1D_float():
+        shape = [1024]
+        return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+    def build_input_3D_float():
+        shape = [128, 64, 128]
+        return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+    def build_input_empty():
+        return [np.array([]).astype(np.float32)]
+
+
+    return [
+        TestCase(
+            name="sigmoid_input_1D_float",
+            op_fn=sigmoid_op,
+            num_iters=0,
+            input_fn=build_input_1D_float
+        ),
+        TestCase(
+            name="sigmoid_input_3D_float",
+            op_fn=sigmoid_op,
+            num_iters=1000,
+            input_fn=build_input_3D_float,
+            optimize_percent = 50,
+            operator_name = "Sigmoid:Sigmoid"
+        ),
+        TestCase(
+            name="sigmoid_input_empty",
+            op_fn=sigmoid_op,
+            num_iters=0,
+            input_fn=build_input_empty
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/tensorflow/python/kernel_tests/benchmark/ops/softmax_ops.py b/tensorflow/python/kernel_tests/benchmark/ops/softmax_ops.py
new file mode 100644
index 0000000..f0157ef
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark/ops/softmax_ops.py
@@ -0,0 +1,53 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+    """每个算子文件都统一实现这个接口"""
+
+    def softmax_op(inputs, meta):
+        return tf.nn.softmax(inputs, axis=meta["axis"])
+
+    def logsoftmax_op(inputs, meta):
+        return tf.nn.log_softmax(inputs, axis=meta["axis"])
+
+    def build_input_2D_float():
+        shape = [1024, 128]
+        return [np.random.uniform(-100, 100, shape).astype(np.float32)]
+
+    def build_input_3D_float():
+        shape = [64, 64, 128]
+        return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+    return [
+        TestCase(
+            name="softmax_input_2D_float",
+            op_fn=softmax_op,
+            input_fn=build_input_2D_float,
+            num_iters=0,
+            meta={'axis': 2,}
+        ),
+        TestCase(
+            name="softmax_input_3D_float_axis_1",
+            op_fn=softmax_op,
+            input_fn=build_input_3D_float,
+            num_iters=0,
+            meta={'axis': 1,}
+        ),
+        TestCase(
+            name="softmax_input_3D_float_axis_2",
+            op_fn=softmax_op,
+            input_fn=build_input_3D_float,
+            num_iters=1000,
+            optimize_percent = 200,
+            operator_name = "Softmax:Softmax",
+            meta={'axis': 2,}
+        ),
+        TestCase(
+            name="logsoftmax_input_2D_float",
+            op_fn=logsoftmax_op,
+            input_fn=build_input_2D_float,
+            num_iters=0,
+            meta={'axis': -1,}
+        ),
+    ]
\ No newline at end of file
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 98adbcc..36239f2 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -61,6 +61,10 @@ load(
     "//third_party/compute_library:build_defs.bzl",
     "if_enable_acl",
 )
+load(
+    "//third_party/KDNN:build_defs.bzl",
+    "if_enable_kdnn",
+)
 load(
     "//third_party/llvm_openmp:openmp.bzl",
     "windows_llvm_openmp_linkopts",
@@ -465,6 +469,7 @@ def tf_copts(
         if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) +
         if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) +
         if_zendnn(["-DAMD_ZENDNN"]) +
+        if_enable_kdnn(["-DENABLE_KDNN"]) +
         if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) +
         if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) +
         if_linux_x86_64(["-msse3"]) +
diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl
index 8601192..e722d3f 100644
--- a/tensorflow/workspace2.bzl
+++ b/tensorflow/workspace2.bzl
@@ -26,6 +26,8 @@ load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
 load("//third_party/ducc:workspace.bzl", ducc = "repo")
 load("//third_party/eigen3:workspace.bzl", eigen3 = "repo")
 load("//third_party/farmhash:workspace.bzl", farmhash = "repo")
+load("//third_party/ktfop:workspace.bzl", ktfop = "repo")
+load("//third_party/kblas:workspace.bzl", kblas = "repo")
 load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
 load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo")
 load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo")
@@ -67,6 +69,8 @@ def _initialize_third_party():
     dlpack()
     eigen3()
     farmhash()
+    kblas()
+    ktfop()
     flatbuffers()
     gemmlowp()
     hexagon_nn()
@@ -803,7 +807,7 @@ def _tf_repositories():
         name = "upb",
         sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454",
         strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482",
-        patch_file = ["//third_party/grpc:upb_platform_fix.patch"],
+        patch_file = ["//third_party/grpc:upb_platform_fix.patch", "//third_party/grpc:upb_gcc10_compile_fix.patch"],
         urls = tf_mirror_urls("https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz"),
     )
 
diff --git a/tensorflow/third_party/KDNN/BUILD b/third_party/KDNN/BUILD
new file mode 100644
index 0000000..d3667ea
--- /dev/null
+++ b/third_party/KDNN/BUILD
@@ -0,0 +1,34 @@
+licenses(["notice"])
+
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+config_setting(
+    name = "enable_kdnn",
+    define_values = {
+        "enable_kdnn": "true",
+    },
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "kdnn",
+    hdrs = glob(["include/**/*.hpp"]),
+    includes = ["include"],
+    srcs = glob(["src/libkdnn.a"]),
+    linkopts = ["-lgomp"],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "kdnn_adapter",
+    hdrs = ["kdnn_adapter.h", "kdnn_threadpool.h", "kdnn_types_adapter.h", "kdnn_layout_adapter.h"],
+    strip_include_prefix = "/third_party/KDNN",
+    visibility = ["//visibility:public"],
+    deps = [":kdnn"],
+)
+
+bzl_library(
+    name = "build_defs_bzl",
+    srcs = ["build_defs.bzl"],
+    visibility = ["//visibility:public"],
+)
\ No newline at end of file
diff --git a/tensorflow/third_party/KDNN/build_defs.bzl b/third_party/KDNN/build_defs.bzl
new file mode 100644
index 0000000..da1748f
--- /dev/null
+++ b/third_party/KDNN/build_defs.bzl
@@ -0,0 +1,24 @@
+def if_enable_kdnn(if_true, if_false = []):
+    """Shorthand to select() if we are building with KDNN and KDNN is enabled.
+
+    This is only effective when built with KDNN.
+
+    Args:
+        if_true: expression to evaluate if building with KDNN and KDNN is enabled      
+        if_false: expression to evaluate if building without KDNN or KDNN is not enabled.
+
+    Returns:
+        A select evaluating to either if_true or if_false as appropriate.
+    """
+    return select({
+        "@org_tensorflow//third_party/KDNN:enable_kdnn": if_true,
+        "//conditions:default": if_false,
+    })
+
+def kdnn_deps():
+    """Shorthand for select() to pull in the correct set of KDNN library deps.
+    """
+    return select({
+        "@org_tensorflow//third_party/KDNN:enable_kdnn": ["//third_party/KDNN:kdnn_adapter"],
+        "//conditions:default": [],
+    })
\ No newline at end of file
diff --git a/tensorflow/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h
new file mode 100644
index 0000000..c72e963
--- /dev/null
+++ b/third_party/KDNN/kdnn_adapter.h
@@ -0,0 +1,289 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_ADAPTER_H_
+#include "kdnn.hpp"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "kdnn_threadpool.h"
+#include "kdnn_types_adapter.h"
+#include "kdnn_layout_adapter.h"
+#include "tensorflow/core/util/port.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "operations/kdnn_softmax.hpp"
+#include "operations/kdnn_eltwise.hpp"
+
+namespace tensorflow {
+
+inline void kdnnFusedGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+                    bool fusion_relu, bool trans_x, bool trans_y) {
+  int m = a.dim_size(0);
+  int n = b.dim_size(trans_y ? 0 : 1);
+  int k = b.dim_size(trans_y ? 1 : 0);
+  const float *A = a.flat<float>().data();
+  const float *B = b.flat<float>().data();
+  float *C = out->flat<float>().data();
+  const Tensor& bias = ctx->input(2);
+  const float *Bias = bias.flat<float>().data();
+  if (bias.dims() != 1 || bias.dim_size(0) != n) {
+    OP_REQUIRES_OK(ctx, errors::InvalidArgument("bias must be 1-dimensional and match n",
+                            bias.shape().DebugString()));
+  }
+  KDNN::PostOpsDataPtrs po_ptrs;
+  KDNN::PostOps post_ops;
+  if (fusion_relu) {
+    post_ops.AppendEltwise(KDNN::ActivationFunction::RELU);
+    po_ptrs.push_back(&post_ops);
+  }
+  // intra_op thread_pool
+  thread::ThreadPool* thread_pool = 
+    ctx->device()
+    ->tensorflow_cpu_worker_threads()
+    ->workers;
+  kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+  KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+  const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, trans_y ? KDNN::Layout::BA : KDNN::Layout::AB};
+  const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  const KDNN::TensorInfo biasInfo = {{1, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  KDNN::Attributes attr;
+  attr.SetPostOps(post_ops);
+  KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, biasInfo, attr);
+  gemm.Run(A, B, C, Bias, po_ptrs); 
+  KDNN::Threading::DeactivateThreadpool();
+}
+
+template<typename Tindices>
+inline void kdnnSparseMatmul(const std::size_t nnz,
+                      const std::size_t rhs_right, const std::size_t lhs_right,
+                      const int lhs_index_a, const int rhs_index_a,
+                      typename TTypes<float>::Matrix out,
+                      typename TTypes<Tindices>::ConstMatrix a_indices, 
+                      typename TTypes<float>::ConstVec a_values,
+                      const float* b_data) {
+    std::vector<int> idx(nnz);
+    int lhs_left = out.dimension(0);
+    std::vector<int> pntrb(lhs_left);
+    std::vector<int> pntre(lhs_left);
+    std::vector<int> row_counts(lhs_left);
+    for (size_t i = 0; i < nnz; ++i) {
+        idx[i] = a_indices(i, rhs_index_a);
+        ++row_counts[a_indices(i, lhs_index_a)];
+    }
+    
+    int current_pos = 0;
+    for (size_t i = 0; i < lhs_left; ++i) {
+        pntrb[i] = current_pos;
+        current_pos += row_counts[i];
+        pntre[i] = current_pos;
+    }
+    const KDNN::CsrSparseTensorInfo aInfo = {{lhs_left, lhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB, pntrb, pntre, idx, nnz};
+    const KDNN::TensorInfo bInfo = {{lhs_right, rhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+    const KDNN::TensorInfo dstInfo = {{lhs_left, rhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+    KDNN::SparseGemm sparse_csr(aInfo, bInfo, dstInfo);
+    sparse_csr.Run(a_values.data(), b_data, out.data());
+}
+
+template<typename T>
+inline void KDNNConcatImpl(OpKernelContext* ctx,
+                    const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+                    typename TTypes<T, 2>::Matrix* output) {
+  KDNN::Element::TypeT kdnnType = KDNN::Element::TypeAdapter<T>::value;
+  KDNN::Layout kdnnLayout = KDNN::LayoutAdapter<2, false>::value;
+  OP_REQUIRES(ctx, kdnnType != KDNN::Element::TypeT::UNDEFINED,
+    errors::InvalidArgument("unsupported kdnn data type"));
+  OP_REQUIRES(ctx, kdnnLayout != KDNN::Layout::UNDEFINED,
+    errors::InvalidArgument("unsupported kdnn layout"));
+  std::vector<KDNN::TensorInfo> inputInfos;
+  std::vector<const void *> input_ptrs;
+  inputInfos.reserve(inputs.size());
+  input_ptrs.reserve(inputs.size());
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    auto dim0 = inputs[i]->dimension(0);
+    auto dim1 = inputs[i]->dimension(1);
+    inputInfos.emplace_back(KDNN::TensorInfo{{dim0, dim1}, kdnnType, kdnnLayout});
+    input_ptrs.push_back(static_cast<const void*>(inputs[i]->data()));
+  }
+  void* output_ptr = static_cast<void *>(output->data());
+  thread::ThreadPool* thread_pool = 
+    ctx->device()
+    ->tensorflow_cpu_worker_threads()
+    ->workers;
+  kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+  KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+  KDNN::TensorInfo outputInfo({output->dimension(0), output->dimension(1)}, kdnnType, kdnnLayout);
+  KDNN::ConcatLayer concat(inputInfos, 1, outputInfo);
+  concat.Run(input_ptrs.data(), output_ptr);
+  KDNN::Threading::DeactivateThreadpool();
+}
+
+inline KDNN::TensorInfo MakeInfo(const tensorflow::Tensor* tensor, bool transposed) {
+  const tensorflow::TensorShape& shape = tensor->shape();
+  int dims = shape.dims();
+
+  std::vector<int64_t> d5 = {1, 1, 1, 1, 1};
+  for (int i = 0; i < dims; ++i) {
+    d5[4 - i] = shape.dim_size(dims - 1 - i);
+  }
+
+  if (transposed) {
+    std::swap(d5[3], d5[4]);
+  }
+
+  return KDNN::TensorInfo(
+    {d5[0], d5[1], d5[2], d5[3], d5[4]},
+    KDNN::Element::TypeT::F32,
+    transposed ? KDNN::Layout::ABCED : KDNN::Layout::ABCDE
+  );
+}
+
+inline KDNN::TensorInfo MakeOutputInfo(const KDNN::TensorInfo &tensorA, const KDNN::TensorInfo &tensorB) {
+  int dims = tensorA.GetNumDims();
+  std::vector<int64_t> d5 = {1, 1, 1, 1, 1};
+  for (int i = 0; i < dims - 2; ++i) {
+    d5[i] = std::max(tensorA.GetDims()[i], tensorB.GetDims()[i]);
+  }
+  d5[3] = tensorA.GetDims()[3];
+  d5[4] = tensorB.GetDims()[4];
+  return KDNN::TensorInfo(
+    {d5[0], d5[1], d5[2], d5[3], d5[4]},
+    KDNN::Element::TypeT::F32, KDNN::Layout::ABCDE
+  );
+}
+
+inline void kdnnGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+                     bool trans_x, bool trans_y) {
+  int m = a.dim_size(trans_x ? 2 : 1);
+  int n = b.dim_size(trans_y ? 1 : 2);
+  int k = b.dim_size(trans_y ? 2 : 1);
+  const float *A = a.flat<float>().data();
+  const float *B = b.flat<float>().data();
+  float *C = out->flat<float>().data();
+  thread::ThreadPool* thread_pool = 
+    ctx->device()
+    ->tensorflow_cpu_worker_threads()
+    ->workers;
+  kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+  KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+  const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, trans_x ? KDNN::Layout::BA : KDNN::Layout::AB};
+  const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, trans_y ? KDNN::Layout::BA : KDNN::Layout::AB};
+  const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo);
+  gemm.Run(A, B, C); 
+  KDNN::Threading::DeactivateThreadpool();
+}
+
+inline void kdnnBatchGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+                          bool trans_x, bool trans_y) {
+  const float *A = a.flat<float>().data();
+  const float *B = b.flat<float>().data();
+  float *C = out->flat<float>().data();
+  thread::ThreadPool* thread_pool = 
+    ctx->device()
+    ->tensorflow_cpu_worker_threads()
+    ->workers;
+  kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+  KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+  const KDNN::TensorInfo srcInfo = MakeInfo(&a, trans_x);
+  const KDNN::TensorInfo weightsInfo = MakeInfo(&b, trans_y);
+  const KDNN::TensorInfo dstInfo = MakeOutputInfo(srcInfo, weightsInfo);
+  KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo);
+  gemm.Run(A, B, C); 
+  KDNN::Threading::DeactivateThreadpool();
+}
+
+template <typename Functor>
+inline void kdnnFloormodOp(OpKernelContext* ctx, const Tensor &input_0, const Tensor &input_1, Tensor *output) {
+    typedef typename Functor::in_type Tin;    // Input scalar data type.
+    const Tin* src = input_0.flat<Tin>().data();
+    const Tin* src_1 = input_1.flat<Tin>().data();
+    Tin* dst = output->flat<Tin>().data();
+
+    KDNN::Shape tensorShape({input_0.shape().num_elements()});
+    thread::ThreadPool* thread_pool = 
+        ctx->device()
+        ->tensorflow_cpu_worker_threads()
+        ->workers;
+    kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+    KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+
+    if (std::is_same<Tin, int64_t>::value) {
+        KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+        KDNN::TensorInfo inputTensorInfo_1(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+        KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+        KDNN::BinaryLayer layer(inputTensorInfo, inputTensorInfo_1, outputTensorInfo, KDNN::BinaryFunction::FLOORMOD);
+        layer.Run(src, src_1, dst);
+    } else {
+        KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+        KDNN::TensorInfo inputTensorInfo_1(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+        KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+        KDNN::BinaryLayer layer(inputTensorInfo, inputTensorInfo_1, outputTensorInfo, KDNN::BinaryFunction::FLOORMOD);
+        layer.Run(src, src_1, dst);
+    }
+
+    KDNN::Threading::DeactivateThreadpool();
+    return;
+}
+
+template <typename Functor>
+inline void kdnnSigmoidOp(OpKernelContext* ctx, const Tensor &input, Tensor *output)
+{
+    typedef typename Functor::in_type Tin;
+    const Tin* src = input.flat<Tin>().data();
+    Tin* dst = output->flat<Tin>().data();
+    KDNN::Shape tensorShape({input.shape().num_elements()});
+    KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+    KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+
+    thread::ThreadPool* thread_pool = 
+        ctx->device()
+        ->tensorflow_cpu_worker_threads()
+        ->workers;
+    kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+    KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+    KDNN::ActivationLayerFWD layer(inputTensorInfo, outputTensorInfo, KDNN::ActivationFunction::SIGMOID);
+    layer.Run(src, dst);
+    KDNN::Threading::DeactivateThreadpool();
+    return;
+}
+
+template <typename T>
+inline void kdnnSoftmaxOp(OpKernelContext* ctx, const Tensor &input, Tensor *output)
+{
+    const T* src = input.flat_inner_dims<T>().data();
+    T* dst = output->flat_inner_dims<T>().data();
+    KDNN::Shape tensorShape({input.flat_inner_dims<T>().dimension(0), input.flat_inner_dims<T>().dimension(1)});
+    KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::AB);
+    KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::AB);
+
+    thread::ThreadPool* thread_pool = 
+        ctx->device()
+        ->tensorflow_cpu_worker_threads()
+        ->workers;
+    kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+    KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+    KDNN::SoftmaxLayerFWD layer(inputTensorInfo, outputTensorInfo, 1, KDNN::AlgorithmKind::SOFTMAX);
+    layer.Run(src, dst);
+    KDNN::Threading::DeactivateThreadpool();
+    return;
+}
+
+}// namespace tensorflow
+#endif
\ No newline at end of file
diff --git a/tensorflow/third_party/KDNN/kdnn_layout_adapter.h b/third_party/KDNN/kdnn_layout_adapter.h
new file mode 100644
index 0000000..8487722
--- /dev/null
+++ b/third_party/KDNN/kdnn_layout_adapter.h
@@ -0,0 +1,42 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+    
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H_
+#include "kdnn.hpp"
+
+namespace KDNN {
+
+template <int Rank, bool Transposed = false>
+struct LayoutAdapter {
+    static constexpr Layout value = Layout::UNDEFINED;
+};
+
+template <>
+struct LayoutAdapter<1, false> {
+    static constexpr Layout value = Layout::A;
+};
+
+template <>
+struct LayoutAdapter<2, false> {
+    static constexpr Layout value = Layout::AB;
+};
+
+template <>
+struct LayoutAdapter<2, true> {
+    static constexpr Layout value = Layout::BA;
+};
+} // KDNN
+#endif  // TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H
\ No newline at end of file
diff --git a/tensorflow/third_party/KDNN/kdnn_threadpool.h b/third_party/KDNN/kdnn_threadpool.h
new file mode 100644
index 0000000..0c76b4f
--- /dev/null
+++ b/third_party/KDNN/kdnn_threadpool.h
@@ -0,0 +1,72 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#define EIGEN_USE_THREADS
+
+#include "kdnn.hpp"
+#include "tensorflow/core/platform/blocking_counter.h"
+#include "tensorflow/core/platform/threadpool.h"
+
+namespace kdnn {
+
+using tensorflow::thread::ThreadPool;
+
+class KDNNThreadPool : public KDNN::Threading::ThreadpoolIface {
+ public:
+  KDNNThreadPool() = default;
+
+  KDNNThreadPool(ThreadPool* thread_pool,
+                int num_threads = -1)
+      : thread_pool_(thread_pool), 
+      eigen_interface_(thread_pool->AsEigenThreadPool()) {
+    set_num_and_max_threads(num_threads);
+  }
+  
+  int GetNumThreads() const override { return num_threads_; }
+
+  void ParallelFor(int n, int64_t cost_per_unit,
+                const std::function<void(int, int)>& fn) override {
+    thread_pool_->ParallelFor(n, cost_per_unit, fn);
+  }
+
+  bool IsInParallel() const override {
+    return eigen_interface_->CurrentThreadId() != -1;
+  }
+
+  ~KDNNThreadPool() {}
+
+ private:
+  ThreadPool* thread_pool_ = nullptr;
+  Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
+  int num_threads_ = 1;
+  inline void set_num_and_max_threads(int num_threads) {
+    num_threads_ =
+        num_threads == -1 ? eigen_interface_->NumThreads() : num_threads;
+  }
+};
+
+}  // namespace kdnn
+
+#endif  // TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
diff --git a/tensorflow/third_party/KDNN/kdnn_types_adapter.h b/third_party/KDNN/kdnn_types_adapter.h
new file mode 100644
index 0000000..751e00b
--- /dev/null
+++ b/third_party/KDNN/kdnn_types_adapter.h
@@ -0,0 +1,60 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
+#include "kdnn.hpp"
+
+namespace KDNN {
+namespace Element {
+
+template <typename T>
+struct TypeAdapter {
+    static constexpr TypeT value = TypeT::UNDEFINED;
+};
+
+template <>
+struct TypeAdapter<float> {
+    static constexpr TypeT value = TypeT::F32;
+};
+
+template <>
+struct TypeAdapter<Eigen::half> {
+    static constexpr TypeT value = TypeT::F16;
+};
+
+template <>
+struct TypeAdapter<tensorflow::bfloat16> {
+    static constexpr TypeT value = TypeT::BF16;
+};
+
+template <>
+struct TypeAdapter<int32_t> {
+    static constexpr TypeT value = TypeT::S32;
+};
+
+template <>
+struct TypeAdapter<int8_t> {
+    static constexpr TypeT value = TypeT::S8;
+};
+
+template <>
+struct TypeAdapter<uint8_t> {
+    static constexpr TypeT value = TypeT::U8;
+};
+} // Element
+} // KDNN
+
+#endif  // TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
\ No newline at end of file
diff --git a/tensorflow/third_party/KDNN/tensorflow_kdnn_include_adapter.patch b/third_party/KDNN/tensorflow_kdnn_include_adapter.patch
new file mode 100644
index 0000000..f6721c1
--- /dev/null
+++ b/third_party/KDNN/tensorflow_kdnn_include_adapter.patch
@@ -0,0 +1,137 @@
+diff -ruN include/service/kdnn_service.hpp include_bak/service/kdnn_service.hpp
+--- include/service/kdnn_service.hpp	2025-10-25 18:55:04.018830570 +0800
++++ include_bak/service/kdnn_service.hpp	2025-10-25 18:23:18.074830570 +0800
+@@ -107,7 +107,7 @@
+     T *allocate(SizeType n) const noexcept(false)
+     {
+         if (n > std::numeric_limits<SizeType>::max() / sizeof(T)) {
+-            throw BadArrayNewLength();
++            return nullptr;
+         }
+         return static_cast<T*>(AlignedAlloc(n * sizeof(T), alignment));
+     }
+diff -ruN include/types/kdnn_data_type.hpp include_bak/types/kdnn_data_type.hpp
+--- include/types/kdnn_data_type.hpp	2025-10-25 18:55:04.018830570 +0800
++++ include_bak/types/kdnn_data_type.hpp	2025-10-25 18:23:18.066830570 +0800
+@@ -65,9 +65,6 @@
+             }
+             default: {}
+         }
+-        if (type == TypeT::UNDEFINED) {
+-            throw Service::LogicError {"Type: unsupported data type"};
+-        }
+     }
+     SizeType GetSize() const noexcept
+     {
+diff -ruN include/types/kdnn_shape.hpp include_bak/types/kdnn_shape.hpp
+--- include/types/kdnn_shape.hpp	2025-10-25 18:55:04.018830570 +0800
++++ include_bak/types/kdnn_shape.hpp	2025-10-25 18:23:18.070830570 +0800
+@@ -36,7 +36,7 @@
+     Shape(T *ptr, const SizeType size) noexcept(false) : numDims(size)
+     {
+         if (ptr == nullptr) {
+-            throw Service::LogicError("Shape: ptr is nullptr");
++            return;
+         }
+         CheckNumDims(numDims);
+         for (SizeType i = 0; i < numDims; ++i) {
+@@ -83,7 +83,7 @@
+     Shape& ResetShape(T *ptr, const SizeType size) noexcept(false)
+     {
+         if (ptr == nullptr) {
+-            throw Service::LogicError("Shape: ptr is nullptr");
++            return *this;
+         }
+         CheckNumDims(size);
+         numDims = size;
+@@ -99,7 +99,7 @@
+     Shape& operator+=(const Shape &adder) noexcept(false)
+     {
+         if (adder.GetNumDims() !=  this->GetNumDims()) {
+-            throw Service::LogicError("Shape: different size of base and adder shapes");
++            return *this;
+         }
+         for (SizeType i = 0; i < adder.GetNumDims(); ++i) {
+             this->operator[](i) += adder[i];
+@@ -109,9 +109,6 @@
+ 
+     Shape operator+(const Shape &adder) const noexcept(false)
+     {
+-        if (adder.GetNumDims() != this->GetNumDims()) {
+-            throw Service::LogicError("Shape: different size of base and adder shapes");
+-        }
+         std::array<SizeType, NUM_MAX_DIMENSIONS> tmp;
+         for (SizeType i = 0; i < adder.GetNumDims(); ++i) {
+             tmp[i] = this->operator[](i) + adder[i];
+@@ -121,17 +118,11 @@
+ 
+     SizeType operator[](SizeType id) const noexcept(false)
+     {
+-        if (id >= numDims) {
+-            throw Service::LogicError("Shape: index >= num_dims");
+-        }
+         return dimsArray[id];
+     }
+ 
+     SizeType& operator[](SizeType id) noexcept(false)
+     {
+-        if (id >= numDims) {
+-            throw Service::LogicError("Shape: index >= num_dims");
+-        }
+         return dimsArray[id];
+     }
+ 
+@@ -142,9 +133,6 @@
+ 
+     SizeType GetTotalDimsSize() const noexcept(false)
+     {
+-        if (Service::WillIntMultOverflow(dimsArray.begin(), dimsArray.begin() + numDims)) {
+-            throw Service::LogicError("Shape: computing total size will cause overflow");
+-        }
+         SizeType accum = 1;
+         for (SizeType i = 0; i < numDims; ++i) {
+             accum *= dimsArray[i];
+@@ -153,12 +141,7 @@
+     }
+ 
+ private:
+-    void CheckNumDims(SizeType nDims) const noexcept(false)
+-    {
+-        if (nDims > NUM_MAX_DIMENSIONS) {
+-            throw Service::LogicError("Shape: dims is greater than NUM_MAX_DIMENSIONS");
+-        }
+-    }
++    void CheckNumDims(SizeType nDims) const noexcept(false) {}
+     std::array<SizeType, NUM_MAX_DIMENSIONS> dimsArray;
+     SizeType numDims;
+ };
+diff -ruN include/types/kdnn_tensor_info.hpp include_bak/types/kdnn_tensor_info.hpp
+--- include/types/kdnn_tensor_info.hpp	2025-10-25 18:55:04.018830570 +0800
++++ include_bak/types/kdnn_tensor_info.hpp	2025-10-25 18:23:18.070830570 +0800
+@@ -214,7 +214,7 @@
+                 return Layout::ABCDE;
+             }
+             default: {
+-                throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++                return Layout::UNDEFINED;
+             }
+         }
+     }
+@@ -238,7 +238,7 @@
+                 return Layout::ACDEB;
+             }
+             default: {
+-                throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++                return Layout::UNDEFINED;
+             }
+         }
+     }
+@@ -262,7 +262,7 @@
+                 return Layout::BCDEA;
+             }
+             default: {
+-                throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++                return Layout::UNDEFINED;
+             }
+         }
+     }
diff --git a/tensorflow/third_party/fused_embedding/BUILD b/third_party/fused_embedding/BUILD
new file mode 100644
index 0000000..a9e1634
--- /dev/null
+++ b/third_party/fused_embedding/BUILD
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+    name = "build_with_fused_embedding",
+    define_values = {
+        "build_with_fused_embedding": "true",
+    },
+)
+
+bzl_library(
+    name = "build_defs_bzl",
+    srcs = ["build_defs.bzl"],
+)
diff --git a/tensorflow/third_party/fused_embedding/build_defs.bzl b/third_party/fused_embedding/build_defs.bzl
new file mode 100644
index 0000000..f3509f5
--- /dev/null
+++ b/third_party/fused_embedding/build_defs.bzl
@@ -0,0 +1,8 @@
+"""Starlark macros for fused_embedding.
+"""
+
+def if_fused_embedding(if_true, if_false = []):
+    return select({
+        "@org_tensorflow//third_party/fused_embedding:build_with_fused_embedding": if_true,
+        "//conditions:default": if_false,
+    })
diff --git a/tensorflow/third_party/grpc/upb_gcc10_compile_fix.patch b/third_party/grpc/upb_gcc10_compile_fix.patch
new file mode 100644
index 0000000..dd38391
--- /dev/null
+++ b/third_party/grpc/upb_gcc10_compile_fix.patch
@@ -0,0 +1,11 @@
+--- a/upb/upb.c	2025-05-30 17:01:35.956845750 +0800
++++ b/upb/upb.c	2025-05-30 16:54:07.768845750 +0800
+@@ -37,7 +37,7 @@
+ void upb_status_seterrmsg(upb_status *status, const char *msg) {
+   if (!status) return;
+   status->ok = false;
+-  strncpy(status->msg, msg, sizeof(status->msg));
++  strncpy(status->msg, msg, sizeof(status->msg) - 1);
+   nullz(status);
+ }
+ 
diff --git a/tensorflow/third_party/kblas/BUILD b/third_party/kblas/BUILD
new file mode 100644
index 0000000..1a1c216
--- /dev/null
+++ b/third_party/kblas/BUILD
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+    name = "build_with_kblas",
+    define_values = {
+        "build_with_kblas": "true",
+    },
+)
+
+bzl_library(
+    name = "build_defs_bzl",
+    srcs = ["build_defs.bzl"],
+)
diff --git a/tensorflow/third_party/kblas/build_defs.bzl b/third_party/kblas/build_defs.bzl
new file mode 100644
index 0000000..77a211b
--- /dev/null
+++ b/third_party/kblas/build_defs.bzl
@@ -0,0 +1,8 @@
+"""Starlark macros for kblas.
+"""
+
+def if_kblas(if_true, if_false = []):
+    return select({
+        "@org_tensorflow//third_party/kblas:build_with_kblas": if_true,
+        "//conditions:default": if_false,
+    })
diff --git a/tensorflow/third_party/kblas/kblas.BUILD b/third_party/kblas/kblas.BUILD
new file mode 100644
index 0000000..04a6d51
--- /dev/null
+++ b/third_party/kblas/kblas.BUILD
@@ -0,0 +1,12 @@
+cc_import(
+    name = "kblas_so",
+    shared_library = "lib/sve/kblas/locking/libkblas.so",
+    visibility = ["//visibility:public"],
+)
+
+cc_library(
+    name = "kblas",
+    hdrs = ["include/kblas.h"],
+    includes = ["include"],
+    visibility = ["//visibility:public"],
+)
\ No newline at end of file
diff --git a/tensorflow/third_party/kblas/workspace.bzl b/third_party/kblas/workspace.bzl
new file mode 100644
index 0000000..c438aa0
--- /dev/null
+++ b/third_party/kblas/workspace.bzl
@@ -0,0 +1,10 @@
+"""Provides the repository macro to import kblas."""
+
+def repo():
+    """Imports kblas."""
+
+    native.new_local_repository(
+        name = "kblas_archive",
+        build_file = "@org_tensorflow//third_party/kblas:kblas.BUILD",
+        path = "/usr/local/kml",
+    )
diff --git a/tensorflow/third_party/ktfop/BUILD b/third_party/ktfop/BUILD
new file mode 100644
index 0000000..7aec8b6
--- /dev/null
+++ b/third_party/ktfop/BUILD
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+    name = "build_with_ktfop",
+    define_values = {
+        "build_with_ktfop": "true",
+    },
+)
+
+bzl_library(
+    name = "build_defs_bzl",
+    srcs = ["build_defs.bzl"],
+)
diff --git a/tensorflow/third_party/ktfop/build_defs.bzl b/third_party/ktfop/build_defs.bzl
new file mode 100644
index 0000000..41cef1e
--- /dev/null
+++ b/third_party/ktfop/build_defs.bzl
@@ -0,0 +1,8 @@
+"""Starlark macros for ktfop.
+"""
+
+def if_ktfop(if_true, if_false = []):
+    return select({
+        "@org_tensorflow//third_party/ktfop:build_with_ktfop": if_true,
+        "//conditions:default": if_false,
+    })
diff --git a/tensorflow/third_party/ktfop/ktfop.BUILD b/third_party/ktfop/ktfop.BUILD
new file mode 100644
index 0000000..063365f
--- /dev/null
+++ b/third_party/ktfop/ktfop.BUILD
@@ -0,0 +1,14 @@
+cc_import(
+    name = "ktfop_so",
+    shared_library = "lib/sve/libktfop.so",
+    deps = ["@kblas_archive//:kblas_so"],
+)
+
+cc_library(
+    name = "ktfop",
+    hdrs = ["include/ktfop.h"],
+    includes = ["include"],
+    deps = [":ktfop_so",
+            "@kblas_archive//:kblas"],
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/third_party/ktfop/workspace.bzl b/third_party/ktfop/workspace.bzl
new file mode 100644
index 0000000..76163ae
--- /dev/null
+++ b/third_party/ktfop/workspace.bzl
@@ -0,0 +1,10 @@
+"""Provides the repository macro to import ktfop."""
+
+def repo():
+    """Imports ktfop."""
+
+    native.new_local_repository(
+        name = "ktfop_archive",
+        build_file = "@org_tensorflow//third_party/ktfop:ktfop.BUILD",
+        path = "/usr/local/sra_inference",
+    )
diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/env.cc b/third_party/xla/third_party/tsl/tsl/platform/default/env.cc
index 62245de..7344b63 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/default/env.cc
+++ b/third_party/xla/third_party/tsl/tsl/platform/default/env.cc
@@ -44,6 +44,8 @@ limitations under the License.
 #include "tsl/platform/strcat.h"
 #include "tsl/protobuf/error_codes.pb.h"
 
+#define GPR_ARRAY_SIZE(array) (sizeof(array) / sizeof(*(array)))
+
 namespace tsl {
 
 namespace {
@@ -86,6 +88,16 @@ class PThread : public Thread {
   static void* ThreadFn(void* params_arg) {
     std::unique_ptr<ThreadParams> params(
         reinterpret_cast<ThreadParams*>(params_arg));
+    
+    if (!params->name.empty()) {
+      /* Linux supports 16 characters max, and will
+        * error if it's longer. */
+      char buf[16];
+      size_t buf_len = GPR_ARRAY_SIZE(buf) - 1;
+      strncpy(buf, params->name.c_str(), buf_len);
+      buf[buf_len] = '\0';
+      pthread_setname_np(pthread_self(), buf); 
+    }
     {
       mutex_lock l(name_mutex);
       GetThreadNameRegistry().emplace(std::this_thread::get_id(), params->name);
diff --git a/third_party/xla/third_party/tsl/tsl/platform/numa.h b/third_party/xla/third_party/tsl/tsl/platform/numa.h
index 997d03d..5ff53a1 100644
--- a/third_party/xla/third_party/tsl/tsl/platform/numa.h
+++ b/third_party/xla/third_party/tsl/tsl/platform/numa.h
@@ -32,6 +32,11 @@ int NUMANumNodes();
 
 static const int kNUMANoAffinity = -1;
 
+enum ThreadAffinity {
+  OFF,
+  ORDER,
+  INTERVAL
+};
 // If possible sets affinity of the current thread to the specified NUMA node.
 // If node == kNUMANoAffinity removes affinity to any particular node.
 void NUMASetThreadNodeAffinity(int node);
diff --git a/third_party/xla/xla/executable_run_options.cc b/third_party/xla/xla/executable_run_options.cc
index 795c5fc..aeafcf4 100644
--- a/third_party/xla/xla/executable_run_options.cc
+++ b/third_party/xla/xla/executable_run_options.cc
@@ -131,6 +131,13 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
 
 int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
 
+ExecutableRunOptions& ExecutableRunOptions::set_run_in_tf_kernel(bool run_in_tf_kernel) {
+  run_in_tf_kernel_ = run_in_tf_kernel;
+  return *this;
+}
+
+bool ExecutableRunOptions::run_in_tf_kernel() const { return run_in_tf_kernel_; }
+
 ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) {
   run_id_ = id;
   return *this;
diff --git a/third_party/xla/xla/executable_run_options.h b/third_party/xla/xla/executable_run_options.h
index 31ba23b..6fe4bd4 100644
--- a/third_party/xla/xla/executable_run_options.h
+++ b/third_party/xla/xla/executable_run_options.h
@@ -171,6 +171,10 @@ class ExecutableRunOptions {
   ExecutableRunOptions& set_rng_seed(int rng_seed);
   int rng_seed() const;
 
+  ExecutableRunOptions& set_run_in_tf_kernel(bool run_in_tf_kernel);
+
+  bool run_in_tf_kernel() const;
+
   ExecutableRunOptions& set_launch_id(int32_t launch_id) {
     launch_id_ = launch_id;
     return *this;
@@ -224,6 +228,7 @@ class ExecutableRunOptions {
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;
   int rng_seed_ = 0;
+  bool run_in_tf_kernel_ = false;
   int32_t launch_id_ = 0;
   stream_executor::Stream* device_to_host_stream_ = nullptr;
   stream_executor::Stream* host_to_device_stream_ = nullptr;
diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc
index 9ff6d3b..3104e65 100644
--- a/third_party/xla/xla/service/cpu/cpu_executable.cc
+++ b/third_party/xla/xla/service/cpu/cpu_executable.cc
@@ -638,33 +638,38 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
       ExecutionOutput result,
       CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
                                absl::MakeSpan(arguments)));
-
-  // Logically we want this lambda to capture `buffers` by move, ultimately our
-  // functor needs to be wrapped in an std::function, and that requires its
-  // functor to be copyable.  Thus we perpetrate the hack of capturing buffers
-  // "by shared pointer".
-  //
-  // We also need to change the types of some of the variables we capture:
-  // run_options needs to change from a pointer to a value type, and arguments
-  // needs to change from a Span into a vector.  We use a struct instead
-  // of a lambda to make this explicit.
-  struct AsyncRunTask {
-    CpuExecutable* executable;
-    ServiceExecutableRunOptions run_options;
-    std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
-    HloExecutionProfile* hlo_execution_profile;
-
-    Status operator()() {
-      return executable->ExecuteComputeFunction(
-          &run_options.run_options(), *task_buffers, hlo_execution_profile);
-    }
-  };
-  host_stream->EnqueueTaskWithStatus(
-      AsyncRunTask{this, *run_options,
-                   std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
-                       std::move(buffers)),
-                   hlo_execution_profile});
-
+  if (run_options->run_options().run_in_tf_kernel()) {
+    std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers =
+        std::make_shared<std::vector<MaybeOwningDeviceMemory>>(std::move(buffers));
+    (void)this->ExecuteComputeFunction(
+        &run_options->run_options(), *task_buffers, hlo_execution_profile);
+  } else {
+    // Logically we want this lambda to capture `buffers` by move, ultimately our
+    // functor needs to be wrapped in an std::function, and that requires its
+    // functor to be copyable.  Thus we perpetrate the hack of capturing buffers
+    // "by shared pointer".
+    //
+    // We also need to change the types of some of the variables we capture:
+    // run_options needs to change from a pointer to a value type, and arguments
+    // needs to change from a Span into a vector.  We use a struct instead
+    // of a lambda to make this explicit.
+    struct AsyncRunTask {
+      CpuExecutable* executable;
+      ServiceExecutableRunOptions run_options;
+      std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
+      HloExecutionProfile* hlo_execution_profile;
+
+      Status operator()() {
+        return executable->ExecuteComputeFunction(
+            &run_options.run_options(), *task_buffers, hlo_execution_profile);
+      }
+    };
+    host_stream->EnqueueTaskWithStatus(
+        AsyncRunTask{this, *run_options,
+                    std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
+                        std::move(buffers)),
+                    hlo_execution_profile});
+  }
   MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
   return std::move(result);
 }
diff --git a/tensorflow/unit_test.sh b/unit_test.sh
new file mode 100644
index 0000000..b76aa98
--- /dev/null
+++ b/unit_test.sh
@@ -0,0 +1,9 @@
+export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
+export TF_NEED_CLANG=0
+export CC_OPT_FLAGS='-march=armv8.3-a+crc'
+
+export PYTHON_BIN_PATH=$(which python)
+yes "" | $PYTHON_BIN_PATH configure.py
+
+bazel --output_user_root=./test_output test --distdir=../serving/download --test_tag_filters=-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test --test_lang_filters=cc,java -k --test_timeout 300,450,1200,3600 --config=opt --test_output=errors --test_size_filters=small,medium,large --build_tests_only -- //tensorflow/core/... //tensorflow/compiler/jit/... -//tensorflow/core/tpu/...
\ No newline at end of file