* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ALG_V2_TEMPLATE_UTILS
#define ALG_V2_TEMPLATE_UTILS
#include <cstdint>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "alg_param.h"
#include "binary_stream.h"
// Weak declaration of HcommThreadJoin. Because of __attribute__((weak)) the
// program can still link when no definition is supplied; callers should then
// check the function's address before calling it.
// NOTE(review): the unit of `timeout` and the meaning of the returned
// HcclResult are not visible in this header - confirm against the definition.
HcclResult __attribute__((weak)) HcommThreadJoin(ThreadHandle thread, uint32_t timeout);
namespace ops_hccl {
// Fallback only. <cstdint> normally supplies UINT32_MAX, and macros ignore the
// enclosing namespace, so an unguarded #define here would re-define the
// standard macro in every translation unit that includes this header.
#ifndef UINT32_MAX
#define UINT32_MAX (4294967295U)
#endif
// Largest value representable in u32; by its name, the "invalid" marker for u32 fields.
constexpr u32 INVALID_U32 = std::numeric_limits<u32>::max();
// Largest value representable in s32; by its name, the "invalid" marker for rank ids.
constexpr s32 INVALID_RANKID = std::numeric_limits<s32>::max();
// Position and extent of one slice; both fields start at zero.
// NOTE(review): the unit (presumably bytes) is not established here - confirm.
struct SliceInfo {
    u64 offset = 0;
    u64 size = 0;
};

// Two-level table of SliceInfo.
// NOTE(review): presumably indexed as [rank][slice] - verify against users.
using RankSliceInfo = std::vector<std::vector<SliceInfo>>;
// Tag naming which buffer a region belongs to.
enum class BufferType {
    INPUT       = 0,
    OUTPUT      = 1,
    HCCL_BUFFER = 2,
    DEFAULT     = 3  // same value the enumerator receives implicitly after HCCL_BUFFER
};
// Kind of operation carried by a batch send/recv entry.
enum class BatchSendRecvOpType {
    RECORD  = 0,
    SEND    = 1,
    RECV    = 2,
    FENCE   = 3,
    DEFAULT = 4  // same value the enumerator receives implicitly after FENCE
};
// A slice of memory identified by a base address, an offset, a size and an
// element count. There is no default constructor; callers must supply at
// least address, offset and size.
// NOTE(review): offset/size are presumably in bytes - confirm with callers.
struct DataSlice {
    void* addr_ = nullptr;
    u64 offset_ = 0;
    u64 size_ = 0;
    u64 count_ = 0;

    // Stores all four values as given.
    DataSlice(void* addr, u64 offset, u64 size, u64 count)
        : addr_(addr), offset_(offset), size_(size), count_(count)
    {}

    // Same as the four-argument form with count fixed to 0.
    DataSlice(void* addr, u64 offset, u64 size)
        : DataSlice(addr, offset, size, 0)
    {}

    // Returns "DataSlice: addr=<addr>, offset=<offset>, size=<size>, count=<count>".
    std::string Describe() const
    {
        std::ostringstream text;
        text << "DataSlice: addr=" << addr_;
        text << ", offset=" << offset_;
        text << ", size=" << size_;
        text << ", count=" << count_;
        return text.str();
    }
};
// Pairs a list of source slices with a list of destination slices.
// Both lists are copied from the constructor arguments.
struct SlicesList {
    std::vector<DataSlice> srcSlices_;
    std::vector<DataSlice> dstSlices_;

    SlicesList(const std::vector<DataSlice> &srcSlices, const std::vector<DataSlice> &dstSlices)
        : srcSlices_(srcSlices),
          dstSlices_(dstSlices)
    {}
};
// Send/receive bookkeeping vectors for an all-to-all style exchange.
// NOTE(review): presumably one entry per peer rank, with the Length/Offset
// vectors in bytes and the Counts/Displs vectors in elements - confirm.
struct A2ASendRecvInfo {
    // Length / offset description of the send and receive sides.
    std::vector<u64> sendLength;
    std::vector<u64> sendOffset;
    std::vector<u64> recvLength;
    std::vector<u64> recvOffset;

    // Count / displacement description of the send and receive sides.
    std::vector<u64> sendCounts;
    std::vector<u64> sendDispls;
    std::vector<u64> recvCounts;
    std::vector<u64> recvDispls;
};
// Bundles the channel used for a transfer with the slices to move and,
// optionally, the element data type.
struct DataInfo {
    ChannelInfo channel_;
    SlicesList slices_;
    // Value-initialised (zero enumerator) so the two-argument constructor,
    // which does not take a data type, never leaves this member indeterminate.
    HcclDataType dataType_{};

    // Copies channel and slices; dataType_ keeps its zero default.
    DataInfo(const ChannelInfo &channel, const SlicesList &slices)
        : channel_(channel), slices_(slices)
    {
    }

    // Copies channel and slices and records the given data type.
    DataInfo(const ChannelInfo &channel, const SlicesList &slices, HcclDataType dataType)
        : channel_(channel), slices_(slices), dataType_(dataType)
    {
    }
};
// Bundles a channel, the slices to process, the element data type and the
// reduce operation. All four values are copied from the constructor arguments.
struct DataReduceInfo {
    ChannelInfo channel_;
    SlicesList slices_;
    HcclDataType dataType_;
    HcclReduceOp reduceType_;

    DataReduceInfo(const ChannelInfo &channel, const SlicesList &slices,
                   HcclDataType dataType, HcclReduceOp reduceType)
        : channel_(channel),
          slices_(slices),
          dataType_(dataType),
          reduceType_(reduceType)
    {}
};
// A transmit channel paired with a receive channel; both are stored by copy.
struct TxRxChannels {
    ChannelInfo txChannel_;
    ChannelInfo rxChannel_;

    TxRxChannels(const ChannelInfo &txLink, const ChannelInfo &rxLink)
        : txChannel_(txLink),
          rxChannel_(rxLink)
    {}
};
// A transmit-side SlicesList paired with a receive-side SlicesList; both are
// stored by copy.
struct TxRxSlicesList {
    SlicesList txSlicesList_;
    SlicesList rxSlicesList_;

    TxRxSlicesList(const SlicesList &txSlicesList, const SlicesList &rxSlicesList)
        : txSlicesList_(txSlicesList),
          rxSlicesList_(rxSlicesList)
    {}
};
// Bundles the tx/rx channel pair with the tx/rx slice lists and, optionally,
// the element data type of a combined send/receive step.
struct SendRecvInfo {
    TxRxChannels sendRecvChannels_;
    TxRxSlicesList sendRecvSlices_;
    // Value-initialised (zero enumerator) so the two-argument constructor,
    // which does not take a data type, never leaves this member indeterminate.
    HcclDataType dataType_{};

    // Copies the channel pair and slice lists; dataType_ keeps its zero default.
    SendRecvInfo(const TxRxChannels &sendRecvLinks, const TxRxSlicesList &sendRecvSlices)
        : sendRecvChannels_(sendRecvLinks), sendRecvSlices_(sendRecvSlices)
    {
    }

    // Copies the channel pair and slice lists and records the given data type.
    SendRecvInfo(const TxRxChannels &sendRecvLinks, const TxRxSlicesList &sendRecvSlices, HcclDataType dataType)
        : sendRecvChannels_(sendRecvLinks), sendRecvSlices_(sendRecvSlices), dataType_(dataType)
    {
    }
};
// Bundles the tx/rx channel pair, the tx/rx slice lists, the element data type
// and the reduce operation. Everything is copied from the constructor arguments.
struct SendRecvReduceInfo {
    TxRxChannels sendRecvChannels_;
    TxRxSlicesList sendRecvSlices_;
    HcclDataType dataType_;
    HcclReduceOp reduceType_;

    SendRecvReduceInfo(const TxRxChannels &sendRecvLinks, const TxRxSlicesList &sendRecvSlices,
                       const HcclDataType dataType, const HcclReduceOp reduceOp)
        : sendRecvChannels_(sendRecvLinks),
          sendRecvSlices_(sendRecvSlices),
          dataType_(dataType),
          reduceType_(reduceOp)
    {}
};
// Describes the input, output and HCCL buffers used by a template: their
// pointers, buffer-type tags, sizes and base offsets.
// Every member now has an initializer so a default-initialised BuffInfo
// (`BuffInfo b;`) holds no indeterminate values.
struct BuffInfo {
    void* inputPtr = nullptr;
    void* outputPtr = nullptr;
    HcclMem hcclBuff{};
    // Value-initialised: BufferType::INPUT (enumerator 0), which is what
    // `BuffInfo{}` already produced for these members.
    BufferType inBuffType{};
    BufferType outBuffType{};
    BufferType hcclBuffType{};
    u64 inputSize = 0;
    u64 outputSize = 0;
    u64 hcclBuffSize = 0;
    u64 inBuffBaseOff = 0;
    u64 outBuffBaseOff = 0;
    u64 hcclBuffBaseOff = 0;
};
struct StepSliceInfo
{
BuffInfo buffInfo;
std::vector<std::vector<u64>> stepCount;
std::vector<std::vector<u64>> stepSliceSize;
std::vector<u64> stepInputSliceStride;
std::vector<u64> stepOutputSliceStride;
std::vector<std::vector<u64>> inputOmniPipeSliceStride;
std::vector<std::vector<u64>> outputOmniPipeSliceStride;
std::vector<char> Serialize() const
{
BinaryStream binaryStream;
binaryStream << stepCount;
binaryStream << stepSliceSize;
binaryStream << stepInputSliceStride;
binaryStream << stepOutputSliceStride;
binaryStream << inputOmniPipeSliceStride;
binaryStream << outputOmniPipeSliceStride;
std::vector<char> result;
binaryStream.Dump(result);
return result;
}
void DeSerialize(std::vector<char> &data)
{
BinaryStream binaryStream(data);
binaryStream >> stepCount;
binaryStream >> stepSliceSize;
binaryStream >> stepInputSliceStride;
binaryStream >> stepOutputSliceStride;
binaryStream >> inputOmniPipeSliceStride;
binaryStream >> outputOmniPipeSliceStride;
}
};
// Context kept for a template's fast-launch path: buffer description, thread
// handles and CCU kernel submit records.
// NOTE(review): how these are consumed is not visible in this header.
struct TemplateFastLaunchCtx {
    BuffInfo buffInfo;
    std::vector<ThreadHandle> threads;
    std::vector<CcuKernelSubmitInfo> ccuKernelSubmitInfos;
};
// Data-layout parameters handed to an algorithm template, with a BinaryStream
// based byte encoding.
// Serialize() and DeSerialize() must stay in lock-step: fields are written and
// read in the same order, which is not the declaration order.
// NOTE(review): processedDataCount is not part of the encoded form - confirm
// that this is intentional.
struct TemplateDataParams {
    BuffInfo buffInfo;
    u64 count{0};
    u64 sliceSize{0};
    u64 inputSliceStride{0};
    u64 outputSliceStride{0};
    u64 repeatNum{0};
    u64 inputRepeatStride{0};
    u64 outputRepeatStride{0};
    u64 tailSize{0};
    bool enableRemoteMemAccess{false};
    u64 processedDataCount{0};
    u64 root{0};
    HcclDataType dataType{HCCL_DATA_TYPE_INT8};
    std::vector<u64> allRankSliceSize;
    std::vector<u64> allRankDispls;
    std::vector<u64> allRankProcessedDataCount;
    std::vector<u64> sendCounts;
    std::vector<u64> recvCounts;
    std::vector<u64> sdispls;
    std::vector<u64> rdispls;
    StepSliceInfo stepSliceInfo;
    // Value-initialised: BatchSendRecvOpType::RECORD (enumerator 0), the value
    // `TemplateDataParams{}` already produced. Without an initializer a
    // default-initialised object made Serialize() write an indeterminate value.
    BatchSendRecvOpType opType{};

    // Encodes the parameters into a byte vector. stepSliceInfo is embedded as
    // its own serialized byte vector.
    std::vector<char> Serialize() const
    {
        BinaryStream binaryStream;
        binaryStream << buffInfo;
        binaryStream << count;
        binaryStream << sliceSize;
        binaryStream << inputSliceStride;
        binaryStream << outputSliceStride;
        binaryStream << repeatNum;
        binaryStream << inputRepeatStride;
        binaryStream << outputRepeatStride;
        binaryStream << tailSize;
        binaryStream << enableRemoteMemAccess;
        binaryStream << allRankSliceSize;
        binaryStream << allRankDispls;
        binaryStream << sendCounts;
        binaryStream << recvCounts;
        binaryStream << sdispls;
        binaryStream << rdispls;
        binaryStream << allRankProcessedDataCount;
        binaryStream << root;
        binaryStream << dataType;
        binaryStream << stepSliceInfo.Serialize();
        binaryStream << opType;
        std::vector<char> result;
        binaryStream.Dump(result);
        return result;
    }

    // Restores the parameters from bytes produced by Serialize(), reading the
    // fields in the same order they were written.
    void DeSerialize(std::vector<char> &data)
    {
        BinaryStream binaryStream(data);
        binaryStream >> buffInfo;
        binaryStream >> count;
        binaryStream >> sliceSize;
        binaryStream >> inputSliceStride;
        binaryStream >> outputSliceStride;
        binaryStream >> repeatNum;
        binaryStream >> inputRepeatStride;
        binaryStream >> outputRepeatStride;
        binaryStream >> tailSize;
        binaryStream >> enableRemoteMemAccess;
        binaryStream >> allRankSliceSize;
        binaryStream >> allRankDispls;
        binaryStream >> sendCounts;
        binaryStream >> recvCounts;
        binaryStream >> sdispls;
        binaryStream >> rdispls;
        binaryStream >> allRankProcessedDataCount;
        binaryStream >> root;
        binaryStream >> dataType;
        // stepSliceInfo travels as a nested byte vector; extract it first and
        // then let StepSliceInfo decode it.
        std::vector<char> stepSliceInfoData;
        binaryStream >> stepSliceInfoData;
        stepSliceInfo.DeSerialize(stepSliceInfoData);
        binaryStream >> opType;
    }
};
// Resources made available to an algorithm template: channels grouped under a
// u32 key, thread handles, CCU kernel handles / submit records and raw
// pointers to shared-memory and AIV communication areas.
// NOTE(review): the channel map key is presumably the remote rank - confirm.
struct TemplateResource {
    std::map<u32, std::vector<ChannelInfo>> channels;
    std::vector<ThreadHandle> threads;
    std::vector<CcuKernelHandle> ccuKernels;
    std::vector<CcuKernelSubmitInfo> submitInfos;
    // All raw pointers start as nullptr so an unset pointer can be detected
    // instead of holding an indeterminate address.
    void *npu2DpuShmemPtr = nullptr;
    void *dpu2NpuShmemPtr = nullptr;
    void* aivCommInfoPtr = nullptr;
};
// Run description with a BinaryStream based byte encoding: template name,
// template data parameters, channels, the local rank and sub-communicator rank
// lists. Serialize() and DeSerialize() process the fields in the same order.
struct DPURunInfo {
    std::string templateName;
    TemplateDataParams tempAlgParams;
    std::map<uint32_t, std::vector<ChannelInfo>> channels;
    // Initialised to 0 (the value `DPURunInfo{}` already produced) so that
    // Serialize() never writes an indeterminate rank.
    u32 myRank{0};
    std::vector<std::vector<uint32_t>> subCommRanks;

    // Encodes the run description; tempAlgParams is embedded as its own
    // serialized byte vector.
    std::vector<char> Serialize() const
    {
        BinaryStream binaryStream;
        binaryStream << templateName;
        binaryStream << tempAlgParams.Serialize();
        binaryStream << channels;
        binaryStream << myRank;
        binaryStream << subCommRanks;
        std::vector<char> result;
        binaryStream.Dump(result);
        return result;
    }

    // Restores the run description from bytes produced by Serialize().
    void DeSerialize(std::vector<char> &data)
    {
        BinaryStream binaryStream(data);
        binaryStream >> templateName;
        // tempAlgParams travels as a nested byte vector; extract then decode.
        std::vector<char> tempAlgParamsData;
        binaryStream >> tempAlgParamsData;
        tempAlgParams.DeSerialize(tempAlgParamsData);
        binaryStream >> channels;
        binaryStream >> myRank;
        binaryStream >> subCommRanks;
    }
};
// Count / displacement vectors for the send and receive sides of an
// all-to-all exchange.
// NOTE(review): presumably one entry per peer rank, in elements - confirm.
struct AlltoAllSendRecvInfo {
    // Send side.
    std::vector<u64> sendCounts;
    std::vector<u64> sendDispls;

    // Receive side.
    std::vector<u64> recvCounts;
    std::vector<u64> recvDispls;
};
// Description of one NHR step: step index, local rank, slice count, the peer
// ranks to send to / receive from and the slice indices to transmit / receive.
// Every scalar starts at 0 and both index lists start empty.
struct AicpuNHRStepInfo {
    u32 step = 0;
    u32 myRank = 0;
    u32 nSlices = 0;
    u32 toRank = 0;
    u32 fromRank = 0;
    std::vector<u32> txSliceIdxs;
    std::vector<u32> rxSliceIdxs;

    // All members are set by their in-class initializers.
    AicpuNHRStepInfo() {}
};
// Declaration only; the definition is not in this header.
// Takes a rank (`virtRank`) and a list of rank ids and reports a result
// through the `algRank` out-parameter, returning an HcclResult status.
// NOTE(review): presumably algRank is the position of virtRank inside rankIds,
// with a failure status when it is absent - confirm against the definition.
HcclResult GetAlgRank(const u32 virtRank, const std::vector<u32> &rankIds, u32 &algRank);
// Declaration only; returns a step count derived from `rankSize`.
// NOTE(review): the exact formula is not visible here - confirm against the definition.
u32 GetNHRStepNum(u32 rankSize);
// Returns the length of the longest run of consecutive channels that share the
// same remoteRank, never less than 1 (so an empty list yields 1).
//
// A warning is logged when a finished run's length differs from the largest
// length seen so far. Runs whose rank equals the first channel's remoteRank,
// and the placeholder state before any real rank was seen, are exempt.
// The last run is checked too, so a short or long trailing group is reported
// like any other.
inline u32 CalcChannelsPerRank(const std::vector<HcclChannelDesc> &channels)
{
    u32 channelsPerRank = 1;
    u32 currentRank = INVALID_VALUE_RANKID;
    u32 currentCount = 0;
    const size_t channelNum = channels.size();

    for (size_t idx = 0; idx < channelNum; ++idx) {
        if (channels[idx].remoteRank == currentRank) {
            currentCount++;
        } else {
            currentRank = channels[idx].remoteRank;
            currentCount = 1;
        }

        // The run goes on only if a next channel exists and targets the same rank.
        const bool runContinues = (idx + 1 < channelNum) && (channels[idx + 1].remoteRank == currentRank);
        if (runContinues) {
            continue;
        }

        // Run finished: report a mismatch, then fold it into the maximum.
        if (currentRank != INVALID_VALUE_RANKID && currentRank != channels[0].remoteRank &&
            currentCount != channelsPerRank) {
            HCCL_WARNING("[CalcChannelsPerRank] channel num[%u] of remote rank[%u] is not equal to "
                "channel num[%u] of previous ranks.",
                currentCount, currentRank, channelsPerRank);
        }
        if (currentCount > channelsPerRank) {
            channelsPerRank = currentCount;
        }
    }
    return channelsPerRank;
}
// Returns the size of the largest channel list stored in the map, never less
// than 1 (an empty map, or a map of empty lists, yields 1).
inline u32 CalcChannelsPerRank(const std::map<u32, std::vector<ChannelInfo>> &channels)
{
    u32 widest = 1;
    for (auto entry = channels.cbegin(); entry != channels.cend(); ++entry) {
        const auto listSize = entry->second.size();
        if (listSize <= widest) {
            continue;
        }
        widest = static_cast<u32>(listSize);
    }
    return widest;
}
// Ceiling division: returns dividend / divisor rounded up to the next integer.
// Note that, despite the name, the result is a quotient and not a multiple of
// divisor. A zero divisor logs a warning and returns dividend unchanged.
inline u64 RoundUp(const u64 dividend, const u64 divisor)
{
    if (divisor != 0) {
        const u64 quotient = dividend / divisor;
        const bool hasRemainder = (dividend % divisor) != 0;
        return hasRemainder ? (quotient + 1) : quotient;
    }
    HCCL_WARNING("[RoundUp] divisor is 0.");
    return dividend;
}
// Copies the given arguments, each converted to uint64_t, into
// info.cachedArgs[0 .. sizeof...(Args) - 1] in call order.
//
// Returns HCCL_E_INTERNAL, leaving info untouched, when more than
// CCU_MAX_TASK_ARG_NUM arguments are passed; HCCL_SUCCESS otherwise.
// Calling it with no arguments is valid and writes nothing.
template <typename... Args>
HcclResult FillCachedArgs(CcuKernelSubmitInfo &info, Args... args)
{
    const size_t argNum = sizeof...(Args);
    if (UNLIKELY(argNum > CCU_MAX_TASK_ARG_NUM)) {
        HCCL_ERROR("[FillCachedArgs] argNum is bigger than CCU_MAX_TASK_ARG_NUM[%d]", CCU_MAX_TASK_ARG_NUM);
        return HcclResult::HCCL_E_INTERNAL;
    }
    // One extra trailing element keeps the array non-empty for an empty pack;
    // an array of unknown bound with an empty initializer list is ill-formed.
    const uint64_t temp[sizeof...(Args) + 1] = { static_cast<uint64_t>(args)..., static_cast<uint64_t>(0) };
    for (size_t i = 0; i < argNum; i++) {
        info.cachedArgs[i] = temp[i];
    }
    return HcclResult::HCCL_SUCCESS;
}
// Declaration only; the definition is not in this header.
// Takes a total element count, the element size, a channel list and a
// channels-per-rank value, and fills three output vectors (element counts,
// sizes and element offsets). Returns an HcclResult status.
// NOTE(review): presumably splits totalDataCount across the channels' port
// groups, one output entry per channel - confirm against the definition.
HcclResult CalcDataSplitByPortGroupCommon(const u64 totalDataCount,
const u64 dataTypeSize,
const std::vector<ChannelInfo> &channels,
std::vector<u64> &elemCountOut,
std::vector<u64> &sizeOut,
std::vector<u64> &elemOffset,
const u32 channelsPerRank);
// Declaration only; the definition is not in this header.
// Same output vectors as the Common variant, but takes separate level-0 and
// level-1 channels-per-rank values plus a data ratio that defaults to 0.5.
// Returns an HcclResult status.
// NOTE(review): presumably level0DataRatio is the fraction of the data routed
// over level-0 channels, the rest going to level-1 - confirm against the definition.
HcclResult CalcDataSplitByPortGroupZAxisDetour(const u64 totalDataCount,
const u64 dataTypeSize,
const std::vector<ChannelInfo> &channels,
std::vector<u64> &elemCountOut,
std::vector<u64> &sizeOut,
std::vector<u64> &elemOffset,
const u32 level0ChannelNumPerRank,
const u32 level1ChannelNumPerRank,
const float level0DataRatio = 0.5f);
}
#endif