@@ -223,6 +223,14 @@ build:mkl_aarch64 -c opt
build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
build:mkl_aarch64_threadpool -c opt
+# Config setting to build ktfop.
+build:ktfop --define=build_with_ktfop=true
+build:ktfop --define=build_with_kblas=true
+build:ktfop -c opt
+
+# Config setting to build fused_embedding.
+build:fused_embedding --define=build_with_fused_embedding=true
+
# CUDA: This config refers to building CUDA op kernels with nvcc.
build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
@@ -19,6 +19,7 @@ tensorflow/contrib/cmake/_build/
/build/
[Bb]uild/
/build_output/
+/test_output
/tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
new file mode 100644
@@ -0,0 +1,130 @@
+#/bin/bash
+set -ex
+
+TENSORFLOW_DIR=""
+ENABLE_GCC12=false
+ENABLE_KDNN=false
+KDNN_OPTIONS=""
+
+usage() {
+ echo "Usage: $0 [--features <feature1,feature2>]"
+ echo "Example: $0 --features gcc12,kdnn"
+ echo "Notes: --features gcc12 is only suitable for openeuler 22.03 to set gcc12 insdead of default gcc10"
+ exit 1
+}
+
+if [ -f /etc/os-release ]; then
+ source /etc/os-release
+fi
+
+echo "current os: $NAME $VERSION_ID"
+
+case "$VERSION_ID" in
+ "22.03")
+ echo "config gcc12 path"
+ GCC12_PATH=/opt/openEuler/gcc-toolset-12/root/usr/bin/
+ GCC12_LD_LIBRARY_PATH=/opt/openEuler/gcc-toolset-12/root/usr/lib64
+ ;;
+ "24.03")
+ echo "use default gcc"
+ ;;
+ *)
+ echo "unsupported os version: $VERSION_ID"
+ exit 1
+ ;;
+esac
+
+# 定义目标软链接路径
+TARGET_LINK="/usr/bin/aarch64-linux-gnu-gcc"
+
+# 检查文件或链接是否已经存在
+if [ ! -f "$TARGET_LINK" ] && [ ! -L "$TARGET_LINK" ]; then
+ echo "未检测到 $TARGET_LINK,正在创建软链接..."
+
+ # 获取系统自带 gcc 的路径
+ GCC_PATH=$(which gcc)
+
+ if [ -n "$GCC_PATH" ]; then
+ # 使用 sudo 权限创建链接(如果是 root 用户可去掉 sudo)
+ ln -s "$GCC_PATH" "$TARGET_LINK"
+ echo "软链接创建成功: $TARGET_LINK -> $GCC_PATH"
+ else
+ echo "错误: 系统未安装 gcc,请先运行 yum install gcc"
+ exit 1
+ fi
+else
+ echo "检测到 $TARGET_LINK 已存在,跳过创建步骤。"
+fi
+
+while [[ "$#" -gt 0 ]]; do
+ case "$1" in
+ --features)
+ if [[ -z "$2" ]]; then
+ echo "Error: --features requires a value"
+ usage
+ fi
+ IFS=',' read -ra features_array <<< "$2"
+ for feature in "${features_array[@]}"; do
+ case "$feature" in
+ "gcc12")
+ ENABLE_GCC12=true
+ ;;
+ "kdnn")
+ ENABLE_KDNN=true
+ ;;
+ *)
+ echo "Warning: Unknown feature '$feature', ignoring"
+ ;;
+ esac
+ done
+ shift 2
+ ;;
+ -h|--help)
+ usage
+ ;;
+ *)
+ echo "Unknown parameter: $1"
+ usage
+ ;;
+ esac
+done
+
+TENSORFLOW_ROOT=$(pwd)
+DIST_DIR=$TENSORFLOW_ROOT/download
+PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin
+
+export PATH=$BAZEL_PATH:$PATH
+DIST_DIR="${DISTDIR:-$DIST_DIR}"
+BAZEL_COMPILE_CACHE="${BUILD_CACHE_DIR:-$TENSORFLOW_ROOT/output}"
+
+if ! command -v bazel &> /dev/null; then
+ echo "Error: Bazel is not installed. Please install Bazel and try again."
+ exit 1
+fi
+
+bazel version
+
+if [ "$ENABLE_GCC12" == true ]; then
+ export PATH=$GCC12_PATH:$PATH
+ export LD_LIBRARY_PATH=$GCC12_LD_LIBRARY_PATH
+ GCC_VERSION=$(gcc -dumpversion | cut -d. -f1)
+ if [[ "$GCC_VERSION" != "12" ]]; then
+ echo "Error: GCC version is $GCC_VERSION. Please install GCC 12. Consider use command: yum install gcc-toolset-12-gcc*"
+ exit 1
+ fi
+fi
+
+if [ "$ENABLE_KDNN" == true ]; then
+ KDNN_OPTIONS="--define=enable_kdnn=true"
+fi
+
+gcc --version
+cd $TENSORFLOW_ROOT && \
+PATH=$PATH \
+LD_LIBRARY_PATH=$LD_LIBRARY_PATH \
+bazel --output_user_root=$BAZEL_COMPILE_CACHE build --distdir=$DIST_DIR \
+--host_copt=-march=armv8.3-a --copt=-march=armv8.3-a --define with_default_optimizations=true \
+--copt=-Wno-sign-compare --config=v2 --config=noaws \
+$KDNN_OPTIONS \
+//tensorflow/tools/pip_package:build_pip_package
+./bazel-bin/tensorflow/tools/pip_package/build_pip_package ./output-release
@@ -327,6 +327,7 @@ StatusOr<xla::ExecutionOutput> RunExecutable(
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
+ run_options.set_run_in_tf_kernel(ctx->executor_policy() == ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR);
StatusOr<xla::ExecutionOutput> execution_output;
bool run_synchronous =
@@ -111,6 +111,14 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
+load(
+ "//third_party/ktfop:build_defs.bzl",
+ "if_ktfop",
+)
+load(
+ "//third_party/fused_embedding:build_defs.bzl",
+ "if_fused_embedding",
+)
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
package(
@@ -650,6 +658,11 @@ cc_library(
"//tensorflow/core/kernels/mkl:mkl_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops",
"//tensorflow/core/kernels/mkl:mkl_deprecated_ops",
+ ]) + if_ktfop([
+ "//tensorflow/core/kernels/ktfop:fused_embedding_ops",
+ "//tensorflow/core/kernels/ktfop:softmax_ops",
+ ]) + if_fused_embedding([
+ "//tensorflow/core/kernels/fused_embedding:fused_embedding_ops",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:cudnn_rnn_kernels",
]) + if_cuda([
@@ -638,6 +638,18 @@ cc_library(
],
)
+cc_library(
+ name = "kernel_stat",
+ hdrs = ["kernel_stat.h"],
+ copts = tf_copts(),
+ deps = [
+ ":graph_view",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_library(
name = "executor",
srcs = ["executor.cc"],
@@ -649,6 +661,7 @@ cc_library(
":entry",
":executor_factory",
":graph_view",
+ ":kernel_stat",
":immutable_executor_state",
":local_executor_params",
":pending_counts",
@@ -327,6 +327,12 @@ DirectSession::DirectSession(const SessionOptions& options,
factory_(factory),
cancellation_manager_(new CancellationManager()),
operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
+
+ if (options_.config.use_batch_op_scheduling()) {
+ use_batch_scheduling_executor_ = true;
+ LOG(INFO) << "enable batch scheduling executor";
+ }
+
const int thread_pool_size =
options_.config.session_inter_op_thread_pool_size();
if (thread_pool_size > 0) {
@@ -679,7 +685,11 @@ Status DirectSession::RunInternal(
args.run_all_kernels_inline = pool == nullptr;
args.start_time_usecs = start_time_usecs;
args.deadline = deadline;
-
+ if (use_batch_scheduling_executor_) {
+ args.executor_policy = ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR;
+ } else {
+ args.executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
+ }
const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
bool update_cost_model = false;
@@ -437,6 +437,7 @@ class DirectSession : public Session {
// Otherwise run in global thread pool, session owned thread pool or handler
// pool according to other specifications of RunOptions and ConfigProto.
bool run_in_caller_thread_ = false;
+ bool use_batch_scheduling_executor_ = false;
DirectSession(const DirectSession&) = delete;
void operator=(const DirectSession&) = delete;
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
+#include "tensorflow/core/common_runtime/kernel_stat.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
@@ -155,76 +156,21 @@ class ExecutorImpl : public Executor {
return OkStatus();
}
+ ImmutableExecutorState& GetImmutableState() {
+ return immutable_state_;
+ }
+
+ ExecutorInternal::KernelStats* GetKernelStat() {
+ return &kernel_stats_;
+ }
private:
void RunAsyncInternal(const Args& args, DoneCallback done) override;
template <class PropagatorStateType>
friend class ExecutorState;
- // Stores execution time information about the kernels in an executor's graph.
- class KernelStats {
- public:
- KernelStats() = default;
-
- void Initialize(const GraphView& gview) {
- is_expensive_.resize(gview.num_nodes());
- cost_estimates_ =
- std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
- for (int32_t i = 0; i < gview.num_nodes(); ++i) {
- if (gview.node(i)) {
- is_expensive_[i] =
- gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
- cost_estimates_[i] = kInitialCostEstimateCycles;
- }
- }
- }
-
- // Returns true iff the given node is considered "expensive". The
- // executor uses this flag to optimize graph execution, for example
- // by "inlining" inexpensive kernels.
- bool IsExpensive(const NodeItem& node) const {
- return is_expensive_[node.node_id] &&
- (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
- kOpIsExpensiveThresholdCycles);
- }
-
- // Returns the value of kernel->IsExpensive().
- bool HasExpensiveMarker(const NodeItem& node) const {
- return is_expensive_[node.node_id];
- }
-
- // Updates the dynamic cost estimate, which is used to determine whether the
- // given node is expensive. The new cost estimate is a weighted average of
- // the old cost estimate and the latest cost. We only update cost estimates
- // for kernels for which IsExpensive() return true.
- void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
- // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
- // updates may result in one or more updates being ignored. This does not
- // affect correctness but may slow down the update frequency.
- std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
- auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
-
- uint64 new_estimate =
- ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
-
- cost_estimate.store(new_estimate, std::memory_order_relaxed);
- }
-
- private:
- // Initial time (in CPU cycles) we expect an operation to take. Used to
- // determine whether an operation should be place in a threadpool.
- // Operations start out "expensive".
- static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
- static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
- static constexpr uint64 kCostDecay = 10;
-
- std::vector<bool> is_expensive_;
- // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
- std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
- };
-
ImmutableExecutorState immutable_state_;
- KernelStats kernel_stats_;
+ ExecutorInternal::KernelStats kernel_stats_;
ExecutorImpl(const ExecutorImpl&) = delete;
void operator=(const ExecutorImpl&) = delete;
@@ -284,12 +230,12 @@ class ExecutorState {
public:
ExecutorState(const Executor::Args& args,
const ImmutableExecutorState& immutable_state_,
- ExecutorImpl::KernelStats* kernel_stats_);
+ ExecutorInternal::KernelStats* kernel_stats_);
~ExecutorState();
void RunAsync(Executor::DoneCallback done);
- private:
+ protected:
// Use `TaggedNode` types defined by `PropagatorStateType`.
typedef typename PropagatorStateType::TaggedNode TaggedNode;
typedef
@@ -338,7 +284,7 @@ class ExecutorState {
// This method will clear `*ready` before returning.
//
// REQUIRES: `!ready->empty()`.
- void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
+ virtual void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
// A wrapper for runner_ to keep track of the pending queue length. Op
// execution should dispatch work using this function instead of using runner_
@@ -388,7 +334,7 @@ class ExecutorState {
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
CallFrameInterface* call_frame_;
const ImmutableExecutorState& immutable_state_;
- ExecutorImpl::KernelStats* const kernel_stats_;
+ ExecutorInternal::KernelStats* const kernel_stats_;
CancellationManager* cancellation_manager_;
tsl::CoordinationServiceAgent* coordination_service_agent_;
absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt;
@@ -397,6 +343,7 @@ class ExecutorState {
Executor::Args::Runner runner_;
bool sync_on_finish_;
const bool run_all_kernels_inline_;
+ ExecutorPolicy executor_policy_ = ExecutorPolicy::USE_NORMAL_EXECUTOR;
PropagatorStateType propagator_;
@@ -418,7 +365,7 @@ class ExecutorState {
template <class PropagatorStateType>
ExecutorState<PropagatorStateType>::ExecutorState(
const Executor::Args& args, const ImmutableExecutorState& immutable_state,
- ExecutorImpl::KernelStats* kernel_stats)
+ ExecutorInternal::KernelStats* kernel_stats)
: vlog_(VLOG_IS_ON(1)),
log_memory_(LogMemory::IsEnabled()),
step_id_(args.step_id),
@@ -446,6 +393,7 @@ ExecutorState<PropagatorStateType>::ExecutorState(
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
run_all_kernels_inline_(args.run_all_kernels_inline),
+ executor_policy_(args.executor_policy),
propagator_(immutable_state, step_id_, vlog_),
num_outstanding_ops_(0) {
if (args.user_intra_op_threadpool != nullptr) {
@@ -463,6 +411,42 @@ ExecutorState<PropagatorStateType>::~ExecutorState() {
delete slice_reader_cache_;
}
+template <class PropagatorStateType>
+class BatchSchedulingExecutorState : public ExecutorState<PropagatorStateType> {
+ public:
+ BatchSchedulingExecutorState(
+ const Executor::Args& args, const ImmutableExecutorState& immutable_state_,
+ ExecutorInternal::KernelStats* kernel_stats_)
+ : ExecutorState<PropagatorStateType>(args, immutable_state_, kernel_stats_) {}
+ ~BatchSchedulingExecutorState() {}
+
+ protected:
+ typedef typename PropagatorStateType::TaggedNode TaggedNode;
+ typedef
+ typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
+ typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
+
+ virtual void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
+};
+
+class ExecutorStateFactory {
+ public:
+ template <class PropagatorStateType>
+ static ExecutorState<PropagatorStateType>* Create(
+ const Executor::Args& args, ExecutorImpl* impl) {
+ ImmutableExecutorState& immutable_state = impl->GetImmutableState();
+ ExecutorInternal::KernelStats* kernel_stats = impl->GetKernelStat();
+
+ if (args.executor_policy == ExecutorPolicy::USE_BATCH_SCHEDULING_EXECUTOR) {
+ return new BatchSchedulingExecutorState<PropagatorStateType>(
+ args, immutable_state, kernel_stats);
+ } else {
+ return new ExecutorState<PropagatorStateType>(
+ args, immutable_state, kernel_stats);
+ }
+ }
+};
+
template <class PropagatorStateType>
template <typename Closure>
void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) {
@@ -511,7 +495,7 @@ void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
} else {
done_cb_ = std::move(done);
// Schedule to run all the ready ops in thread pool.
- ScheduleReady(&ready, nullptr);
+ this->ScheduleReady(&ready, nullptr);
}
}
@@ -730,7 +714,7 @@ void ExecutorState<PropagatorStateType>::Process(const TaggedNode& tagged_node,
profiler::TraceMeLevel::kVerbose);
TaggedNodeReadyQueue inline_ready;
inline_ready.push_back(tagged_node);
- return ProcessInline(&inline_ready, scheduled_nsec);
+ return this->ProcessInline(&inline_ready, scheduled_nsec);
}
template <class PropagatorStateType>
@@ -773,6 +757,7 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
params->slice_reader_cache = slice_reader_cache_;
params->runner = &runner_;
params->run_all_kernels_inline = run_all_kernels_inline_;
+ params->executor_policy = executor_policy_;
params->stats_collector = stats_collector_;
params->inc_num_deferred_ops_function = [this]() {
mutex_lock lock(num_deferred_ops_mu_);
@@ -886,13 +871,13 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
if (tagged_node.get_is_dead() && !item.is_transfer_node) {
if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
} else if (TF_PREDICT_FALSE(item.is_noop)) {
- ProcessNoop(stats);
+ this->ProcessNoop(stats);
} else if (item.const_tensor != nullptr && !params->track_allocations) {
- ProcessConstTensor(item, &outputs, stats);
+ this->ProcessConstTensor(item, &outputs, stats);
} else {
// Prepares inputs.
bool is_input_dead = false;
- s = PrepareInputs(item, first_input, inputs.get(), &input_alloc_attrs,
+ s = this->PrepareInputs(item, first_input, inputs.get(), &input_alloc_attrs,
&is_input_dead);
if (!s.ok()) {
// Clear inputs.
@@ -903,7 +888,7 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
propagator_.MaybeMarkCompleted(tagged_node);
activity_watcher::ActivityEnd(activity_id);
// Continue to process the nodes in 'inline_ready'.
- completed = NodeDone(s, ready.get(), stats, inline_ready);
+ completed = this->NodeDone(s, ready.get(), stats, inline_ready);
continue;
}
@@ -918,11 +903,11 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
params->input_alloc_attrs = input_alloc_attrs;
if (item.kernel_is_async) {
- ProcessAsync(item, *params, tagged_node, first_input, stats,
+ this->ProcessAsync(item, *params, tagged_node, first_input, stats,
activity_id);
launched_asynchronously = true;
} else {
- s = ProcessSync(item, params.get(), &outputs, stats);
+ s = this->ProcessSync(item, params.get(), &outputs, stats);
}
}
@@ -957,12 +942,12 @@ void ExecutorState<PropagatorStateType>::ProcessInline(
scheduled_nsec = nodestats::NowInNsec();
}
// Postprocess.
- completed = NodeDone(s, ready.get(), stats, inline_ready);
+ completed = this->NodeDone(s, ready.get(), stats, inline_ready);
}
} // while !inline_ready.empty()
// This thread of computation is done if completed = true.
- if (completed) ScheduleFinish();
+ if (completed) this->ScheduleFinish();
}
template <class PropagatorStateType>
@@ -1202,7 +1187,7 @@ bool ExecutorState<PropagatorStateType>::NodeDone(
}
// Schedule the ready nodes in 'ready'.
- ScheduleReady(ready, inline_ready);
+ this->ScheduleReady(ready, inline_ready);
return false;
}
@@ -1388,7 +1373,7 @@ void ExecutorState<PropagatorStateType>::ScheduleFinish() {
// Finish is always called exactly once per ExecutorState, either here if
// there aren't any deferred ops, or in the dec_num_deferred_ops_function if
// there are deferred ops.
- Finish();
+ this->Finish();
}
template <class PropagatorStateType>
@@ -1507,17 +1492,44 @@ void ExecutorState<PropagatorStateType>::Finish() {
}
}
+template <class PropagatorStateType>
+void BatchSchedulingExecutorState<PropagatorStateType>::ScheduleReady(
+ TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
+ DCHECK(!ready->empty());
+
+ int64_t scheduled_nsec = 0;
+ if (this->stats_collector_) {
+ scheduled_nsec = nodestats::NowInNsec();
+ }
+
+ if (inline_ready == nullptr) {
+ // Schedule all ready kernels from a single closure. This ensure that,
+ // regardless of the `runner_` implementation, all kernels will run
+ // sequentially on the same thread, and thread wakeup overhead and
+ // executor mutex contention will be minimized.
+ this->RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
+ for (auto& tagged_node : ready) {
+ this->Process(tagged_node, scheduled_nsec);
+ }
+ });
+ } else {
+ for (auto& tagged_node : *ready) {
+ inline_ready->push_back(tagged_node);
+ }
+ }
+
+ ready->clear();
+}
+
void ExecutorImpl::RunAsyncInternal(const Args& args, DoneCallback done) {
if (OpOrderDeterminismRequired()) {
- (new ExecutorState<OrderedPropagatorState>(args, immutable_state_,
- &kernel_stats_))
+ (ExecutorStateFactory::Create<OrderedPropagatorState>(args, this))
->RunAsync(std::move(done));
} else if (immutable_state_.requires_control_flow_support()) {
- (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
+ (ExecutorStateFactory::Create<PropagatorState>(args, this))
->RunAsync(std::move(done));
} else {
- (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
- &kernel_stats_))
+ (ExecutorStateFactory::Create<SimplePropagatorState>(args, this))
->RunAsync(std::move(done));
}
}
@@ -124,6 +124,7 @@ class Executor {
// If true, all kernels will be treated as "inexpensive", and hence executed
// on the scheduling thread.
bool run_all_kernels_inline = false;
+ ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
};
typedef std::function<void(const Status&)> DoneCallback;
new file mode 100644
@@ -0,0 +1,104 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
+
+#include <atomic>
+#include <memory>
+#include <queue>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/time/time.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/common_runtime/graph_view.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+namespace ExecutorInternal {
+
+// Stores execution time information about the kernels in an executor's graph.
+class KernelStats {
+ public:
+ KernelStats() = default;
+
+ void Initialize(const GraphView& gview) {
+ is_expensive_.resize(gview.num_nodes());
+ cost_estimates_ =
+ std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
+ for (int32_t i = 0; i < gview.num_nodes(); ++i) {
+ if (gview.node(i)) {
+ is_expensive_[i] =
+ gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
+ cost_estimates_[i] = kInitialCostEstimateCycles;
+ }
+ }
+ }
+
+ // Returns true iff the given node is considered "expensive". The
+ // executor uses this flag to optimize graph execution, for example
+ // by "inlining" inexpensive kernels.
+ bool IsExpensive(const NodeItem& node) const {
+ return is_expensive_[node.node_id] &&
+ (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
+ kOpIsExpensiveThresholdCycles);
+ }
+
+ // Returns the value of kernel->IsExpensive().
+ bool HasExpensiveMarker(const NodeItem& node) const {
+ return is_expensive_[node.node_id];
+ }
+
+ // Updates the dynamic cost estimate, which is used to determine whether the
+ // given node is expensive. The new cost estimate is a weighted average of
+ // the old cost estimate and the latest cost. We only update cost estimates
+ // for kernels for which IsExpensive() return true.
+ void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
+ // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
+ // updates may result in one or more updates being ignored. This does not
+ // affect correctness but may slow down the update frequency.
+ std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
+ auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
+
+ uint64 new_estimate =
+ ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
+
+ cost_estimate.store(new_estimate, std::memory_order_relaxed);
+ }
+
+ private:
+ // Initial time (in CPU cycles) we expect an operation to take. Used to
+ // determine whether an operation should be place in a threadpool.
+ // Operations start out "expensive".
+ static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
+ static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
+ static constexpr uint64 kCostDecay = 10;
+
+ std::vector<bool> is_expensive_;
+ // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
+ std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
+};
+} // namespace ExecutorInternal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_STAT_H_
\ No newline at end of file
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_FLAGS_H_
#include "absl/flags/declare.h"
+#include "absl/flags/flag.h"
ABSL_DECLARE_FLAG(bool, next_pluggable_device_use_c_api);
@@ -672,6 +672,7 @@ class OpKernelContext {
StepStatsCollectorInterface* stats_collector = nullptr;
GraphCollector* graph_collector = nullptr;
bool run_all_kernels_inline = false;
+ ExecutorPolicy executor_policy = ExecutorPolicy::USE_NORMAL_EXECUTOR;
const std::string* executor_type = nullptr;
// TensorSliceReaderCache support.
@@ -833,6 +834,9 @@ class OpKernelContext {
return params_->run_all_kernels_inline;
}
+ ExecutorPolicy executor_policy() const {
+ return params_->executor_policy;
+ }
// Returns the registered name for the executor type that is executing the
// current kernel. If empty, the default executor is used.
const std::string& executor_type() const;
@@ -38,6 +38,15 @@ load(
"if_mkl",
"mkl_deps",
)
+load(
+ "//third_party/KDNN:build_defs.bzl",
+ "if_enable_kdnn",
+ "kdnn_deps",
+)
+load(
+ "//third_party/ktfop:build_defs.bzl",
+ "if_ktfop",
+)
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm",
@@ -855,7 +864,7 @@ ARRAY_DEPS = [
"//tensorflow/core/framework:bounds_check",
"//tensorflow/core/profiler/lib:scoped_memory_debug_annotation",
"@eigen_archive//:eigen3",
-]
+] + kdnn_deps()
tf_kernel_library(
name = "immutable_constant_op",
@@ -3488,7 +3497,7 @@ tf_kernel_library(
"//tensorflow/core/protobuf:autotuning_proto_cc",
"//tensorflow/core/util/autotune_maps:conv_parameters",
"//tensorflow/core/util/proto:proto_utils",
- ]),
+ ]) + kdnn_deps(),
)
cc_library(
@@ -3568,7 +3577,7 @@ tf_kernel_library(
prefix = "cwise_op",
deps = MATH_DEPS + [
"//tensorflow/core/kernels/mlir_generated:cwise_op",
- ],
+ ] + kdnn_deps(),
)
tf_kernel_library(
@@ -4296,7 +4305,7 @@ tf_kernel_library(
]) + [
":gpu_prim_hdrs",
":loose_headers",
- ],
+ ] + kdnn_deps(),
)
tf_kernel_library(
@@ -5136,7 +5145,7 @@ tf_kernel_library(
"//tensorflow/core/framework:bounds_check",
"//tensorflow/core/util:determinism_for_kernels",
"@eigen_archive//:eigen3",
- ],
+ ] + kdnn_deps(),
)
tf_kernel_library(
@@ -31,6 +31,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
namespace tensorflow {
@@ -172,6 +177,13 @@ class ConcatBaseOp : public OpKernel {
return;
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if defined(ENABLE_KDNN)
+ KDNN::Element::TypeT kdnnType = KDNN::Element::TypeAdapter<T>::value;
+ if(IsKDNNEnabled() && kdnnType != KDNN::Element::TypeT::UNDEFINED) {
+ KDNNConcatImpl<T>(c, inputs_flat, &output_flat);
+ return;
+ }
+#endif
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
}
@@ -35,6 +35,10 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -109,6 +113,15 @@ class BinaryOp : public BinaryOpShared {
Tensor* out;
OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
{0, 1}, 0, input_0.shape(), &out));
+#if defined(ENABLE_KDNN)
+ if constexpr (std::is_same<Functor, functor::safe_floor_mod<int64_t>>::value ||
+ std::is_same<Functor, functor::floor_fmod<float>>::value) {
+ if (IsKDNNEnabled()) {
+ kdnnFloormodOp<Functor>(ctx, input_0, input_1, out);
+ return;
+ }
+ }
+#endif
functor::BinaryFunctor<Device, Functor, 1>()(
eigen_device, out->template flat<Tout>(),
input_0.template flat<Tin>(), input_1.template flat<Tin>(),
@@ -322,6 +335,16 @@ class UnaryOp : public OpKernel {
} else {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
}
+
+#if defined(ENABLE_KDNN)
+ if constexpr (std::is_same<Functor, functor::sigmoid<float>>::value) {
+ if (IsKDNNEnabled() & inp.NumElements() > 0) {
+ kdnnSigmoidOp<Functor>(ctx, inp, out);
+ return;
+ }
+ }
+#endif
+
functor::UnaryFunctor<Device, Functor>()(
ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
}
new file mode 100644
@@ -0,0 +1,45 @@
+load("//tensorflow:tensorflow.default.bzl", "tf_kernel_library")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+DYNAMIC_DEPS = [
+ "//tensorflow/core/framework:bounds_check",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+ name = "fused_embedding_ops",
+ srcs = [
+ "embedding_lookup_hash_op.cc",
+ "lookup_embedding_by_hash.h",
+ ],
+ deps = [
+ "@eigen_archive//:eigen3",
+ "@farmhash_archive//:farmhash",
+ ] + DYNAMIC_DEPS
+)
+
+tf_cc_test(
+ name = "fused_embedding_ops_test",
+ srcs = ["fused_embedding_with_hash_bucket_ops_test.cc"],
+ deps = [
+ ":fused_embedding_ops",
+ "//tensorflow/core/ops:fused_embedding_ops_op_lib",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,92 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/bounds_check.h"
+
+#include "lookup_embedding_by_hash.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+struct LookupEmbeddingByHashFunctor {
+ int operator()(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table,
+ int64_t embedding_size, int32_t embedding_dims, float *output) {
+ return KPLookupEmbeddingByHashImpl::Compute(lookup_embedding, lookup_length, batch_size, embedding_table,
+ embedding_size, embedding_dims, output);
+ }
+};
+
+class KPLookupEmbeddingByHashOp : public OpKernel {
+ public:
+ explicit KPLookupEmbeddingByHashOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+ node_name = context->def().name();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ float *weight;
+ const Tensor& input_tensor = context->input(0);
+ const Tensor* weight_tensor = &context->input(1);
+
+ if (weight_tensor->dtype() == DT_RESOURCE) {
+ Var* variable;
+ OP_REQUIRES_OK(context,
+ LookupResource(context, HandleFromInput(context, 1),
+ &variable));
+ core::ScopedUnref s(variable);
+ weight_tensor = variable->tensor();
+ OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+ errors::InvalidArgument("Expect float weight in ",
+ node_name));
+ }
+
+ auto input = input_tensor.flat<tstring>();
+ weight = (float *)weight_tensor->tensor_data().data();
+ int64_t batch = input_tensor.dim_size(0);
+ int64_t embedding_dims = weight_tensor->dim_size(1);
+ uintptr_t cstr_addresses[batch];
+ size_t cstr_length[batch];
+ for (int i = 0; i < batch; ++i) {
+ cstr_addresses[i] = reinterpret_cast<uintptr_t>(input(i).c_str());
+ cstr_length[i] = input(i).length();
+ }
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({batch, embedding_dims}),
+ &output_tensor));
+ float *output = (float *)output_tensor->tensor_data().data();
+ LookupEmbeddingByHashFunctor lookfunctor;
+ int result = lookfunctor(cstr_addresses, cstr_length, batch,
+ weight, num_buckets_, embedding_dims, output);
+ OP_REQUIRES(context, (result == 0),
+ errors::InvalidArgument("Invalid argument, error code: ", result));
+ }
+
+ private:
+ int64_t num_buckets_;
+ std::string node_name;
+};
+REGISTER_KERNEL_BUILDER(Name("KPLookupEmbeddingByHash").Device(DEVICE_CPU), KPLookupEmbeddingByHashOp);
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,59 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class KPLookupEmbeddingByHashOpTest : public OpsTestBase {};
+
+TEST_F(KPLookupEmbeddingByHashOpTest, WithHashBucket) {
+ TF_EXPECT_OK(
+ NodeDefBuilder("fused_embedding_with_hash_bucket_op", "KPLookupEmbeddingByHash")
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("num_buckets", 5)
+ .Attr("combiner", 1)
+ .Attr("T_weight", DT_FLOAT)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ AddInputFromArray<tstring>(TensorShape({2, 1}),
+ {"ktfop", "fused_embedding_with_hash_bucket"});
+ AddInputFromArray<float>(TensorShape({5, 10}), {3.21, 7.89, 1.45, 9.32, 0.67, 5.43, 2.98, 8.76, 4.12, 6.54,
+ 0.23, 9.87, 3.56, 7.01, 2.34, 8.09, 5.67, 1.89, 6.78, 4.45,
+ 0.98, 9.01, 3.45, 7.23, 2.67, 8.34, 5.89, 1.23, 6.45, 4.78,
+ 0.56, 9.45, 3.78, 7.56, 2.12, 8.67, 5.34, 1.67, 6.23, 4.89,
+ 0.34, 9.78, 3.12, 7.45, 2.89, 8.23, 5.78, 1.45, 6.67, 4.23});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 10}));
+ test::FillValues<float>(
+ &expected, {3.21, 7.89, 1.45, 9.32, 0.67, 5.43, 2.98, 8.76, 4.12, 6.54,
+ 0.98, 9.01, 3.45, 7.23, 2.67, 8.34, 5.89, 1.23, 6.45, 4.78});
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 0.0);
+}
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,82 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <farmhash.h>
+#include <cstring>
+
+namespace tensorflow {
+
+enum ReturnCode {
+ OK = 0,
+ NULL_POINTER = 1,
+ INVALID_PARAMETER = 2
+};
+
+static inline int Lookup1D(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table,
+ int64_t embedding_size, int32_t embedding_dims, float *output)
+{
+ uint64_t embedding_size_u64 = static_cast<uint64_t>(embedding_size);
+ for (int64_t i = 0; i < batch_size; ++i) {
+ if (lookup_length[i] != 0) {
+ uint64_t hash_value = ::util::Fingerprint64((char *)(lookup_embedding[i]), lookup_length[i]);
+ uint64_t x = hash_value % embedding_size_u64;
+ output[i] = embedding_table[x];
+ } else {
+ output[i] = 0;
+ }
+ }
+ return OK;
+}
+
+static inline int RegularLookup(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table,
+ int64_t embedding_size, int32_t embedding_dims, float *output)
+{
+ uint64_t embedding_dims_u64 = static_cast<uint64_t>(embedding_dims);
+ uint64_t embedding_size_u64 = static_cast<uint64_t>(embedding_size);
+ for (int64_t i = 0; i < batch_size; ++i) {
+ if (lookup_length[i] != 0) {
+ uint64_t hash_value = ::util::Fingerprint64((char *)(lookup_embedding[i]), lookup_length[i]);
+
+ uint64_t x = hash_value % embedding_size_u64;
+ for (uint64_t j = 0; j < embedding_dims; ++j) {
+ output[j] = embedding_table[x * embedding_dims + j];
+ }
+ output += embedding_dims_u64;
+ }
+ }
+ return OK;
+}
+
+struct KPLookupEmbeddingByHashImpl {
+ static int Compute(uintptr_t *lookup_embedding, size_t *lookup_length, int32_t batch_size, float *embedding_table,
+ int64_t embedding_size, int32_t embedding_dims, float *output) {
+ if (output == nullptr || lookup_embedding == nullptr || embedding_table == nullptr || lookup_length == nullptr) {
+ return NULL_POINTER;
+ }
+ if (batch_size < 0 || embedding_size <= 0 || embedding_dims <= 0) {
+ return INVALID_PARAMETER;
+ }
+
+ if (embedding_dims == 1) {
+ return Lookup1D(lookup_embedding, lookup_length, batch_size, embedding_table,
+ embedding_size, embedding_dims, output);
+ } else {
+ return RegularLookup(lookup_embedding, lookup_length, batch_size, embedding_table,
+ embedding_size, embedding_dims, output);
+ }
+ }
+};
+
+} // namespace tensorflow
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,36 @@
+load("//tensorflow:tensorflow.default.bzl", "tf_kernel_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//tensorflow:internal"],
+)
+
+DYNAMIC_DEPS = [
+ "//tensorflow/core/framework:bounds_check",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+]
+
+tf_kernel_library(
+ name = "fused_embedding_ops",
+ srcs = [
+ "embedding_lookup_op.cc",
+ ],
+ deps = [
+ "@eigen_archive//:eigen3",
+ "@ktfop_archive//:ktfop",
+ ] + DYNAMIC_DEPS
+)
+
+tf_kernel_library(
+ name = "softmax_ops",
+ srcs = [
+ "softmax.cc",
+ ],
+ deps = [
+ "@eigen_archive//:eigen3",
+ "@ktfop_archive//:ktfop",
+ ] + DYNAMIC_DEPS
+)
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,164 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_var.h"
+#include "tensorflow/core/framework/bounds_check.h"
+
+#include "ktfop.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+class KPFusedEmbeddingOp : public OpKernel {
+public:
+ explicit KPFusedEmbeddingOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_));
+ node_name = context->def().name();
+ }
+
+ ~KPFusedEmbeddingOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ float *weight;
+ const Tensor* weight_tensor = &context->input(0);
+
+ if (weight_tensor->dtype() == DT_RESOURCE) {
+ Var* variable;
+ OP_REQUIRES_OK(context,
+ LookupResource(context, HandleFromInput(context, 0),
+ &variable));
+ core::ScopedUnref s(variable);
+ weight_tensor = variable->tensor();
+ OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+ errors::InvalidArgument("Expect float weight in ",
+ node_name));
+ }
+
+ weight = (float *)weight_tensor->tensor_data().data();
+
+ const Tensor& input_tensor = context->input(1);
+ int64 *input = (int64 *)input_tensor.tensor_data().data();
+ const Tensor& shape_tensor = context->input(2);
+ int64 *shape = (int64 *)shape_tensor.tensor_data().data();
+
+ OP_REQUIRES(context, (shape_tensor.dims() == 1),
+ errors::InvalidArgument("Shape tensor is not valid (dims != 1)"));
+ OP_REQUIRES(context, (shape_tensor.dim_size(0) >= 2),
+ errors::InvalidArgument("Shape tensor is not valid (dim_size(0) < 2)"));
+
+ int64 input_size = 1;
+ for (int i = 0; i < input_tensor.dims(); ++i) {
+ input_size *= input_tensor.dim_size(i);
+ }
+ int input_dims = shape_tensor.dim_size(0);
+ int cols = shape[input_dims - 1];
+ int batch_size = 1;
+ for (int i = 0; i < input_dims - 1; ++i) {
+ batch_size *= shape[i];
+ }
+ OP_REQUIRES(context, (input_size == batch_size * cols),
+ errors::InvalidArgument("input id is dense"));
+ int embedding_dims = weight_tensor->dim_size(1);
+ bool is_mean = (combiner_ == 1);
+
+ Tensor* output_tensor = NULL;
+ TensorShape output_shape({batch_size, embedding_dims});
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+ &output_tensor));
+ float *output = (float *)output_tensor->tensor_data().data();
+ ktfop::EmbeddingParams params(input,
+ batch_size,
+ cols,
+ weight,
+ embedding_dims,
+ is_mean);
+ int result = ktfop::FusedEmbedding(params, output);
+ OP_REQUIRES(context, (result == 0),
+ errors::InvalidArgument("Invalid argument, error code: ", result));
+ }
+
+private:
+ int combiner_;
+ std::string node_name;
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbedding").Device(DEVICE_CPU), KPFusedEmbeddingOp);
+
+class KPFusedEmbeddingWithHashBucketOp : public OpKernel {
+ public:
+ explicit KPFusedEmbeddingWithHashBucketOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
+ node_name = context->def().name();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ float *weight;
+ const Tensor& input_tensor = context->input(0);
+ const Tensor* weight_tensor = &context->input(1);
+
+ if (weight_tensor->dtype() == DT_RESOURCE) {
+ Var* variable;
+ OP_REQUIRES_OK(context,
+ LookupResource(context, HandleFromInput(context, 1),
+ &variable));
+ core::ScopedUnref s(variable);
+ weight_tensor = variable->tensor();
+ OP_REQUIRES(context, weight_tensor->dtype() == DT_FLOAT,
+ errors::InvalidArgument("Expect float weight in ",
+ node_name));
+ }
+
+ auto input = input_tensor.flat<tstring>();
+ weight = (float *)weight_tensor->tensor_data().data();
+ int64_t batch = input_tensor.dim_size(0);
+ int64_t embedding_dims = weight_tensor->dim_size(1);
+ uintptr_t cstr_addresses[batch];
+ size_t cstr_length[batch];
+ for (int i = 0; i < batch; ++i) {
+ cstr_addresses[i] = reinterpret_cast<uintptr_t>(input(i).c_str());
+ cstr_length[i] = input(i).length();
+ }
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({batch, embedding_dims}),
+ &output_tensor));
+ float *output = (float *)output_tensor->tensor_data().data();
+ ktfop::EmbeddingParamsWithHash params(cstr_addresses,
+ cstr_length,
+ batch,
+ weight,
+ num_buckets_,
+ embedding_dims);
+ int result = ktfop::FusedEmbeddingWithHashBucket(params, output);
+ OP_REQUIRES(context, (result == 0),
+ errors::InvalidArgument("Invalid argument, error code: ", result));
+ }
+
+ private:
+ int64_t num_buckets_;
+ std::string node_name;
+};
+REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingWithHashBucket").Device(DEVICE_CPU), KPFusedEmbeddingWithHashBucketOp);
+} // namespace tensorflow
new file mode 100644
@@ -0,0 +1,52 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
+#include "ktfop.h"
+
+namespace tensorflow {
+class KPSoftmaxOp : public OpKernel {
+ public:
+ explicit KPSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& logits_in = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in.shape()),
+ errors::InvalidArgument("logits must have >= 1 dimension, got ",
+ logits_in.shape().DebugString()));
+ Tensor* softmax_out = nullptr;
+ OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
+ {0}, 0, logits_in.shape(), &softmax_out));
+ if (logits_in.NumElements() > 0) {
+ typename TTypes<float>::ConstMatrix input = logits_in.flat_inner_dims<float>();
+ float* input_data = (float *)logits_in.data();
+ float* output_data = (float *)softmax_out->data();
+ int result = ktfop::Softmax(input_data, output_data, input.dimension(0), input.dimension(1));
+ OP_REQUIRES(context, (result == 0),
+ errors::InvalidArgument("Invalid argument, error code: ", result));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("KPSoftmax").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ KPSoftmaxOp);
+
+} // namespace tensorflow
\ No newline at end of file
@@ -42,6 +42,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/einsum_op_util.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
@@ -466,6 +471,16 @@ struct EinsumHelper {
set_zero(ctx->eigen_device<Device>(), output->flat<T>());
return OkStatus();
}
+#if defined(ENABLE_KDNN)
+ if (IsKDNNEnabled()
+ && std::is_same<T, float>::value
+ && inputs.size() == 2
+ && inputs[0].dims() >= 2 && inputs[0].dims() <= 5
+ && inputs[1].dims() >= 2 && inputs[1].dims() <= 5) {
+ kdnnBatchGemm(ctx, inputs[0], inputs[1], output, trans_x, trans_y);
+ return OkStatus();
+ }
+#endif
Tensor output_reshaped;
TF_RETURN_IF_ERROR(
ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/port.h"
#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tsl/framework/contraction/eigen_contraction_kernel.h"
@@ -69,6 +70,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -93,6 +98,16 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
errors::InvalidArgument("_FusedMatMul doesn't support DT_HALF "
"data type on CPU devices."));
+#if defined(ENABLE_KDNN)
+ bool transpose_a_ = dim_pair[0].first == 0;
+ bool transpose_b_ = dim_pair[0].second == 1;
+ bool fusion_relu = fusion == FusedComputationType::kBiasAddWithRelu;
+ bool kdnn_enable_fusion = (fusion == FusedComputationType::kBiasAdd) || fusion_relu;
+ if (IsKDNNEnabled() && std::is_same<T, float>::value && kdnn_enable_fusion && !transpose_a_) {
+ kdnnFusedGemm(context, a, b, output, fusion_relu, transpose_a_, transpose_b_);
+ return;
+ }
+#endif
auto lhs = a.matrix<T>();
auto rhs = b.matrix<T>();
auto out = output->matrix<T>();
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"
+#include "tensorflow/core/util/port.h"
#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tsl/framework/contraction/eigen_contraction_kernel.h"
@@ -64,6 +65,10 @@ limitations under the License.
#endif
#endif
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -927,6 +932,24 @@ class BaseBatchMatMulOp : public OpKernel {
f(ctx->eigen_device<Device>(), out->flat<Tout>());
return;
}
+#if defined(ENABLE_KDNN)
+ int dims = in0.shape().dims();
+ bool trans_x = (adj_x_ || trans_x_);
+ bool trans_y = (adj_y_ || trans_y_);
+
+ if (IsKDNNEnabled()
+ && std::is_same<Ta, float>::value
+ && std::is_same<Tb, float>::value
+ && std::is_same<Tout, float>::value
+ && dims >= 2 && dims <= 5) {
+ if (batch_size == 1) {
+ kdnnGemm(ctx, in0_reshaped, in1_reshaped, out, trans_x, trans_y);
+ } else {
+ kdnnBatchGemm(ctx, in0, in1, out, trans_x, trans_y);
+ }
+ return;
+ }
+#endif
Tensor out_reshaped;
OP_REQUIRES(ctx,
out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
@@ -25,6 +25,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/softmax_op_functor.h"
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -61,6 +65,15 @@ class SoftmaxOp : public OpKernel {
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, logits_in.shape(), &softmax_out));
if (logits_in.NumElements() > 0) {
+#if defined(ENABLE_KDNN)
+ if constexpr (std::is_same<T, float>::value) {
+ if (!log_ && IsKDNNEnabled()) {
+ kdnnSoftmaxOp<T>(context, logits_in, softmax_out);
+ return;
+ }
+ }
+#endif
+
functor::SoftmaxFunctor<Device, T> functor;
functor(context->eigen_device<Device>(), logits_in.flat_inner_dims<T>(),
softmax_out->flat_inner_dims<T>(), log_);
@@ -24,6 +24,11 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/bfloat16.h"
+#include "tensorflow/core/util/port.h"
+
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+#endif
namespace tensorflow {
@@ -133,6 +138,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
return;
}
+#if defined(ENABLE_KDNN)
+#define KDNN_ADJOINT(ADJ_A, ADJ_B) \
+ if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
+ Status functor_status = functor::KDNNSparseMatMulFunctor< \
+ CPUDevice, float, Tindices, ADJ_A, \
+ ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<float>(), \
+ a_indices->matrix<Tindices>(), a_values->vec<float>(), \
+ b->matrix<float>()); \
+ OP_REQUIRES_OK(ctx, functor_status); \
+ }
+ const int int32max = std::numeric_limits<int>::max();
+ if (IsKDNNEnabled() && std::is_same<T, float>::value && adjoint_a_ == false &&
+ FastBoundsCheck(inner_left, int32max) &&
+ FastBoundsCheck(inner_right, int32max) &&
+ FastBoundsCheck(outer_left, int32max)) {
+ KDNN_ADJOINT(false, false);
+ KDNN_ADJOINT(false, true);
+ return;
+ }
+#undef KDNN_ADJOINT
+#endif
+
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
@@ -355,6 +382,56 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
}
};
+#if defined(ENABLE_KDNN)
+template <typename Tindices, bool ADJ_A, bool ADJ_B>
+struct KDNNSparseMatMulFunctor<CPUDevice, float, Tindices, ADJ_A, ADJ_B> {
+ static const std::size_t kNumVectorize = 32;
+
+ static Status Compute(const CPUDevice& d, typename TTypes<float>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<float>::ConstVec a_values,
+ typename TTypes<float>::ConstMatrix b) {
+ const std::size_t nnz = a_values.size();
+ const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
+ const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
+ const int lhs_index_a = ADJ_A ? 1 : 0;
+ const int rhs_index_a = ADJ_A ? 0 : 1;
+
+ out.setZero();
+
+ if (rhs_right < kNumVectorize) {
+ // Disable vectorization if the RHS of output is too small
+ auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
+
+ for (std::size_t i = 0; i < nnz; ++i) {
+ const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
+ const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
+ if (!FastBoundsCheck(k, lhs_right)) {
+ return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
+ }
+ if (!FastBoundsCheck(m, out.dimension(0))) {
+ return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));
+ }
+ const float a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i);
+ for (std::size_t n = 0; n < rhs_right; ++n) {
+ const float b_value = maybe_adjoint_b(k, n);
+ out(m, n) += a_value * b_value;
+ }
+ }
+ } else {
+ const float* b_data = b.data();
+ Eigen::Tensor<float, 2, Eigen::ColMajor> col_major_conj_b;
+ if (ADJ_B) {
+ Eigen::array<int, 2> shuffle(1, 0);
+ col_major_conj_b = b.swap_layout().shuffle(shuffle).conjugate().eval();
+ b_data = col_major_conj_b.data();
+ }
+ kdnnSparseMatmul<Tindices>(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b_data);
+ }
+ return OkStatus();
+ }
+};
+#endif
} // namespace functor
} // namespace tensorflow
@@ -35,6 +35,17 @@ struct SparseTensorDenseMatMulFunctor {
typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
};
+#if defined(ENABLE_KDNN)
+template <typename Device, typename T, typename Tindices, bool ADJ_A,
+ bool ADJ_B>
+struct KDNNSparseMatMulFunctor {
+ static EIGEN_ALWAYS_INLINE Status Compute(
+ const Device& d, typename TTypes<T>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
+};
+#endif
+
template <typename MATRIX, bool ADJ>
class MaybeAdjoint;
@@ -18,7 +18,14 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
-
+load(
+ "//third_party/ktfop:build_defs.bzl",
+ "if_ktfop",
+)
+load(
+ "//third_party/fused_embedding:build_defs.bzl",
+ "if_fused_embedding",
+)
# A lot of packages try to minimize binary size by depending on individual ops,\
# so they need access here.
package(
@@ -60,6 +67,8 @@ tf_gen_op_libs(
"functional_ops",
"image_ops",
"io_ops",
+ "ktfop_ops",
+ "fused_embedding_ops",
"linalg_ops",
"list_ops",
"map_ops",
@@ -344,6 +353,10 @@ cc_library(
}) + if_mkl([
":mkl_array_ops_op_lib",
":mkl_nn_ops_op_lib",
+ ]) + if_ktfop([
+ ":ktfop_ops_op_lib",
+ ]) + if_fused_embedding([
+ ":fused_embedding_ops_op_lib",
]),
alwayslink = 1,
)
new file mode 100644
@@ -0,0 +1,48 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPLookupEmbeddingByHash")
+ .Input("lookup: string")
+ .Input("weights: T_weight")
+ .Attr("num_buckets: int >= 1")
+ .Attr("combiner: int")
+ .Attr("T_weight: {resource, float}")
+ .Output("output: float")
+ .SetShapeFn([](InferenceContext* ctx) {
+ ShapeHandle temp;
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &temp));
+ DimensionHandle emb_size_dim = ctx->UnknownDim();
+ DimensionHandle batch_dim = ctx->UnknownDim();
+
+ ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+ ctx->set_output(0, output_shape);
+
+ return OkStatus();
+ });
+
+} // namespace tensorflow
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,81 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPFusedEmbedding")
+ .Input("weights: float")
+ .Input("lookup: int64")
+ .Input("dense_shape: int64")
+ .Input("indices: int64")
+ .Output("output: float")
+ .Attr("combiner: int")
+
+ .SetShapeFn([](InferenceContext* ctx) {
+ ShapeHandle temp;
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(1), 1, &temp));
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(3), 2, &temp));
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(2), 1, &temp));
+ ShapeHandle emb_var_shape;
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 2, &emb_var_shape));
+
+ DimensionHandle emb_size_dim = ctx->Dim(emb_var_shape, 1);
+ DimensionHandle batch_dim = ctx->UnknownDim();
+
+ ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+ ctx->set_output(0, output_shape);
+
+ return OkStatus();
+ });
+
+REGISTER_OP("KPFusedEmbeddingWithHashBucket")
+ .Input("lookup: string")
+ .Input("weights: T_weight")
+ .Attr("num_buckets: int >= 1")
+ .Attr("combiner: int")
+ .Attr("T_weight: {resource, float}")
+ .Output("output: float")
+ .SetShapeFn([](InferenceContext* ctx) {
+ ShapeHandle temp;
+ TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 1, &temp));
+ DimensionHandle emb_size_dim = ctx->UnknownDim();
+ DimensionHandle batch_dim = ctx->UnknownDim();
+
+ ShapeHandle output_shape = ctx->MakeShape({batch_dim, emb_size_dim});
+ ctx->set_output(0, output_shape);
+
+ return OkStatus();
+ });
+
+REGISTER_OP("KPSoftmax")
+ .Input("logits: T")
+ .Output("softmax: T")
+ .Attr("T: {float}")
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
+ });
+
+} // namespace tensorflow
\ No newline at end of file
@@ -30,6 +30,7 @@ using tsl::port::NUMAGetThreadNodeAffinity;
using tsl::port::NUMAMalloc;
using tsl::port::NUMANumNodes;
using tsl::port::NUMASetThreadNodeAffinity;
+using tsl::port::ThreadAffinity;
} // namespace port
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PLATFORM_NUMA_H_
@@ -17,6 +17,13 @@ option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto";
+enum ExecutorPolicy {
+ // Tensorflow default executor
+ USE_NORMAL_EXECUTOR = 0;
+ // Kunpeng batch scheduling executor
+ USE_BATCH_SCHEDULING_EXECUTOR = 1;
+}
+
message GPUOptions {
// Fraction of the total GPU memory to allocate for each process.
// 1 means to allocate all of the GPU memory, 0.5 means the process
@@ -716,7 +723,8 @@ message ConfigProto {
Experimental experimental = 16;
- // Next: 18
+ bool use_batch_op_scheduling = 18;
+ // Next: 19
}
// Options for a single Run() call.
@@ -148,4 +148,31 @@ bool IsZenDnnEnabled() {
#endif // !AMD_ZENDNN
}
+bool IsKDNNEnabled() {
+#ifndef ENABLE_KDNN
+ return false;
+#else
+ static absl::once_flag once;
+ static bool KDNN_enabled = true;
+ absl::call_once(once, [&] {
+ auto status = ReadBoolFromEnvVar("TF_ENABLE_KDNN_OPTS", KDNN_enabled,
+ &KDNN_enabled);
+
+ if (!status.ok()) {
+ LOG(WARNING) << "TF_ENABLE_KDNN_OPTS is not set to either '0', 'false',"
+ << " '1', or 'true'. Using the default setting: "
+ << KDNN_enabled;
+ }
+ if (KDNN_enabled) {
+ LOG(INFO) << "KDNN custom operations are on. "
+ << "You may see slightly different numerical results due to "
+ << "floating-point round-off errors from different computation "
+ << "orders. To turn them off, set the environment variable "
+ << "`TF_ENABLE_KDNN_OPTS=0`.";
+ }
+ });
+ return KDNN_enabled;
+#endif // !KDNN
+}
+
} // namespace tensorflow
@@ -47,6 +47,8 @@ bool IsMklEnabled();
// Returns true if TF_ENABLE_ZENDNN_OPTS is set to 1
bool IsZenDnnEnabled();
+// Returns true if TF_ENABLE_KDNN_OPTS is set to 1
+bool IsKDNNEnabled();
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_PORT_H_
Binary files a/tensorflow/lite/g3doc/images/build/build_workflow_diag.png and b/tensorflow/lite/g3doc/images/build/build_workflow_diag.png differ
new file mode 100644
new file mode 100644
@@ -0,0 +1,329 @@
+import os
+import time
+import shutil
+import glob
+import json
+import sys
+import re
+import struct
+import tensorflow.compat.v1 as tf
+from dataclasses import dataclass, field
+from typing import List, Callable, Dict, Any, Union, Optional
+from tensorflow.python.profiler.internal import _pywrap_profiler
+from tensorflow.python.profiler import profiler_v2 as profiler
+from tensorflow.python.profiler.profiler_v2 import ProfilerOptions
+from tensorflow.core.profiler.protobuf.xplane_pb2 import XSpace
+import numpy as np
+import multiprocessing as mp
+import pickle
+
+tf.disable_eager_execution()
+
+
+@dataclass
+class TestCase:
+ name: str
+ op_fn: Callable
+ input_fn: Callable
+ check_fn: Optional[Callable] = None
+ num_iters: int = 0
+ optimize_percent: int = 0
+ operator_name: str = ""
+ meta: Dict = field(default_factory=dict)
+
+class CheckFuncClass:
+ def check_fn_float(A, B, meta):
+ atol = meta.get('atol')
+ rtol = meta.get('rtol')
+ if rtol is None:
+ rtol = 1e-3
+ if atol is None:
+ atol = 1e-4
+ # ∣a−b∣≤ atol + rtol ×∣b∣
+ return np.allclose(A, B, rtol=rtol, atol=atol)
+
+ def check_fn_equal(A, B, meta):
+ np.testing.assert_array_equal(A, B) # 如果失败抛出异常, 如果成功无返回值,
+ return True
+
+class UniversalOpBenchmark:
+ def __init__(self, log_dir="bench_logs"):
+ self.log_dir = log_dir
+ self.results = []
+
+ def _subprocess_worker(
+ self,
+ queue: mp.Queue,
+ func: Callable,
+ is_kdnn_enable: bool,
+ raw_inputs: list,
+ func_args: TestCase,
+ temp_dir: str
+ ) -> None:
+ """
+ 子进程工作函数:创建临时目录,执行目标函数,将结果/异常放入队列
+ """
+ try:
+ # 1. 自动创建临时目录(不存在则创建,存在则不报错)
+ os.makedirs(temp_dir, exist_ok=True)
+
+ # 2. 执行目标函数A
+ result = func(is_kdnn_enable, raw_inputs, func_args)
+ try:
+ pickle.dumps(result)
+ except Exception as e:
+ raise Exception(f"返回结果无法序列化,无法传递给主进程:{str(e)}") from e
+ # 3. 将执行结果放入队列(传递给主进程)
+ try:
+ queue.put(('success', result), block=True, timeout=10)
+ except mp.Queue.Full:
+ raise Exception("队列缓冲区已满,无法写入执行结果")
+
+ except Exception as e:
+ # 捕获所有异常,传递给主进程
+ queue.put(('error', e))
+ return
+
+ def run_func_in_subprocess(
+ self,
+ func_A: Callable,
+ is_kdnn_enable: bool,
+ raw_inputs: list,
+ func_args: TestCase,
+ temp_dir: str = "./temp_workdir"
+ ) -> str | np.ndarray:
+
+ # 2. 创建进程间通信队列(用于传递子进程执行结果/异常)
+ result_queue = mp.Queue()
+
+ # 3. 构建子进程
+ sub_process = mp.Process(
+ target=self._subprocess_worker,
+ args=(result_queue, func_A, is_kdnn_enable, raw_inputs, func_args, temp_dir)
+ )
+
+ # 4. 启动并等待子进程执行完成
+ sub_process.start()
+
+ # 5. 从队列中获取子进程执行结果
+ process_status, result = result_queue.get()
+ if process_status == 'error':
+ raise Exception(f"子进程中函数执行失败:{str(result)}") from result
+
+ sub_process.join() # 阻塞主进程,等待子进程退出
+ return result
+
+ def _get_event_name_by_metadata_id(self, metadata_id, plane):
+ if not hasattr(plane, 'event_metadata'):
+ return f"unknown_operator_{metadata_id}"
+
+ metadata=plane.event_metadata[metadata_id]
+ if metadata.id == metadata_id:
+ return metadata.name if (metadata.name and metadata.name.strip()) else f"unknown_operator_{metadata_id}"
+
+ # 未匹配到返回默认名称
+ return f"unknown_operator_{metadata_id}"
+
+ def parse_xplane_file_optimized(self, dir_path):
+ if not os.path.isdir(dir_path):
+ print(f"dir is not exist: {dir_path}")
+ return None
+
+ xplane_files = []
+ for file_name in os.listdir(dir_path):
+ if file_name.endswith(".xplane.pb"):
+ xplane_files.append(os.path.join(dir_path, file_name))
+
+ if not xplane_files:
+ print(f"file is not exist in {dir_path}")
+ return None
+ elif len(xplane_files) > 1:
+ print(f"Found multiple.xplane.pb files: {xplane_files}; the first file will be used.")
+
+ file_path = xplane_files[0]
+ with open(file_path, 'rb') as f:
+ content = f.read()
+
+ xspace = XSpace()
+ xspace.ParseFromString(content)
+ if len(xspace.planes) > 0:
+ return xspace
+
+ print("parse_xplane_file_optimized failed")
+ return None
+
+ def get_average_wall_during(self, input_dir, operator_name, test_case: TestCase):
+ # 1.提取数据
+ xplane_data = self.parse_xplane_file_optimized(input_dir)
+ if not xplane_data:
+ return 0.0
+
+ op_name_list = []
+ # 2. 获取wall_duration
+ wall_durations_us_list = []
+ for plane in xplane_data.planes:
+ if hasattr(plane, 'lines') and len(plane.lines) > 0:
+ for line in plane.lines:
+ if hasattr(line, 'events') and len(line.events) > 0:
+ for event in line.events:
+ # 1. 获取事件时长(ps转换为us)
+ duration_ps = event.duration_ps if hasattr(event, 'duration_ps') else 0
+ if duration_ps <= 0:
+ continue # 过滤无效时长事件
+ #转换成us
+ wall_duration_us = duration_ps / 1000000
+
+ # 2. 通过metadata_id获取算子名称
+ metadata_id = event.metadata_id if hasattr(event, 'metadata_id') else 0
+ op_name = self._get_event_name_by_metadata_id(metadata_id, plane)
+ if op_name not in op_name_list:
+ op_name_list.append(op_name)
+ if operator_name in op_name:
+ wall_durations_us_list.append(wall_duration_us)
+ # print(f"解析到的算子名称列表:{op_name_list}") # 打印所有算子名称
+ if test_case.num_iters != np.size(wall_durations_us_list):
+ print(f"解析到有效数据{np.size(wall_durations_us_list)}条, 预期{test_case.num_iters}条")
+ return 0.0
+
+ return float(np.mean(wall_durations_us_list)), np.var(wall_durations_us_list, ddof=1)
+
+ def parse_performance_data(self, kdnn_enable, raw_inputs, test_case: TestCase):
+ print(f"Testing: {test_case.name} ...")
+ os.environ['TF_ENABLE_KDNN_OPTS'] = str(kdnn_enable)
+
+ placeholders = []
+ feed_dict = {}
+
+ for i, val in enumerate(raw_inputs):
+ np_val = np.array(val)
+
+ p = tf.placeholder(
+ dtype=tf.as_dtype(np_val.dtype),
+ shape=np_val.shape,
+ name=f"input_{i}"
+ )
+ placeholders.append(p)
+ feed_dict[p] = np_val
+
+ res_node = test_case.op_fn(placeholders, test_case.meta)
+ config = tf.ConfigProto(
+ inter_op_parallelism_threads=16,
+ intra_op_parallelism_threads=16
+ )
+ options = ProfilerOptions(
+ host_tracer_level=2,
+ device_tracer_level=1,
+ python_tracer_level=0 # 建议保持 0 以减少对 Kernel 测量的干扰
+ )
+ with tf.Session(config=config) as sess:
+ for _ in range(10):
+ sess.run(res_node, feed_dict=feed_dict)
+
+ profiler.start(self.log_dir, options)
+ time.sleep(3)
+ for _ in range(test_case.num_iters):
+ sess.run(res_node, feed_dict=feed_dict)
+ profiler.stop()
+
+ try:
+ kdnn_status = os.environ.get('TF_ENABLE_KDNN_OPTS', '1')
+ tag = f"{test_case.name}_kdnn_{kdnn_status}" # 结果为 kdnn_0 或 kdnn_1
+ profile_root = os.path.join(self.log_dir, "plugins", "profile")
+ timestamp_dirs = [d for d in os.listdir(profile_root)
+ if os.path.isdir(os.path.join(profile_root, d)) and d.startswith("202")]
+
+ if timestamp_dirs:
+ src_dir = os.path.join(profile_root, timestamp_dirs[0])
+ # 目标路径: bench_logs/MatMul_Test/plugins/profile/kdnn_1
+ target_dir = os.path.join(profile_root, tag)
+
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+
+ os.rename(src_dir, target_dir)
+ print(f"✅ Profile 已固化至: {target_dir}")
+ return target_dir
+
+ # 可选:清理掉空的时间戳父目录或日志文件
+ # (TensorFlow 有时会留下一些空的 event 文件,通常不影响解析)
+ except Exception as e:
+ print(f"⚠️ 路径固化失败: {e}")
+
+ def run_performance_test(self, test_case: TestCase):
+ if test_case.num_iters == 0:
+ return True
+
+ print(f"Performance testing: {test_case.name} ...")
+
+ optimize_percent = test_case.optimize_percent
+ operator_name = test_case.operator_name
+ raw_inputs = test_case.input_fn()
+ if not isinstance(raw_inputs, (list, tuple)):
+ raw_inputs = [raw_inputs]
+
+ no_kdnn_path = self.run_func_in_subprocess(self.parse_performance_data, 0, raw_inputs, test_case)
+ no_kdnn_wall_during, _ = self.get_average_wall_during(no_kdnn_path, operator_name, test_case)
+ kdnn_path = self.run_func_in_subprocess(self.parse_performance_data, 1, raw_inputs, test_case)
+ kdnn_wall_during, _ = self.get_average_wall_during(kdnn_path, operator_name, test_case)
+ if kdnn_wall_during <= 0.0 or no_kdnn_wall_during <= 0.0:
+ print(f"⚠️ 获取平均运行时长失败, 无法进行比较")
+ return False
+ real_percent = 100 * (no_kdnn_wall_during/kdnn_wall_during - 1)
+ if real_percent < optimize_percent:
+ raise Exception(f"⚠️ 性能提升{real_percent:.2f}%, 低于预期{optimize_percent}%, KDNN开启状态下耗时{kdnn_wall_during:.2f}us, 关闭状态下耗时{no_kdnn_wall_during:.2f}us")
+ return False
+
+ print(f"性能测试通过, 性能提升{real_percent:.2f}%, KDNN开启状态下耗时{kdnn_wall_during:.2f}us, 关闭状态下耗时{no_kdnn_wall_during:.2f}us")
+ return True
+
+ def run_function_test(self, test_case: TestCase):
+ print(f"Function testing: {test_case.name} ...")
+
+ raw_inputs = test_case.input_fn()
+ if not isinstance(raw_inputs, list):
+ raw_inputs = [raw_inputs]
+
+ def execute_variant(enable_kdnn, raw_inputs, test_case: TestCase):
+ os.environ['TF_ENABLE_KDNN_OPTS'] = str(enable_kdnn)
+ tf.reset_default_graph()
+
+ placeholders = []
+ feed_dict = {}
+
+ for i, val in enumerate(raw_inputs):
+ np_val = np.array(val)
+
+ p = tf.placeholder(
+ dtype=tf.as_dtype(np_val.dtype),
+ shape=np_val.shape,
+ name=f"input_{i}"
+ )
+ placeholders.append(p)
+ feed_dict[p] = np_val
+ res_node = test_case.op_fn(placeholders, test_case.meta)
+ config = tf.ConfigProto(
+ inter_op_parallelism_threads=16,
+ intra_op_parallelism_threads=16
+ )
+ options = ProfilerOptions(
+ host_tracer_level=2,
+ device_tracer_level=1,
+ python_tracer_level=0 # 建议保持 0 以减少对 Kernel 测量的干扰
+ )
+ with tf.Session(config=config) as sess:
+ return sess.run(res_node, feed_dict=feed_dict)
+ return
+
+ kdnn_data = self.run_func_in_subprocess(execute_variant, 1, raw_inputs, test_case)
+ no_kdnn_data = self.run_func_in_subprocess(execute_variant, 0, raw_inputs, test_case)
+ if test_case.check_fn is None:
+ is_correct = CheckFuncClass.check_fn_float(no_kdnn_data, kdnn_data, test_case.meta)
+ else:
+ is_correct = test_case.check_fn(no_kdnn_data, kdnn_data, test_case.meta)
+
+ if not is_correct:
+ print(f"⚠️ 误差较大")
+ raise Exception(f"⚠️ 误差较大")
+ return False
+ print("功能测试通过")
+ return True
new file mode 100644
@@ -0,0 +1,51 @@
+import argparse
+import importlib
+import pkgutil
+import traceback
+import os
+
+def main():
+ parser = argparse.ArgumentParser(description="TF Op Benchmark Framework with KDNN Control")
+ parser.add_argument('--op', type=str, help='指定要运行的算子模块名')
+ parser.add_argument('--list', action='store_true', help='列出所有模块')
+ parser.add_argument('--performance_test', choices=['True', 'False'], default='True', help='是否运行性能测试模块')
+
+ args = parser.parse_args()
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' #日志太多, 关掉
+
+ import tensorflow as TF
+ from framework.runner import UniversalOpBenchmark
+ import ops
+
+ # 4. 扫描模块
+ available_ops = {
+ name: f'ops.{name}'
+ for loader, name, is_pkg in pkgutil.iter_modules(ops.__path__)
+ }
+
+ if args.list:
+ print("📁 可用算子列表:")
+ for name in available_ops: print(f" - {name}")
+ return
+
+ target_modules = [available_ops[args.op]] if args.op else list(available_ops.values())
+ root_log_dir = "bench_logs"
+ bench = UniversalOpBenchmark(log_dir=root_log_dir)
+
+ for module_path in target_modules:
+ try:
+ module = importlib.import_module(module_path)
+ if hasattr(module, 'get_test_cases'):
+ cases = module.get_test_cases()
+ for case in cases:
+ bench.run_function_test(case)
+ if args.performance_test == "True":
+ for case in cases:
+ bench.run_performance_test(case)
+ except Exception as e:
+ print(f"❌ 运行模块 {module_path} 失败: {e}")
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,176 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+
+ def matmul_op(inputs, meta):
+ return tf.matmul(inputs[0], inputs[1], transpose_a=meta.get('trans_a', False), transpose_b=meta.get('trans_b', False))
+
+ def input_2D_float(m, k, n):
+ return [
+ np.random.uniform(0, 1, (m, k)).astype(np.float32),
+ np.random.uniform(0, 1, (k, n)).astype(np.float32),
+ ]
+
+ def input_3D_no_broadcast_float(b1, m, k, n):
+ return [
+ np.random.uniform(0, 1, (b1, m, k)).astype(np.float32),
+ np.random.uniform(0, 1, (b1, k, n)).astype(np.float32),
+ ]
+
+ def input_3D_broadcast_float(b1, m, k, n):
+ return [
+ np.random.uniform(0, 1, (b1, m, k)).astype(np.float32),
+ np.random.uniform(0, 1, (1, k, n)).astype(np.float32),
+ ]
+
+ def build_batch_matmul_input(shape_a, shape_b, trans_a, trans_b, dtype):
+ """
+ 构造满足矩阵乘法条件的两个输入 Tensor
+ """
+ # 如果转置,需要交换最后两个维度来生成原始数据
+ final_shape_a = list(shape_a)
+ if trans_a:
+ final_shape_a[-1], final_shape_a[-2] = final_shape_a[-2], final_shape_a[-1]
+
+ final_shape_b = list(shape_b)
+ if trans_b:
+ final_shape_b[-1], final_shape_b[-2] = final_shape_b[-2], final_shape_b[-1]
+
+ a = np.random.uniform(-1, 1, final_shape_a).astype(dtype)
+ b = np.random.uniform(-1, 1, final_shape_b).astype(dtype)
+ return [a, b]
+
+ class BatchMatMulTestCaseFactory:
+ @staticmethod
+ def create_func_test():
+ # 基础组合维度定义
+ dims_to_test = [2, 3, 4, 5]
+ trans_options = [True, False]
+ broadcast_options = [True, False]
+ dtypes = {"fp32": np.float32}
+
+ # 矩阵乘法的核心维度 M, K, N
+ M, K, N = 64, 32, 48
+
+ all_cases = []
+
+ for dim in dims_to_test:
+ for is_bc in broadcast_options:
+ # 2D 场景不存在广播概念,跳过重复
+ if dim == 2 and is_bc: continue
+
+ for ta in trans_options:
+ for tb in trans_options:
+ for dt_name, dt_val in dtypes.items():
+
+ # 构造形状逻辑
+ # A: (..., M, K), B: (..., K, N)
+ # 如果转置参数为真,op内部会处理,但输入生成需匹配
+ batch_dims_a = [2, 3, 1][:dim-2] if dim > 2 else []
+ if is_bc:
+ # 构造广播场景:A的batch维有1,B的batch维正常
+ batch_dims_a = [1] * (dim - 2)
+ batch_dims_b = [2] * (dim - 2)
+ else:
+ batch_dims_b = batch_dims_a
+
+ shape_a = tuple(batch_dims_a + [M, K])
+ shape_b = tuple(batch_dims_b + [K, N])
+
+ case_name = f"batch_matmul_{dim}D_{dt_name}_bc{is_bc}_ta{ta}_tb{tb}"
+
+ # 预绑定输入构造函数
+ input_fn = partial(
+ build_batch_matmul_input,
+ shape_a=shape_a,
+ shape_b=shape_b,
+ trans_a=ta,
+ trans_b=tb,
+ dtype=dt_val
+ )
+
+ all_cases.append(TestCase(
+ name=case_name,
+ op_fn=matmul_op,
+ input_fn=input_fn,
+ num_iters=0,
+ meta={
+ 'trans_a': ta, # 对应算子属性:是否转置第一个输入
+ 'trans_b': tb # 对应算子属性:是否转置第二个输入
+ }
+ ))
+ return all_cases
+ test_case = []
+ test_case.extend(BatchMatMulTestCaseFactory.create_func_test())
+ # perf test case
+ test_case.extend([
+ TestCase(
+ name="MatMul_2D_79_1570_256_False_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_2D_float, m=79, k=1570, n=256),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_2D_79_1570_128_False_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_2D_float, m=79, k=1570, n=128),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_2D_4480_32_16_False_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_2D_float, m=4480, k=32, n=16),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_2D_64_256_128_False_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_2D_float, m=64, k=256, n=128),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_2D_128_592_128_False_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_2D_float, m=128, k=592, n=128),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_3D_No_Broadcast_64_64_256_256_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_3D_no_broadcast_float, b1=64, m=64, k=256, n=256),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+ TestCase(
+ name="MatMul_3D_No_Broadcast_64_32_64_64_False",
+ op_fn=matmul_op,
+ input_fn=partial(input_3D_no_broadcast_float, b1=64, m=32, k=64, n=64),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "MatMul",
+ meta={'trans_a': False, 'trans_b': False},
+ ),
+
+ ])
+ return test_case
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,278 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase, CheckFuncClass
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+ def concat_op(inputs, meta):
+ return tf.concat(inputs, meta["axis"])
+
+ def build_18input_2D(dtype=np.float32):
+ batch = 150
+ shapes = [
+ (batch, 960),
+ (batch, 24),
+ (batch, 4),
+ (batch, 4),
+ (batch, 4),
+ (batch, 64),
+ (batch, 4),
+ (batch, 4),
+ (batch, 32),
+ (batch, 4),
+ (batch, 4),
+ (batch, 24),
+ (batch, 4),
+ (batch, 4),
+ (batch, 24),
+ (batch, 4),
+ (batch, 4),
+ (batch, 16),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_12input_2D(dtype=np.float32):
+ batch = 150
+ shapes = [
+ (batch, 192),
+ (batch, 288),
+ (batch, 64),
+ (batch, 48),
+ (batch, 48),
+ (batch, 8),
+ (batch, 48),
+ (batch, 48),
+ (batch, 8),
+ (batch, 256),
+ (batch, 128),
+ (batch, 128),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_8input_3D(dtype=np.float32):
+ batch = 16
+ seq = 64
+ shapes = [
+ (batch, seq, 32),
+ (batch, seq, 64),
+ (batch, seq, 16),
+ (batch, seq, 16),
+ (batch, seq, 16),
+ (batch, seq, 16),
+ (batch, seq, 16),
+ (batch, seq, 16),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_4input_2D(dtype=np.float32):
+ batch = 256
+ shapes = [
+ (batch, 192),
+ (batch, 288),
+ (batch, 64),
+ (batch, 48),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_3input_2D(dtype=np.float32):
+ shapes = [
+ (128, 1510),
+ (128, 36),
+ (128, 24),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_2input_2D(batch, input1_dim, input2_dim, dtype=np.float32):
+ shapes = [
+ (batch, input1_dim),
+ (batch, input2_dim),
+ ]
+ return [np.random.uniform(0, 1, shape).astype(dtype) for shape in shapes]
+
+ def build_dynamic_input(num_inputs, dtype, shape_template, concat_axis):
+ """
+ 通用输入构造器
+ :param num_inputs: 输入 Tensor 的个数 (1, 2, 8, 12, 18)
+ :param dtype: 数据类型 (np.float32, np.int32 等)
+ :param shape_template: 基础 3D 形状,例如 (batch, seq, feature)
+ :param concat_axis: 拼接的轴,用于微调形状确保可以 concat
+ """
+ inputs = []
+ for i in range(num_inputs):
+ curr_shape = list(shape_template)
+ # 为了模拟真实情况,我们可以让拼接轴上的长度略有不同
+ curr_shape[concat_axis] = np.random.randint(1, 10) # 如果需要动态长度可开启
+
+ if np.issubdtype(dtype, np.integer):
+ data = np.random.randint(0, 100, curr_shape).astype(dtype)
+ else:
+ data = np.random.uniform(0, 1, curr_shape).astype(dtype)
+ inputs.append(data)
+ return inputs
+
+ class ConcatTestCaseFactory:
+ @staticmethod
+ def create_func_test():
+ all_cases = []
+ input_counts = [1, 2, 8, 12, 18]
+ dtypes = {
+ "fp32": np.float32,
+ "bf16": np.uint16, # 工业界常用 uint16 模拟 bf16 存储
+ "int32": np.int32,
+ "int8": np.int8,
+ "uint8": np.uint8
+ }
+ axes = [0, 1, 2, -1, -2, -3]
+ base_3d_shape = (32, 64, 128) # 示例 3D 基础形状
+ for num in input_counts:
+ for dtype_name, dtype_val in dtypes.items():
+ for axis in axes:
+ case_name = f"concat_3D_{num}in_{dtype_name}_axis{axis}"
+
+ # 使用 partial 预绑定所有参数
+ input_fn = partial(
+ build_dynamic_input,
+ num_inputs=num,
+ dtype=dtype_val,
+ shape_template=base_3d_shape,
+ concat_axis=axis
+ )
+
+ case = TestCase(
+ name=case_name,
+ op_fn=concat_op,
+ input_fn=input_fn,
+ num_iters=0,
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': axis}
+ )
+ all_cases.append(case)
+ return all_cases
+
+ test_case = []
+ test_case.extend(ConcatTestCaseFactory.create_func_test())
+ # perf test case
+ test_case.extend([
+ TestCase(
+ name="concat_18input_2D_float",
+ op_fn=concat_op,
+ input_fn=partial(build_18input_2D, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_12input_2D_float",
+ op_fn=concat_op,
+ input_fn=partial(build_12input_2D, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_8input_3D_float",
+ op_fn=concat_op,
+ input_fn=partial(build_8input_3D, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_4input_2D_float",
+ op_fn=concat_op,
+ input_fn=partial(build_4input_2D, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_3input_2D_float",
+ op_fn=concat_op,
+ input_fn=partial(build_3input_2D, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_2input_2D_small_float",
+ op_fn=concat_op,
+ input_fn=partial(build_2input_2D, batch=128, input1_dim=400, input2_dim=400, dtype=np.float32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_18input_2D_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_18input_2D, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_12input_2D_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_12input_2D, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_8input_3D_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_8input_3D, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_4input_2D_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_4input_2D, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_3input_2D_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_3input_2D, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ TestCase(
+ name="concat_2input_2D_small_int32",
+ op_fn=concat_op,
+ input_fn=partial(build_2input_2D, batch=128, input1_dim=400, input2_dim=400, dtype=np.int32),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "ConcatV2",
+ check_fn=CheckFuncClass.check_fn_equal,
+ meta={'axis': -1}
+ ),
+ ])
+ return test_case
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,118 @@
+import tensorflow as tf
+import numpy as np
+from functools import partial
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+
+ def einsum_op(inputs, meta):
+ return tf.einsum(meta["label"], inputs[0], inputs[1])
+
+ def input_3D_bh_bsh_bs_float(B, H, S):
+ return [
+ np.random.uniform(0, 1, (B, H)).astype(np.float32),
+ np.random.uniform(0, 1, (B, S, H)).astype(np.float32),
+ ]
+
+ def input_3D_bs_bsh_bh_float(B, H, S):
+ return [
+ np.random.uniform(0, 1, (B, S)).astype(np.float32),
+ np.random.uniform(0, 1, (B, S, H)).astype(np.float32),
+ ]
+
+ def build_einsum_input(shapes, dtype):
+ """
+ 根据 einsum 表达式和形状列表构造输入
+ """
+ inputs = []
+ for shape in shapes:
+ if np.issubdtype(dtype, np.integer):
+ data = np.random.randint(0, 10, shape).astype(dtype)
+ else:
+ data = np.random.uniform(-1, 1, shape).astype(dtype)
+ inputs.append(data)
+ return inputs
+
+ class EinsumTestCaseFactory:
+ @staticmethod
+ def create_func_test():
+ all_cases = []
+ # 定义 B, H, S 的测试点
+ B_list = [1, 79, 256]
+ H_list = [1, 79, 128]
+ S_list = [1, 79, 512]
+
+ # 1. Scoring 场景: bh, bsh -> bs
+ for b in B_list:
+ for h in H_list:
+ for s in S_list:
+ eq = "bh,bsh->bs"
+ shapes = [(b, h), (b, s, h)]
+ all_cases.append(EinsumTestCaseFactory._create_case(
+ "Scoring", eq, shapes, np.float32, b, h, s
+ ))
+
+ # 2. Pooling 场景: bs, bsh -> bh
+ for b in B_list:
+ for h in H_list:
+ for s in S_list:
+ eq = "bs,bsh->bh"
+ shapes = [(b, s), (b, s, h)]
+ all_cases.append(EinsumTestCaseFactory._create_case(
+ "Pooling", eq, shapes, np.float32, b, h, s
+ ))
+
+ # 3. 批量矩阵乘法场景: bij, bjk -> bik
+ ikj_variants = [
+ (64, 128, 64), (79, 128, 64), (1, 128, 64),
+ (128, 1, 64), (128, 128, 1)
+ ]
+ for b in B_list:
+ for i, k, j in ikj_variants:
+ eq = "bij,bjk->bik"
+ shapes = [(b, i, j), (b, j, k)]
+ all_cases.append(EinsumTestCaseFactory._create_case(
+ "BatchMatMul", eq, shapes, np.float32, b, i, j, k
+ ))
+
+ return all_cases
+
+ @staticmethod
+ def _create_case(scene, eq, shapes, dtype, *args):
+ dims_str = "_".join([str(a) for a in args])
+ case_name = f"einsum_{scene}_{eq.replace(',', '_').replace('->', '_')}_params_{dims_str}"
+
+ return TestCase(
+ name=case_name,
+ op_fn=einsum_op,
+ input_fn=partial(build_einsum_input, shapes=shapes, dtype=dtype),
+ num_iters=0,
+ operator_name="Einsum",
+ meta={'label': eq}
+ )
+
+ test_case = []
+ test_case.extend(EinsumTestCaseFactory.create_func_test())
+ # perf test case
+ test_case.extend([
+ TestCase(
+ name="Einsum_3D_float_bh_bsh_bs_128_256_512",
+ op_fn=einsum_op,
+ input_fn=partial(input_3D_bh_bsh_bs_float, B=128, H=256, S=512),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "Einsum",
+ meta={'label': "bh,bsh->bs"}
+ ),
+ TestCase(
+ name="Einsum_3D_float_bs_bsh_bh_128_128_512",
+ op_fn=einsum_op,
+ input_fn=partial(input_3D_bs_bsh_bh_float, B=128, H=128, S=512),
+ num_iters=1000,
+ optimize_percent = 5,
+ operator_name = "Einsum",
+ meta={'label': "bs,bsh->bh"}
+ ),
+ ])
+ return test_case
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,67 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+from framework.runner import CheckFuncClass
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+
+ def reduce_floormod_op(inputs, meta):
+ return tf.math.floormod(inputs[0], inputs[1])
+
+ def build_input_1D_float():
+ shape = [500000]
+ return [np.random.uniform(-1, 1, shape).astype(np.float32) for i in range(2)]
+
+ def build_input_3D_float():
+ shape = [64, 64, 128]
+ return [np.random.uniform(0, 1, shape).astype(np.float32) for i in range(2)]
+
+ def build_input_1D_int64():
+ shape = [1000000]
+ return [np.random.randint(-2**63, 2**63-1, size = shape, dtype=np.int64) for i in range(2)]
+
+ def build_input_3D_int64():
+ shape = [128, 64, 128]
+ return [np.random.randint(0, 2**63-1, size = shape, dtype=np.int64) for i in range(2)]
+
+ return [
+ TestCase(
+ name="floormod_input_1D_float",
+ op_fn=reduce_floormod_op,
+ input_fn=build_input_3D_float,
+ num_iters=1000,
+ optimize_percent = 150,
+ operator_name = "FloorMod:FloorMod",
+ meta={}
+ ),
+ TestCase(
+ name="floormod_input_3D_float",
+ op_fn=reduce_floormod_op,
+ input_fn=build_input_3D_float,
+ num_iters=1000,
+ optimize_percent = 150,
+ operator_name = "FloorMod:FloorMod",
+ meta={}
+ ),
+ TestCase(
+ name="floormod_input_1D_int64",
+ op_fn=reduce_floormod_op,
+ input_fn=build_input_1D_int64,
+ check_fn=CheckFuncClass.check_fn_equal,
+ num_iters=1000,
+ optimize_percent = 25,
+ operator_name = "FloorMod:FloorMod",
+ meta={}
+ ),
+ TestCase(
+ name="floormod_input_3D_int64",
+ op_fn=reduce_floormod_op,
+ input_fn=build_input_3D_int64,
+ check_fn=CheckFuncClass.check_fn_equal,
+ num_iters=1000,
+ optimize_percent = 25,
+ operator_name = "FloorMod:FloorMod",
+ meta={}
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,44 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+
+ def sigmoid_op(inputs, meta):
+ return tf.nn.sigmoid(inputs)
+
+ def build_input_1D_float():
+ shape = [1024]
+ return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+ def build_input_3D_float():
+ shape = [128, 64, 128]
+ return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+ def build_input_empty():
+ return [np.array([]).astype(np.float32)]
+
+
+ return [
+ TestCase(
+ name="sigmoid_input_1D_float",
+ op_fn=sigmoid_op,
+ num_iters=0,
+ input_fn=build_input_1D_float
+ ),
+ TestCase(
+ name="sigmoid_input_3D_float",
+ op_fn=sigmoid_op,
+ num_iters=1000,
+ input_fn=build_input_3D_float,
+ optimize_percent = 50,
+ operator_name = "Sigmoid:Sigmoid"
+ ),
+ TestCase(
+ name="sigmoid_input_empty",
+ op_fn=sigmoid_op,
+ num_iters=0,
+ input_fn=build_input_empty
+ ),
+ ]
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,53 @@
+import tensorflow as tf
+import numpy as np
+from framework.runner import TestCase
+
+def get_test_cases():
+ """每个算子文件都统一实现这个接口"""
+
+ def softmax_op(inputs, meta):
+ return tf.nn.softmax(inputs, axis=meta["axis"])
+
+ def logsoftmax_op(inputs, meta):
+ return tf.nn.log_softmax(inputs, axis=meta["axis"])
+
+ def build_input_2D_float():
+ shape = [1024, 128]
+ return [np.random.uniform(-100, 100, shape).astype(np.float32)]
+
+ def build_input_3D_float():
+ shape = [64, 64, 128]
+ return [np.random.uniform(-10, 10, shape).astype(np.float32)]
+
+ return [
+ TestCase(
+ name="softmax_input_2D_float",
+ op_fn=softmax_op,
+ input_fn=build_input_2D_float,
+ num_iters=0,
+ meta={'axis': 2,}
+ ),
+ TestCase(
+ name="softmax_input_3D_float_axis_1",
+ op_fn=softmax_op,
+ input_fn=build_input_3D_float,
+ num_iters=0,
+ meta={'axis': 1,}
+ ),
+ TestCase(
+ name="softmax_input_3D_float_axis_2",
+ op_fn=softmax_op,
+ input_fn=build_input_3D_float,
+ num_iters=1000,
+ optimize_percent = 200,
+ operator_name = "Softmax:Softmax",
+ meta={'axis': 2,}
+ ),
+ TestCase(
+ name="logsoftmax_input_2D_float",
+ op_fn=logsoftmax_op,
+ input_fn=build_input_2D_float,
+ num_iters=0,
+ meta={'axis': -1,}
+ ),
+ ]
\ No newline at end of file
@@ -61,6 +61,10 @@ load(
"//third_party/compute_library:build_defs.bzl",
"if_enable_acl",
)
+load(
+ "//third_party/KDNN:build_defs.bzl",
+ "if_enable_kdnn",
+)
load(
"//third_party/llvm_openmp:openmp.bzl",
"windows_llvm_openmp_linkopts",
@@ -465,6 +469,7 @@ def tf_copts(
if_mkldnn_aarch64_acl(["-DDNNL_AARCH64_USE_ACL=1"]) +
if_mkldnn_aarch64_acl_openmp(["-DENABLE_ONEDNN_OPENMP"]) +
if_zendnn(["-DAMD_ZENDNN"]) +
+ if_enable_kdnn(["-DENABLE_KDNN"]) +
if_enable_acl(["-DXLA_CPU_USE_ACL=1", "-fexceptions"]) +
if_android_arm(["-mfpu=neon", "-fomit-frame-pointer"]) +
if_linux_x86_64(["-msse3"]) +
@@ -26,6 +26,8 @@ load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
load("//third_party/ducc:workspace.bzl", ducc = "repo")
load("//third_party/eigen3:workspace.bzl", eigen3 = "repo")
load("//third_party/farmhash:workspace.bzl", farmhash = "repo")
+load("//third_party/ktfop:workspace.bzl", ktfop = "repo")
+load("//third_party/kblas:workspace.bzl", kblas = "repo")
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo")
load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo")
@@ -67,6 +69,8 @@ def _initialize_third_party():
dlpack()
eigen3()
farmhash()
+ kblas()
+ ktfop()
flatbuffers()
gemmlowp()
hexagon_nn()
@@ -803,7 +807,7 @@ def _tf_repositories():
name = "upb",
sha256 = "61d0417abd60e65ed589c9deee7c124fe76a4106831f6ad39464e1525cef1454",
strip_prefix = "upb-9effcbcb27f0a665f9f345030188c0b291e32482",
- patch_file = ["//third_party/grpc:upb_platform_fix.patch"],
+ patch_file = ["//third_party/grpc:upb_platform_fix.patch", "//third_party/grpc:upb_gcc10_compile_fix.patch"],
urls = tf_mirror_urls("https://github.com/protocolbuffers/upb/archive/9effcbcb27f0a665f9f345030188c0b291e32482.tar.gz"),
)
new file mode 100644
@@ -0,0 +1,34 @@
+licenses(["notice"])
+
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+config_setting(
+ name = "enable_kdnn",
+ define_values = {
+ "enable_kdnn": "true",
+ },
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "kdnn",
+ hdrs = glob(["include/**/*.hpp"]),
+ includes = ["include"],
+ srcs = glob(["src/libkdnn.a"]),
+ linkopts = ["-lgomp"],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "kdnn_adapter",
+ hdrs = ["kdnn_adapter.h", "kdnn_threadpool.h", "kdnn_types_adapter.h", "kdnn_layout_adapter.h"],
+ strip_include_prefix = "/third_party/KDNN",
+ visibility = ["//visibility:public"],
+ deps = [":kdnn"],
+)
+
+bzl_library(
+ name = "build_defs_bzl",
+ srcs = ["build_defs.bzl"],
+ visibility = ["//visibility:public"],
+)
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,24 @@
+def if_enable_kdnn(if_true, if_false = []):
+ """Shorthand to select() if we are building with KDNN and KDNN is enabled.
+
+ This is only effective when built with KDNN.
+
+ Args:
+ if_true: expression to evaluate if building with KDNN and KDNN is enabled
+ if_false: expression to evaluate if building without KDNN or KDNN is not enabled.
+
+ Returns:
+ A select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ "@org_tensorflow//third_party/KDNN:enable_kdnn": if_true,
+ "//conditions:default": if_false,
+ })
+
+def kdnn_deps():
+ """Shorthand for select() to pull in the correct set of KDNN library deps.
+ """
+ return select({
+ "@org_tensorflow//third_party/KDNN:enable_kdnn": ["//third_party/KDNN:kdnn_adapter"],
+ "//conditions:default": [],
+ })
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,289 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_ADAPTER_H_
+#include "kdnn.hpp"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/matmul_bcast.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "kdnn_threadpool.h"
+#include "kdnn_types_adapter.h"
+#include "kdnn_layout_adapter.h"
+#include "tensorflow/core/util/port.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "operations/kdnn_softmax.hpp"
+#include "operations/kdnn_eltwise.hpp"
+
+namespace tensorflow {
+
+inline void kdnnFusedGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+ bool fusion_relu, bool trans_x, bool trans_y) {
+ int m = a.dim_size(0);
+ int n = b.dim_size(trans_y ? 0 : 1);
+ int k = b.dim_size(trans_y ? 1 : 0);
+ const float *A = a.flat<float>().data();
+ const float *B = b.flat<float>().data();
+ float *C = out->flat<float>().data();
+ const Tensor& bias = ctx->input(2);
+ const float *Bias = bias.flat<float>().data();
+ if (bias.dims() != 1 || bias.dim_size(0) != n) {
+ OP_REQUIRES_OK(ctx, errors::InvalidArgument("bias must be 1-dimensional and match n",
+ bias.shape().DebugString()));
+ }
+ KDNN::PostOpsDataPtrs po_ptrs;
+ KDNN::PostOps post_ops;
+ if (fusion_relu) {
+ post_ops.AppendEltwise(KDNN::ActivationFunction::RELU);
+ po_ptrs.push_back(&post_ops);
+ }
+ // intra_op thread_pool
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, trans_y ? KDNN::Layout::BA : KDNN::Layout::AB};
+ const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ const KDNN::TensorInfo biasInfo = {{1, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ KDNN::Attributes attr;
+ attr.SetPostOps(post_ops);
+ KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, biasInfo, attr);
+ gemm.Run(A, B, C, Bias, po_ptrs);
+ KDNN::Threading::DeactivateThreadpool();
+}
+
+template<typename Tindices>
+inline void kdnnSparseMatmul(const std::size_t nnz,
+ const std::size_t rhs_right, const std::size_t lhs_right,
+ const int lhs_index_a, const int rhs_index_a,
+ typename TTypes<float>::Matrix out,
+ typename TTypes<Tindices>::ConstMatrix a_indices,
+ typename TTypes<float>::ConstVec a_values,
+ const float* b_data) {
+ std::vector<int> idx(nnz);
+ int lhs_left = out.dimension(0);
+ std::vector<int> pntrb(lhs_left);
+ std::vector<int> pntre(lhs_left);
+ std::vector<int> row_counts(lhs_left);
+ for (size_t i = 0; i < nnz; ++i) {
+ idx[i] = a_indices(i, rhs_index_a);
+ ++row_counts[a_indices(i, lhs_index_a)];
+ }
+
+ int current_pos = 0;
+ for (size_t i = 0; i < lhs_left; ++i) {
+ pntrb[i] = current_pos;
+ current_pos += row_counts[i];
+ pntre[i] = current_pos;
+ }
+ const KDNN::CsrSparseTensorInfo aInfo = {{lhs_left, lhs_right},
+ KDNN::Element::TypeT::F32, KDNN::Layout::AB, pntrb, pntre, idx, nnz};
+ const KDNN::TensorInfo bInfo = {{lhs_right, rhs_right},
+ KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ const KDNN::TensorInfo dstInfo = {{lhs_left, rhs_right},
+ KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ KDNN::SparseGemm sparse_csr(aInfo, bInfo, dstInfo);
+ sparse_csr.Run(a_values.data(), b_data, out.data());
+}
+
+template<typename T>
+inline void KDNNConcatImpl(OpKernelContext* ctx,
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ KDNN::Element::TypeT kdnnType = KDNN::Element::TypeAdapter<T>::value;
+ KDNN::Layout kdnnLayout = KDNN::LayoutAdapter<2, false>::value;
+ OP_REQUIRES(ctx, kdnnType != KDNN::Element::TypeT::UNDEFINED,
+ errors::InvalidArgument("unsupported kdnn data type"));
+ OP_REQUIRES(ctx, kdnnLayout != KDNN::Layout::UNDEFINED,
+ errors::InvalidArgument("unsupported kdnn layout"));
+ std::vector<KDNN::TensorInfo> inputInfos;
+ std::vector<const void *> input_ptrs;
+ inputInfos.reserve(inputs.size());
+ input_ptrs.reserve(inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ auto dim0 = inputs[i]->dimension(0);
+ auto dim1 = inputs[i]->dimension(1);
+ inputInfos.emplace_back(KDNN::TensorInfo{{dim0, dim1}, kdnnType, kdnnLayout});
+ input_ptrs.push_back(static_cast<const void*>(inputs[i]->data()));
+ }
+ void* output_ptr = static_cast<void *>(output->data());
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ KDNN::TensorInfo outputInfo({output->dimension(0), output->dimension(1)}, kdnnType, kdnnLayout);
+ KDNN::ConcatLayer concat(inputInfos, 1, outputInfo);
+ concat.Run(input_ptrs.data(), output_ptr);
+ KDNN::Threading::DeactivateThreadpool();
+}
+
+inline KDNN::TensorInfo MakeInfo(const tensorflow::Tensor* tensor, bool transposed) {
+ const tensorflow::TensorShape& shape = tensor->shape();
+ int dims = shape.dims();
+
+ std::vector<int64_t> d5 = {1, 1, 1, 1, 1};
+ for (int i = 0; i < dims; ++i) {
+ d5[4 - i] = shape.dim_size(dims - 1 - i);
+ }
+
+ if (transposed) {
+ std::swap(d5[3], d5[4]);
+ }
+
+ return KDNN::TensorInfo(
+ {d5[0], d5[1], d5[2], d5[3], d5[4]},
+ KDNN::Element::TypeT::F32,
+ transposed ? KDNN::Layout::ABCED : KDNN::Layout::ABCDE
+ );
+}
+
+inline KDNN::TensorInfo MakeOutputInfo(const KDNN::TensorInfo &tensorA, const KDNN::TensorInfo &tensorB) {
+ int dims = tensorA.GetNumDims();
+ std::vector<int64_t> d5 = {1, 1, 1, 1, 1};
+ for (int i = 0; i < dims - 2; ++i) {
+ d5[i] = std::max(tensorA.GetDims()[i], tensorB.GetDims()[i]);
+ }
+ d5[3] = tensorA.GetDims()[3];
+ d5[4] = tensorB.GetDims()[4];
+ return KDNN::TensorInfo(
+ {d5[0], d5[1], d5[2], d5[3], d5[4]},
+ KDNN::Element::TypeT::F32, KDNN::Layout::ABCDE
+ );
+}
+
+inline void kdnnGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+ bool trans_x, bool trans_y) {
+ int m = a.dim_size(trans_x ? 2 : 1);
+ int n = b.dim_size(trans_y ? 1 : 2);
+ int k = b.dim_size(trans_y ? 2 : 1);
+ const float *A = a.flat<float>().data();
+ const float *B = b.flat<float>().data();
+ float *C = out->flat<float>().data();
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, trans_x ? KDNN::Layout::BA : KDNN::Layout::AB};
+ const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, trans_y ? KDNN::Layout::BA : KDNN::Layout::AB};
+ const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+ KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo);
+ gemm.Run(A, B, C);
+ KDNN::Threading::DeactivateThreadpool();
+}
+
+inline void kdnnBatchGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out,
+ bool trans_x, bool trans_y) {
+ const float *A = a.flat<float>().data();
+ const float *B = b.flat<float>().data();
+ float *C = out->flat<float>().data();
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ const KDNN::TensorInfo srcInfo = MakeInfo(&a, trans_x);
+ const KDNN::TensorInfo weightsInfo = MakeInfo(&b, trans_y);
+ const KDNN::TensorInfo dstInfo = MakeOutputInfo(srcInfo, weightsInfo);
+ KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo);
+ gemm.Run(A, B, C);
+ KDNN::Threading::DeactivateThreadpool();
+}
+
+template <typename Functor>
+inline void kdnnFloormodOp(OpKernelContext* ctx, const Tensor &input_0, const Tensor &input_1, Tensor *output) {
+ typedef typename Functor::in_type Tin; // Input scalar data type.
+ const Tin* src = input_0.flat<Tin>().data();
+ const Tin* src_1 = input_1.flat<Tin>().data();
+ Tin* dst = output->flat<Tin>().data();
+
+ KDNN::Shape tensorShape({input_0.shape().num_elements()});
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+
+ if (std::is_same<Tin, int64_t>::value) {
+ KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+ KDNN::TensorInfo inputTensorInfo_1(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+ KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::S64, KDNN::Layout::A);
+ KDNN::BinaryLayer layer(inputTensorInfo, inputTensorInfo_1, outputTensorInfo, KDNN::BinaryFunction::FLOORMOD);
+ layer.Run(src, src_1, dst);
+ } else {
+ KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+ KDNN::TensorInfo inputTensorInfo_1(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+ KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+ KDNN::BinaryLayer layer(inputTensorInfo, inputTensorInfo_1, outputTensorInfo, KDNN::BinaryFunction::FLOORMOD);
+ layer.Run(src, src_1, dst);
+ }
+
+ KDNN::Threading::DeactivateThreadpool();
+ return;
+}
+
+template <typename Functor>
+inline void kdnnSigmoidOp(OpKernelContext* ctx, const Tensor &input, Tensor *output)
+{
+ typedef typename Functor::in_type Tin;
+ const Tin* src = input.flat<Tin>().data();
+ Tin* dst = output->flat<Tin>().data();
+ KDNN::Shape tensorShape({input.shape().num_elements()});
+ KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+ KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::A);
+
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ KDNN::ActivationLayerFWD layer(inputTensorInfo, outputTensorInfo, KDNN::ActivationFunction::SIGMOID);
+ layer.Run(src, dst);
+ KDNN::Threading::DeactivateThreadpool();
+ return;
+}
+
+template <typename T>
+inline void kdnnSoftmaxOp(OpKernelContext* ctx, const Tensor &input, Tensor *output)
+{
+ const T* src = input.flat_inner_dims<T>().data();
+ T* dst = output->flat_inner_dims<T>().data();
+ KDNN::Shape tensorShape({input.flat_inner_dims<T>().dimension(0), input.flat_inner_dims<T>().dimension(1)});
+ KDNN::TensorInfo inputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::AB);
+ KDNN::TensorInfo outputTensorInfo(tensorShape, KDNN::Element::TypeT::F32, KDNN::Layout::AB);
+
+ thread::ThreadPool* thread_pool =
+ ctx->device()
+ ->tensorflow_cpu_worker_threads()
+ ->workers;
+ kdnn::KDNNThreadPool kdnn_tp(thread_pool);
+ KDNN::Threading::ActivateThreadpool(&kdnn_tp);
+ KDNN::SoftmaxLayerFWD layer(inputTensorInfo, outputTensorInfo, 1, KDNN::AlgorithmKind::SOFTMAX);
+ layer.Run(src, dst);
+ KDNN::Threading::DeactivateThreadpool();
+ return;
+}
+
+}// namespace tensorflow
+#endif
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,42 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H_
+#include "kdnn.hpp"
+
+namespace KDNN {
+
+template <int Rank, bool Transposed = false>
+struct LayoutAdapter {
+ static constexpr Layout value = Layout::UNDEFINED;
+};
+
+template <>
+struct LayoutAdapter<1, false> {
+ static constexpr Layout value = Layout::A;
+};
+
+template <>
+struct LayoutAdapter<2, false> {
+ static constexpr Layout value = Layout::AB;
+};
+
+template <>
+struct LayoutAdapter<2, true> {
+ static constexpr Layout value = Layout::BA;
+};
+} // KDNN
+#endif // TENSORFLOW_CORE_UTIL_KDNN_LAYOUT_ADAPTER_H
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,72 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#define EIGEN_USE_THREADS
+
+#include "kdnn.hpp"
+#include "tensorflow/core/platform/blocking_counter.h"
+#include "tensorflow/core/platform/threadpool.h"
+
+namespace kdnn {
+
+using tensorflow::thread::ThreadPool;
+
+class KDNNThreadPool : public KDNN::Threading::ThreadpoolIface {
+ public:
+ KDNNThreadPool() = default;
+
+ KDNNThreadPool(ThreadPool* thread_pool,
+ int num_threads = -1)
+ : thread_pool_(thread_pool),
+ eigen_interface_(thread_pool->AsEigenThreadPool()) {
+ set_num_and_max_threads(num_threads);
+ }
+
+ int GetNumThreads() const override { return num_threads_; }
+
+ void ParallelFor(int n, int64_t cost_per_unit,
+ const std::function<void(int, int)>& fn) override {
+ thread_pool_->ParallelFor(n, cost_per_unit, fn);
+ }
+
+ bool IsInParallel() const override {
+ return eigen_interface_->CurrentThreadId() != -1;
+ }
+
+ ~KDNNThreadPool() {}
+
+ private:
+ ThreadPool* thread_pool_ = nullptr;
+ Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
+ int num_threads_ = 1;
+ inline void set_num_and_max_threads(int num_threads) {
+ num_threads_ =
+ num_threads == -1 ? eigen_interface_->NumThreads() : num_threads;
+ }
+};
+
+} // namespace kdnn
+
+#endif // TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_
new file mode 100644
@@ -0,0 +1,60 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
+#define TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
+#include "kdnn.hpp"
+
+namespace KDNN {
+namespace Element {
+
+template <typename T>
+struct TypeAdapter {
+ static constexpr TypeT value = TypeT::UNDEFINED;
+};
+
+template <>
+struct TypeAdapter<float> {
+ static constexpr TypeT value = TypeT::F32;
+};
+
+template <>
+struct TypeAdapter<Eigen::half> {
+ static constexpr TypeT value = TypeT::F16;
+};
+
+template <>
+struct TypeAdapter<tensorflow::bfloat16> {
+ static constexpr TypeT value = TypeT::BF16;
+};
+
+template <>
+struct TypeAdapter<int32_t> {
+ static constexpr TypeT value = TypeT::S32;
+};
+
+template <>
+struct TypeAdapter<int8_t> {
+ static constexpr TypeT value = TypeT::S8;
+};
+
+template <>
+struct TypeAdapter<uint8_t> {
+ static constexpr TypeT value = TypeT::U8;
+};
+} // Element
+} // KDNN
+
+#endif // TENSORFLOW_CORE_UTIL_KDNN_TYPES_ADAPTER_H_
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,137 @@
+diff -ruN include/service/kdnn_service.hpp include_bak/service/kdnn_service.hpp
+--- include/service/kdnn_service.hpp 2025-10-25 18:55:04.018830570 +0800
+@@ -107,7 +107,7 @@
+ T *allocate(SizeType n) const noexcept(false)
+ {
+ if (n > std::numeric_limits<SizeType>::max() / sizeof(T)) {
+- throw BadArrayNewLength();
++ return nullptr;
+ }
+ return static_cast<T*>(AlignedAlloc(n * sizeof(T), alignment));
+ }
+diff -ruN include/types/kdnn_data_type.hpp include_bak/types/kdnn_data_type.hpp
+--- include/types/kdnn_data_type.hpp 2025-10-25 18:55:04.018830570 +0800
+@@ -65,9 +65,6 @@
+ }
+ default: {}
+ }
+- if (type == TypeT::UNDEFINED) {
+- throw Service::LogicError {"Type: unsupported data type"};
+- }
+ }
+ SizeType GetSize() const noexcept
+ {
+diff -ruN include/types/kdnn_shape.hpp include_bak/types/kdnn_shape.hpp
+--- include/types/kdnn_shape.hpp 2025-10-25 18:55:04.018830570 +0800
+@@ -36,7 +36,7 @@
+ Shape(T *ptr, const SizeType size) noexcept(false) : numDims(size)
+ {
+ if (ptr == nullptr) {
+- throw Service::LogicError("Shape: ptr is nullptr");
++ return;
+ }
+ CheckNumDims(numDims);
+ for (SizeType i = 0; i < numDims; ++i) {
+@@ -83,7 +83,7 @@
+ Shape& ResetShape(T *ptr, const SizeType size) noexcept(false)
+ {
+ if (ptr == nullptr) {
+- throw Service::LogicError("Shape: ptr is nullptr");
++ return *this;
+ }
+ CheckNumDims(size);
+ numDims = size;
+@@ -99,7 +99,7 @@
+ Shape& operator+=(const Shape &adder) noexcept(false)
+ {
+ if (adder.GetNumDims() != this->GetNumDims()) {
+- throw Service::LogicError("Shape: different size of base and adder shapes");
++ return *this;
+ }
+ for (SizeType i = 0; i < adder.GetNumDims(); ++i) {
+ this->operator[](i) += adder[i];
+@@ -109,9 +109,6 @@
+
+ Shape operator+(const Shape &adder) const noexcept(false)
+ {
+- if (adder.GetNumDims() != this->GetNumDims()) {
+- throw Service::LogicError("Shape: different size of base and adder shapes");
+- }
+ std::array<SizeType, NUM_MAX_DIMENSIONS> tmp;
+ for (SizeType i = 0; i < adder.GetNumDims(); ++i) {
+ tmp[i] = this->operator[](i) + adder[i];
+@@ -121,17 +118,11 @@
+
+ SizeType operator[](SizeType id) const noexcept(false)
+ {
+- if (id >= numDims) {
+- throw Service::LogicError("Shape: index >= num_dims");
+- }
+ return dimsArray[id];
+ }
+
+ SizeType& operator[](SizeType id) noexcept(false)
+ {
+- if (id >= numDims) {
+- throw Service::LogicError("Shape: index >= num_dims");
+- }
+ return dimsArray[id];
+ }
+
+@@ -142,9 +133,6 @@
+
+ SizeType GetTotalDimsSize() const noexcept(false)
+ {
+- if (Service::WillIntMultOverflow(dimsArray.begin(), dimsArray.begin() + numDims)) {
+- throw Service::LogicError("Shape: computing total size will cause overflow");
+- }
+ SizeType accum = 1;
+ for (SizeType i = 0; i < numDims; ++i) {
+ accum *= dimsArray[i];
+@@ -153,12 +141,7 @@
+ }
+
+ private:
+- void CheckNumDims(SizeType nDims) const noexcept(false)
+- {
+- if (nDims > NUM_MAX_DIMENSIONS) {
+- throw Service::LogicError("Shape: dims is greater than NUM_MAX_DIMENSIONS");
+- }
+- }
++ void CheckNumDims(SizeType nDims) const noexcept(false) {}
+ std::array<SizeType, NUM_MAX_DIMENSIONS> dimsArray;
+ SizeType numDims;
+ };
+diff -ruN include/types/kdnn_tensor_info.hpp include_bak/types/kdnn_tensor_info.hpp
+--- include/types/kdnn_tensor_info.hpp 2025-10-25 18:55:04.018830570 +0800
+@@ -214,7 +214,7 @@
+ return Layout::ABCDE;
+ }
+ default: {
+- throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++ return Layout::UNDEFINED;
+ }
+ }
+ }
+@@ -238,7 +238,7 @@
+ return Layout::ACDEB;
+ }
+ default: {
+- throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++ return Layout::UNDEFINED;
+ }
+ }
+ }
+@@ -262,7 +262,7 @@
+ return Layout::BCDEA;
+ }
+ default: {
+- throw Service::LogicError {"Tensor Info: tensor dimensionality is incorrect"};
++ return Layout::UNDEFINED;
+ }
+ }
+ }
new file mode 100644
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+ name = "build_with_fused_embedding",
+ define_values = {
+ "build_with_fused_embedding": "true",
+ },
+)
+
+bzl_library(
+ name = "build_defs_bzl",
+ srcs = ["build_defs.bzl"],
+)
new file mode 100644
@@ -0,0 +1,8 @@
+"""Starlark macros for fused_embedding.
+"""
+
+def if_fused_embedding(if_true, if_false = []):
+ return select({
+ "@org_tensorflow//third_party/fused_embedding:build_with_fused_embedding": if_true,
+ "//conditions:default": if_false,
+ })
new file mode 100644
@@ -0,0 +1,11 @@
+--- a/upb/upb.c 2025-05-30 17:01:35.956845750 +0800
+@@ -37,7 +37,7 @@
+ void upb_status_seterrmsg(upb_status *status, const char *msg) {
+ if (!status) return;
+ status->ok = false;
+- strncpy(status->msg, msg, sizeof(status->msg));
++ strncpy(status->msg, msg, sizeof(status->msg) - 1);
+ nullz(status);
+ }
+
new file mode 100644
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+ name = "build_with_kblas",
+ define_values = {
+ "build_with_kblas": "true",
+ },
+)
+
+bzl_library(
+ name = "build_defs_bzl",
+ srcs = ["build_defs.bzl"],
+)
new file mode 100644
@@ -0,0 +1,8 @@
+"""Starlark macros for kblas.
+"""
+
+def if_kblas(if_true, if_false = []):
+ return select({
+ "@org_tensorflow//third_party/kblas:build_with_kblas": if_true,
+ "//conditions:default": if_false,
+ })
new file mode 100644
@@ -0,0 +1,12 @@
+cc_import(
+ name = "kblas_so",
+ shared_library = "lib/sve/kblas/locking/libkblas.so",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "kblas",
+ hdrs = ["include/kblas.h"],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+)
\ No newline at end of file
new file mode 100644
@@ -0,0 +1,10 @@
+"""Provides the repository macro to import kblas."""
+
+def repo():
+ """Imports kblas."""
+
+ native.new_local_repository(
+ name = "kblas_archive",
+ build_file = "@org_tensorflow//third_party/kblas:kblas.BUILD",
+ path = "/usr/local/kml",
+ )
new file mode 100644
@@ -0,0 +1,21 @@
+load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+
+package(
+ # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+exports_files(["LICENSE"])
+
+config_setting(
+ name = "build_with_ktfop",
+ define_values = {
+ "build_with_ktfop": "true",
+ },
+)
+
+bzl_library(
+ name = "build_defs_bzl",
+ srcs = ["build_defs.bzl"],
+)
new file mode 100644
@@ -0,0 +1,8 @@
+"""Starlark macros for ktfop.
+"""
+
+def if_ktfop(if_true, if_false = []):
+ return select({
+ "@org_tensorflow//third_party/ktfop:build_with_ktfop": if_true,
+ "//conditions:default": if_false,
+ })
new file mode 100644
@@ -0,0 +1,14 @@
+cc_import(
+ name = "ktfop_so",
+ shared_library = "lib/sve/libktfop.so",
+ deps = ["@kblas_archive//:kblas_so"],
+)
+
+cc_library(
+ name = "ktfop",
+ hdrs = ["include/ktfop.h"],
+ includes = ["include"],
+ deps = [":ktfop_so",
+ "@kblas_archive//:kblas"],
+ visibility = ["//visibility:public"],
+)
new file mode 100644
@@ -0,0 +1,10 @@
+"""Provides the repository macro to import ktfop."""
+
+def repo():
+ """Imports ktfop."""
+
+ native.new_local_repository(
+ name = "ktfop_archive",
+ build_file = "@org_tensorflow//third_party/ktfop:ktfop.BUILD",
+ path = "/usr/local/sra_inference",
+ )
@@ -44,6 +44,8 @@ limitations under the License.
#include "tsl/platform/strcat.h"
#include "tsl/protobuf/error_codes.pb.h"
+#define GPR_ARRAY_SIZE(array) (sizeof(array) / sizeof(*(array)))
+
namespace tsl {
namespace {
@@ -86,6 +88,16 @@ class PThread : public Thread {
static void* ThreadFn(void* params_arg) {
std::unique_ptr<ThreadParams> params(
reinterpret_cast<ThreadParams*>(params_arg));
+
+ if (!params->name.empty()) {
+ /* Linux supports 16 characters max, and will
+ * error if it's longer. */
+ char buf[16];
+ size_t buf_len = GPR_ARRAY_SIZE(buf) - 1;
+ strncpy(buf, params->name.c_str(), buf_len);
+ buf[buf_len] = '\0';
+ pthread_setname_np(pthread_self(), buf);
+ }
{
mutex_lock l(name_mutex);
GetThreadNameRegistry().emplace(std::this_thread::get_id(), params->name);
@@ -32,6 +32,11 @@ int NUMANumNodes();
static const int kNUMANoAffinity = -1;
+enum ThreadAffinity {
+ OFF,
+ ORDER,
+ INTERVAL
+};
// If possible sets affinity of the current thread to the specified NUMA node.
// If node == kNUMANoAffinity removes affinity to any particular node.
void NUMASetThreadNodeAffinity(int node);
@@ -131,6 +131,13 @@ ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
+ExecutableRunOptions& ExecutableRunOptions::set_run_in_tf_kernel(bool run_in_tf_kernel) {
+ run_in_tf_kernel_ = run_in_tf_kernel;
+ return *this;
+}
+
+bool ExecutableRunOptions::run_in_tf_kernel() const { return run_in_tf_kernel_; }
+
ExecutableRunOptions& ExecutableRunOptions::set_run_id(RunId id) {
run_id_ = id;
return *this;
@@ -171,6 +171,10 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_rng_seed(int rng_seed);
int rng_seed() const;
+ ExecutableRunOptions& set_run_in_tf_kernel(bool run_in_tf_kernel);
+
+ bool run_in_tf_kernel() const;
+
ExecutableRunOptions& set_launch_id(int32_t launch_id) {
launch_id_ = launch_id;
return *this;
@@ -224,6 +228,7 @@ class ExecutableRunOptions {
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
+ bool run_in_tf_kernel_ = false;
int32_t launch_id_ = 0;
stream_executor::Stream* device_to_host_stream_ = nullptr;
stream_executor::Stream* host_to_device_stream_ = nullptr;
@@ -638,33 +638,38 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
ExecutionOutput result,
CreateResultShapedBuffer(run_options, absl::MakeSpan(buffers),
absl::MakeSpan(arguments)));
-
- // Logically we want this lambda to capture `buffers` by move, ultimately our
- // functor needs to be wrapped in an std::function, and that requires its
- // functor to be copyable. Thus we perpetrate the hack of capturing buffers
- // "by shared pointer".
- //
- // We also need to change the types of some of the variables we capture:
- // run_options needs to change from a pointer to a value type, and arguments
- // needs to change from a Span into a vector. We use a struct instead
- // of a lambda to make this explicit.
- struct AsyncRunTask {
- CpuExecutable* executable;
- ServiceExecutableRunOptions run_options;
- std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
- HloExecutionProfile* hlo_execution_profile;
-
- Status operator()() {
- return executable->ExecuteComputeFunction(
- &run_options.run_options(), *task_buffers, hlo_execution_profile);
- }
- };
- host_stream->EnqueueTaskWithStatus(
- AsyncRunTask{this, *run_options,
- std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
- std::move(buffers)),
- hlo_execution_profile});
-
+ if (run_options->run_options().run_in_tf_kernel()) {
+ std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers =
+ std::make_shared<std::vector<MaybeOwningDeviceMemory>>(std::move(buffers));
+ (void)this->ExecuteComputeFunction(
+ &run_options->run_options(), *task_buffers, hlo_execution_profile);
+ } else {
+ // Logically we want this lambda to capture `buffers` by move, ultimately our
+ // functor needs to be wrapped in an std::function, and that requires its
+ // functor to be copyable. Thus we perpetrate the hack of capturing buffers
+ // "by shared pointer".
+ //
+ // We also need to change the types of some of the variables we capture:
+ // run_options needs to change from a pointer to a value type, and arguments
+ // needs to change from a Span into a vector. We use a struct instead
+ // of a lambda to make this explicit.
+ struct AsyncRunTask {
+ CpuExecutable* executable;
+ ServiceExecutableRunOptions run_options;
+ std::shared_ptr<std::vector<MaybeOwningDeviceMemory>> task_buffers;
+ HloExecutionProfile* hlo_execution_profile;
+
+ Status operator()() {
+ return executable->ExecuteComputeFunction(
+ &run_options.run_options(), *task_buffers, hlo_execution_profile);
+ }
+ };
+ host_stream->EnqueueTaskWithStatus(
+ AsyncRunTask{this, *run_options,
+ std::make_shared<std::vector<MaybeOwningDeviceMemory>>(
+ std::move(buffers)),
+ hlo_execution_profile});
+ }
MarkToBeReleasedArguments(absl::MakeSpan(arguments), result);
return std::move(result);
}
new file mode 100644
@@ -0,0 +1,9 @@
+export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
+export TF_NEED_CLANG=0
+export CC_OPT_FLAGS='-march=armv8.3-a+crc'
+
+export PYTHON_BIN_PATH=$(which python)
+yes "" | $PYTHON_BIN_PATH configure.py
+
+bazel --output_user_root=./test_output test --distdir=../serving/download --test_tag_filters=-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test --test_lang_filters=cc,java -k --test_timeout 300,450,1200,3600 --config=opt --test_output=errors --test_size_filters=small,medium,large --build_tests_only -- //tensorflow/core/... //tensorflow/compiler/jit/... -//tensorflow/core/tpu/...
\ No newline at end of file