* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is part of the MindStudio project.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*/
#include "csrc/common/inject/hccl_inject.h"
#include <functional>
#include <utility>
#include <string>
#include "csrc/callback/callback_manager.h"
#include "csrc/common/function_loader.h"
#include "csrc/common/inject/hccl_range_mark.h"
#include "csrc/common/context_manager.h"
#include "csrc/common/plog_manager.h"
namespace {
enum HcclFunctionIndex {
FUNC_HCCL_ALL_REDUCE,
FUNC_HCCL_BROADCAST,
FUNC_HCCL_ALL_GATHER,
FUNC_HCCL_REDUCE_SCATTER,
FUNC_HCCL_REDUCE,
FUNC_HCCL_ALL_TO_ALL,
FUNC_HCCL_ALL_TO_ALL_V,
FUNC_HCCL_BARRIER,
FUNC_HCCL_SCATTER,
FUNC_HCCL_SEND,
FUNC_HCCL_RECV,
FUNC_HCCL_BATCH_SEND_RECV,
FUNC_HCCL_GET_COMM_NAME,
FUNC_HCCL_GET_RANK_SIZE,
FUNC_HCCL_COUNT
};
pthread_once_t g_once = PTHREAD_ONCE_INIT;
void *g_hcclFuncArray[FUNC_HCCL_COUNT];
void LoadHcclFunction()
{
g_hcclFuncArray[FUNC_HCCL_ALL_REDUCE] = Mspti::Common::RegisterFunction("libhccl", "HcclAllReduce");
g_hcclFuncArray[FUNC_HCCL_BROADCAST] = Mspti::Common::RegisterFunction("libhccl", "HcclBroadcast");
g_hcclFuncArray[FUNC_HCCL_ALL_GATHER] = Mspti::Common::RegisterFunction("libhccl", "HcclAllGather");
g_hcclFuncArray[FUNC_HCCL_REDUCE_SCATTER] = Mspti::Common::RegisterFunction("libhccl", "HcclReduceScatter");
g_hcclFuncArray[FUNC_HCCL_REDUCE] = Mspti::Common::RegisterFunction("libhccl", "HcclReduce");
g_hcclFuncArray[FUNC_HCCL_ALL_TO_ALL] = Mspti::Common::RegisterFunction("libhccl", "HcclAlltoAll");
g_hcclFuncArray[FUNC_HCCL_ALL_TO_ALL_V] = Mspti::Common::RegisterFunction("libhccl", "HcclAlltoAllV");
g_hcclFuncArray[FUNC_HCCL_BARRIER] = Mspti::Common::RegisterFunction("libhccl", "HcclBarrier");
g_hcclFuncArray[FUNC_HCCL_SCATTER] = Mspti::Common::RegisterFunction("libhccl", "HcclScatter");
g_hcclFuncArray[FUNC_HCCL_SEND] = Mspti::Common::RegisterFunction("libhccl", "HcclSend");
g_hcclFuncArray[FUNC_HCCL_RECV] = Mspti::Common::RegisterFunction("libhccl", "HcclRecv");
g_hcclFuncArray[FUNC_HCCL_BATCH_SEND_RECV] = Mspti::Common::RegisterFunction("libhccl", "HcclBatchSendRecv");
g_hcclFuncArray[FUNC_HCCL_GET_COMM_NAME] = Mspti::Common::RegisterFunction("libhccl", "HcclGetCommName");
g_hcclFuncArray[FUNC_HCCL_GET_RANK_SIZE] = Mspti::Common::RegisterFunction("libhccl", "HcclGetRankSize");
}
}
HcclResult HcclAllReduce(VOID_PTR sendBuf, VOID_PTR recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_ALL_REDUCE];
using HcclAllReduceFunc = std::function<decltype(HcclAllReduce)>;
HcclAllReduceFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclAllReduce)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<P2pOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = count;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_ALLREDUCE, __FUNCTION__);
return func(sendBuf, recvBuf, count, dataType, op, comm, stream);
}
HcclResult HcclBroadcast(VOID_PTR buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_BROADCAST];
using HcclBroadcastFunc = std::function<decltype(HcclBroadcast)>;
HcclBroadcastFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclBroadcast)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<P2pOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = count;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_BROADCAST, __FUNCTION__);
return func(buf, count, dataType, root, comm, stream);
}
HcclResult HcclAllGather(VOID_PTR sendBuf, VOID_PTR recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm,
aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_ALL_GATHER];
using HcclAllGatherFunc = std::function<decltype(HcclAllGather)>;
HcclAllGatherFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclAllGather)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<CollectiveOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = sendCount;
HcclGetRankSize(comm, static_cast<uint32_t *>(&hcclOpDesc->rankSize));
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_ALLGATHER, __FUNCTION__);
return func(sendBuf, recvBuf, sendCount, dataType, comm, stream);
}
HcclResult HcclReduceScatter(VOID_PTR sendBuf, VOID_PTR recvBuf, uint64_t recvCount, HcclDataType dataType,
HcclReduceOp op, HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_REDUCE_SCATTER];
using HcclReduceScatterFunc = std::function<decltype(HcclReduceScatter)>;
HcclReduceScatterFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclReduceScatter)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<CollectiveOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = recvCount;
HcclGetRankSize(comm, static_cast<uint32_t *>(&hcclOpDesc->rankSize));
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_REDUCE_SCATTER, __FUNCTION__);
return func(sendBuf, recvBuf, recvCount, dataType, op, comm, stream);
}
HcclResult HcclReduce(VOID_PTR sendBuf, VOID_PTR recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
uint32_t root, HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_REDUCE];
using HcclReduceFunc = std::function<decltype(HcclReduce)>;
HcclReduceFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclReduce)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<P2pOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = count;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_REDUCE, __FUNCTION__);
return func(sendBuf, recvBuf, count, dataType, op, root, comm, stream);
}
HcclResult HcclAlltoAll(const VOID_PTR sendBuf, uint64_t sendCount, HcclDataType sendType, const VOID_PTR recvBuf,
uint64_t recvCount, HcclDataType recvType, HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_ALL_TO_ALL];
using HcclAlltoAllFunc = std::function<decltype(HcclAlltoAll)>;
HcclAlltoAllFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclAlltoAll)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<CollectiveOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = sendType;
hcclOpDesc->count = sendCount;
HcclGetRankSize(comm, static_cast<uint32_t *>(&hcclOpDesc->rankSize));
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_ALL_TO_ALL, __FUNCTION__);
return func(sendBuf, sendCount, sendType, recvBuf, recvCount, recvType, comm, stream);
}
HcclResult HcclAlltoAllV(const VOID_PTR sendBuf, const VOID_PTR sendCounts, const VOID_PTR sdispls,
HcclDataType sendType, const VOID_PTR recvBuf, const VOID_PTR recvCounts, const VOID_PTR rdispls,
HcclDataType recvType, HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_ALL_TO_ALL_V];
using HcclAlltoAllVFunc = std::function<decltype(HcclAlltoAllV)>;
HcclAlltoAllVFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclAlltoAllV)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<All2AllVOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->sendType = sendType;
hcclOpDesc->sendCounts = sendCounts;
hcclOpDesc->recvType = recvType;
hcclOpDesc->recvCounts = recvCounts;
HcclGetRankSize(comm, static_cast<uint32_t *>(&hcclOpDesc->rankSize));
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_ALL_TO_ALLV, __FUNCTION__);
return func(sendBuf, sendCounts, sdispls, sendType, recvBuf, recvCounts, rdispls, recvType, comm, stream);
}
HcclResult HcclBarrier(HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_BARRIER];
using HcclBarrierFunc = std::function<decltype(HcclBarrier)>;
HcclBarrierFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclBarrier)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_BARRIER, __FUNCTION__);
return func(comm, stream);
}
HcclResult HcclScatter(VOID_PTR sendBuf, VOID_PTR recvBuf, uint64_t recvCount, HcclDataType dataType, uint32_t root,
HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_SCATTER];
using HcclScatterFunc = std::function<decltype(HcclScatter)>;
HcclScatterFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclScatter)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<CollectiveOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = recvCount;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_SCATTER, __FUNCTION__);
return func(sendBuf, recvBuf, recvCount, dataType, root, comm, stream);
}
HcclResult HcclSend(VOID_PTR sendBuf, uint64_t count, HcclDataType dataType, uint32_t destRank, HcclComm comm,
aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_SEND];
using HcclSendFunc = std::function<decltype(HcclSend)>;
HcclSendFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclSend)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<P2pOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = count;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_SEND, __FUNCTION__);
return func(sendBuf, count, dataType, destRank, comm, stream);
}
HcclResult HcclRecv(VOID_PTR recvBuf, uint64_t count, HcclDataType dataType, uint32_t srcRank, HcclComm comm,
aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_RECV];
using HcclRecvFunc = std::function<decltype(HcclRecv)>;
HcclRecvFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclRecv)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<P2pOpDesc> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
hcclOpDesc->dataType = dataType;
hcclOpDesc->count = count;
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_RECV, __FUNCTION__);
return func(recvBuf, count, dataType, srcRank, comm, stream);
}
HcclResult HcclBatchSendRecv(HcclSendRecvItem *sendRecvInfo, uint32_t itemNum, HcclComm comm, aclrtStream stream)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_BATCH_SEND_RECV];
using HcclBatchSendRecvFunc = std::function<decltype(HcclBatchSendRecv)>;
HcclBatchSendRecvFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclBatchSendRecv)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
std::shared_ptr<BatchSendRecvOp> hcclOpDesc;
Mspti::Common::MsptiMakeSharedPtr(hcclOpDesc);
if (sendRecvInfo != nullptr) {
for (uint32_t i = 0; i < itemNum; i++) {
P2pOpDesc batchSendRecvItemDesc;
batchSendRecvItemDesc.dataType = sendRecvInfo[i].dataType;
batchSendRecvItemDesc.count = sendRecvInfo[i].count;
hcclOpDesc->batchSendRecvItem.emplace_back(batchSendRecvItemDesc);
}
} else {
MSPTI_LOGE("op HcclBatchSendRecv, sendRecvInfo is nullptr");
}
hcclOpDesc->opName = __FUNCTION__;
auto marker = Mspti::Parser::HcclMarkFactory::createMarker(stream, comm, hcclOpDesc);
Mspti::Callback::CallbackScope scope(MSPTI_CB_DOMAIN_HCCL, MSPTI_CBID_HCCL_SENDRECV, __FUNCTION__);
return func(sendRecvInfo, itemNum, comm, stream);
}
HcclResult HcclGetCommName(HcclComm comm, char *commName)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_GET_COMM_NAME];
using HcclGetCommNameFunc = std::function<decltype(HcclGetCommName)>;
HcclGetCommNameFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclGetCommName)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
return func(comm, commName);
}
HcclResult HcclGetRankSize(HcclComm comm, uint32_t *rankSize)
{
pthread_once(&g_once, LoadHcclFunction);
void *voidFunc = g_hcclFuncArray[FUNC_HCCL_GET_RANK_SIZE];
using HcclGetRankSizeFunc = std::function<decltype(HcclGetRankSize)>;
HcclGetRankSizeFunc func = Mspti::Common::ReinterpretConvert<decltype(&HcclGetRankSize)>(voidFunc);
if (func == nullptr) {
Mspti::Common::GetFunction("libhccl", __FUNCTION__, func);
}
THROW_FUNC_NOTFOUND(func, __FUNCTION__, "libhccl.so");
return func(comm, rankSize);
}