* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ins_send_executor.h"
#include "alg_data_trans_wrapper.h"
namespace ops_hccl {
/**
 * @brief Gives a short human-readable label for this executor.
 * @return The fixed description text of the instruction based send executor.
 */
std::string InsSendExecutor::Describe() const
{
    return std::string("Instruction based Send Executor.");
}
/**
 * @brief Caches the rank, topology and data-description fields used by the send executor.
 *
 * @param comm     Communicator handle; not used by this function.
 * @param param    Operator parameters; supplies the peer rank, element count and data type.
 * @param topoInfo Topology information; supplies the local rank, rank size and device type.
 *                 It is dereferenced unconditionally, so the caller must pass a valid pointer.
 * @return Always HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::InitSendInfo(
    const HcclComm comm, const OpParam &param, const TopoInfoWithNetLayerDetails *topoInfo)
{
    (void) comm;

    // Topology-derived state.
    myRank_ = topoInfo->userRank;
    rankSize_ = topoInfo->userRankSize;
    devType_ = topoInfo->deviceType;

    // Operator-derived state.
    remoteRank_ = param.sendRecvRemoteRank;
    dataCount_ = param.DataDes.count;
    dataType_ = param.DataDes.dataType;
    dataTypeSize_ = static_cast<u64>(DATATYPE_SIZE_TABLE[dataType_]);

    // dataTypeSize_ is assigned a u64 value above, so it is widened explicitly and printed with
    // %llu; passing it to a %u conversion is a varargs type mismatch.
    HCCL_INFO(
        "[InsSendExecutor][InitSendInfo] myRank [%u], remoteRank [%u], rankSize [%u], devType [%u], "
        "dataType [%u] dataTypeSize [%llu]",
        myRank_, remoteRank_, rankSize_, devType_, dataType_,
        static_cast<unsigned long long>(dataTypeSize_));
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Builds a single-level, single-group hierarchy that lists every rank id.
 *
 * After a successful call, algHierarchyInfo.infos has exactly one level holding exactly one
 * group, and that group contains the rank ids 0 .. userRankSize - 1 in ascending order.
 *
 * @param comm             Communicator handle; not referenced by this function.
 * @param topoInfo         Topology information; supplies userRank and userRankSize.
 * @param algHierarchyInfo Output hierarchy description, overwritten as described above.
 * @return HcclResult::HCCL_E_PARA when userRankSize is 0, otherwise HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::CalcAlgHierarchyInfo(
    HcclComm comm, TopoInfoWithNetLayerDetails *topoInfo, AlgHierarchyInfoForAllLevel &algHierarchyInfo)
{
    myRank_ = topoInfo->userRank;
    HCCL_DEBUG("[InsSendExecutor][CalcAlgHierarchyInfo][%d] Start.", myRank_);

    // An empty communicator cannot be described; reject it before touching the output.
    const bool hasNoRanks = (topoInfo->userRankSize == 0);
    CHK_PRT_RET(
        hasNoRanks,
        HCCL_ERROR("[InsSendExecutor][CalcAlgHierarchyInfo] Rank [%d], rankSize is 0.", myRank_),
        HcclResult::HCCL_E_PARA);

    // Shape the output as one level containing one group, then start that group from empty.
    algHierarchyInfo.infos.resize(1);
    auto &onlyLevel = algHierarchyInfo.infos[0];
    onlyLevel.resize(1);
    auto &allRanks = onlyLevel[0];
    allRanks.clear();

    uint32_t nextRank = 0;
    while (nextRank < topoInfo->userRankSize) {
        allRanks.push_back(nextRank);
        ++nextRank;
    }

    HCCL_DEBUG("[InsSendExecutor][CalcAlgHierarchyInfo][%d] Success.", myRank_);
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Fills the resource request (notifies, slave threads, channels) for a send operation.
 *
 * When AICPU_COMPILE is defined the whole body is compiled out and the function only returns
 * HcclResult::HCCL_SUCCESS. Otherwise it initialises the cached send information, requests no
 * main-thread notifies and no slave threads, and appends one channel list for the peer rank.
 *
 * @param comm             Communicator handle, forwarded to the helpers.
 * @param param            Operator parameters, forwarded to the helpers.
 * @param topoInfo         Topology information, forwarded to InitSendInfo.
 * @param algHierarchyInfo Not used by this function.
 * @param resourceRequest  Output request; its notify/thread counts are set to 0 and one channel
 *                         list is pushed onto resourceRequest.channels.
 * @return HcclResult::HCCL_SUCCESS, or the first error returned by a helper.
 */
HcclResult InsSendExecutor::CalcRes(
    HcclComm comm, const OpParam &param, const TopoInfoWithNetLayerDetails *topoInfo,
    const AlgHierarchyInfoForAllLevel &algHierarchyInfo, AlgResourceRequest &resourceRequest)
{
    (void) algHierarchyInfo;
#ifndef AICPU_COMPILE
    // Propagate the initialisation result like every other helper call in this file instead of
    // silently discarding it.
    CHK_RET(InitSendInfo(comm, param, topoInfo));
    HCCL_DEBUG("[InsSendExecutor][CalcRes][%d]->[%d] Start.", myRank_, remoteRank_);

    resourceRequest.notifyNumOnMainThread = 0;
    resourceRequest.slaveThreadNum = 0;

    // The group status is only queried when the platform reports support for it; otherwise
    // group mode is treated as disabled.
    bool isGroupEnabled = false;
    if (HcommIsSupportHcclGroupStatusGet()) {
        CHK_RET(HcclGroupStatusGet(&isGroupEnabled));
    }

    std::vector<HcclChannelDesc> level0Channels;
    if (isGroupEnabled) {
        // NOTE(review): the trailing 2 is presumably the number of channels requested towards
        // the peer in group mode (OrchestrateP2p picks between two channels) - confirm against
        // CreateChannelRequestByRankId.
        CHK_RET(CreateChannelRequestByRankId(comm, param, myRank_, remoteRank_, level0Channels, 2));
    } else {
        CHK_RET(CreateChannelRequestByRankId(comm, param, myRank_, remoteRank_, level0Channels));
    }
    resourceRequest.channels.push_back(level0Channels);

    HCCL_DEBUG("[InsSendExecutor][CalcRes][%d]->[%d] Success.", myRank_, remoteRank_);
#endif
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Entry point that issues the send towards param.sendRecvRemoteRank.
 *
 * Caches the operator and resource fields in members, uses the first thread of resCtx.threads,
 * searches resCtx.channels.at(0) for the channel whose remoteRank equals the peer rank, and
 * dispatches to OrchestrateOffload or OrchestrateOpbase according to param.opMode.
 *
 * @param param  Operator parameters (op mode, peer rank, data description, input pointer).
 * @param resCtx Resource context (topology info, CCL memory, threads, channels).
 * @return HcclResult::HCCL_E_NOT_FOUND when no channel targets the peer rank; otherwise the
 *         helper's error, or HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::Orchestrate(const OpParam &param, const AlgResourceCtxSerializable &resCtx)
{
    // Cache operator-side fields.
    opMode_ = param.opMode;
    remoteRank_ = param.sendRecvRemoteRank;
    dataCount_ = param.DataDes.count;
    dataType_ = param.DataDes.dataType;
    dataTypeSize_ = static_cast<u64>(DATATYPE_SIZE_TABLE[dataType_]);
    dataSize_ = dataCount_ * dataTypeSize_;

    // Cache resource-side fields.
    myRank_ = resCtx.topoInfo.userRank;
    maxTmpMemSize_ = resCtx.cclMem.size;

    HCCL_DEBUG("[InsSendExecutor][Orchestrate][%d]->[%d] Start.", myRank_, remoteRank_);

    const ThreadHandle &thread = resCtx.threads.at(0);
    const auto &level0Channels = resCtx.channels.at(0);

    // Locate the channel that points at the peer rank.
    const auto peerRank = remoteRank_;
    const auto channelIt = std::find_if(
        level0Channels.begin(), level0Channels.end(),
        [peerRank](const ChannelInfo &candidate) { return candidate.remoteRank == peerRank; });
    CHK_PRT_RET(
        channelIt == level0Channels.end(),
        HCCL_ERROR("[InsSendExecutor][Orchestrate] Channel[%d]-[%d] not found.", myRank_, remoteRank_),
        HcclResult::HCCL_E_NOT_FOUND);
    const ChannelInfo &channel = *channelIt;

    // A PCIE channel switches the helpers to the SendRead path; the flag is only ever set
    // here, never cleared.
    if (channel.protocol == CommProtocol::COMM_PROTOCOL_PCIE) {
        isDmaRead_ = true;
    }

    if (opMode_ != OpMode::OFFLOAD) {
        CHK_RET(OrchestrateOpbase(param, resCtx, thread, channel));
    } else {
        CHK_RET(OrchestrateOffload(param, resCtx, thread, channel));
    }

    HCCL_DEBUG("[InsSendExecutor][Orchestrate][%d]->[%d] Success.", myRank_, remoteRank_);
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Issues the send on a caller-supplied thread, choosing the channel by position.
 *
 * Channel choice within resCtx.channels.at(0): when more than one channel is present, index 0
 * is used if the local rank is less than or equal to the peer rank and index 1 otherwise; with
 * a single channel, index 0 is always used. The work is then dispatched to OrchestrateOffload
 * or OrchestrateOpbase according to param.opMode.
 *
 * @param param          Operator parameters (op mode, peer rank, data description, input pointer).
 * @param resCtx         Resource context (topology info, CCL memory, channels).
 * @param sendRecvThread Thread handle on which the transfer tasks are issued.
 * @return The helper's error, or HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::OrchestrateP2p(const OpParam &param, const AlgResourceCtxSerializable &resCtx,
    ThreadHandle sendRecvThread)
{
    // Cache operator-side fields.
    opMode_ = param.opMode;
    remoteRank_ = param.sendRecvRemoteRank;
    dataCount_ = param.DataDes.count;
    dataType_ = param.DataDes.dataType;
    dataTypeSize_ = static_cast<u64>(DATATYPE_SIZE_TABLE[dataType_]);
    dataSize_ = dataCount_ * dataTypeSize_;

    // Cache resource-side fields.
    myRank_ = resCtx.topoInfo.userRank;
    maxTmpMemSize_ = resCtx.cclMem.size;

    HCCL_DEBUG("[InsSendExecutor][OrchestrateP2p][%d]->[%d] Start.", myRank_, remoteRank_);

    const auto &level0Channels = resCtx.channels.at(0);
    HCCL_INFO("[InsSendExecutor][OrchestrateP2p] resCtx channels size [%zu].", level0Channels.size());

    // The second channel is used only when it exists and the local rank is above the peer rank.
    const bool useSecondChannel = (level0Channels.size() > 1) && !(myRank_ <= remoteRank_);
    const ChannelInfo &channel = useSecondChannel ? level0Channels.at(1) : level0Channels.at(0);

    // A PCIE channel switches the helpers to the SendRead path; the flag is only ever set
    // here, never cleared.
    if (channel.protocol == CommProtocol::COMM_PROTOCOL_PCIE) {
        isDmaRead_ = true;
    }

    if (opMode_ != OpMode::OFFLOAD) {
        CHK_RET(OrchestrateOpbase(param, resCtx, sendRecvThread, channel));
    } else {
        CHK_RET(OrchestrateOffload(param, resCtx, sendRecvThread, channel));
    }

    HCCL_DEBUG("[InsSendExecutor][OrchestrateP2p][%d]->[%d] Success.", myRank_, remoteRank_);
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Offload-mode send: slices the input and issues one SendRead/SendWrite for all slices.
 *
 * The input (param.inputPtr, dataCount_ elements) is cut into slices of at most
 * UB_MAX_DATA_SIZE bytes. Source and destination slices share the same offset; the destination
 * base is channel.remoteOutputGraphMode.addr. All slices are handed over in a single DataInfo.
 *
 * Relies on dataCount_, dataTypeSize_, isDmaRead_, myRank_ and remoteRank_ having been set by
 * the calling Orchestrate/OrchestrateP2p.
 *
 * @param param   Operator parameters; only inputPtr is read here.
 * @param resCtx  Not used in offload mode.
 * @param thread  Thread handle on which the transfer is issued.
 * @param channel Channel towards the peer rank.
 * @return HcclResult::HCCL_E_PARA when the element size is 0 or no element fits into one
 *         transfer; otherwise the transfer helper's error, or HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::OrchestrateOffload(const OpParam &param, const AlgResourceCtxSerializable &resCtx,
    const ThreadHandle &thread, const ChannelInfo &channel)
{
    (void) resCtx;

    // A zero element size would divide by zero below.
    CHK_PRT_RET(
        dataTypeSize_ == 0,
        HCCL_ERROR("[InsSendExecutor][OrchestrateOffload] Rank [%d], dataTypeSize is 0.", myRank_),
        HcclResult::HCCL_E_PARA);

    void *dstBufferPtr = static_cast<void *>(channel.remoteOutputGraphMode.addr);
    maxLoopTransSize_ = UB_MAX_DATA_SIZE;
    maxLoopTransCount_ = maxLoopTransSize_ / dataTypeSize_;

    // With a per-slice capacity of zero elements the slicing loop would never make progress
    // and would grow the slice vectors without bound.
    CHK_PRT_RET(
        (dataCount_ > 0) && (maxLoopTransCount_ == 0),
        HCCL_ERROR("[InsSendExecutor][OrchestrateOffload] Rank [%d], dataTypeSize [%llu] exceeds max "
                   "transfer size [%llu].",
            myRank_, static_cast<unsigned long long>(dataTypeSize_),
            static_cast<unsigned long long>(maxLoopTransSize_)),
        HcclResult::HCCL_E_PARA);

    std::vector<DataSlice> srcSlices;
    std::vector<DataSlice> dstSlices;
    if (dataCount_ > 0) {
        // Number of slices = ceil(dataCount_ / maxLoopTransCount_), written without overflow.
        const u64 sliceNum =
            dataCount_ / maxLoopTransCount_ + ((dataCount_ % maxLoopTransCount_ != 0) ? 1 : 0);
        srcSlices.reserve(static_cast<size_t>(sliceNum));
        dstSlices.reserve(static_cast<size_t>(sliceNum));
    }

    HCCL_DEBUG("[InsSendExecutor][Orchestrate][%d]->[%d] OFFLOAD Generating tasks.", myRank_, remoteRank_);
    u64 dataCountToSend = dataCount_;
    u64 currentOffset = 0;
    while (dataCountToSend > 0) {
        const u64 transferCount = dataCountToSend > maxLoopTransCount_ ? maxLoopTransCount_ : dataCountToSend;
        const u64 transferSize = transferCount * dataTypeSize_;
        srcSlices.emplace_back(param.inputPtr, currentOffset, transferSize, transferCount);
        dstSlices.emplace_back(dstBufferPtr, currentOffset, transferSize, transferCount);
        currentOffset += transferSize;
        dataCountToSend -= transferCount;
    }

    SlicesList sendSlicesList{srcSlices, dstSlices};
    DataInfo sendInfo{channel, sendSlicesList};
    if (isDmaRead_) {
        CHK_RET(SendRead(sendInfo, thread));
    } else {
        CHK_RET(SendWrite(sendInfo, thread));
    }
    return HcclResult::HCCL_SUCCESS;
}
/**
 * @brief Opbase-mode send: streams the input to the peer's CCL buffer chunk by chunk.
 *
 * Each chunk is at most min(UB_MAX_DATA_SIZE, channel.remoteCclMem.size) bytes and always
 * targets offset 0 of the remote CCL buffer. When isDmaRead_ is set, the chunk is first copied
 * into the local CCL buffer (offset 0) and then sent with SendRead; otherwise it is sent
 * straight from the input buffer with SendWrite.
 *
 * Relies on dataCount_, dataTypeSize_, isDmaRead_, myRank_ and remoteRank_ having been set by
 * the calling Orchestrate/OrchestrateP2p.
 *
 * @param param   Operator parameters; only inputPtr is read here.
 * @param resCtx  Resource context; its cclMem is the staging buffer on the SendRead path.
 * @param thread  Thread handle on which copies and transfers are issued.
 * @param channel Channel towards the peer rank; supplies the remote CCL buffer.
 * @return HcclResult::HCCL_E_PARA when the element size is 0 or the per-chunk capacity is
 *         smaller than one element; otherwise the first helper error, or
 *         HcclResult::HCCL_SUCCESS.
 */
HcclResult InsSendExecutor::OrchestrateOpbase(const OpParam &param, const AlgResourceCtxSerializable &resCtx,
    const ThreadHandle &thread, const ChannelInfo &channel)
{
    // A zero element size would divide by zero below.
    CHK_PRT_RET(
        dataTypeSize_ == 0,
        HCCL_ERROR("[InsSendExecutor][OrchestrateOpbase] Rank [%d], dataTypeSize is 0.", myRank_),
        HcclResult::HCCL_E_PARA);

    maxLoopTransSize_ = std::min<u64>(UB_MAX_DATA_SIZE, channel.remoteCclMem.size);
    maxLoopTransCount_ = maxLoopTransSize_ / dataTypeSize_;

    // An empty (or too small) remote CCL buffer yields a per-chunk capacity of zero elements;
    // the loop below would then never make progress.
    CHK_PRT_RET(
        (dataCount_ > 0) && (maxLoopTransCount_ == 0),
        HCCL_ERROR("[InsSendExecutor][OrchestrateOpbase] Rank [%d], max transfer size [%llu] is smaller "
                   "than dataTypeSize [%llu].",
            myRank_, static_cast<unsigned long long>(maxLoopTransSize_),
            static_cast<unsigned long long>(dataTypeSize_)),
        HcclResult::HCCL_E_PARA);

    HCCL_DEBUG("[InsSendExecutor][Orchestrate][%d]->[%d] OPBASE Generating tasks.", myRank_, remoteRank_);
    u64 dataCountToSend = dataCount_;
    u64 currentOffset = 0;
    while (dataCountToSend > 0) {
        const u64 transferCount = dataCountToSend > maxLoopTransCount_ ? maxLoopTransCount_ : dataCountToSend;
        const u64 transferSize = transferCount * dataTypeSize_;
        DataSlice inputSlice{param.inputPtr, currentOffset, transferSize, transferCount};
        DataSlice remoteCclSlice{channel.remoteCclMem.addr, 0, transferSize, transferCount};
        if (isDmaRead_) {
            // NOTE(review): the chunk size is bounded by the remote CCL buffer only; confirm
            // that the local CCL buffer (resCtx.cclMem) is at least as large.
            DataSlice cclSlice{resCtx.cclMem.addr, 0, transferSize, transferCount};
            CHK_RET(LocalCopy(thread, inputSlice, cclSlice));
            SlicesList sendSlicesList{{cclSlice}, {remoteCclSlice}};
            DataInfo sendInfo{channel, sendSlicesList};
            CHK_RET(SendRead(sendInfo, thread));
        } else {
            SlicesList sendSlicesList{{inputSlice}, {remoteCclSlice}};
            DataInfo sendInfo{channel, sendSlicesList};
            CHK_RET(SendWrite(sendInfo, thread));
        }
        currentOffset += transferSize;
        dataCountToSend -= transferCount;
    }
    return HcclResult::HCCL_SUCCESS;
}
// Associates InsSendExecutor with HcclCMDType::HCCL_CMD_SEND under the tag InsSend.
// NOTE(review): the exact registration mechanism is defined by REGISTER_EXECUTOR_IMPL, which is
// declared outside this file - confirm there.
REGISTER_EXECUTOR_IMPL(HcclCMDType::HCCL_CMD_SEND, InsSend, InsSendExecutor);
}