* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <string>
#include <sstream>
#include <memory>
#include <cstring>
#include "alg_param.h"
#include "executor_base.h"
#include "coll_alg_exec_registry.h"
#include "coll_alg_v2_exec_registry.h"
#include "hcomm_primitives.h"
#include "hcomm_primitives_dl.h"
#include "dfx/task_exception_fun.h"
#include "kernel_launch.h"
#include "hcomm_diag_dl.h"
#include "hcomm_device_profiling_dl.h"
#include <unordered_map>
#include <shared_mutex>
#include <atomic>
#include "hccl_device_comm_dl.h"
#include "exec_timeout_manager.h"
#include "alg_data_trans_wrapper.h"
using namespace ops_hccl;
namespace {
struct CacheStats {
std::atomic<uint64_t> hits{0};
std::atomic<uint64_t> misses{0};
double hitRate() const {
uint64_t total = hits + misses;
return total > 0 ? static_cast<double>(hits) / total : 0.0;
}
void Reset() {
hits = 0;
misses = 0;
}
};
class CommDomainCache {
public:
explicit CommDomainCache(const std::string& commName) : commName_(commName) {}
const std::string& GetCommName() const {return commName_; }
std::shared_ptr<const AlgResourceCtxSerializable> Get(const std::string& algTag) {
std::shared_lock<std::shared_timed_mutex> lock(mutex_);
auto it = cache_.find(algTag);
return it != cache_.end() ? it->second : nullptr;
}
void Put(const std::string& algTag, const AlgResourceCtxSerializable& value) {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
cache_[algTag] = std::make_shared<AlgResourceCtxSerializable>(value);
}
bool Remove(const std::string& algTag) {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
return cache_.erase(algTag) > 0;
}
void Clear() {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
cache_.clear();
}
CacheStats& GetStats() { return stats_; }
const CacheStats& GetStats() const { return stats_; }
size_t GetCacheSize() const {
std::shared_lock<std::shared_timed_mutex> lock(mutex_);
return cache_.size();
}
private:
std::string commName_;
std::unordered_map<std::string, std::shared_ptr<const AlgResourceCtxSerializable>> cache_;
CacheStats stats_;
mutable std::shared_timed_mutex mutex_;
};
class CommDomainCacheManager {
public:
std::shared_ptr<const AlgResourceCtxSerializable> Get(const std::string& algTag, const std::string& paramCommName) {
std::string commName = ExtractCommName(algTag);
if (commName.empty()) commName = paramCommName;
CommDomainCache* commCache = GetOrCreateComm(commName);
if (commCache) {
auto& stats = commCache->GetStats();
auto result = commCache->Get(algTag);
if (result) {
stats.hits++;
return result;
}
stats.misses++;
}
return nullptr;
}
void Put(const std::string& algTag, const AlgResourceCtxSerializable& value, const std::string& paramCommName) {
std::string commName = ExtractCommName(algTag);
if (commName.empty()) commName = paramCommName;
CommDomainCache* commCache = GetOrCreateComm(commName);
if (commCache) {
commCache->Put(algTag, value);
}
}
bool ReleaseComm(const std::string& commName) {
std::unique_lock<std::shared_timed_mutex> lock(mapMutex_);
return commCaches_.erase(commName) > 0;
}
bool GetCommStats(const std::string& commName, CacheStats& outStats, size_t& outCacheSize) const {
std::shared_lock<std::shared_timed_mutex> lock(mapMutex_);
auto it = commCaches_.find(commName);
if (it != commCaches_.end()) {
outStats.hits = it->second.GetStats().hits.load();
outStats.misses = it->second.GetStats().misses.load();
outCacheSize = it->second.GetCacheSize();
return true;
}
return false;
}
void GetGlobalStats(size_t& totalCommDomains, size_t& totalcacheEntries, uint64_t& totalHits, uint64_t& totalMisses) const {
std::shared_lock<std::shared_timed_mutex> lock(mapMutex_);
totalCommDomains = commCaches_.size();
totalcacheEntries = 0;
totalHits = 0;
totalMisses = 0;
for (const auto& pair : commCaches_) {
const auto& commName = pair.first;
const auto& commCache = pair.second;
totalcacheEntries += commCache.GetCacheSize();
totalHits += commCache.GetStats().hits.load();
totalMisses += commCache.GetStats().misses.load();
}
}
void ClearAll() {
std::unique_lock<std::shared_timed_mutex> lock(mapMutex_);
commCaches_.clear();
}
std::string ExtractCommName(const std::string& algTag) {
size_t firstUnderscore = algTag.find('_');
if (firstUnderscore == std::string::npos) return "";
size_t secondUnderscore = algTag.find('_', firstUnderscore + 1);
if (secondUnderscore == std::string::npos) return "";
return algTag.substr(firstUnderscore+1, secondUnderscore - firstUnderscore - 1);
}
private:
CommDomainCache* GetOrCreateComm(const std::string& commName) {
{
std::shared_lock<std::shared_timed_mutex> lock(mapMutex_);
auto it = commCaches_.find(commName);
if (it != commCaches_.end()) {
return &it->second;
}
}
{
std::unique_lock<std::shared_timed_mutex> lock(mapMutex_);
auto it = commCaches_.find(commName);
if (it != commCaches_.end()) {
return &it->second;
}
auto result = commCaches_.emplace(
std::piecewise_construct,
std::forward_as_tuple(commName),
std::forward_as_tuple(commName)
);
return &result.first->second;
}
}
mutable std::shared_timed_mutex mapMutex_;
std::unordered_map<std::string, CommDomainCache> commCaches_;
};
thread_local CommDomainCacheManager g_cacheManager;
std::unique_ptr<AlgResourceCtxSerializable> DeserializeResCtx(const OpParam *param)
{
std::unique_ptr<AlgResourceCtxSerializable> resCtx = std::make_unique<AlgResourceCtxSerializable>();
char *ctx = static_cast<char *>(param->resCtx);
std::vector<char> seq(ctx, ctx + param->ctxSize);
resCtx->DeSerialize(seq);
return resCtx;
}
}
namespace ops_hccl {
bool IsOpsV2(const char* algName, DevType deviceType)
{
if (algName != nullptr) {
const char* prefix = "opv2_";
if (strncmp(algName, prefix, strlen(prefix)) == 0) {
return true;
}
}
#ifdef MACRO_DEV_TYPE_NEW
if (deviceType == DevType::DEV_TYPE_950) {
#else
if (deviceType == DevType::DEV_TYPE_910_95) {
#endif
return true;
}
return false;
}
}
extern "C" unsigned int HcclLaunchAicpuKernel(OpParam *param)
{
if (param == nullptr) {
HCCL_ERROR("%s param is nullptr", __func__);
return 1;
}
HCCL_INFO("Entry-%s, commName[%s], tag[%s], algTag[%s]", __func__, param->commName, param->tag, param->algTag);
if (HcommAcquireComm(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommAcquireComm fail, commName[%s]", __func__, param->commName);
return 1;
}
std::string algName = std::string(param->algName);
if (!ops_hccl::IsOpsV2(param->algName, param->deviceType)) {
ScatterOpInfo opInfo;
if (CreateScatter(param, &opInfo) != HCCL_SUCCESS) {
HCCL_ERROR("%s CreateScatter fail", __func__);
return 1;
}
if (HcommIsSupportHcommRegOpInfo() &&
HcommRegOpInfo(param->commName, reinterpret_cast<void *>(&opInfo), sizeof(ScatterOpInfo)) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommRegOpInfo fail, commName[%s], algTag[%s], size[%u]",
__func__, param->commName, opInfo.algTag, sizeof(ScatterOpInfo));
return 1;
}
if (HcommIsSupportHcommRegOpTaskException() &&
HcommRegOpTaskException(param->commName, ops_hccl::GetScatterOpInfo) != HCCL_SUCCESS) {
HCCL_ERROR(
"%s HcommRegOpTaskException fail, commName[%s], algTag[%s]", __func__, param->commName, param->algTag);
return 1;
}
}
if (ops_hccl::IsOpsV2(param->algName, param->deviceType)) {
HcclCommStatus commStatus = HCCL_COMM_STATUS_INVALID;
if (HcommIsSupportHcclCommGetStatus()) {
auto statusRet = HcclCommGetStatus(param->commName, &commStatus);
if (statusRet != HCCL_SUCCESS) {
HCCL_ERROR("%s HcclCommGetStatus fail, commName[%s], ret = %d", __func__, param->commName, statusRet);
return 1;
}
if (commStatus != HCCL_COMM_STATUS_READY) {
HCCL_ERROR("%s commStatus is not ready!, commStatus = %d", __func__, static_cast<int>(commStatus));
return 1;
}
}
std::shared_ptr<const AlgResourceCtxSerializable> cachedResCtxHolder;
std::unique_ptr<AlgResourceCtxSerializable> resCtx;
const AlgResourceCtxSerializable* resCtxPtr{nullptr};
u32 hitRateNum = 100;
if (param->opType != HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) {
cachedResCtxHolder = g_cacheManager.Get(param->algTag, param->commName);
if (cachedResCtxHolder != nullptr && IsResCtxCacheReusable(*cachedResCtxHolder, *param)) {
HCCL_INFO("[%s] Cache HIT for algTag[%s]", __func__, param->algTag);
std::string commName = g_cacheManager.ExtractCommName(param->algTag);
if (commName.empty()) commName = param->commName;
CacheStats stats;
size_t cacheSize;
if (g_cacheManager.GetCommStats(commName, stats, cacheSize)) {
HCCL_DEBUG("[%s] comm[%s] hitRate=%.2f%%, cacheSize=%zu",
__func__, commName.c_str(), stats.hitRate() * hitRateNum, cacheSize);
}
resCtxPtr = cachedResCtxHolder.get();
} else {
bool isStaleCache = (cachedResCtxHolder != nullptr);
resCtx = DeserializeResCtx(param);
g_cacheManager.Put(param->algTag, *resCtx, param->commName);
resCtxPtr = resCtx.get();
if (isStaleCache) {
HCCL_INFO("[%s] Cache STALE and refreshed for algTag[%s], cachedComm[%p], currentComm[%p]",
__func__, param->algTag, cachedResCtxHolder->commInfoPtr, param->hcclComm);
} else {
HCCL_INFO("[%s] Cache MISS and stored for algTag[%s]", __func__, param->algTag);
}
}
} else {
resCtx = DeserializeResCtx(param);
resCtxPtr = resCtx.get();
}
HcclResult ret = HCCL_SUCCESS;
if (param->opType == HCCL_CMD_BATCH_SEND_RECV) {
ret = ops_hccl::RestoreVarDataBatchSendRecv(*param);
} else if (param->opType == HCCL_CMD_ALLTOALLV || param->opType == HCCL_CMD_ALLTOALLVC ||
param->opType == HCCL_CMD_ALLTOALL) {
ret = ops_hccl::RestoreVarDataAlltoAllV(*param, *resCtxPtr);
} else if (param->opType == HCCL_CMD_REDUCE_SCATTER_V) {
ret = ops_hccl::RestoreVarDataReduceScatterV(*param, *resCtxPtr);
} else if (param->opType == HCCL_CMD_ALLGATHER_V) {
ret = ops_hccl::RestoreVarDataAllGatherV(*param, *resCtxPtr);
}
if (ret != HCCL_SUCCESS) {
HCCL_ERROR("failed to restore optype [%d] data and counts.", param->opType);
return 1;
}
ThreadHandle thread = resCtxPtr->threads[0];
if (HcommBatchModeStart(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set batch mode, tag is %s.", param->algTag);
return 1;
}
HcclDfxOpInfoCompat dfxOpInfo{};
if (ConvertToHcclDfxOpInfo(param, &dfxOpInfo) != HCCL_SUCCESS) {
HCCL_ERROR("ConvertToHcclDfxOpInfo fail, commName is %s, tag is %s", param->commName, param->algTag);
return 1;
}
if (HcclDfxRegOpInfoByCommId(param->commName, reinterpret_cast<void *>(&dfxOpInfo)) != HCCL_SUCCESS) {
HCCL_ERROR("HcclDfxRegOpInfoByCommId fail, commName is %s, tag is %s", param->commName, param->algTag);
return 1;
}
if (HcommProfilingReportKernelStartTask(thread, param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%sfailed to report MainStream And FirstTask, thread %lu, param->commName %s.", __func__, thread, param->commName);
return 1;
}
ThreadHandle exportedAicpuTsThread = param->opThread;
u32 maxNotifyNum = resCtxPtr->notifyNumOnMainThread;
for (u32 i = 0; i < resCtxPtr->notifyNumPerThread.size(); i++) {
if (resCtxPtr->notifyNumPerThread[i] > maxNotifyNum) {
maxNotifyNum = resCtxPtr->notifyNumPerThread[i];
}
}
if (HcommIsSupportHcommThreadResAcquireTimeOut()) {
CHK_RET(HcclThreadResAcquireTimeOut(resCtxPtr->fullTimeout));
}
if (HcommIsSupportHcommSetNotifyWaitTimeOut()) {
CHK_RET(HcclSetNotifyWaitTimeOut(resCtxPtr->waitTimeout));
}
HCCL_DEBUG("[%s]Notify wait on thread[%llu], maxNotifyNum[%u], timeout[%u]", __func__, thread,
maxNotifyNum, resCtxPtr->waitTimeout);
CHK_RET(HcclThreadNotifyWaitOnThreadDefault(thread, maxNotifyNum, resCtxPtr->waitTimeout));
std::shared_ptr<InsCollAlgBase> executor = CollAlgExecRegistryV2::Instance().GetAlgExec(param->opType, algName);
if (executor.get() == nullptr) {
HCCL_ERROR("Fail to find executor for algName[%s]", algName.c_str());
return 1;
}
ExecTimeoutManager::Instance().SetExecTimeout(param->opConfig.execTimeout);
CHK_RET(InitHcommBatchTransferOnThreadSupported(resCtxPtr->isHcommBatchTransferOnThreadSupported));
if (executor->Orchestrate(*param, *resCtxPtr) != HCCL_SUCCESS) {
HCCL_ERROR("orchestrate failed for alg:%s", param->algName);
return 1;
}
if (HcommProfilingReportKernelEndTask(thread, param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s failed to report MainStream And LastTask, thread %lu, param->commName %s.", __func__, thread, param->commName);
return 1;
}
constexpr u32 DEFAULT_NOTIFY_IDX = 0;
HCCL_DEBUG("[%s]Notify record on srcThread[%llu], dstThread[%llu], notifyIdx[%u]",__func__, thread, exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX);
CHK_RET(static_cast<HcclResult>(HcommThreadNotifyRecordOnThread(thread, exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX)));
if (HcommProfilingReportDeviceOp(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommProfilingReportDeviceOp fail, commName[%s]", __func__, param->commName);
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
} else {
std::unique_ptr<ExecutorBase> executor = CollAlgExecRegistry::Instance().GetAlgExec(algName);
if (executor.get() == nullptr) {
HCCL_ERROR("Fail to find executor for algName[%s]", algName.c_str());
return 1;
}
AlgResourceCtx *resCtx = reinterpret_cast<AlgResourceCtx *>(param->resCtx);
ThreadHandle *threadHandlePtr =
reinterpret_cast<ThreadHandle *>(reinterpret_cast<u8 *>(resCtx) + sizeof(AlgResourceCtx));
ThreadHandle thread = threadHandlePtr[0];
ThreadHandle exportedAicpuTsThread = resCtx->opThread;
u32 notifyNumOnMainThread = resCtx->notifyNumOnMainThread;
if (HcommBatchModeStart(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set batch mode, tag is %s.", param->algTag);
return 1;
}
if (exportedAicpuTsThread != 0) {
if (HcommProfilingInit(threadHandlePtr, resCtx->slaveThreadNum + 1) != HCCL_SUCCESS) {
HCCL_ERROR("failed to init Profiling");
return 1;
}
if (HcommProfilingReportMainStreamAndFirstTask(thread) != HCCL_SUCCESS) {
HCCL_ERROR("failed to report MainStream And FirstTask");
return 1;
}
HCCL_DEBUG("[%s]Notify wait on thread[%llu], notifyNumOnMainThread[%u], timeout[%u]",
__func__,
thread,
notifyNumOnMainThread,
CUSTOM_TIMEOUT);
CHK_RET(static_cast<HcclResult>(HcommThreadNotifyWaitOnThread(thread, notifyNumOnMainThread, CUSTOM_TIMEOUT)));
} else {
if (HcommAclrtNotifyWaitOnThread(thread, resCtx->notifyIds[0], CUSTOM_TIMEOUT) != HCCL_SUCCESS) {
HCCL_ERROR("failed to wait notify[%d] from host main stream", resCtx->notifyIds[0]);
return 1;
}
}
if (executor->Orchestrate(*param, resCtx) != HCCL_SUCCESS) {
HCCL_ERROR("orchestrate failed for alg:%s", param->algName);
return 1;
}
if (exportedAicpuTsThread != 0) {
HcomProInfoTmp profInfo;
std::string algTypeStr(param->algTypeStr);
strcpy_s(profInfo.algType, sizeof(profInfo.algType), algTypeStr.c_str());
strcpy_s(profInfo.commName, sizeof(profInfo.commName), param->commName);
profInfo.commNameLen = strlen(param->commName);
profInfo.dataCount = param->DataDes.count;
profInfo.dataType = static_cast<uint8_t>(param->DataDes.dataType);
profInfo.rankSize = resCtx->topoInfo.userRankSize;
HcommProfilingReportDeviceHcclOpInfo(profInfo);
constexpr u32 DEFAULT_NOTIFY_IDX = 0;
HCCL_DEBUG("[%s]Notify record on srcThread[%llu], dstThread[%llu], notifyIdx[%u]",
__func__,
thread,
exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX);
CHK_RET(static_cast<HcclResult>(
HcommThreadNotifyRecordOnThread(thread, exportedAicpuTsThread, DEFAULT_NOTIFY_IDX)));
if (HcommProfilingReportMainStreamAndLastTask(thread) != HCCL_SUCCESS) {
HCCL_ERROR("failed to report MainStream And LastTask");
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
if (HcommProfilingEnd(threadHandlePtr, resCtx->slaveThreadNum + 1) != HCCL_SUCCESS) {
HCCL_ERROR("failed to End Profiling");
return 1;
}
} else {
if (HcommAclrtNotifyRecordOnThread(thread, resCtx->notifyIds[1]) != HCCL_SUCCESS) {
HCCL_ERROR("failed to record host main stream");
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
}
}
if (HcommReleaseComm(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommReleaseComm fail, commName[%s]", __func__, param->commName);
return 1;
}
HCCL_INFO("%s success, tag[%s], algTag[%s], commName[%s]", __func__, param->tag, param->algTag, param->commName);
return 0;
}
HcclResult ops_hccl::RestoreVarDataBatchSendRecv(OpParam ¶m)
{
u64 sendRecvItemSize = static_cast<u64>(sizeof(HcclSendRecvItem));
u64 itemNum = static_cast<u64>(param.batchSendRecvDataDes.itemNum);
if (param.varMemSize != itemNum * sendRecvItemSize) {
HCCL_ERROR("param.varMemSize[%lu] is not equal to itemNum[%lu] multiply [HcclSendRecvItem] size[%lu]."
"Failed to restore end recv info for BatchSendRecv!",
param.varMemSize,
itemNum,
sendRecvItemSize);
return HCCL_E_PARA;
}
param.batchSendRecvDataDes.sendRecvItemsPtr = reinterpret_cast<HcclSendRecvItem *>(param.varData);
return HCCL_SUCCESS;
}
HcclResult ops_hccl::RestoreVarDataAlltoAllV(OpParam ¶m, const AlgResourceCtxSerializable &resCtx)
{
u64 rankSize = resCtx.topoInfo.userRankSize;
CHK_PRT_RET(param.varMemSize != ALL_TO_ALL_V_VECTOR_NUM * rankSize * sizeof(u64),
HCCL_ERROR("[RestoreVarDataAlltoAllV] param.varMemSize [%llu] is invalid,"
" ALL_TO_ALL_V_VECTOR_NUM is [%u], rankSize is [%u], sizeof(u64) is [%u],",
param.varMemSize,
ALL_TO_ALL_V_VECTOR_NUM,
rankSize,
sizeof(u64)),
HCCL_E_PARA);
constexpr u32 ALL_TO_ALL_V_OFFSET_SCOUNTS = 0;
constexpr u32 ALL_TO_ALL_V_OFFSET_RECV_COUNTS = 1;
constexpr u32 ALL_TO_ALL_V_OFFSET_SDISPLS = 2;
constexpr u32 ALL_TO_ALL_V_OFFSET_RDISPLS = 3;
u64 *data = reinterpret_cast<u64 *>(param.varData);
param.all2AllVDataDes.sendCounts = data;
param.all2AllVDataDes.recvCounts = data + ALL_TO_ALL_V_OFFSET_RECV_COUNTS * rankSize;
param.all2AllVDataDes.sdispls = data + ALL_TO_ALL_V_OFFSET_SDISPLS * rankSize;
param.all2AllVDataDes.rdispls = data + ALL_TO_ALL_V_OFFSET_RDISPLS * rankSize;
return HCCL_SUCCESS;
}
HcclResult ops_hccl::RestoreVarDataReduceScatterV(OpParam ¶m, const AlgResourceCtxSerializable &resCtx)
{
u64 rankSize = resCtx.topoInfo.userRankSize;
HCCL_INFO("rankSize:%u", rankSize);
CHK_PRT_RET(param.varMemSize != REDUCE_SCATTER_V_VECTOR_NUM * rankSize * sizeof(u64),
HCCL_ERROR("[RestoreVarDataReduceScatterV] param.varMemSize [%llu] is invalid,"
"REDUCE_SCATTER_V_VECTOR_NUM is [%u], rankSize is [%u], sizeof(u64) is [%u],",
param.varMemSize,
REDUCE_SCATTER_V_VECTOR_NUM,
rankSize,
sizeof(u64)),
HCCL_E_PARA);
u64 *data = reinterpret_cast<u64 *>(param.varData);
param.vDataDes.counts = data;
param.vDataDes.displs = data + rankSize;
return HCCL_SUCCESS;
}
HcclResult ops_hccl::RestoreVarDataAllGatherV(OpParam ¶m, const AlgResourceCtxSerializable &resCtx)
{
u64 rankSize = resCtx.topoInfo.userRankSize;
HCCL_INFO("rankSize:%u", rankSize);
CHK_PRT_RET(param.varMemSize != ALL_GATHER_V_VECTOR_NUM * rankSize * sizeof(u64),
HCCL_ERROR("[RestoreVarDataAllGatherV] param.varMemSize [%llu] is invalid,"
"ALL_GATHER_V_VECTOR_NUM is [%u], rankSize is [%u], sizeof(u64) is [%u],",
param.varMemSize,
ALL_GATHER_V_VECTOR_NUM,
rankSize,
sizeof(u64)),
HCCL_E_PARA);
u64 *data = reinterpret_cast<u64 *>(param.varData);
param.vDataDes.counts = data;
for (u64 i = 0; i < rankSize; i++) {
HCCL_INFO("param.vDataDes.counts[%u]:%u", i, reinterpret_cast<u64 *>(param.vDataDes.counts)[i]);
}
param.vDataDes.displs = data + rankSize;
for (u64 i = 0; i < rankSize; i++) {
HCCL_INFO("param.vDataDes.displs[%u]:%u", i, reinterpret_cast<u64 *>(param.vDataDes.displs)[i]);
}
return HCCL_SUCCESS;
}
extern "C" unsigned int HcclLaunchAicpuKernelA3(OpParam *param)
{
if (param == nullptr) {
HCCL_ERROR("%s param is nullptr", __func__);
return 1;
}
HCCL_INFO("Entry-%s, commName[%s], tag[%s], algTag[%s]", __func__, param->commName, param->tag, param->algTag);
if (HcommAcquireComm(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommAcquireComm fail, commName[%s]", __func__, param->commName);
return 1;
}
std::string algName = std::string(param->algName);
if (!ops_hccl::IsOpsV2(param->algName, param->deviceType)) {
ScatterOpInfo opInfo;
if (CreateScatter(param, &opInfo) != HCCL_SUCCESS) {
HCCL_ERROR("%s CreateScatter fail", __func__);
return 1;
}
if (HcommIsSupportHcommRegOpInfo() &&
HcommRegOpInfo(param->commName, reinterpret_cast<void *>(&opInfo), sizeof(ScatterOpInfo)) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommRegOpInfo fail, commName[%s], algTag[%s], size[%u]",
__func__, param->commName, opInfo.algTag, sizeof(ScatterOpInfo));
return 1;
}
if (HcommIsSupportHcommRegOpTaskException() &&
HcommRegOpTaskException(param->commName, ops_hccl::GetScatterOpInfo) != HCCL_SUCCESS) {
HCCL_ERROR(
"%s HcommRegOpTaskException fail, commName[%s], algTag[%s]", __func__, param->commName, param->algTag);
return 1;
}
}
if (ops_hccl::IsOpsV2(param->algName, param->deviceType)) {
HcclCommStatus commStatus = HCCL_COMM_STATUS_INVALID;
if (HcommIsSupportHcclCommGetStatus()) {
auto statusRet = HcclCommGetStatus(param->commName, &commStatus);
if (statusRet != HCCL_SUCCESS) {
HCCL_ERROR("%s HcclCommGetStatus fail, commName[%s], ret = %d", __func__, param->commName, statusRet);
return 1;
}
if (commStatus != HCCL_COMM_STATUS_READY) {
HCCL_ERROR("%s commStatus is not ready!, commStatus = %d", __func__, static_cast<int>(commStatus));
return 1;
}
}
std::shared_ptr<const AlgResourceCtxSerializable> cachedResCtxHolder;
std::unique_ptr<AlgResourceCtxSerializable> resCtx;
const AlgResourceCtxSerializable* resCtxPtr{nullptr};
if (param->opType != HcclCMDType::HCCL_CMD_BATCH_SEND_RECV) {
cachedResCtxHolder = g_cacheManager.Get(param->algTag, param->commName);
if (cachedResCtxHolder != nullptr && IsResCtxCacheReusable(*cachedResCtxHolder, *param)) {
HCCL_INFO("[%s] Cache HIT for algTag[%s]", __func__, param->algTag);
std::string commName = g_cacheManager.ExtractCommName(param->algTag);
if (commName.empty()) commName = param->commName;
CacheStats stats;
size_t cacheSize;
if (g_cacheManager.GetCommStats(commName, stats, cacheSize)) {
HCCL_DEBUG("[%s] comm[%s] hitRate=%.2f%%, cacheSize=%zu",
__func__, commName.c_str(), stats.hitRate() * 100, cacheSize);
}
resCtxPtr = cachedResCtxHolder.get();
} else {
bool isStaleCache = (cachedResCtxHolder != nullptr);
resCtx = DeserializeResCtx(param);
g_cacheManager.Put(param->algTag, *resCtx, param->commName);
resCtxPtr = resCtx.get();
if (isStaleCache) {
HCCL_INFO("[%s] Cache STALE and refreshed for algTag[%s], cachedComm[%p], currentComm[%p]",
__func__, param->algTag, cachedResCtxHolder->commInfoPtr, param->hcclComm);
} else {
HCCL_INFO("[%s] Cache MISS and stored for algTag[%s]", __func__, param->algTag);
}
}
} else {
resCtx = DeserializeResCtx(param);
resCtxPtr = resCtx.get();
}
HcclResult ret = HCCL_SUCCESS;
if (param->opType == HCCL_CMD_BATCH_SEND_RECV) {
ret = ops_hccl::RestoreVarDataBatchSendRecv(*param);
} else if (param->opType == HCCL_CMD_ALLTOALLV || param->opType == HCCL_CMD_ALLTOALLVC ||
param->opType == HCCL_CMD_ALLTOALL) {
ret = ops_hccl::RestoreVarDataAlltoAllV(*param, *resCtxPtr);
} else if (param->opType == HCCL_CMD_REDUCE_SCATTER_V) {
ret = ops_hccl::RestoreVarDataReduceScatterV(*param, *resCtxPtr);
} else if (param->opType == HCCL_CMD_ALLGATHER_V) {
ret = ops_hccl::RestoreVarDataAllGatherV(*param, *resCtxPtr);
}
if (ret != HCCL_SUCCESS) {
HCCL_ERROR("failed to restore optype [%d] data and counts.", param->opType);
return 1;
}
ThreadHandle thread = resCtxPtr->threads[0];
if (HcommBatchModeStart(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set batch mode, tag is %s.", param->algTag);
return 1;
}
HcclDfxOpInfoCompat dfxOpInfo{};
if (ConvertToHcclDfxOpInfo(param, &dfxOpInfo) != HCCL_SUCCESS) {
HCCL_ERROR("ConvertToHcclDfxOpInfo fail, commName is %s, tag is %s", param->commName, param->algTag);
return 1;
}
if (HcclDfxRegOpInfoByCommId(param->commName, reinterpret_cast<void *>(&dfxOpInfo)) != HCCL_SUCCESS) {
HCCL_ERROR("HcclDfxRegOpInfoByCommId fail, commName is %s, tag is %s", param->commName, param->algTag);
return 1;
}
if (HcommProfilingReportKernelStartTask(thread, param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%sfailed to report MainStream And FirstTask, thread %lu, param->commName %s.", __func__, thread, param->commName);
return 1;
}
ThreadHandle exportedAicpuTsThread = param->opThread;
u32 maxNotifyNum = resCtxPtr->notifyNumOnMainThread;
for (u32 i = 0; i < resCtxPtr->notifyNumPerThread.size(); i++) {
if (resCtxPtr->notifyNumPerThread[i] > maxNotifyNum) {
maxNotifyNum = resCtxPtr->notifyNumPerThread[i];
}
}
HCCL_DEBUG("[%s]Notify wait on thread[%llu], maxNotifyNum[%u], timeout[%u]", __func__, thread,
maxNotifyNum, CUSTOM_TIMEOUT);
CHK_RET(static_cast<HcclResult>(HcommThreadNotifyWaitOnThread(thread, maxNotifyNum, CUSTOM_TIMEOUT)));
std::shared_ptr<InsCollAlgBase> executor = CollAlgExecRegistryV2::Instance().GetAlgExec(param->opType, algName);
if (executor.get() == nullptr) {
HCCL_ERROR("Fail to find executor for algName[%s]", algName.c_str());
return 1;
}
ExecTimeoutManager::Instance().SetExecTimeout(param->opConfig.execTimeout);
CHK_RET(InitHcommBatchTransferOnThreadSupported(resCtxPtr->isHcommBatchTransferOnThreadSupported));
if (executor->Orchestrate(*param, *resCtxPtr) != HCCL_SUCCESS) {
HCCL_ERROR("orchestrate failed for alg:%s", param->algName);
return 1;
}
if (HcommProfilingReportKernelEndTask(thread, param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s failed to report MainStream And LastTask, thread %lu, param->commName %s.", __func__, thread, param->commName);
return 1;
}
constexpr u32 DEFAULT_NOTIFY_IDX = 0;
HCCL_DEBUG("[%s]Notify record on srcThread[%llu], dstThread[%llu], notifyIdx[%u]",__func__, thread, exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX);
CHK_RET(static_cast<HcclResult>(HcommThreadNotifyRecordOnThread(thread, exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX)));
if (HcommProfilingReportDeviceOp(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommProfilingReportDeviceOp fail, commName[%s]", __func__, param->commName);
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
} else {
std::unique_ptr<ExecutorBase> executor = CollAlgExecRegistry::Instance().GetAlgExec(algName);
if (executor.get() == nullptr) {
HCCL_ERROR("Fail to find executor for algName[%s]", algName.c_str());
return 1;
}
AlgResourceCtx *resCtx = reinterpret_cast<AlgResourceCtx *>(param->resCtx);
ThreadHandle *threadHandlePtr =
reinterpret_cast<ThreadHandle *>(reinterpret_cast<u8 *>(resCtx) + sizeof(AlgResourceCtx));
ThreadHandle thread = threadHandlePtr[0];
ThreadHandle exportedAicpuTsThread = resCtx->opThread;
u32 notifyNumOnMainThread = resCtx->notifyNumOnMainThread;
if (HcommBatchModeStart(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set batch mode, tag is %s.", param->algTag);
return 1;
}
if (exportedAicpuTsThread != 0) {
if (HcommProfilingInit(threadHandlePtr, resCtx->slaveThreadNum + 1) != HCCL_SUCCESS) {
HCCL_ERROR("failed to init Profiling");
return 1;
}
if (HcommProfilingReportMainStreamAndFirstTask(thread) != HCCL_SUCCESS) {
HCCL_ERROR("failed to report MainStream And FirstTask");
return 1;
}
HCCL_DEBUG("[%s]Notify wait on thread[%llu], notifyNumOnMainThread[%u], timeout[%u]",
__func__,
thread,
notifyNumOnMainThread,
CUSTOM_TIMEOUT);
CHK_RET(static_cast<HcclResult>(HcommThreadNotifyWaitOnThread(thread, notifyNumOnMainThread, CUSTOM_TIMEOUT)));
} else {
if (HcommAclrtNotifyWaitOnThread(thread, resCtx->notifyIds[0], CUSTOM_TIMEOUT) != HCCL_SUCCESS) {
HCCL_ERROR("failed to wait notify[%d] from host main stream", resCtx->notifyIds[0]);
return 1;
}
}
if (executor->Orchestrate(*param, resCtx) != HCCL_SUCCESS) {
HCCL_ERROR("orchestrate failed for alg:%s", param->algName);
return 1;
}
if (exportedAicpuTsThread != 0) {
HcomProInfoTmp profInfo;
std::string algTypeStr(param->algTypeStr);
strcpy_s(profInfo.algType, sizeof(profInfo.algType), algTypeStr.c_str());
strcpy_s(profInfo.commName, sizeof(profInfo.commName), param->commName);
profInfo.commNameLen = strlen(param->commName);
profInfo.dataCount = param->DataDes.count;
profInfo.dataType = static_cast<uint8_t>(param->DataDes.dataType);
profInfo.rankSize = resCtx->topoInfo.userRankSize;
HcommProfilingReportDeviceHcclOpInfo(profInfo);
constexpr u32 DEFAULT_NOTIFY_IDX = 0;
HCCL_DEBUG("[%s]Notify record on srcThread[%llu], dstThread[%llu], notifyIdx[%u]",
__func__,
thread,
exportedAicpuTsThread,
DEFAULT_NOTIFY_IDX);
CHK_RET(static_cast<HcclResult>(
HcommThreadNotifyRecordOnThread(thread, exportedAicpuTsThread, DEFAULT_NOTIFY_IDX)));
if (HcommProfilingReportMainStreamAndLastTask(thread) != HCCL_SUCCESS) {
HCCL_ERROR("failed to report MainStream And LastTask");
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
if (HcommProfilingEnd(threadHandlePtr, resCtx->slaveThreadNum + 1) != HCCL_SUCCESS) {
HCCL_ERROR("failed to End Profiling");
return 1;
}
} else {
if (HcommAclrtNotifyRecordOnThread(thread, resCtx->notifyIds[1]) != HCCL_SUCCESS) {
HCCL_ERROR("failed to record host main stream");
return 1;
}
if (HcommBatchModeEnd(param->algTag) != HCCL_SUCCESS) {
HCCL_ERROR("failed set eager mode, tag is %s.", param->algTag);
return 1;
}
}
}
if (HcommReleaseComm(param->commName) != HCCL_SUCCESS) {
HCCL_ERROR("%s HcommReleaseComm fail, commName[%s]", __func__, param->commName);
return 1;
}
HCCL_INFO("%s success, tag[%s], algTag[%s], commName[%s]", __func__, param->tag, param->algTag, param->commName);
return 0;
}