* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "eigen_threadpool_embedding.h"
#include <sys/sysinfo.h>
#include "log.h"
#include "mmpa/mmpa_api.h"
namespace {
// kTaskSize and kMaxOverShardingFactor are only combined into kMaxTaskSize below.
constexpr uint32_t kTaskSize = 40000U;
constexpr uint32_t kMaxOverShardingFactor = 4U;
// Divisor applied to the get_nprocs() result when deriving the default thread count.
constexpr int32_t kHalfCoresNum = 2;
// Thread count used whenever the derived or configured value is rejected.
constexpr int32_t kPerformanceMulitCoresNum = 48;
// Upper cap on the default (get_nprocs()-derived) thread count.
constexpr int32_t kPerformanceMulitCoresNumForLargeCores = 72;
// Numeric base handed to std::strtol when reading the thread-count environment variable.
constexpr uint32_t kDecimalScaleNum = 10U;
// Numerator of the alternative per-unit cost (kTotalCostFactor / total) considered in ParallelFor.
constexpr uint32_t kTotalCostFactor = 210000U;
// Numerator of the per-unit cost formulas in ParallelFor.
constexpr uint32_t kMaxTaskSize = kTaskSize * kMaxOverShardingFactor;
}  // namespace
namespace aicpu {
// Definitions of the static data members of EigenThreadPoolEmbedding.
// They start out empty/zero and are populated by GetInstance() on first use.

// Held by GetInstance() while the members below are being initialised.
std::mutex EigenThreadPoolEmbedding::mutex_;
// Becomes true once GetInstance() has built the thread pool.
bool EigenThreadPoolEmbedding::init_flag_{false};
// Number of threads the pool was created with; also read by ParallelFor().
int32_t EigenThreadPoolEmbedding::core_num_{0};
// Owns the Eigen thread pool.
std::unique_ptr<Eigen::ThreadPool> EigenThreadPoolEmbedding::eigen_threadpool_{nullptr};
// Owns the Eigen device wrapper constructed on top of eigen_threadpool_.
std::unique_ptr<Eigen::ThreadPoolDevice> EigenThreadPoolEmbedding::threadpool_device_{nullptr};
/**
 * @brief Returns the process-wide EigenThreadPoolEmbedding instance, building the shared
 *        Eigen thread pool on the first call.
 *
 * Thread count selection:
 *   1. Default is get_nprocs() / kHalfCoresNum, capped at kPerformanceMulitCoresNumForLargeCores.
 *   2. A non-empty MM_ENV_EMBEDDING_MAX_THREAD_CORE_NUMBER value overrides the default when it
 *      starts with a decimal number in [1, get_nprocs()]; any other value is rejected with a
 *      warning and kPerformanceMulitCoresNum is used instead.
 *   3. If the resulting count is still not in [1, get_nprocs()], kPerformanceMulitCoresNum is used.
 *
 * @return Pointer to the singleton instance (never nullptr).
 */
EigenThreadPoolEmbedding* EigenThreadPoolEmbedding::GetInstance()
{
    // init_flag_ is a plain bool, so checking it outside the mutex (as a fast path) is a data race:
    // another thread could observe the flag as set before the pool pointers are visible to it.
    // std::call_once gives the same "initialise once, cheap afterwards" behaviour without that race.
    static std::once_flag init_once;
    std::call_once(init_once, []() {
        // mutex_ is still taken so that anything else synchronising on it keeps working.
        std::lock_guard<std::mutex> lock(mutex_);
        if (init_flag_) {
            return;
        }

        const int online_cores = get_nprocs();
        int32_t core_num = online_cores / kHalfCoresNum;
        if (core_num > kPerformanceMulitCoresNumForLargeCores) {
            core_num = kPerformanceMulitCoresNumForLargeCores;
        }

        const char* core_num_value = nullptr;
        MM_SYS_GET_ENV(MM_ENV_EMBEDDING_MAX_THREAD_CORE_NUMBER, core_num_value);
        if ((core_num_value != nullptr) && (core_num_value[0U] != '\0')) {
            // std::strtol reports failure through its end pointer / return value, never by throwing.
            // Validate in `long` before narrowing so an oversized value cannot wrap into range.
            char* parse_end = nullptr;
            const long requested = std::strtol(core_num_value, &parse_end, static_cast<int>(kDecimalScaleNum));
            const bool no_digits = (parse_end == core_num_value);
            if (no_digits || (requested <= 0) || (requested > static_cast<long>(online_cores))) {
                KERNEL_LOG_WARN(
                    "Invalid embedding thread core num env value[%s], valid range is [1, %d], use default value[%d].",
                    core_num_value, online_cores, kPerformanceMulitCoresNum);
                core_num = kPerformanceMulitCoresNum;
            } else {
                core_num = static_cast<int32_t>(requested);
            }
        }

        // Last-resort fallback, e.g. when the halved default is 0 on a single-core system.
        if ((core_num <= 0) || (core_num > online_cores)) {
            core_num = kPerformanceMulitCoresNum;
        }

        core_num_ = core_num;
        eigen_threadpool_.reset(new Eigen::ThreadPool(core_num_));
        threadpool_device_.reset(new Eigen::ThreadPoolDevice(eigen_threadpool_.get(), core_num_));
        init_flag_ = true;
        KERNEL_LOG_INFO("Init embedding thread pool success, core num[%d]", core_num_);
    });

    static EigenThreadPoolEmbedding instance;
    return &instance;
}
/**
 * @brief Runs work over the index range [0, total) on the shared Eigen thread pool.
 *
 * The call is a no-op (with a warning) when total or per_unit_size is not positive or when
 * work is empty. Otherwise a per-unit cost is derived from total, per_unit_size and the pool's
 * thread count, and passed as the third argument of Eigen::TensorOpCost to
 * ThreadPoolDevice::parallelFor.
 *
 * @param total          Number of units to process; must be > 0.
 * @param per_unit_size  Caller-supplied size per unit; must be > 0.
 * @param work           Callback invoked as work(first, last) for each sub-range Eigen hands out.
 */
void EigenThreadPoolEmbedding::ParallelFor(int64_t total, int64_t per_unit_size, const SharderWork& work) const
{
    const bool invalid_args = (total <= 0) || (work == nullptr) || (per_unit_size <= 0);
    if (invalid_args) {
        KERNEL_LOG_WARN(
            "Invalid param: total[%ld] <= 0 or per_unit_size[%ld] <= 0 or work is nullptr", total, per_unit_size);
        return;
    }

    double per_unit_cost = 1.0;
    if (per_unit_size > (total / core_num_)) {
        // The requested unit size exceeds an even split across the pool's threads.
        per_unit_cost = 1.0 * kMaxTaskSize / per_unit_size;
    } else {
        // Take the larger of the thread-count-scaled cost and the fixed-factor cost.
        const double scaled_cost = 1.0 * kMaxTaskSize * core_num_ / total;
        const double factor_cost = 1.0 * kTotalCostFactor / total;
        per_unit_cost = (scaled_cost > factor_cost) ? scaled_cost : factor_cost;
    }

    const auto run_range = [&work](Eigen::Index first, Eigen::Index last) { work(first, last); };
    threadpool_device_->parallelFor(total, Eigen::TensorOpCost(0, 0, per_unit_cost), run_range);
    KERNEL_LOG_INFO("Eigen threadpool parallel for success.");
}
}