* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "alg_data_trans_wrapper.h"
#include "exec_timeout_manager.h"
#include "hcomm_primitives_dl.h"
#include <atomic>
#include <limits>
#include <algorithm>
#include <vector>
namespace ops_hccl {
namespace {
/**
 * @brief File-local forwarder for a channel notify wait issued on a thread.
 *
 * All four arguments are handed unchanged to HcclChannelNotifyWaitOnThreadDefault
 * and its result is returned as-is.
 *
 * @param thread          thread handle the wait is issued on
 * @param channel         channel handle owning the notify
 * @param localNotifyIdx  index of the notify to wait on
 * @param timeout         timeout value forwarded to the underlying call
 * @return result of HcclChannelNotifyWaitOnThreadDefault
 */
HcclResult HcommChannelNotifyWaitOnThread(ThreadHandle thread, ChannelHandle channel, u32 localNotifyIdx, u32 timeout)
{
    const HcclResult waitResult = HcclChannelNotifyWaitOnThreadDefault(thread, channel, localNotifyIdx, timeout);
    return waitResult;
}
/**
 * @brief File-local forwarder for a thread notify wait.
 *
 * The arguments are passed unchanged to HcclThreadNotifyWaitOnThreadDefault and
 * its result is returned as-is.
 *
 * @param thread     thread handle the wait is issued on
 * @param notifyIdx  index of the notify to wait on
 * @param timeout    timeout value forwarded to the underlying call
 * @return result of HcclThreadNotifyWaitOnThreadDefault
 */
HcclResult HcommThreadNotifyWaitOnThread(ThreadHandle thread, u32 notifyIdx, u32 timeout)
{
    const HcclResult waitResult = HcclThreadNotifyWaitOnThreadDefault(thread, notifyIdx, timeout);
    return waitResult;
}
/**
 * Tri-state value cached in g_hcommBatchTransferSupportState.
 * Kept as an unscoped enum because the enumerators are used directly as int values.
 */
enum HcommBatchTransferSupportState {
    HCOMM_BATCH_TRANSFER_UNINIT = -1,      // nothing has been cached yet
    HCOMM_BATCH_TRANSFER_UNSUPPORTED = 0,  // cached answer: batch transfer is not supported
    HCOMM_BATCH_TRANSFER_SUPPORTED = 1,    // cached answer: batch transfer is supported
};

// Cached support state; starts out as "uninitialised" until it is set through the atomic.
std::atomic<int> g_hcommBatchTransferSupportState(HCOMM_BATCH_TRANSFER_UNINIT);
void *GetSliceAddr(const DataSlice &slice)
{
return static_cast<void *>(static_cast<s8 *>(slice.addr_) + slice.offset_);
}
/**
 * @brief Emits one HCCL_DEBUG line describing a single src/dst slice pair.
 *
 * @param funcName   name of the calling wrapper, printed in the log prefix
 * @param transType  transfer kind label, printed in the log prefix
 * @param sliceIdx   index of this slice within the slice list
 * @param sliceNum   total number of slices in the list
 * @param srcSlice   source slice (base, offset, size and count are logged)
 * @param dstSlice   destination slice (base, offset, size and count are logged)
 * @param src        resolved source address
 * @param dst        resolved destination address
 * @param len        length value supplied by the caller
 * @param dataType   data type, logged as an int
 * @param reduceOp   reduce operation, logged as an int
 */
void TraceDataSlice(const char *funcName, const char *transType, u32 sliceIdx, u32 sliceNum,
    const DataSlice &srcSlice, const DataSlice &dstSlice, const void *src, const void *dst, u64 len,
    HcclDataType dataType, HcclReduceOp reduceOp)
{
    // Integer fields are widened to unsigned long long so they always match %llu.
    HCCL_DEBUG(
        "[AlgDataTransWrapper][%s][%s] sliceIdx[%u], sliceNum[%u], "
        "srcBase[%p], srcOffset[%llu], srcAddr[%p], srcSize[%llu], srcCount[%llu], "
        "dstBase[%p], dstOffset[%llu], dstAddr[%p], dstSize[%llu], dstCount[%llu], "
        "len[%llu], dataType[%d], reduceOp[%d].",
        funcName,
        transType,
        sliceIdx,
        sliceNum,
        srcSlice.addr_,
        static_cast<unsigned long long>(srcSlice.offset_),
        src,
        static_cast<unsigned long long>(srcSlice.size_),
        static_cast<unsigned long long>(srcSlice.count_),
        dstSlice.addr_,
        static_cast<unsigned long long>(dstSlice.offset_),
        dst,
        static_cast<unsigned long long>(dstSlice.size_),
        static_cast<unsigned long long>(dstSlice.count_),
        static_cast<unsigned long long>(len),
        static_cast<int>(dataType),
        static_cast<int>(reduceOp));
}
/**
 * @brief Emits one HCCL_DEBUG line summarising a batch before it is submitted.
 *
 * @param funcName       name of the calling wrapper, printed in the log prefix
 * @param transType      transfer kind label, printed in the log prefix
 * @param totalSliceNum  number of slices the caller was given
 * @param validSliceNum  number of slices that produced a descriptor
 * @param channel        channel whose handle is logged
 */
void TraceBatchSummary(const char *funcName, const char *transType, u32 totalSliceNum, u32 validSliceNum,
    const ChannelInfo &channel)
{
    HCCL_DEBUG(
        "[AlgDataTransWrapper][%s][%s] totalSliceNum[%u], validSliceNum[%u], channelHandle[%llu].",
        funcName,
        transType,
        totalSliceNum,
        validSliceNum,
        static_cast<unsigned long long>(channel.handle));
}
/**
 * @brief Builds a zero-initialised batch descriptor for a plain read or write.
 *
 * HCCL_HCOMM_TRANSFER_TYPE_READ fills the `read` member of transferInfo; every
 * other transfer type fills the `write` member.
 *
 * @param transType  transfer type stored in the descriptor
 * @param dst        destination address
 * @param src        source address
 * @param len        length stored in the descriptor
 * @return the populated descriptor
 */
HcclHcommBatchTransferDesc MakeBatchTransDesc(HcclHcommTransferType transType, void *dst, void *src, u64 len)
{
    HcclHcommBatchTransferDesc desc = {};
    desc.transType = transType;

    if (transType != HCCL_HCOMM_TRANSFER_TYPE_READ) {
        auto &writeInfo = desc.transferInfo.write;
        writeInfo.len = len;
        writeInfo.dst = dst;
        writeInfo.src = src;
        return desc;
    }

    auto &readInfo = desc.transferInfo.read;
    readInfo.len = len;
    readInfo.dst = dst;
    readInfo.src = src;
    return desc;
}
/**
 * @brief Builds a zero-initialised batch descriptor for a reduce transfer.
 *
 * Fills the `reduce` member of transferInfo; dataType and reduceOp are converted
 * to their Hcomm counterparts with static_cast.
 *
 * @param transType  transfer type stored in the descriptor
 * @param dst        destination address
 * @param src        source address
 * @param count      element count stored in the descriptor
 * @param dataType   HCCL data type of the elements
 * @param reduceOp   HCCL reduce operation
 * @return the populated descriptor
 */
HcclHcommBatchTransferDesc MakeBatchReduceDesc(HcclHcommTransferType transType, void *dst, void *src, u64 count,
    HcclDataType dataType, HcclReduceOp reduceOp)
{
    HcclHcommBatchTransferDesc desc = {};
    desc.transType = transType;

    auto &reduceInfo = desc.transferInfo.reduce;
    reduceInfo.count = count;
    reduceInfo.dst = dst;
    reduceInfo.src = src;
    reduceInfo.dataType = static_cast<HcommDataType>(dataType);
    reduceInfo.reduceOp = static_cast<HcommReduceOp>(reduceOp);
    return desc;
}
/**
 * @brief Converts slice pairs into batch descriptors and submits them in one call.
 *
 * For every index of srcSlices, slices whose size_ is 0 are skipped with a warning;
 * every other pair is handed to processSlice, which appends descriptors to the batch.
 * When at least one descriptor was produced, the batch is submitted with
 * HcclHcommBatchTransferOnThread on channel.handle.
 *
 * @param thread        thread handle the batch is issued on
 * @param channel       channel used for the batch transfer
 * @param srcSlices     source slices; drives the iteration count
 * @param dstSlices     destination slices; must hold at least srcSlices.size() entries
 * @param funcName      caller name used in log messages
 * @param transType     transfer kind label used in log messages
 * @param processSlice  callable (index, src, dst, descs, total) -> HcclResult
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL when dstSlices is shorter than srcSlices,
 *         or the first failing result from processSlice / the batch submission
 */
template<typename ProcessSliceFunc>
HcclResult RunBatchTransfer(const ThreadHandle &thread, const ChannelInfo &channel,
    const std::vector<DataSlice> &srcSlices, const std::vector<DataSlice> &dstSlices,
    const char *funcName, const char *transType, ProcessSliceFunc processSlice)
{
    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    if (dstSlices.size() < srcSlices.size()) {
        HCCL_ERROR("[AlgDataTransWrapper] %s: dst slice num [%llu] is less than src slice num [%llu].",
            funcName, static_cast<unsigned long long>(dstSlices.size()),
            static_cast<unsigned long long>(srcSlices.size()));
        return HCCL_E_INTERNAL;
    }

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    std::vector<HcclHcommBatchTransferDesc> transferDescs;
    transferDescs.reserve(repeatNum);

    for (u32 i = 0; i < repeatNum; i++) {
        // References avoid copying each DataSlice per iteration.
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] %s: size is 0.", funcName);
            continue;
        }
        CHK_RET(processSlice(i, srcSlice, dstSlice, transferDescs, repeatNum));
    }

    // Nothing to submit when every slice was empty.
    if (!transferDescs.empty()) {
        const u32 validNum = static_cast<u32>(transferDescs.size());
        TraceBatchSummary(funcName, transType, repeatNum, validNum, channel);
        CHK_RET(static_cast<HcclResult>(
            HcclHcommBatchTransferOnThread(thread, channel.handle, transferDescs.data(), validNum)));
    }
    return HCCL_SUCCESS;
}
/**
 * @brief Bidirectional batch transfer driven from the tx slice list.
 *
 * Falls back to `fallback(sendRecvInfo, thread)` when batch transfer is not supported.
 * Otherwise: records NOTIFY_IDX_ACK on the rx channel, waits for NOTIFY_IDX_ACK on the
 * tx channel, batch-transfers txSlicesList_ over the tx channel, records
 * NOTIFY_IDX_DATA_SIGNAL on the tx channel and waits for it on the rx channel.
 *
 * @param sendRecvInfo  slices and channels of the exchange
 * @param thread        thread handle all operations are issued on
 * @param funcName      caller name used in log messages
 * @param transType     transfer kind label used in log messages
 * @param processSlice  callable that turns one slice pair into batch descriptors
 * @param fallback      non-batch implementation used when batching is unsupported
 * @return HCCL_SUCCESS or the first failing result
 */
template<typename SendRecvInfoType, typename ProcessSliceFunc, typename FallbackFunc>
HcclResult DoSendRecvBatchTx(const SendRecvInfoType &sendRecvInfo, const ThreadHandle &thread,
    const char *funcName, const char *transType, ProcessSliceFunc processSlice, FallbackFunc fallback)
{
    if (!IsHcommBatchTransferOnThreadSupported()) {
        return fallback(sendRecvInfo, thread);
    }
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK)));
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    CHK_RET(RunBatchTransfer(thread, sendChannel, srcSlices, dstSlices, funcName, transType, processSlice));

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief One-directional batch transfer issued by the sending side.
 *
 * Falls back to `fallback(sendInfo, thread)` when batch transfer is not supported.
 * Otherwise: waits for NOTIFY_IDX_ACK on the channel, batch-transfers the slice list
 * and records NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param sendInfo      slices and channel of the transfer
 * @param thread        thread handle all operations are issued on
 * @param funcName      caller name used in log messages
 * @param transType     transfer kind label used in log messages
 * @param processSlice  callable that turns one slice pair into batch descriptors
 * @param fallback      non-batch implementation used when batching is unsupported
 * @return HCCL_SUCCESS or the first failing result
 */
template<typename DataInfoType, typename ProcessSliceFunc, typename FallbackFunc>
HcclResult DoSendBatchTx(const DataInfoType &sendInfo, const ThreadHandle &thread,
    const char *funcName, const char *transType, ProcessSliceFunc processSlice, FallbackFunc fallback)
{
    if (!IsHcommBatchTransferOnThreadSupported()) {
        return fallback(sendInfo, thread);
    }
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendInfo.slices_.dstSlices_;
    const ChannelInfo &sendChannel = sendInfo.channel_;

    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    CHK_RET(RunBatchTransfer(thread, sendChannel, srcSlices, dstSlices, funcName, transType, processSlice));

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief One-directional batch transfer issued by the receiving side.
 *
 * Falls back to `fallback(recvInfo, thread)` when batch transfer is not supported.
 * Otherwise: waits for NOTIFY_IDX_ACK on the channel, batch-transfers the slice list
 * and records NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param recvInfo      slices and channel of the transfer
 * @param thread        thread handle all operations are issued on
 * @param funcName      caller name used in log messages
 * @param transType     transfer kind label used in log messages
 * @param processSlice  callable that turns one slice pair into batch descriptors
 * @param fallback      non-batch implementation used when batching is unsupported
 * @return HCCL_SUCCESS or the first failing result
 */
template<typename DataInfoType, typename ProcessSliceFunc, typename FallbackFunc>
HcclResult DoRecvBatchRx(const DataInfoType &recvInfo, const ThreadHandle &thread,
    const char *funcName, const char *transType, ProcessSliceFunc processSlice, FallbackFunc fallback)
{
    if (!IsHcommBatchTransferOnThreadSupported()) {
        return fallback(recvInfo, thread);
    }
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = recvInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = recvInfo.slices_.dstSlices_;
    const ChannelInfo &recvChannel = recvInfo.channel_;

    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    CHK_RET(RunBatchTransfer(thread, recvChannel, srcSlices, dstSlices, funcName, transType, processSlice));

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief Bidirectional batch transfer driven from the rx slice list.
 *
 * Falls back to `fallback(sendRecvInfo, thread)` when batch transfer is not supported.
 * Otherwise: records NOTIFY_IDX_ACK on the tx channel, waits for NOTIFY_IDX_ACK on the
 * rx channel, batch-transfers rxSlicesList_ over the rx channel, records
 * NOTIFY_IDX_DATA_SIGNAL on the rx channel and waits for it on the tx channel.
 *
 * @param sendRecvInfo  slices and channels of the exchange
 * @param thread        thread handle all operations are issued on
 * @param funcName      caller name used in log messages
 * @param transType     transfer kind label used in log messages
 * @param processSlice  callable that turns one slice pair into batch descriptors
 * @param fallback      non-batch implementation used when batching is unsupported
 * @return HCCL_SUCCESS or the first failing result
 */
template<typename SendRecvInfoType, typename ProcessSliceFunc, typename FallbackFunc>
HcclResult DoSendRecvBatchRx(const SendRecvInfoType &sendRecvInfo, const ThreadHandle &thread,
    const char *funcName, const char *transType, ProcessSliceFunc processSlice, FallbackFunc fallback)
{
    if (!IsHcommBatchTransferOnThreadSupported()) {
        return fallback(sendRecvInfo, thread);
    }
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK)));
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    CHK_RET(RunBatchTransfer(thread, recvChannel, srcSlices, dstSlices, funcName, transType, processSlice));

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
}
/**
 * @brief Caches whether batch transfer on a thread is supported.
 *
 * The first call atomically moves the cached state from "uninitialised" to the
 * requested value. Later calls succeed only when they request the value that is
 * already cached; a conflicting request is logged and rejected.
 *
 * @param isSupported  support flag to cache
 * @return HCCL_SUCCESS when the value was stored or already matches,
 *         HCCL_E_INTERNAL when a different value is already cached
 */
HcclResult InitHcommBatchTransferOnThreadSupported(bool isSupported)
{
    const int requested = isSupported ? HCOMM_BATCH_TRANSFER_SUPPORTED : HCOMM_BATCH_TRANSFER_UNSUPPORTED;

    // On CAS failure `observed` is overwritten with the value currently cached.
    int observed = HCOMM_BATCH_TRANSFER_UNINIT;
    const bool storedNow = g_hcommBatchTransferSupportState.compare_exchange_strong(observed, requested);
    if (storedNow || observed == requested) {
        return HCCL_SUCCESS;
    }

    HCCL_ERROR("[AlgDataTransWrapper] HcommBatchTransferOnThread support mismatch, cached[%d], ctx[%d].",
        observed, requested);
    return HCCL_E_INTERNAL;
}
/**
 * @brief Reports whether the cached state says batch transfer is supported.
 *
 * @return true only when the cached state equals HCOMM_BATCH_TRANSFER_SUPPORTED;
 *         false for both the "unsupported" and the "uninitialised" state
 */
bool IsHcommBatchTransferOnThreadSupported()
{
    const int cachedState = g_hcommBatchTransferSupportState.load();
    return cachedState == HCOMM_BATCH_TRANSFER_SUPPORTED;
}
/**
 * @brief Sender side of a one-directional write, one HcommWriteOnThread call per slice.
 *
 * Waits for NOTIFY_IDX_ACK on the channel, writes every non-empty src slice to the
 * dst slice with the same index, then records NOTIFY_IDX_DATA_SIGNAL on the channel.
 *
 * @param sendInfo  slices, channel and data type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL when dstSlices_ is shorter than srcSlices_,
 *         or the first failing result of an underlying call
 */
HcclResult SendWrite(const DataInfo &sendInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendInfo.slices_.dstSlices_;
    const ChannelInfo &sendChannel = sendInfo.channel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    if (dstSlices.size() < srcSlices.size()) {
        HCCL_ERROR("[AlgDataTransWrapper] SendWrite: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()), static_cast<unsigned long long>(srcSlices.size()));
        return HCCL_E_INTERNAL;
    }

    const u32 sliceNum = static_cast<u32>(srcSlices.size());
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < sliceNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendWrite: size is 0.");
            continue;
        }
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        TraceDataSlice("SendWrite", "WRITE", i, sliceNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, sendInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_RET(static_cast<HcclResult>(HcommWriteOnThread(thread, sendChannel.handle, dst, src, srcSlice.size_)));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of SendWrite.
 *
 * Delegates to DoSendBatchTx with a callback that appends one WRITE descriptor per
 * slice pair; SendWrite is passed as the fallback.
 *
 * @param sendInfo  slices, channel and data type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return result of DoSendBatchTx
 */
HcclResult SendBatchWrite(const DataInfo &sendInfo, const ThreadHandle &thread)
{
    // Resolves both addresses, traces the pair and queues a WRITE descriptor for it.
    const auto appendWriteDesc = [&sendInfo](int sliceIdx, const DataSlice &from, const DataSlice &to,
        std::vector<HcclHcommBatchTransferDesc> &batch, u32 sliceTotal) -> HcclResult {
        void *const dstAddr = GetSliceAddr(to);
        void *const srcAddr = GetSliceAddr(from);
        TraceDataSlice("SendBatchWrite", "BATCH_WRITE", sliceIdx, sliceTotal, from, to, srcAddr, dstAddr,
            from.size_, sendInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        batch.push_back(MakeBatchTransDesc(HCCL_HCOMM_TRANSFER_TYPE_WRITE, dstAddr, srcAddr, from.size_));
        return HCCL_SUCCESS;
    };
    return DoSendBatchTx(sendInfo, thread, "SendBatchWrite", "BATCH_WRITE", appendWriteDesc, SendWrite);
}
/**
 * @brief Receiver side of a one-directional write.
 *
 * Issues no data operation itself: it records NOTIFY_IDX_ACK on the channel and then
 * waits for NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param recvInfo  transfer description; only channel_ is used here
 * @param thread    thread handle both notify operations are issued on
 * @return HCCL_SUCCESS or the first failing result
 */
HcclResult RecvWrite(const DataInfo &recvInfo, const ThreadHandle &thread)
{
    const auto &channelHandle = recvInfo.channel_.handle;

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, channelHandle, NOTIFY_IDX_ACK)));

    const u32 waitTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, channelHandle, NOTIFY_IDX_DATA_SIGNAL, waitTimeout)));
    return HCCL_SUCCESS;
}
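"SendRecv" is named from the notify point of view: on a single thread, the notify is both recorded and waited on.
Why SendRecv: the write is bidirectional. Rank 0 needs to write to rank 1, and rank 1 also needs to write to rank 0.
Rank 0 therefore records to rank 1 to tell rank 1 that rank 0 is ready to be written to,
and rank 0 also waits for rank 1's record to learn that rank 1 is ready to be written to as well.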
*/
/**
 * @brief Bidirectional write, one HcommWriteOnThread call per tx slice.
 *
 * Records NOTIFY_IDX_ACK on the rx channel and waits for NOTIFY_IDX_ACK on the tx
 * channel, writes every non-empty tx src slice to the tx dst slice with the same
 * index over the tx channel, then records NOTIFY_IDX_DATA_SIGNAL on the tx channel
 * and waits for it on the rx channel.
 *
 * @param sendRecvInfo  slices, channels and data type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL when the tx dst list is shorter than the tx
 *         src list, or the first failing result of an underlying call
 */
HcclResult SendRecvWrite(const SendRecvInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    if (dstSlices.size() < srcSlices.size()) {
        HCCL_ERROR("[AlgDataTransWrapper] SendRecvWrite: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()), static_cast<unsigned long long>(srcSlices.size()));
        return HCCL_E_INTERNAL;
    }

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK)));
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendRecvWrite: size is 0.");
            continue;
        }
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        TraceDataSlice("SendRecvWrite", "WRITE", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, sendRecvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_RET(static_cast<HcclResult>(HcommWriteOnThread(thread, sendChannel.handle, dst, src, srcSlice.size_)));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of SendRecvWrite.
 *
 * Delegates to DoSendRecvBatchTx with a callback that appends one WRITE descriptor
 * per slice pair; SendRecvWrite is passed as the fallback.
 *
 * @param sendRecvInfo  slices, channels and data type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return result of DoSendRecvBatchTx
 */
HcclResult SendRecvBatchWrite(const SendRecvInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Resolves both addresses, traces the pair and queues a WRITE descriptor for it.
    const auto appendWriteDesc = [&sendRecvInfo](int sliceIdx, const DataSlice &from, const DataSlice &to,
        std::vector<HcclHcommBatchTransferDesc> &batch, u32 sliceTotal) -> HcclResult {
        void *const dstAddr = GetSliceAddr(to);
        void *const srcAddr = GetSliceAddr(from);
        TraceDataSlice("SendRecvBatchWrite", "BATCH_WRITE", sliceIdx, sliceTotal, from, to, srcAddr, dstAddr,
            from.size_, sendRecvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        batch.push_back(MakeBatchTransDesc(HCCL_HCOMM_TRANSFER_TYPE_WRITE, dstAddr, srcAddr, from.size_));
        return HCCL_SUCCESS;
    };
    return DoSendRecvBatchTx(sendRecvInfo, thread, "SendRecvBatchWrite", "BATCH_WRITE", appendWriteDesc,
        SendRecvWrite);
}
/**
 * @brief Batch variant of SendRecvWriteReduce.
 *
 * Delegates to DoSendRecvBatchTx with a callback that validates each slice pair
 * (count_ * element size must equal size_ for both src and dst) and appends one
 * WRITE_REDUCE descriptor; SendRecvWriteReduce is passed as the fallback.
 *
 * @param sendRecvInfo  slices, channels, data type and reduce type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return result of DoSendRecvBatchTx; the callback reports HCCL_E_INTERNAL for an
 *         inconsistent slice
 */
HcclResult SendRecvBatchWriteReduce(const SendRecvReduceInfo &sendRecvInfo, const ThreadHandle &thread)
{
    auto processSlice = [&sendRecvInfo](int i, const DataSlice &srcSlice, const DataSlice &dstSlice,
        std::vector<HcclHcommBatchTransferDesc> &transferDescs, u32 repeatNum) -> HcclResult {
        const auto typeSize = DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvBatchWriteReduce: src slice count [%llu] is not mate to src slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvBatchWriteReduce: dst slice count [%llu] is not mate to dst slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        u64 len = srcSlice.count_ * typeSize;
        TraceDataSlice("SendRecvBatchWriteReduce", "BATCH_WRITE_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            len, sendRecvInfo.dataType_, sendRecvInfo.reduceType_);
        transferDescs.push_back(MakeBatchReduceDesc(HCCL_HCOMM_TRANSFER_TYPE_WRITE_REDUCE, dst, src, srcSlice.count_,
            sendRecvInfo.dataType_, sendRecvInfo.reduceType_));
        return HCCL_SUCCESS;
    };
    return DoSendRecvBatchTx(sendRecvInfo, thread, "SendRecvBatchWriteReduce", "BATCH_WRITE_REDUCE",
        processSlice, SendRecvWriteReduce);
}
/**
 * @brief Sender side of a one-directional write-reduce, one call per slice.
 *
 * Waits for NOTIFY_IDX_ACK on the channel, validates every non-empty slice pair
 * (count_ * element size must equal size_ for both src and dst), issues
 * HcommWriteReduceOnThread for it, then records NOTIFY_IDX_DATA_SIGNAL.
 *
 * @param sendInfo  slices, channel, data type and reduce type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL for a short dst list or an inconsistent
 *         slice, or the first failing result of an underlying call
 */
HcclResult SendWriteReduce(const DataReduceInfo &sendInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendInfo.slices_.dstSlices_;
    const ChannelInfo &sendChannel = sendInfo.channel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    CHK_PRT_RET(dstSlices.size() < srcSlices.size(),
        HCCL_ERROR("[AlgDataTransWrapper] SendWriteReduce: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()),
            static_cast<unsigned long long>(srcSlices.size())),
        HcclResult::HCCL_E_INTERNAL);

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendWriteReduce: size is 0.");
            continue;
        }
        const auto typeSize = DATATYPE_SIZE_TABLE[sendInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendWriteReduce: src slice count [%llu] is not mate to src slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendWriteReduce: dst slice count [%llu] is not mate to dst slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        // Trace the byte length (validated above to equal count_ * element size), matching the batch variant.
        TraceDataSlice("SendWriteReduce", "WRITE_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, sendInfo.dataType_, sendInfo.reduceType_);
        CHK_RET(static_cast<HcclResult>(HcommWriteReduceOnThread(thread,
            sendChannel.handle,
            dst,
            src,
            srcSlice.count_,
            static_cast<HcommDataType>(sendInfo.dataType_),
            static_cast<HcommReduceOp>(sendInfo.reduceType_))));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of SendWriteReduce.
 *
 * Delegates to DoSendBatchTx with a callback that validates each slice pair
 * (count_ * element size must equal size_ for both src and dst) and appends one
 * WRITE_REDUCE descriptor; SendWriteReduce is passed as the fallback.
 *
 * @param sendInfo  slices, channel, data type and reduce type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return result of DoSendBatchTx; the callback reports HCCL_E_INTERNAL for an
 *         inconsistent slice
 */
HcclResult SendBatchWriteReduce(const DataReduceInfo &sendInfo, const ThreadHandle &thread)
{
    auto processSlice = [&sendInfo](int i, const DataSlice &srcSlice, const DataSlice &dstSlice,
        std::vector<HcclHcommBatchTransferDesc> &transferDescs, u32 repeatNum) -> HcclResult {
        const auto typeSize = DATATYPE_SIZE_TABLE[sendInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendBatchWriteReduce: src slice count [%llu] is not mate to src slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendBatchWriteReduce: dst slice count [%llu] is not mate to dst slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        u64 len = srcSlice.count_ * typeSize;
        TraceDataSlice("SendBatchWriteReduce", "BATCH_WRITE_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            len, sendInfo.dataType_, sendInfo.reduceType_);
        transferDescs.push_back(MakeBatchReduceDesc(HCCL_HCOMM_TRANSFER_TYPE_WRITE_REDUCE, dst, src, srcSlice.count_,
            sendInfo.dataType_, sendInfo.reduceType_));
        return HCCL_SUCCESS;
    };
    return DoSendBatchTx(sendInfo, thread, "SendBatchWriteReduce", "BATCH_WRITE_REDUCE",
        processSlice, SendWriteReduce);
}
/**
 * @brief Receiver side of a one-directional write-reduce.
 *
 * Issues no data operation itself: it records NOTIFY_IDX_ACK on the channel and then
 * waits for NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param recvInfo  transfer description; only channel_ is used here
 * @param thread    thread handle both notify operations are issued on
 * @return HCCL_SUCCESS or the first failing result
 */
HcclResult RecvWriteReduce(const DataReduceInfo &recvInfo, const ThreadHandle &thread)
{
    const auto &channelHandle = recvInfo.channel_.handle;
    const u32 waitTimeout = ExecTimeoutManager::Instance().GetExecTimeout();

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, channelHandle, NOTIFY_IDX_ACK)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, channelHandle, NOTIFY_IDX_DATA_SIGNAL, waitTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Bidirectional write-reduce, one HcommWriteReduceOnThread call per tx slice.
 *
 * Records NOTIFY_IDX_ACK on the rx channel and waits for NOTIFY_IDX_ACK on the tx
 * channel, validates and write-reduces every non-empty tx slice pair over the tx
 * channel, then records NOTIFY_IDX_DATA_SIGNAL on the tx channel and waits for it on
 * the rx channel.
 *
 * @param sendRecvInfo  slices, channels, data type and reduce type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL for a short dst list or an inconsistent
 *         slice, or the first failing result of an underlying call
 */
HcclResult SendRecvWriteReduce(const SendRecvReduceInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.txSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    CHK_PRT_RET(dstSlices.size() < srcSlices.size(),
        HCCL_ERROR("[AlgDataTransWrapper] SendRecvWriteReduce: dst slice num [%llu] is less than src slice num "
                   "[%llu].",
            static_cast<unsigned long long>(dstSlices.size()),
            static_cast<unsigned long long>(srcSlices.size())),
        HcclResult::HCCL_E_INTERNAL);

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK)));
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendRecvWriteReduce: size is 0.");
            continue;
        }
        const auto typeSize = DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvWriteReduce: src slice count [%llu] is not mate to src slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvWriteReduce: dst slice count [%llu] is not mate to dst slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        // Trace the byte length (validated above to equal count_ * element size), matching the batch variant.
        TraceDataSlice("SendRecvWriteReduce", "WRITE_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, sendRecvInfo.dataType_, sendRecvInfo.reduceType_);
        CHK_RET(static_cast<HcclResult>(HcommWriteReduceOnThread(thread,
            sendChannel.handle,
            dst,
            src,
            srcSlice.count_,
            static_cast<HcommDataType>(sendRecvInfo.dataType_),
            static_cast<HcommReduceOp>(sendRecvInfo.reduceType_))));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Sender side of a one-directional read.
 *
 * Issues no data operation itself: it records NOTIFY_IDX_ACK on the channel and then
 * waits for NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param sendInfo  transfer description; only channel_ is used here
 * @param thread    thread handle both notify operations are issued on
 * @return HCCL_SUCCESS or the first failing result
 */
HcclResult SendRead(const DataInfo &sendInfo, const ThreadHandle &thread)
{
    const auto &channelHandle = sendInfo.channel_.handle;

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, channelHandle, NOTIFY_IDX_ACK)));

    const u32 waitTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, channelHandle, NOTIFY_IDX_DATA_SIGNAL, waitTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Receiver side of a one-directional read, one HcommReadOnThread call per slice.
 *
 * Waits for NOTIFY_IDX_ACK on the channel, reads every non-empty src slice into the
 * dst slice with the same index, then records NOTIFY_IDX_DATA_SIGNAL on the channel.
 *
 * @param recvInfo  slices, channel and data type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL when dstSlices_ is shorter than srcSlices_,
 *         or the first failing result of an underlying call
 */
HcclResult RecvRead(const DataInfo &recvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = recvInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = recvInfo.slices_.dstSlices_;
    const ChannelInfo &recvChannel = recvInfo.channel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    if (dstSlices.size() < srcSlices.size()) {
        HCCL_ERROR("[AlgDataTransWrapper] RecvRead: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()), static_cast<unsigned long long>(srcSlices.size()));
        return HCCL_E_INTERNAL;
    }

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] RecvRead: size is 0.");
            continue;
        }
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        TraceDataSlice("RecvRead", "READ", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, recvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_RET(static_cast<HcclResult>(HcommReadOnThread(thread, recvChannel.handle, dst, src, srcSlice.size_)));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of RecvRead.
 *
 * Delegates to DoRecvBatchRx with a callback that appends one READ descriptor per
 * slice pair; RecvRead is passed as the fallback.
 *
 * @param recvInfo  slices, channel and data type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return result of DoRecvBatchRx
 */
HcclResult RecvBatchRead(const DataInfo &recvInfo, const ThreadHandle &thread)
{
    // Resolves both addresses, traces the pair and queues a READ descriptor for it.
    const auto appendReadDesc = [&recvInfo](int sliceIdx, const DataSlice &from, const DataSlice &to,
        std::vector<HcclHcommBatchTransferDesc> &batch, u32 sliceTotal) -> HcclResult {
        void *const dstAddr = GetSliceAddr(to);
        void *const srcAddr = GetSliceAddr(from);
        TraceDataSlice("RecvBatchRead", "BATCH_READ", sliceIdx, sliceTotal, from, to, srcAddr, dstAddr,
            from.size_, recvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        batch.push_back(MakeBatchTransDesc(HCCL_HCOMM_TRANSFER_TYPE_READ, dstAddr, srcAddr, from.size_));
        return HCCL_SUCCESS;
    };
    return DoRecvBatchRx(recvInfo, thread, "RecvBatchRead", "BATCH_READ", appendReadDesc, RecvRead);
}
/**
 * @brief Bidirectional read, one HcommReadOnThread call per rx slice.
 *
 * Records NOTIFY_IDX_ACK on the tx channel and waits for NOTIFY_IDX_ACK on the rx
 * channel, reads every non-empty rx src slice into the rx dst slice with the same
 * index over the rx channel, then records NOTIFY_IDX_DATA_SIGNAL on the rx channel
 * and waits for it on the tx channel.
 *
 * @param sendRecvInfo  slices, channels and data type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL when the rx dst list is shorter than the rx
 *         src list, or the first failing result of an underlying call
 */
HcclResult SendRecvRead(const SendRecvInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    if (dstSlices.size() < srcSlices.size()) {
        HCCL_ERROR("[AlgDataTransWrapper] SendRecvRead: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()), static_cast<unsigned long long>(srcSlices.size()));
        return HCCL_E_INTERNAL;
    }

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK)));
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendRecvRead: size is 0.");
            continue;
        }
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        TraceDataSlice("SendRecvRead", "READ", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, sendRecvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_RET(static_cast<HcclResult>(HcommReadOnThread(thread, recvChannel.handle, dst, src, srcSlice.size_)));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of SendRecvRead.
 *
 * Delegates to DoSendRecvBatchRx with a callback that appends one READ descriptor
 * per slice pair; SendRecvRead is passed as the fallback.
 *
 * @param sendRecvInfo  slices, channels and data type of the exchange
 * @param thread        thread handle all operations are issued on
 * @return result of DoSendRecvBatchRx
 */
HcclResult SendRecvBatchRead(const SendRecvInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Resolves both addresses, traces the pair and queues a READ descriptor for it.
    const auto appendReadDesc = [&sendRecvInfo](int sliceIdx, const DataSlice &from, const DataSlice &to,
        std::vector<HcclHcommBatchTransferDesc> &batch, u32 sliceTotal) -> HcclResult {
        void *const dstAddr = GetSliceAddr(to);
        void *const srcAddr = GetSliceAddr(from);
        TraceDataSlice("SendRecvBatchRead", "BATCH_READ", sliceIdx, sliceTotal, from, to, srcAddr, dstAddr,
            from.size_, sendRecvInfo.dataType_, HcclReduceOp::HCCL_REDUCE_RESERVED);
        batch.push_back(MakeBatchTransDesc(HCCL_HCOMM_TRANSFER_TYPE_READ, dstAddr, srcAddr, from.size_));
        return HCCL_SUCCESS;
    };
    return DoSendRecvBatchRx(sendRecvInfo, thread, "SendRecvBatchRead", "BATCH_READ", appendReadDesc,
        SendRecvRead);
}
/**
 * @brief Sender side of a one-directional read-reduce.
 *
 * Issues no data operation itself: it records NOTIFY_IDX_ACK on the channel and then
 * waits for NOTIFY_IDX_DATA_SIGNAL on the same channel.
 *
 * @param sendInfo  transfer description; only channel_ is used here
 * @param thread    thread handle both notify operations are issued on
 * @return HCCL_SUCCESS or the first failing result
 */
HcclResult SendReadReduce(const DataReduceInfo &sendInfo, const ThreadHandle &thread)
{
    const auto &channelHandle = sendInfo.channel_.handle;

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, channelHandle, NOTIFY_IDX_ACK)));

    const u32 waitTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, channelHandle, NOTIFY_IDX_DATA_SIGNAL, waitTimeout)));
    return HCCL_SUCCESS;
}
/**
 * @brief Receiver side of a one-directional read-reduce, one call per slice.
 *
 * Waits for NOTIFY_IDX_ACK on the channel, validates every non-empty slice pair
 * (count_ * element size must equal size_ for both src and dst), issues
 * HcommReadReduceOnThread for it, then records NOTIFY_IDX_DATA_SIGNAL.
 *
 * @param recvInfo  slices, channel, data type and reduce type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return HCCL_SUCCESS, HCCL_E_INTERNAL for a short dst list or an inconsistent
 *         slice, or the first failing result of an underlying call
 */
HcclResult RecvReadReduce(const DataReduceInfo &recvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, so copying the vectors is unnecessary.
    const std::vector<DataSlice> &srcSlices = recvInfo.slices_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = recvInfo.slices_.dstSlices_;
    const ChannelInfo &recvChannel = recvInfo.channel_;

    // dstSlices is indexed with src indices below; reject a shorter list instead of reading out of bounds.
    CHK_PRT_RET(dstSlices.size() < srcSlices.size(),
        HCCL_ERROR("[AlgDataTransWrapper] RecvReadReduce: dst slice num [%llu] is less than src slice num [%llu].",
            static_cast<unsigned long long>(dstSlices.size()),
            static_cast<unsigned long long>(srcSlices.size())),
        HcclResult::HCCL_E_INTERNAL);

    const u32 repeatNum = static_cast<u32>(srcSlices.size());
    const u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] RecvReadReduce: size is 0.");
            continue;
        }
        const auto typeSize = DATATYPE_SIZE_TABLE[recvInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] RecvReadReduce: src slice count [%llu] is not mate to src slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(recvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] RecvReadReduce: dst slice count [%llu] is not mate to dst slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(recvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        // Trace the byte length (validated above to equal count_ * element size), matching the batch variant.
        TraceDataSlice("RecvReadReduce", "READ_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.size_, recvInfo.dataType_, recvInfo.reduceType_);
        CHK_RET(static_cast<HcclResult>(HcommReadReduceOnThread(thread,
            recvChannel.handle,
            dst,
            src,
            srcSlice.count_,
            static_cast<HcommDataType>(recvInfo.dataType_),
            static_cast<HcommReduceOp>(recvInfo.reduceType_))));
    }

    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    return HCCL_SUCCESS;
}
/**
 * @brief Batch variant of RecvReadReduce.
 *
 * Delegates to DoRecvBatchRx with a callback that validates each slice pair
 * (count_ * element size must equal size_ for both src and dst) and appends one
 * READ_REDUCE descriptor; RecvReadReduce is passed as the fallback.
 *
 * @param recvInfo  slices, channel, data type and reduce type of the transfer
 * @param thread    thread handle all operations are issued on
 * @return result of DoRecvBatchRx; the callback reports HCCL_E_INTERNAL for an
 *         inconsistent slice
 */
HcclResult RecvBatchReadReduce(const DataReduceInfo &recvInfo, const ThreadHandle &thread)
{
    auto processSlice = [&recvInfo](int i, const DataSlice &srcSlice, const DataSlice &dstSlice,
        std::vector<HcclHcommBatchTransferDesc> &transferDescs, u32 repeatNum) -> HcclResult {
        const auto typeSize = DATATYPE_SIZE_TABLE[recvInfo.dataType_];
        // count_/size_ are widened and printed with %llu so the specifier always matches the argument.
        CHK_PRT_RET(srcSlice.count_ * typeSize != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] RecvBatchReadReduce: src slice count [%llu] is not mate to src slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(recvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * typeSize != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] RecvBatchReadReduce: dst slice count [%llu] is not mate to dst slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(recvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        u64 len = srcSlice.count_ * typeSize;
        TraceDataSlice("RecvBatchReadReduce", "BATCH_READ_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            len, recvInfo.dataType_, recvInfo.reduceType_);
        transferDescs.push_back(MakeBatchReduceDesc(HCCL_HCOMM_TRANSFER_TYPE_READ_REDUCE, dst, src, srcSlice.count_,
            recvInfo.dataType_, recvInfo.reduceType_));
        return HCCL_SUCCESS;
    };
    return DoRecvBatchRx(recvInfo, thread, "RecvBatchReadReduce", "BATCH_READ_REDUCE",
        processSlice, RecvReadReduce);
}
/**
 * Bidirectional read-reduce step on one thread.
 *
 * Sequence issued on `thread`:
 *   1. record NOTIFY_IDX_ACK on the tx channel, then wait for NOTIFY_IDX_ACK on the rx channel;
 *   2. for every non-empty rx slice pair, issue HcommReadReduceOnThread over the rx channel;
 *   3. record NOTIFY_IDX_DATA_SIGNAL on the rx channel, then wait for NOTIFY_IDX_DATA_SIGNAL on the
 *      tx channel.
 *
 * @param sendRecvInfo channels, rx slice lists, data type and reduce op of the step.
 * @param thread       thread handle all operations are issued on.
 * @return HCCL_SUCCESS on success; HCCL_E_INTERNAL when the dst list is shorter than the src list or
 *         when a slice's count_ * element size differs from its size_; otherwise the first failing
 *         Hcomm call's result.
 */
HcclResult SendRecvReadReduce(const SendRecvReduceInfo &sendRecvInfo, const ThreadHandle &thread)
{
    // Bind by reference: the slice lists are only read, copying them per call is wasted work.
    const std::vector<DataSlice> &srcSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.srcSlices_;
    const std::vector<DataSlice> &dstSlices = sendRecvInfo.sendRecvSlices_.rxSlicesList_.dstSlices_;
    const ChannelInfo &sendChannel = sendRecvInfo.sendRecvChannels_.txChannel_;
    const ChannelInfo &recvChannel = sendRecvInfo.sendRecvChannels_.rxChannel_;

    // dstSlices is indexed with src indices below; refuse a shorter dst list instead of reading
    // past its end. A longer dst list is tolerated, as before.
    CHK_PRT_RET(dstSlices.size() < srcSlices.size(),
        HCCL_ERROR("[AlgDataTransWrapper] SendRecvReadReduce: num of dst slices [%u] is less than num of src "
                   "slices [%u].",
            static_cast<u32>(dstSlices.size()),
            static_cast<u32>(srcSlices.size())),
        HcclResult::HCCL_E_INTERNAL);
    const u32 repeatNum = static_cast<u32>(srcSlices.size());

    CHK_RET(static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, sendChannel.handle, NOTIFY_IDX_ACK)));
    u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, recvChannel.handle, NOTIFY_IDX_ACK, execTimeout)));

    for (u32 i = 0; i < repeatNum; i++) {
        const DataSlice &srcSlice = srcSlices[i];
        const DataSlice &dstSlice = dstSlices[i];
        if (srcSlice.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] SendRecvReadReduce: size is 0.");
            continue;
        }
        CHK_PRT_RET(srcSlice.count_ * DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_] != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvReadReduce: src slice count [%llu] is not mate to src slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_] != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvReadReduce: dst slice count [%llu] is not mate to dst slice size "
                       "[%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);

        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        TraceDataSlice("SendRecvReadReduce", "READ_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            srcSlice.count_, sendRecvInfo.dataType_, sendRecvInfo.reduceType_);
        CHK_RET(static_cast<HcclResult>(HcommReadReduceOnThread(thread,
            recvChannel.handle,
            dst,
            src,
            srcSlice.count_,
            static_cast<HcommDataType>(sendRecvInfo.dataType_),
            static_cast<HcommReduceOp>(sendRecvInfo.reduceType_))));
    }

    CHK_RET(
        static_cast<HcclResult>(HcommChannelNotifyRecordOnThread(thread, recvChannel.handle, NOTIFY_IDX_DATA_SIGNAL)));
    CHK_RET(static_cast<HcclResult>(
        HcommChannelNotifyWaitOnThread(thread, sendChannel.handle, NOTIFY_IDX_DATA_SIGNAL, execTimeout)));
    return HCCL_SUCCESS;
}
/**
 * Batched bidirectional read-reduce step.
 *
 * Defines the per-slice step (validate the slice pair, trace it, append one
 * HCCL_HCOMM_TRANSFER_TYPE_READ_REDUCE descriptor) and passes it to DoSendRecvBatchRx together with
 * SendRecvReadReduce. How DoSendRecvBatchRx drives the callback and when it uses SendRecvReadReduce
 * is decided there (presumably as the non-batched fallback -- confirm against DoSendRecvBatchRx).
 *
 * @param sendRecvInfo step description; only dataType_ and reduceType_ are read here, the object
 *                     itself is forwarded to DoSendRecvBatchRx.
 * @param thread       thread handle forwarded to DoSendRecvBatchRx.
 * @return the result of DoSendRecvBatchRx. The per-slice step reports HCCL_E_INTERNAL when a slice's
 *         count_ * element size differs from its size_.
 */
HcclResult SendRecvBatchReadReduce(const SendRecvReduceInfo &sendRecvInfo, const ThreadHandle &thread)
{
    auto processSlice = [&sendRecvInfo](int i, const DataSlice &srcSlice, const DataSlice &dstSlice,
        std::vector<HcclHcommBatchTransferDesc> &transferDescs, u32 repeatNum) -> HcclResult {
        // count_/size_ are logged through an explicit 64-bit cast so the specifier always matches.
        CHK_PRT_RET(srcSlice.count_ * DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_] != srcSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvBatchReadReduce: src slice count [%llu] is not mate to src slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(srcSlice.count_),
                static_cast<unsigned long long>(srcSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);
        CHK_PRT_RET(dstSlice.count_ * DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_] != dstSlice.size_,
            HCCL_ERROR("[AlgDataTransWrapper] SendRecvBatchReadReduce: dst slice count [%llu] is not mate to dst slice "
                       "size [%llu], dataType is [%d].",
                static_cast<unsigned long long>(dstSlice.count_),
                static_cast<unsigned long long>(dstSlice.size_),
                static_cast<int>(sendRecvInfo.dataType_)),
            HcclResult::HCCL_E_INTERNAL);

        void *dst = GetSliceAddr(dstSlice);
        void *src = GetSliceAddr(srcSlice);
        // Byte length is only used for the trace; the descriptor itself carries the element count.
        u64 len = srcSlice.count_ * DATATYPE_SIZE_TABLE[sendRecvInfo.dataType_];
        TraceDataSlice("SendRecvBatchReadReduce", "BATCH_READ_REDUCE", i, repeatNum, srcSlice, dstSlice, src, dst,
            len, sendRecvInfo.dataType_, sendRecvInfo.reduceType_);
        transferDescs.push_back(MakeBatchReduceDesc(HCCL_HCOMM_TRANSFER_TYPE_READ_REDUCE, dst, src, srcSlice.count_,
            sendRecvInfo.dataType_, sendRecvInfo.reduceType_));
        return HCCL_SUCCESS;
    };
    return DoSendRecvBatchRx(sendRecvInfo, thread, "SendRecvBatchReadReduce", "BATCH_READ_REDUCE",
        processSlice, SendRecvReadReduce);
}
/**
 * Issue a local copy of one slice on `thread`.
 *
 * @param thread   thread handle the copy is issued on.
 * @param srcSlice source slice (base address + offset, size in bytes).
 * @param dstSlice destination slice; must have the same size_ as srcSlice.
 * @return HCCL_SUCCESS when the copy was issued or when srcSlice is empty (a warning is logged and
 *         nothing is copied); HCCL_E_INTERNAL when the sizes differ; otherwise the result of
 *         HcommLocalCopyOnThread.
 */
HcclResult LocalCopy(const ThreadHandle &thread, const DataSlice &srcSlice, const DataSlice &dstSlice)
{
    // Sizes are logged through an explicit 64-bit cast so the specifier always matches.
    CHK_PRT_RET(srcSlice.size_ == 0,
        HCCL_WARNING("[AlgDataTransWrapper] LocalCopy: src slice size is [%llu].",
            static_cast<unsigned long long>(srcSlice.size_)),
        HcclResult::HCCL_SUCCESS);
    CHK_PRT_RET(srcSlice.size_ != dstSlice.size_,
        HCCL_ERROR("[AlgDataTransWrapper] LocalCopy: src slice size [%llu] is not equal to dst slice size [%llu].",
            static_cast<unsigned long long>(srcSlice.size_),
            static_cast<unsigned long long>(dstSlice.size_)),
        HcclResult::HCCL_E_INTERNAL);

    void *srcIn = GetSliceAddr(srcSlice);
    void *dstOut = GetSliceAddr(dstSlice);
    TraceDataSlice("LocalCopy", "LOCAL_COPY", 0, 1, srcSlice, dstSlice, srcIn, dstOut,
        srcSlice.size_, HCCL_DATA_TYPE_RESERVED, HcclReduceOp::HCCL_REDUCE_RESERVED);
    CHK_RET(static_cast<HcclResult>(HcommLocalCopyOnThread(thread, dstOut, srcIn, srcSlice.size_)));
    return HCCL_SUCCESS;
}
/**
 * Reduce srcSlice into dstSlice locally.
 *
 * INT64 / UINT64 / FP64 data and the PROD operation are delegated to AicpuReduce (this happens
 * before the empty-slice check below); everything else is issued through HcommLocalReduceOnThread
 * using srcSlice.count_ as the element count.
 *
 * @param thread   thread handle the reduce is issued on.
 * @param srcSlice source slice.
 * @param dstSlice destination slice (also the accumulator); must have the same size_ as srcSlice.
 * @param dataType element type.
 * @param reduceOp reduction operation.
 * @return HCCL_SUCCESS on success or when srcSlice is empty on the Hcomm path; HCCL_E_INTERNAL when
 *         the sizes differ; otherwise the failing callee's result.
 */
HcclResult LocalReduce(const ThreadHandle &thread, const DataSlice &srcSlice, const DataSlice &dstSlice,
    const HcclDataType dataType, const HcclReduceOp reduceOp)
{
    const bool is64BitType = (dataType == HCCL_DATA_TYPE_INT64) || (dataType == HCCL_DATA_TYPE_UINT64) ||
        (dataType == HCCL_DATA_TYPE_FP64);
    if (is64BitType || reduceOp == HcclReduceOp::HCCL_REDUCE_PROD) {
        CHK_RET(AicpuReduce(thread, srcSlice, dstSlice, dataType, reduceOp));
        return HCCL_SUCCESS;
    }

    // Sizes are logged through an explicit 64-bit cast so the specifier always matches.
    CHK_PRT_RET(srcSlice.size_ == 0,
        HCCL_WARNING("[AlgDataTransWrapper] LocalReduce: src slice size is [%llu].",
            static_cast<unsigned long long>(srcSlice.size_)),
        HcclResult::HCCL_SUCCESS);
    CHK_PRT_RET(srcSlice.size_ != dstSlice.size_,
        HCCL_ERROR("[InsCollAlgFactory] LocalReduce: src slice size [%llu] is not equal to dst slice size [%llu].",
            static_cast<unsigned long long>(srcSlice.size_),
            static_cast<unsigned long long>(dstSlice.size_)),
        HcclResult::HCCL_E_INTERNAL);

    void *src = GetSliceAddr(srcSlice);
    void *dst = GetSliceAddr(dstSlice);
    TraceDataSlice("LocalReduce", "LOCAL_REDUCE", 0, 1, srcSlice, dstSlice, src, dst,
        srcSlice.count_, dataType, reduceOp);
    CHK_RET(static_cast<HcclResult>(HcommLocalReduceOnThread(thread,
        dst,
        src,
        srcSlice.count_,
        static_cast<HcommDataType>(dataType),
        static_cast<HcommReduceOp>(reduceOp))));
    return HCCL_SUCCESS;
}
/**
 * Copy a list of slice pairs locally, merging neighbouring pairs into one copy when both the source
 * and the destination slices are contiguous (see IsContinuousSlice).
 *
 * Empty slices (size_ == 0) are skipped with a warning and never take part in a merged run; the
 * non-empty slices around them are still copied. Two empty lists are a no-op.
 *
 * @param thread    thread handle the copies are issued on.
 * @param srcSlices source slices.
 * @param dstSlices destination slices; same length as srcSlices, pairwise equal sizes.
 * @return HCCL_SUCCESS on success; HCCL_E_INTERNAL when the list lengths differ or a pair has
 *         different sizes (runs flushed before the offending pair have already been issued);
 *         otherwise the result of the failing HcommLocalCopyOnThread call.
 */
HcclResult LocalCopySlices(
    const ThreadHandle &thread, const std::vector<DataSlice> &srcSlices, const std::vector<DataSlice> &dstSlices)
{
    CHK_PRT_RET(srcSlices.size() != dstSlices.size(),
        HCCL_ERROR("[InsCollAlgFactory] [AlgDataTrans] LocalCopySlices: num of src slices [%u], is not equal "
                   "to num of dst slices [%u].",
            static_cast<u32>(srcSlices.size()),
            static_cast<u32>(dstSlices.size())),
        HcclResult::HCCL_E_INTERNAL);
    if (srcSlices.empty()) {
        // Nothing to copy; also keeps the [0] accesses below in bounds.
        return HcclResult::HCCL_SUCCESS;
    }
    const u32 sliceNum = static_cast<u32>(srcSlices.size());

    // Pending merged run. The two slices only hold meaningful data while hasRun is true; they are
    // seeded from element 0 purely so that they can be constructed.
    DataSlice runSrc = srcSlices[0];
    DataSlice runDst = dstSlices[0];
    bool hasRun = false;
    u32 runLastIdx = 0;  // index of the last slice folded into the pending run (trace only)

    // Issues the copy for the pending run.
    auto flushRun = [&]() -> HcclResult {
        void *src = GetSliceAddr(runSrc);
        void *dst = GetSliceAddr(runDst);
        TraceDataSlice("LocalCopySlices", "LOCAL_COPY_MERGED", runLastIdx, sliceNum,
            runSrc, runDst, src, dst, runSrc.size_, HCCL_DATA_TYPE_RESERVED,
            HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_RET(static_cast<HcclResult>(HcommLocalCopyOnThread(thread, dst, src, runSrc.size_)));
        return HcclResult::HCCL_SUCCESS;
    };

    for (u32 sliceIdx = 0; sliceIdx < sliceNum; sliceIdx++) {
        const DataSlice &curSrc = srcSlices[sliceIdx];
        const DataSlice &curDst = dstSlices[sliceIdx];
        if (curSrc.size_ == 0) {
            HCCL_WARNING("[AlgDataTransWrapper] LocalCopySlices: size is 0.");
            continue;
        }
        TraceDataSlice("LocalCopySlices", "LOCAL_COPY_SLICE", sliceIdx, sliceNum,
            curSrc, curDst, GetSliceAddr(curSrc), GetSliceAddr(curDst), curSrc.size_,
            HCCL_DATA_TYPE_RESERVED, HcclReduceOp::HCCL_REDUCE_RESERVED);
        CHK_PRT_RET(curSrc.size_ != curDst.size_,
            HCCL_ERROR("[InsCollAlgFactory] [AlgDataTransWrapper] LocalCopySlices: [%u]-th slice, src slice size "
                       "[%llu] is not equal to dst slice size [%llu].",
                sliceIdx,
                static_cast<unsigned long long>(curSrc.size_),
                static_cast<unsigned long long>(curDst.size_)),
            HcclResult::HCCL_E_INTERNAL);

        if (hasRun && IsContinuousSlice(curSrc, runSrc) && IsContinuousSlice(curDst, runDst)) {
            // Both sides continue the pending run: grow it instead of issuing a copy.
            u64 newRunSize = runSrc.size_ + curSrc.size_;
            runSrc = DataSlice(runSrc.addr_, runSrc.offset_, newRunSize);
            runDst = DataSlice(runDst.addr_, runDst.offset_, newRunSize);
        } else {
            if (hasRun) {
                CHK_RET(flushRun());
            }
            runSrc = curSrc;
            runDst = curDst;
            hasRun = true;
        }
        runLastIdx = sliceIdx;
    }

    // The last run is still pending, regardless of whether trailing slices were empty.
    if (hasRun) {
        CHK_RET(flushRun());
    }
    return HcclResult::HCCL_SUCCESS;
}
/**
 * Tell whether nxtSlice starts exactly where currSlice ends.
 *
 * @param nxtSlice  candidate follow-up slice.
 * @param currSlice slice it is compared against.
 * @return true when both slices share the same addr_ and nxtSlice.offset_ equals
 *         currSlice.offset_ + currSlice.size_; false otherwise.
 */
bool IsContinuousSlice(const DataSlice &nxtSlice, const DataSlice &currSlice)
{
    const bool sameBase = (nxtSlice.addr_ == currSlice.addr_);
    const bool startsAtEnd = (nxtSlice.offset_ == currSlice.offset_ + currSlice.size_);
    return sameBase && startsAtEnd;
}
/**
 * Main-to-sub synchronisation: for every sub thread the main thread records the paired notify, then
 * every sub thread waits on its notify with the configured execution timeout.
 *
 * @param mainThread         thread the records are issued on.
 * @param subThreads         threads the waits are issued on.
 * @param notifyIdxMainToSub notify index per sub thread; same length as subThreads.
 * @return HCCL_SUCCESS on success; HCCL_E_INTERNAL when either list is empty or the lengths differ;
 *         otherwise the result of the failing notify call.
 */
HcclResult PreSyncInterThreads(const ThreadHandle &mainThread, const std::vector<ThreadHandle> &subThreads,
    const std::vector<u32>& notifyIdxMainToSub)
{
    // Container sizes are narrowed explicitly so that they match the %u specifiers.
    const u32 subThreadNum = static_cast<u32>(subThreads.size());
    const u32 notifyNum = static_cast<u32>(notifyIdxMainToSub.size());
    CHK_PRT_RET(subThreads.empty() || notifyIdxMainToSub.empty(),
        HCCL_ERROR("[AlgDataTransWrapper] [PreSyncInterThreads] subThreads size: [%u], notifyIdxMainToSub size [%u] "
                   "0 is not correct.",
            subThreadNum,
            notifyNum),
        HcclResult::HCCL_E_INTERNAL);
    CHK_PRT_RET(subThreads.size() != notifyIdxMainToSub.size(),
        HCCL_ERROR("[AlgDataTransWrapper] [PreSyncInterThreads] subThreads size: [%u], notifyIdxMainToSub size [%u] "
                   "is not equal.",
            subThreadNum,
            notifyNum),
        HcclResult::HCCL_E_INTERNAL);

    u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    // All records are issued first, then all waits.
    for (u32 tidx = 0; tidx < subThreadNum; tidx++) {
        CHK_RET(static_cast<HcclResult>(
            HcommThreadNotifyRecordOnThread(mainThread, subThreads[tidx], notifyIdxMainToSub[tidx])));
    }
    for (u32 tidx = 0; tidx < subThreadNum; tidx++) {
        CHK_RET(static_cast<HcclResult>(
            HcommThreadNotifyWaitOnThread(subThreads[tidx], notifyIdxMainToSub[tidx], execTimeout)));
    }
    return HcclResult::HCCL_SUCCESS;
}
/**
 * Sub-to-main synchronisation: the main thread waits on every notify index with the configured
 * execution timeout, and every sub thread records its paired notify towards the main thread.
 * The waits are issued before the records, as in the original call order.
 *
 * @param mainThread         thread the waits are issued on.
 * @param subThreads         threads the records are issued on.
 * @param notifyIdxSubToMain notify index per sub thread; same length as subThreads.
 * @return HCCL_SUCCESS on success; HCCL_E_INTERNAL when either list is empty or the lengths differ;
 *         otherwise the result of the failing notify call.
 */
HcclResult PostSyncInterThreads(const ThreadHandle &mainThread, const std::vector<ThreadHandle> &subThreads,
    const std::vector<u32>& notifyIdxSubToMain)
{
    // Container sizes are narrowed explicitly so that they match the %u specifiers.
    const u32 subThreadNum = static_cast<u32>(subThreads.size());
    const u32 notifyNum = static_cast<u32>(notifyIdxSubToMain.size());
    CHK_PRT_RET(subThreads.empty() || notifyIdxSubToMain.empty(),
        HCCL_ERROR("[AlgDataTransWrapper] [PostSyncInterThreads] subThreads size: [%u], notifyIdxSubToMain size [%u] "
                   "0 is not correct.",
            subThreadNum,
            notifyNum),
        HcclResult::HCCL_E_INTERNAL);
    CHK_PRT_RET(subThreads.size() != notifyIdxSubToMain.size(),
        HCCL_ERROR("[AlgDataTransWrapper] [PostSyncInterThreads] subThreads size: [%u], notifyIdxSubToMain size [%u] "
                   "is not equal.",
            subThreadNum,
            notifyNum),
        HcclResult::HCCL_E_INTERNAL);

    u32 execTimeout = ExecTimeoutManager::Instance().GetExecTimeout();
    for (u32 tidx = 0; tidx < subThreadNum; tidx++) {
        CHK_RET(static_cast<HcclResult>(
            HcommThreadNotifyWaitOnThread(mainThread, notifyIdxSubToMain[tidx], execTimeout)));
    }
    for (u32 tidx = 0; tidx < subThreadNum; tidx++) {
        CHK_RET(static_cast<HcclResult>(
            HcommThreadNotifyRecordOnThread(subThreads[tidx], mainThread, notifyIdxSubToMain[tidx])));
    }
    return HcclResult::HCCL_SUCCESS;
}
/**
 * Convert an IEEE 754 binary16 bit pattern to float.
 *
 * Handles signed zero, subnormals (normalised into a float normal number), normal numbers, and
 * infinity / NaN (the 10 fraction bits are moved to the top of the float fraction).
 *
 * @param fp16Bits raw half-precision bits: 1 sign, 5 exponent, 10 fraction.
 * @return the float with the same value.
 */
float Fp16ToFp32(uint16_t fp16Bits)
{
    const uint32_t sign = (fp16Bits >> 15) & 0x1;
    const uint32_t exponent = (fp16Bits >> 10) & 0x1F;
    uint32_t mantissa = fp16Bits & 0x3FF;

    // Start from a signed zero and fill in exponent / fraction for every other class.
    uint32_t fp32Bits = sign << 31;
    if (exponent == 0) {
        if (mantissa != 0) {
            // Subnormal half: shift the leading one up to the implicit-bit position (bit 10).
            uint32_t shift = 0;
            while ((mantissa & 0x400) == 0) {
                mantissa <<= 1;
                shift++;
            }
            mantissa &= 0x3FF;
            // shift is at most 10, so the float exponent (127 - 15 + 1 - shift) stays >= 103 and
            // can never underflow; no zero fallback is needed.
            const uint32_t fp32Exp = 127U - 15U + 1U - shift;
            fp32Bits |= (fp32Exp << 23) | (mantissa << 13);
        }
    } else if (exponent == 0x1F) {
        // Infinity (fraction 0) or NaN (fraction != 0): all-ones float exponent.
        fp32Bits |= (0xFFU << 23) | (mantissa << 13);
    } else {
        // Normal number: rebias the exponent from 15 to 127 (+112).
        fp32Bits |= ((exponent + 112) << 23) | (mantissa << 13);
    }

    float f = 0.0f;
    // Sizes are equal compile-time constants, so the copy cannot fail; the status is ignored on purpose.
    (void)memcpy_s(&f, sizeof(f), &fp32Bits, sizeof(f));
    return f;
}
/**
 * Build a half-precision subnormal (or zero / smallest normal) from the fields of a float whose
 * rebiased half exponent is not positive, rounding to nearest with ties to even.
 *
 * @param sign     sign bit (0 or 1).
 * @param mantissa 23-bit float fraction, without the implicit leading one.
 * @param fp16Exp  float exponent rebiased for half precision (exponent - 127 + 15).
 * @return the half-precision bit pattern; signed zero when fp16Exp is below -10.
 */
uint16_t Fp32DenormToFp16(uint32_t sign, uint32_t mantissa, int32_t fp16Exp)
{
    const uint32_t signBits = sign << 15;
    if (fp16Exp < -10) {
        return static_cast<uint16_t>(signBits);
    }

    // Restore the implicit leading one, then drop (14 - fp16Exp) low bits.
    const uint32_t significand = mantissa | 0x800000;
    const int32_t dropCount = 14 - fp16Exp;
    const uint32_t halfMask = 1U << (dropCount - 1);
    const bool halfSet = (significand & halfMask) != 0;
    const bool belowHalfSet = (significand & (halfMask - 1)) != 0;

    uint16_t kept = static_cast<uint16_t>(significand >> dropCount);
    const bool keptIsOdd = (kept & 0x1) != 0;
    if (halfSet && (belowHalfSet || keptIsOdd)) {
        ++kept;
    }

    // Rounding carried into bit 10: the result is the smallest normal half.
    if ((kept & 0x400) != 0) {
        return static_cast<uint16_t>(signBits | (1U << 10));
    }
    return static_cast<uint16_t>(signBits | kept);
}
/**
 * Convert a float to an IEEE 754 binary16 bit pattern, rounding to nearest with ties to even.
 *
 * Float zeros and float subnormals become signed zero, overflow becomes signed infinity, NaN stays
 * NaN (a non-zero fraction is forced when the top fraction bits are all zero), and values below the
 * half normal range are handed to Fp32DenormToFp16.
 *
 * @param value float to convert.
 * @return the half-precision bit pattern.
 */
uint16_t Fp32ToFp16(float value)
{
    uint32_t rawBits;
    memcpy_s(&rawBits, sizeof(rawBits), &value, sizeof(rawBits));

    const uint32_t signBit = (rawBits >> 31) & 0x1;
    const uint32_t expField = (rawBits >> 23) & 0xFF;
    const uint32_t fraction = rawBits & 0x7FFFFF;
    const uint32_t signedZero = signBit << 15;
    const uint32_t signedInf = signedZero | 0x7C00;

    if (expField == 0) {
        return static_cast<uint16_t>(signedZero);
    }
    if (expField == 0xFF) {
        if (fraction == 0) {
            return static_cast<uint16_t>(signedInf);
        }
        // NaN: keep the top fraction bits, but never let the payload collapse to infinity.
        uint16_t nanPayload = static_cast<uint16_t>(fraction >> 13);
        nanPayload = (nanPayload == 0) ? 1 : nanPayload;
        return static_cast<uint16_t>(signedInf | nanPayload);
    }

    int32_t halfExp = static_cast<int32_t>(expField) - 127 + 15;
    if (halfExp >= 31) {
        return static_cast<uint16_t>(signedInf);
    }
    if (halfExp <= 0) {
        return Fp32DenormToFp16(signBit, fraction, halfExp);
    }

    // Normal range: keep the top 10 fraction bits and round on the 13 dropped ones.
    uint16_t halfFrac = static_cast<uint16_t>(fraction >> 13);
    const uint32_t dropped = fraction & 0x1FFF;
    const bool halfSet = ((dropped >> 12) & 0x1) != 0;
    const bool belowHalfSet = (dropped & 0xFFF) != 0;
    const bool fracIsOdd = (halfFrac & 0x1) != 0;
    if (halfSet && (belowHalfSet || fracIsOdd)) {
        ++halfFrac;
    }
    if (halfFrac == 0x400) {
        // Rounding overflowed the fraction: carry into the exponent.
        halfFrac = 0;
        ++halfExp;
    }
    if (halfExp >= 31) {
        return static_cast<uint16_t>(signedInf);
    }
    return static_cast<uint16_t>(signedZero | (static_cast<uint32_t>(halfExp) << 10) | halfFrac);
}
/**
 * Reduce two half-precision buffers by widening both to float, running AicpuReduceTemplate<float>
 * on the widened copies and narrowing the result back into dst.
 *
 * @param dst      destination / accumulator buffer, read and written as uint16_t elements.
 * @param src      source buffer, read as uint16_t elements.
 * @param size     buffer size in bytes; size / sizeof(uint16_t) elements are processed.
 * @param reduceOp reduction operation forwarded to AicpuReduceTemplate.
 * @return the result of AicpuReduceTemplate; dst is left untouched when it fails.
 */
HcclResult AicpuReduceFp16(u8 *dst, u8 *src, u64 size, const HcclReduceOp reduceOp)
{
    const u64 elemNum = size / sizeof(uint16_t);
    uint16_t *srcHalf = reinterpret_cast<uint16_t *>(src);
    uint16_t *dstHalf = reinterpret_cast<uint16_t *>(dst);

    // Widen both operands into temporary float buffers.
    std::vector<float> srcWide(elemNum);
    std::vector<float> dstWide(elemNum);
    const auto widen = [](uint16_t halfBits) { return Fp16ToFp32(halfBits); };
    std::transform(srcHalf, srcHalf + elemNum, srcWide.begin(), widen);
    std::transform(dstHalf, dstHalf + elemNum, dstWide.begin(), widen);

    const u64 srcWideBytes = srcWide.size() * sizeof(float);
    const u64 dstWideBytes = dstWide.size() * sizeof(float);
    HcclResult ret = AicpuReduceTemplate<float>(dstWide.data(), dstWideBytes, srcWide.data(), srcWideBytes, reduceOp);
    CHK_PRT_RET(ret != HcclResult::HCCL_SUCCESS,
        HCCL_ERROR("[AicpuReduceFp16] AicpuReduceTemplate failed, ret[%d].", static_cast<int>(ret)),
        ret);

    // Narrow the reduced values back in place.
    std::transform(dstWide.begin(), dstWide.end(), dstHalf,
        [](float wideValue) { return Fp32ToFp16(wideValue); });
    return ret;
}
/**
 * Reduce srcSlice into dstSlice with a CPU loop, dispatching on the element type.
 *
 * INT8, INT32, FP32, INT64, UINT64 and FP64 go straight to AicpuReduceTemplate; FP16 goes through
 * AicpuReduceFp16. Any other data type is rejected.
 *
 * @param thread   unused.
 * @param srcSlice source slice.
 * @param dstSlice destination slice (also the accumulator); must have the same size_ as srcSlice.
 * @param dataType element type.
 * @param reduceOp reduction operation.
 * @return the result of the typed reduce; HCCL_E_INTERNAL when the sizes differ or the data type is
 *         not handled.
 */
HcclResult AicpuReduce(const ThreadHandle &thread, const DataSlice &srcSlice, const DataSlice &dstSlice,
    const HcclDataType dataType, const HcclReduceOp reduceOp)
{
    (void)thread;
    // Sizes are logged through an explicit 64-bit cast so the specifier always matches.
    CHK_PRT_RET(srcSlice.size_ != dstSlice.size_,
        HCCL_ERROR("[AlgDataTransWrapper] [AicpuReduce] AicpuReduce: src slice size [%llu] "
                   "is not equal to dst slice size [%llu].",
            static_cast<unsigned long long>(srcSlice.size_),
            static_cast<unsigned long long>(dstSlice.size_)),
        HcclResult::HCCL_E_INTERNAL);

    auto ret = HcclResult::HCCL_SUCCESS;
    u8 *src = static_cast<u8 *>(GetSliceAddr(srcSlice));
    u8 *dst = static_cast<u8 *>(GetSliceAddr(dstSlice));
    TraceDataSlice("AicpuReduce", "AICPU_REDUCE", 0, 1, srcSlice, dstSlice, src, dst, srcSlice.size_, dataType,
        reduceOp);

    switch (dataType) {
        case HcclDataType::HCCL_DATA_TYPE_INT8:
            ret = AicpuReduceTemplate<int8_t>(reinterpret_cast<int8_t *>(dst), dstSlice.size_,
                reinterpret_cast<int8_t *>(src), srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_INT32:
            ret = AicpuReduceTemplate<int32_t>(reinterpret_cast<int32_t *>(dst), dstSlice.size_,
                reinterpret_cast<int32_t *>(src), srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_FP16:
            // Half precision is widened to float inside AicpuReduceFp16.
            ret = AicpuReduceFp16(dst, src, srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_FP32:
            ret = AicpuReduceTemplate<float>(reinterpret_cast<float *>(dst), dstSlice.size_,
                reinterpret_cast<float *>(src), srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_INT64:
            ret = AicpuReduceTemplate<int64_t>(reinterpret_cast<int64_t *>(dst), dstSlice.size_,
                reinterpret_cast<int64_t *>(src), srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_UINT64:
            ret = AicpuReduceTemplate<uint64_t>(reinterpret_cast<uint64_t *>(dst), dstSlice.size_,
                reinterpret_cast<uint64_t *>(src), srcSlice.size_, reduceOp);
            break;
        case HcclDataType::HCCL_DATA_TYPE_FP64:
            ret = AicpuReduceTemplate<double>(reinterpret_cast<double *>(dst), dstSlice.size_,
                reinterpret_cast<double *>(src), srcSlice.size_, reduceOp);
            break;
        default:
            HCCL_ERROR("DataType[%d] not support", static_cast<int>(dataType));
            ret = HCCL_E_INTERNAL;
            break;
    }
    return ret;
}
/**
 * Addition for element types that have no wider type (WiderType<T>::Type is T itself).
 *
 * No clamping is possible here: the sum is computed directly in T.
 * NOTE(review): for signed integer T an overflowing sum is undefined behaviour -- confirm whether
 * that is acceptable for the types mapped onto this overload.
 */
template <typename T>
typename std::enable_if<std::is_same<typename WiderType<T>::Type, T>::value, T>::type
SaturatedAdd(T a, T b)
{
    return a + b;
}

/**
 * Saturating addition for element types that have a wider type: the sum is computed in
 * WiderType<T>::Type and clamped to the representable range of T.
 *
 * The lower bound is numeric_limits<T>::lowest(): for integer T it equals min(), while for
 * floating-point T min() is the smallest positive normal value and would clamp every negative
 * overflow to a tiny positive number.
 */
template <typename T>
typename std::enable_if<!std::is_same<typename WiderType<T>::Type, T>::value, T>::type
SaturatedAdd(T a, T b)
{
    using W = typename WiderType<T>::Type;
    const W upperBound = static_cast<W>(std::numeric_limits<T>::max());
    const W lowerBound = static_cast<W>(std::numeric_limits<T>::lowest());
    const W result = static_cast<W>(a) + static_cast<W>(b);
    if (result > upperBound) {
        return std::numeric_limits<T>::max();
    }
    if (result < lowerBound) {
        return std::numeric_limits<T>::lowest();
    }
    return static_cast<T>(result);
}
/**
 * Multiplication for element types that have no wider type (WiderType<T>::Type is T itself).
 *
 * No clamping is possible here: the product is computed directly in T.
 * NOTE(review): for signed integer T an overflowing product is undefined behaviour -- confirm
 * whether that is acceptable for the types mapped onto this overload.
 */
template <typename T>
typename std::enable_if<std::is_same<typename WiderType<T>::Type, T>::value, T>::type
SaturatedMul(T a, T b)
{
    return a * b;
}

/**
 * Saturating multiplication for element types that have a wider type: the product is computed in
 * WiderType<T>::Type and clamped to the representable range of T.
 *
 * The lower bound is numeric_limits<T>::lowest(): for integer T it equals min(), while for
 * floating-point T min() is the smallest positive normal value and would clamp every negative
 * overflow to a tiny positive number.
 */
template <typename T>
typename std::enable_if<!std::is_same<typename WiderType<T>::Type, T>::value, T>::type
SaturatedMul(T a, T b)
{
    using W = typename WiderType<T>::Type;
    const W upperBound = static_cast<W>(std::numeric_limits<T>::max());
    const W lowerBound = static_cast<W>(std::numeric_limits<T>::lowest());
    const W result = static_cast<W>(a) * static_cast<W>(b);
    if (result > upperBound) {
        return std::numeric_limits<T>::max();
    }
    if (result < lowerBound) {
        return std::numeric_limits<T>::lowest();
    }
    return static_cast<T>(result);
}
/**
 * Element-wise reduction of src into dst: dst[i] = op(src[i], dst[i]).
 *
 * SUM and PROD go through SaturatedAdd / SaturatedMul, MAX and MIN through std::max / std::min.
 *
 * @param dst      destination / accumulator buffer.
 * @param dstSize  size of dst in bytes.
 * @param src      source buffer.
 * @param srcSize  size of src in bytes; must equal dstSize.
 * @param reduceOp reduction operation.
 * @return HCCL_SUCCESS on success (including when there are no elements); HCCL_E_INTERNAL when the
 *         sizes differ or when reduceOp is not one of SUM / PROD / MAX / MIN and at least one
 *         element is present. dst is not modified in either error case.
 */
template <typename T>
HcclResult AicpuReduceTemplate(T *dst, u64 dstSize, T *src, u64 srcSize, const HcclReduceOp reduceOp)
{
    if (dstSize != srcSize) {
        HCCL_ERROR("srcSize[%llu] should be equal to dstSize[%llu]", srcSize, dstSize);
        return HcclResult::HCCL_E_INTERNAL;
    }
    const u64 count = dstSize / u64(sizeof(T));
    for (u64 i = 0; i < count; ++i) {
        const T dstData = dst[i];
        const T srcData = src[i];
        switch (reduceOp) {
            case HcclReduceOp::HCCL_REDUCE_SUM:
                dst[i] = SaturatedAdd(srcData, dstData);
                break;
            case HcclReduceOp::HCCL_REDUCE_PROD:
                dst[i] = SaturatedMul(srcData, dstData);
                break;
            case HcclReduceOp::HCCL_REDUCE_MAX:
                dst[i] = std::max(srcData, dstData);
                break;
            case HcclReduceOp::HCCL_REDUCE_MIN:
                dst[i] = std::min(srcData, dstData);
                break;
            default:
                // The operation is the same for every element: report it once and stop instead of
                // logging the same error for each remaining element.
                HCCL_ERROR("ReduceOp[%d] not support", static_cast<int>(reduceOp));
                return HcclResult::HCCL_E_INTERNAL;
        }
    }
    return HcclResult::HCCL_SUCCESS;
}
}