* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "reducer.h"
namespace hccl {
// Stores the element data type, the reduction operator and the reduce attribute bit-set.
// Both sync hooks are initialised with a callback that does nothing and reports HCCL_SUCCESS,
// so the run() overloads that invoke preSync_() / postSync_() work without any hook being installed.
Reducer::Reducer(const HcclDataType dataType, const HcclReduceOp reductionOp, const u64 reduceAttribute)
    : dataType_(dataType), reductionOp_(reductionOp), reduceAttribute_(reduceAttribute)
{
    const auto noopSync = []() { return HCCL_SUCCESS; };
    SetPreSyncFunc(noopSync);
    SetPostSyncFunc(noopSync);
}
// No explicit cleanup is performed by the destructor itself.
Reducer::~Reducer() = default;
void Reducer::SetPreSyncFunc(std::function<HcclResult()> lambda)
{
preSync_ = std::move(lambda);
}
void Reducer::SetPostSyncFunc(std::function<HcclResult()> lambda)
{
postSync_ = std::move(lambda);
}
// Single-buffer receive-and-reduce. Every operation is issued with `stream`.
// This overload does not invoke preSync_ / postSync_.
//
// The path is chosen from the link capabilities and reduceAttribute_, in this order:
//   1. link->IsSupportTransportWithReduce() and RDMA_REDUCE_BITMASK set in reduceAttribute_:
//      link->RxAsync into localDst (followed by DataReceivedAck when the link supports it);
//      if resultMem is RESULT_OUTPUT_MEM, HcclD2DMemcpyAsync is then called with (localDst, localSrc).
//   2. link->IsSupportTransportWithReduce() and link type LINK_STANDARD_ROCE:
//      one link->RxWithReduce call receives into remoteRcvTemp and reduces into localDst.
//   3. link->IsSpInlineReduce() and INLINE_REDUCE_BITMASK set in reduceAttribute_:
//      after RxDataSignal, HcclReduceAsync is called with the remote memory (base obtained for
//      srcMemType, plus remoteMemOffset) as source and localSrc as destination; localSrc is then
//      copied to localDst when the two differ, and a TxAck/RxAck/TxDataSignal/RxDataSignal
//      sequence runs when the link supports DataReceivedAck.
//   4. otherwise: link->RxAsync into remoteRcvTemp, optional DataReceivedAck, then
//      HcclReduceAsync into localDst.
// In paths 2 and 4 the reduce source is remoteRcvTemp when localSrc == localDst, else localSrc.
//
// Parameters:
//   dispatcher      - handle forwarded to HcclD2DMemcpyAsync / HcclReduceAsync.
//   link            - transport to the remote peer; must not be null.
//   remoteMemOffset - offset applied to the remote memory.
//   localSrc        - local source buffer; must have a non-null pointer.
//   localDst        - local destination buffer; must have a non-null pointer.
//   remoteRcvTemp   - receive buffer; its size() is used as the received byte count.
//   stream          - stream passed to every call; must have a non-null pointer.
//   resultMem       - only consulted in path 1 (see above).
//   srcMemType      - remote memory type used by paths 2 and 3.
//
// Returns HCCL_SUCCESS when every step succeeds; otherwise the failing status is propagated
// by the CHK_* macros.
HcclResult Reducer::run(const HcclDispatcher dispatcher, const std::shared_ptr<Transport> &link,
    const u64 remoteMemOffset, DeviceMem &localSrc, DeviceMem &localDst, DeviceMem &remoteRcvTemp, Stream &stream,
    DstMemType resultMem, const UserMemType srcMemType) const
{
    CHK_PTR_NULL(localSrc.ptr());
    CHK_PTR_NULL(localDst.ptr());
    CHK_PTR_NULL(remoteRcvTemp.ptr());
    CHK_PTR_NULL(stream.ptr());
    // link is dereferenced on every path below; reject a null link up front,
    // matching the check done by the notifyIdx overload of run().
    CHK_SMART_PTR_NULL(link);

    HcclResult ret = HCCL_SUCCESS;
    u64 dataBytes = remoteRcvTemp.size();
    HCCL_DEBUG("localSrc[%p] localDst[%p] remoteRcvtmep[%p] offset[%llu]", localSrc.ptr(), localDst.ptr(),
        remoteRcvTemp.ptr(), remoteMemOffset);

    bool isSpInlineReduce = link->IsSpInlineReduce();

    if (link->IsSupportTransportWithReduce() && (RDMA_REDUCE_BITMASK & reduceAttribute_)) {
        // Path 1: receive directly into localDst.
        CHK_RET(link->RxAsync(UserMemType::INPUT_MEM, remoteMemOffset, localDst.ptr(), localDst.size(), stream));
        if (link->GetSupportDataReceivedAck()) {
            CHK_RET(link->DataReceivedAck(stream));
        }

        if (resultMem == DstMemType::RESULT_OUTPUT_MEM) {
            ret = HcclD2DMemcpyAsync(dispatcher, localDst, localSrc, stream);
            CHK_PRT_RET(ret != HCCL_SUCCESS,
                HCCL_ERROR("[Reducer][Run]memcpy_async localSrc[%p] localDst[%p] failed", localSrc.ptr(),
                    localDst.ptr()),
                ret);
        }
    } else if (link->IsSupportTransportWithReduce() && link->GetLinkType() == LinkType::LINK_STANDARD_ROCE) {
        // Path 2: the transport performs receive and reduce in one call.
        u64 dataCount = localDst.size() / SIZE_TABLE[dataType_];
        DeviceMem &reduceSrc = (localSrc == localDst) ? remoteRcvTemp : localSrc;
        CHK_RET(link->RxWithReduce(srcMemType, remoteMemOffset, remoteRcvTemp.ptr(), dataBytes,
            reduceSrc.ptr(), localDst.ptr(), dataCount, dataType_, reductionOp_, stream, reduceAttribute_));
    } else if (isSpInlineReduce && (INLINE_REDUCE_BITMASK & reduceAttribute_)) {
        // Path 3: reduce straight from the remote memory, with localSrc as the destination.
        CHK_RET(link->RxDataSignal(stream));
        void *remoteMem = nullptr;
        CHK_RET(link->GetRemoteMem(srcMemType, &remoteMem));
        CHK_RET(HcclReduceAsync(dispatcher, static_cast<s8 *>(remoteMem) + remoteMemOffset,
            dataBytes / SIZE_TABLE[dataType_], dataType_, reductionOp_, stream, localSrc.ptr(), link->GetRemoteRank(),
            link->GetLinkType(), INLINE_REDUCE_BIT));

        if (localSrc != localDst) {
            ret = HcclD2DMemcpyAsync(dispatcher, localDst, localSrc, stream);
            CHK_PRT_RET(ret != HCCL_SUCCESS,
                HCCL_ERROR("[Reducer][Run]memcpy_async localSrc[%p] localDst[%p] failed", localSrc.ptr(),
                    localDst.ptr()),
                ret);
        }
        if (link->GetSupportDataReceivedAck()) {
            CHK_RET(link->TxAck(stream));
            CHK_RET(link->RxAck(stream));
            CHK_RET(link->TxDataSignal(stream));
            CHK_RET(link->RxDataSignal(stream));
        }
    } else {
        // Path 4: receive into the temporary buffer, then reduce locally into localDst.
        ret = link->RxAsync(UserMemType::INPUT_MEM, remoteMemOffset, remoteRcvTemp.ptr(), dataBytes, stream);
        CHK_PRT_RET(ret != HCCL_SUCCESS,
            HCCL_ERROR("[Reducer][Run]rx_sync remoteRcvTemp[%p] offset[%llu] size[%llu] "
                       "failed",
                remoteRcvTemp.ptr(), remoteMemOffset, dataBytes),
            ret);
        if (link->GetSupportDataReceivedAck()) {
            ret = link->DataReceivedAck(stream);
            CHK_PRT_RET(ret != HCCL_SUCCESS, HCCL_ERROR("[Reducer][Run]rx_sync data received ack failed"), ret);
        }

        u64 dataCount = localDst.size() / SIZE_TABLE[dataType_];
        // Bind by reference (as path 2 does) instead of copying a DeviceMem just to read its pointer.
        DeviceMem &reduceSrc = (localSrc == localDst) ? remoteRcvTemp : localSrc;
        ret = HcclReduceAsync(dispatcher, reduceSrc.ptr(), dataCount, dataType_, reductionOp_, stream, localDst.ptr(),
            INVALID_VALUE_RANKID, LinkType::LINK_ONCHIP, reduceAttribute_);
        CHK_PRT_RET(ret != HCCL_SUCCESS,
            HCCL_ERROR("[Reducer][Run]reduce_async remoteRcvTemp[%p] localSrc[%p] "
                       "localDst[%p] failed",
                remoteRcvTemp.ptr(), localSrc.ptr(), localDst.ptr()),
            ret);
    }

    return ret;
}
// Appends one RxMemoryInfo to rxMems for every entry of reducerMems.
// Each appended descriptor carries UserMemType::INPUT_MEM, the entry's remoteMemOffset and
// the pointer and size of the entry's remoteRcvTemp buffer. Entries already present in
// rxMems are left in place. Always returns HCCL_SUCCESS.
HcclResult Reducer::PrepareRxMems(const std::vector<ReducerMemoryInfo> &reducerMems,
    std::vector<RxMemoryInfo> &rxMems) const
{
    rxMems.reserve(reducerMems.size());

    for (auto it = reducerMems.begin(); it != reducerMems.end(); ++it) {
        const auto &rcvTemp = it->remoteRcvTemp;
        rxMems.push_back(
            RxMemoryInfo{ UserMemType::INPUT_MEM, it->remoteMemOffset, rcvTemp.ptr(), rcvTemp.size() });
    }

    return HCCL_SUCCESS;
}
// Appends one RxWithReduceMemoryInfo to rxWithReduceMems for every entry of reducerMems.
// Per entry:
//   - the receive part uses UserMemType::INPUT_MEM, remoteMemOffset and the remoteRcvTemp buffer;
//   - the reduce source is remoteRcvTemp when localsrc == localdst, otherwise localsrc;
//   - the reduce destination is localdst;
//   - the element count is localdst.size() divided by SIZE_TABLE[dataType_].
// Entries already present in rxWithReduceMems are left in place. Always returns HCCL_SUCCESS.
HcclResult Reducer::PrepareRxWithReduceMems(const std::vector<ReducerMemoryInfo> &reducerMems,
    std::vector<RxWithReduceMemoryInfo> &rxWithReduceMems) const
{
    rxWithReduceMems.reserve(reducerMems.size());

    for (const auto &entry : reducerMems) {
        const u64 elemCount = entry.localdst.size() / SIZE_TABLE[dataType_];
        const bool srcIsDst = (entry.localsrc == entry.localdst);
        const auto &reduceInput = srcIsDst ? entry.remoteRcvTemp : entry.localsrc;

        rxWithReduceMems.push_back(RxWithReduceMemoryInfo{
            UserMemType::INPUT_MEM, entry.remoteMemOffset,
            entry.remoteRcvTemp.ptr(), entry.remoteRcvTemp.size(),
            reduceInput.ptr(), entry.localdst.ptr(),
            elemCount});
    }

    return HCCL_SUCCESS;
}
// Batched receive-and-reduce over a list of memory descriptors. Every operation is issued
// with `stream`; preSync_() and postSync_() are invoked once per call on each path.
//
// The path is chosen from the link capabilities and reduceAttribute_, in this order:
//   1. link->IsSupportTransportWithReduce() and RDMA_REDUCE_BITMASK set in reduceAttribute_:
//      link->RxAsync for the whole batch, optional DataReceivedAck, preSync_, then (only when
//      resultMem is RESULT_OUTPUT_MEM) HcclD2DMemcpyAsync(localdst, localsrc) per entry, postSync_.
//   2. link->IsSupportTransportWithReduce() and link type LINK_STANDARD_ROCE:
//      preSync_, one batched link->RxWithReduce, postSync_.
//   3. link->IsSpInlineReduce() and INLINE_REDUCE_BITMASK set in reduceAttribute_:
//      RxDataSignal, preSync_, then per entry HcclReduceAsync from the remote INPUT_MEM
//      (base + remoteMemOffset) with localsrc as destination, followed by a copy of localsrc to
//      localdst when they differ; finally postSync_.
//   4. otherwise: preSync_, batched link->RxAsync, postSync_, optional DataReceivedAck, then
//      one HcclReduceAsync per prepared descriptor.
//
// Parameters:
//   dispatcher  - handle forwarded to HcclD2DMemcpyAsync / HcclReduceAsync.
//   link        - transport to the remote peer; must not be null.
//   reducerMems - descriptors processed in order.
//   stream      - stream passed to every call; must have a non-null pointer.
//   resultMem   - only consulted in path 1 (see above).
//
// Returns HCCL_SUCCESS when every step succeeds; otherwise the failing status is propagated
// by the CHK_* macros.
HcclResult Reducer::run(const HcclDispatcher dispatcher, const std::shared_ptr<Transport> &link,
    const std::vector<ReducerMemoryInfo> &reducerMems, Stream &stream, DstMemType resultMem) const
{
    CHK_PTR_NULL(stream.ptr());
    // link is dereferenced immediately below; reject a null link up front,
    // matching the check done by the notifyIdx overload of run().
    CHK_SMART_PTR_NULL(link);

    LinkType linkType = link->GetLinkType();
    bool isSpInlineReduce = link->IsSpInlineReduce();
    bool isSpRdmaReduce = RDMA_REDUCE_BITMASK & reduceAttribute_;
    bool isSpTransportWithReduce = link->IsSupportTransportWithReduce();
    HcclResult ret = HCCL_SUCCESS;

    if (isSpTransportWithReduce && isSpRdmaReduce) {
        std::vector<RxMemoryInfo> rxMems;
        CHK_RET(PrepareRxMems(reducerMems, rxMems));
        CHK_RET(link->RxAsync(rxMems, stream));
        if (link->GetSupportDataReceivedAck()) {
            CHK_RET(link->DataReceivedAck(stream));
        }
        CHK_RET(preSync_());
        if (resultMem == DstMemType::RESULT_OUTPUT_MEM) {
            // NOTE(review): iterated by value on purpose - HcclD2DMemcpyAsync may take its
            // destination as a mutable reference; confirm before switching to const&.
            for (ReducerMemoryInfo reduceMem : reducerMems) {
                ret = HcclD2DMemcpyAsync(dispatcher, reduceMem.localdst, reduceMem.localsrc, stream);
                CHK_PRT_RET(ret != HCCL_SUCCESS,
                    HCCL_ERROR("[Reducer][Run]memcpy_async localSrc[%p] localDst[%p] failed", reduceMem.localsrc.ptr(),
                        reduceMem.localdst.ptr()),
                    ret);
            }
        }
        CHK_RET(postSync_());
    } else if (isSpTransportWithReduce && (linkType == LinkType::LINK_STANDARD_ROCE)) {
        std::vector<RxWithReduceMemoryInfo> rxWithReduceMems;
        CHK_RET(PrepareRxWithReduceMems(reducerMems, rxWithReduceMems));
        CHK_RET(preSync_());
        CHK_RET(link->RxWithReduce(rxWithReduceMems, dataType_, reductionOp_, stream, reduceAttribute_));
        CHK_RET(postSync_());
    } else if (isSpInlineReduce && (INLINE_REDUCE_BITMASK & reduceAttribute_)) {
        CHK_RET(link->RxDataSignal(stream));
        void *remoteMem = nullptr;
        CHK_RET(link->GetRemoteMem(UserMemType::INPUT_MEM, &remoteMem));
        CHK_RET(preSync_());
        // NOTE(review): iterated by value on purpose - see the note in the first branch.
        for (ReducerMemoryInfo reduceMem : reducerMems) {
            const u64 dataBytes = reduceMem.remoteRcvTemp.size();
            CHK_RET(
                HcclReduceAsync(dispatcher, static_cast<s8 *>(remoteMem) + reduceMem.remoteMemOffset,
                    dataBytes / SIZE_TABLE[dataType_], dataType_, reductionOp_, stream, reduceMem.localsrc.ptr(),
                    link->GetRemoteRank(), link->GetLinkType(), INLINE_REDUCE_BIT));

            if (reduceMem.localsrc != reduceMem.localdst) {
                ret = HcclD2DMemcpyAsync(dispatcher, reduceMem.localdst, reduceMem.localsrc, stream);
                CHK_PRT_RET(ret != HCCL_SUCCESS,
                    HCCL_ERROR("[Reducer][Run]memcpy_async localSrc[%p] localDst[%p] failed", reduceMem.localsrc.ptr(),
                        reduceMem.localdst.ptr()),
                    ret);
            }
            HCCL_DEBUG("[Reducer][Run]memcpy_async localSrc is [%p]", reduceMem.localsrc.ptr());
        }
        CHK_RET(postSync_());
    } else {
        std::vector<RxMemoryInfo> rxMems;
        CHK_RET(PrepareRxMems(reducerMems, rxMems));
        std::vector<RxWithReduceMemoryInfo> rxWithReduceMems;
        CHK_RET(PrepareRxWithReduceMems(reducerMems, rxWithReduceMems));
        CHK_RET(preSync_());
        CHK_RET(link->RxAsync(rxMems, stream));
        CHK_RET(postSync_());
        if (link->GetSupportDataReceivedAck()) {
            CHK_RET(link->DataReceivedAck(stream));
        }
        // Only the descriptor's fields are read here, so no per-iteration copy is needed.
        for (const auto &rxReduceMem : rxWithReduceMems) {
            CHK_RET(HcclReduceAsync(dispatcher, rxReduceMem.reduceSrc, rxReduceMem.reduceDataCount, dataType_,
                reductionOp_, stream, rxReduceMem.reduceDst, INVALID_VALUE_RANKID, LinkType::LINK_ONCHIP,
                reduceAttribute_));
        }
    }

    return HCCL_SUCCESS;
}
// Batched receive-and-reduce that synchronises with the peer through link->Wait(notifyIdx, stream).
//
// Two paths:
//   - link->IsSpInlineReduce() and INLINE_REDUCE_BITMASK set in reduceAttribute_:
//     Wait, preSync_, then per entry HcclReduceAsync from the remote INPUT_MEM
//     (base + remoteMemOffset) with localsrc as destination, followed by a copy of localsrc to
//     localdst when they differ; finally postSync_.
//   - otherwise: preSync_, Wait, then per prepared receive descriptor an HcclD2DMemcpyAsync from
//     the remote memory into the descriptor's dst, postSync_, optional DataReceivedAck, and one
//     HcclReduceAsync per prepared reduce descriptor.
//
// Parameters:
//   dispatcher  - handle forwarded to HcclD2DMemcpyAsync / HcclReduceAsync.
//   link        - transport to the remote peer; must not be null.
//   reducerMems - descriptors processed in order.
//   notifyIdx   - index handed to link->Wait.
//   stream      - stream passed to every call; must have a non-null pointer.
//   resultMem   - unused by this overload.
//
// Returns HCCL_SUCCESS when every step succeeds; otherwise the failing status is propagated
// by the CHK_* macros.
HcclResult Reducer::run(const HcclDispatcher dispatcher, const std::shared_ptr<Transport> &link,
    const std::vector<ReducerMemoryInfo> &reducerMems, u32 notifyIdx, Stream &stream, DstMemType resultMem) const
{
    (void)resultMem;
    CHK_PTR_NULL(stream.ptr());
    CHK_SMART_PTR_NULL(link);

    const bool linkHasInlineReduce = link->IsSpInlineReduce();
    const bool inlineReduceRequested = static_cast<bool>(INLINE_REDUCE_BITMASK & reduceAttribute_);

    if (linkHasInlineReduce && inlineReduceRequested) {
        // Inline-reduce path: reduce directly out of the peer's memory.
        CHK_RET(link->Wait(notifyIdx, stream));
        void *remoteBase = nullptr;
        CHK_RET(link->GetRemoteMem(UserMemType::INPUT_MEM, &remoteBase));
        CHK_RET(preSync_());

        for (ReducerMemoryInfo memInfo : reducerMems) {
            const u64 byteLen = memInfo.remoteRcvTemp.size();
            CHK_RET(HcclReduceAsync(dispatcher, static_cast<s8 *>(remoteBase) + memInfo.remoteMemOffset,
                byteLen / SIZE_TABLE[dataType_], dataType_, reductionOp_, stream, memInfo.localsrc.ptr(),
                link->GetRemoteRank(), link->GetLinkType(), INLINE_REDUCE_BIT));
            HCCL_DEBUG("[Reducer][Run]memcpy_async localSrc[%p]", memInfo.localsrc.ptr());

            if (memInfo.localsrc != memInfo.localdst) {
                const HcclResult copyRet = HcclD2DMemcpyAsync(dispatcher, memInfo.localdst, memInfo.localsrc, stream);
                CHK_PRT_RET(copyRet != HCCL_SUCCESS,
                    HCCL_ERROR("[Reducer][Run]memcpy_async localSrc[%p] localDst[%p] failed", memInfo.localsrc.ptr(),
                        memInfo.localdst.ptr()),
                    copyRet);
            }
        }

        CHK_RET(postSync_());
        return HCCL_SUCCESS;
    }

    // Fallback path: copy the remote data into the receive buffers, then reduce locally.
    std::vector<RxMemoryInfo> rxMems;
    CHK_RET(PrepareRxMems(reducerMems, rxMems));
    std::vector<RxWithReduceMemoryInfo> rxWithReduceMems;
    CHK_RET(PrepareRxWithReduceMems(reducerMems, rxWithReduceMems));

    CHK_RET(preSync_());
    CHK_RET(link->Wait(notifyIdx, stream));

    for (auto &rxMem : rxMems) {
        CHK_PTR_NULL(rxMem.dst);
        void *remoteBase = nullptr;
        CHK_RET(link->GetRemoteMem(rxMem.srcMemType, &remoteBase));
        DeviceMem remoteView(static_cast<s8 *>(remoteBase) + rxMem.srcOffset, rxMem.len);
        DeviceMem localView(static_cast<s8 *>(rxMem.dst), rxMem.len);
        CHK_RET(HcclD2DMemcpyAsync(dispatcher, localView, remoteView, stream, link->GetRemoteRank(),
            link->GetLinkType()));
    }

    CHK_RET(postSync_());
    if (link->GetSupportDataReceivedAck()) {
        CHK_RET(link->DataReceivedAck(stream));
    }

    for (const auto &reduceDesc : rxWithReduceMems) {
        CHK_RET(HcclReduceAsync(dispatcher, reduceDesc.reduceSrc, reduceDesc.reduceDataCount, dataType_,
            reductionOp_, stream, reduceDesc.reduceDst, INVALID_VALUE_RANKID, LinkType::LINK_ONCHIP,
            reduceAttribute_));
    }

    return HCCL_SUCCESS;
}
}