* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef OPS_HCCL_SRC_OPS_INC_COLL_ALG_PARAM
#define OPS_HCCL_SRC_OPS_INC_COLL_ALG_PARAM
#include <string>
#include <vector>
#include <map>
#include <set>
#include <unordered_set>
#include <memory>
#include <functional>
#include <functional>
#include <memory>
#include <hccl/hccl_comm.h>
#include "hccl_common.h"
#include "hccl_types.h"
#include "alg_type.h"
#include "hccl_res_dl.h"
#include "hcomm_primitives_dl.h"
#include "hccl_rank_graph_dl.h"
#include "hccl_host_comm_dl.h"
#include "binary_stream.h"
#include "hccl_ccu_res_dl.h"
#include "ccu_types_dl.h"
namespace ops_hccl {
// Byte cap of 256 MiB (256 * 1024 * 1024).
// NOTE(review): the name suggests a per-transfer limit for the "UB" path — confirm with users.
constexpr uint64_t UB_MAX_DATA_SIZE = 256*1024*1024;
// Block-count limit; also the default value of OpExchangeInfo::aivCoreLimit.
constexpr u32 MAX_NUM_BLOCKS = 56;
// Element size in bytes, indexed positionally for every value below HCCL_DATA_TYPE_RESERVED.
// NOTE(review): entries must follow the HcclDataType enum order exactly; the bare literals
// stand for data types with no matching C++ type used here — re-verify whenever that enum
// gains or reorders values.
constexpr uint32_t DATATYPE_SIZE_TABLE[HCCL_DATA_TYPE_RESERVED] = {sizeof(int8_t), sizeof(int16_t), sizeof(int32_t),
    2, sizeof(float), sizeof(int64_t), sizeof(uint64_t), sizeof(uint8_t), sizeof(uint16_t), sizeof(uint32_t),
    8, 2, 16, 2, 1, 1, 1, 1};
// Capacity of OpParam::commName (identifier spelling kept for compatibility).
constexpr u32 COMM_INDENTIFIER_MAX_LENGTH = 128;
constexpr uint32_t OP_NAME_LENGTH = 32;
// Capacity of OpParam::tag / OpParam::commModeTag / OpExchangeInfo::tag:
// room for an op name plus a communicator identifier.
constexpr uint32_t TAG_LENGTH = OP_NAME_LENGTH + COMM_INDENTIFIER_MAX_LENGTH;
// Capacity of algorithm-name buffers (OpParam::algName, CcuFastLaunchCtx::algName).
constexpr uint32_t OP_ALG_LENGTH = 128;
// Capacity of OpParam::algTag / OpParam::fastLaunchTag: a TAG_LENGTH tag plus an algorithm name.
constexpr uint32_t ALG_TAG_LENGTH = TAG_LENGTH + OP_ALG_LENGTH;
constexpr uint32_t MAX_TAG_LENGTH = 255;
// Length of AlgResourceCtx::notifyIds.
constexpr uint32_t AICPU_CONTROL_NOTIFY_NUM = 2;
// Capacity of MemRegInfo::inputBuffTag / outputBuffTag.
constexpr uint32_t MAX_MEM_TAG_LENGTH = OP_ALG_LENGTH + 32;
// Capacity of ResPackGraphMode::tag.
constexpr uint32_t RES_PACK_TAG_LENGTH = 255;
// Length of CcuFastLaunchCtx::ccuKernelNum.
constexpr uint32_t MAX_TEMP_NUM_IN_ALGO = 8;
// Notify slot indices.
// NOTE(review): meanings are inferred from the names only — confirm against the notify protocol.
constexpr u32 LOCAL_NOTIFY_IDX_ZERO = 0;
constexpr u32 NOTIFY_IDX_ACK = 0;
constexpr u32 NOTIFY_IDX_DATA_SIGNAL = 1;
constexpr u32 NOTIFY_IDX_FIN_ACK = 2;
// NOTE(review): the unit is not visible here (possibly seconds, given TIME_S_TO_US) — confirm.
constexpr u32 CUSTOM_TIMEOUT = 1836;
// Seconds-to-microseconds conversion factor.
constexpr u32 TIME_S_TO_US = 1000000;
// Capacity of generic fixed-size string fields
// (HcomProInfo strings, OpParamGraphMode::group, OpExchangeInfo::group).
constexpr u32 MAX_LENGTH = 128;
// Capacity of OpParam::fallbackTag and OpParam::algTypeStr.
constexpr u32 ALG_MAX_LENGTH = 128;
// Matches the four per-rank arrays in OpParam::all2AllVDataDes (sendCounts, recvCounts, sdispls, rdispls).
constexpr u64 ALL_TO_ALL_V_VECTOR_NUM = 4;
// NOTE(review): presumably the two arrays of OpParam::vDataDes (counts, displs) — confirm.
constexpr u64 REDUCE_SCATTER_V_VECTOR_NUM = 2;
constexpr u64 ALL_GATHER_V_VECTOR_NUM = 2;
// NOTE(review): meaning of this value is not visible in this header.
constexpr uint64_t GE_PARALLEL = 36;
// NOTE(review): presumably an alignment in bytes (4 KiB) for AICPU-side buffers — confirm.
constexpr uint64_t AICPU_ALIGN_SIZE = 4096;
constexpr u32 MESH_CHANNELS_NUM = 1;
// Element count of CcuKernelArgBase::channels.
constexpr uint64_t CCU_MAX_RANK_SIZE = 16;
/**
 * Topology classification. Values are explicit because they may be stored or
 * exchanged numerically; TOPO_TYPE_RESERVED stays implicit (one past the last
 * real value) so it keeps acting as the sentinel when entries are appended.
 */
enum class TopoType : int {
    TOPO_TYPE_COMMON = 0,
    TOPO_TYPE_8P_RING = 1,
    TOPO_TYPE_4P_MESH = 2,
    TOPO_TYPE_2P_MESH = 3,
    TOPO_TYPE_1P_MESH = 4,
    TOPO_TYPE_4P_RING = 5,
    TOPO_TYPE_NP_SINGLE_RING = 6,
    TOPO_TYPE_8P_MESH = 7,
    TOPO_TYPE_NP_MESH = 8,
    TOPO_TYPE_NP_DOUBLE_RING = 9,
    TOPO_TYPE_HETEROG = 10,
    TOPO_TYPE_ES_MESH = 11,
    TOPO_TYPE_RESERVED  // sentinel
};
/**
 * Execution backend selector for an operation. CCU_FAIL follows HOSTCPU
 * implicitly (value 9).
 */
enum class OpExecuteConfig : int {
    DEFAULT = 0,
    HOSTCPU_TS = 1,
    AICPU_TS = 2,
    AIV = 3,
    AIV_ONLY = 4,
    CCU_MS = 5,
    CCU_SCHED = 6,
    AICPU = 7,
    HOSTCPU = 8,
    CCU_FAIL
};
/** Operation mode: offload (0) or op-base (1). */
enum class OpMode : int {
    OFFLOAD = 0,
    OPBASE = 1
};
/** Shape of the level-0 topology. CLOS (0) is the zero / value-initialized state. */
enum class Level0Shape : int {
    CLOS = 0,
    MESH_1D = 1,
    MESH_1D_CLOS = 2,
};
/** Level-0 mesh sub-type. NOT_MESH (0) is the zero / value-initialized state. */
enum class Level0MeshType : int {
    NOT_MESH = 0,
    SINGLE_DIE = 1,
    TWO_DIE_REGULAR = 2,
    TWO_DIE_NOT_REGULAR = 3,
};
/**
 * Per-network-layer topology description; serialized field by field by
 * TopoInfoWithNetLayerDetails::Serialize().
 */
struct NetLayerDetails {
    // Previously had no initializer, so a default-initialized instance carried an
    // indeterminate value into serialization. 0 is what value-initialization
    // (`NetLayerDetails{}`) already produced.
    u32 netLayerNum = 0;
    std::vector<u32> netLayers;
    std::vector<u32> netInstNumOfLayer;
    std::vector<std::vector<u32>> instSizeListOfLayer;
    std::vector<u32> localNetInsSizeOfLayer;
};
/**
 * Topology-instance details for one layer; serialized field by field by
 * TopoInfoWithNetLayerDetails::Serialize().
 */
struct TopoInstDetails {
    // Previously uninitialized yet serialized; 0 matches value-initialization.
    u32 topoInstNum = 0;
    std::vector<u32> sizeOfTopo;
    std::vector<CommTopo> typeOfTopo;
    std::vector<std::vector<u32>> ranksInTopo;
    std::map<CommTopo, std::vector<u32>> rankNumForTopoType;
};
/**
 * Basic topology facts for the local rank. Index-like fields default to
 * INVALID_UINT and counts default to 0.
 */
struct TopoInfo {
    // userRank, userRankSize and mainThread previously had no initializer, so a
    // default-initialized TopoInfo carried indeterminate values (which
    // TopoInfoWithNetLayerDetails::Serialize() writes out). The zero defaults
    // equal what value-initialization (`TopoInfo{}`) already produced.
    u32 userRank = 0;
    u32 userRankSize = 0;
    u32 serverIdx = INVALID_UINT;
    u32 superPodIdx = INVALID_UINT;
    DevType deviceType = DevType::DEV_TYPE_COUNT;
    u32 deviceNumPerModule = 0;
    u32 serverNumPerSuperPod = 0;
    u32 serverNum = 0;
    u32 moduleNum = 0;
    u32 superPodNum = 0;
    u32 moduleIdx = INVALID_UINT;
    bool isDiffDeviceModule = false;
    bool multiModuleDiffDeviceNumMode = false;
    bool multiSuperPodDiffServerNumMode = false;
    bool isHCCSSWNumEqualToTwiceSIONum = false;
    ThreadHandle mainThread = 0;  // same "unset" value used for ThreadHandle members elsewhere in this header
    u32 notifyNumOnMainThread = 0;
};
struct TopoInfoWithNetLayerDetails : public TopoInfo {
u32 topoLevelNums = 0;
Level0Shape level0Topo;
bool Level0Nhr{false};
bool Level1Nhr{false};
bool Level1Hd{false};
bool is2DieFullMesh{false};
bool level0PcieMix{false};
bool level0BigClosRange{false};
u32 topoInstDetailsOfLayerSize = 0;
Level0MeshType level0MeshType;
NetLayerDetails netLayerDetails;
std::vector<TopoInstDetails> topoInstDetailsOfLayer;
std::vector<char> Serialize()
{
BinaryStream binaryStream;
binaryStream << userRank;
binaryStream << userRankSize;
binaryStream << serverIdx;
binaryStream << superPodIdx;
binaryStream << deviceType;
binaryStream << deviceNumPerModule;
binaryStream << serverNumPerSuperPod;
binaryStream << serverNum;
binaryStream << moduleNum;
binaryStream << superPodNum;
binaryStream << moduleIdx;
binaryStream << isDiffDeviceModule;
binaryStream << multiModuleDiffDeviceNumMode;
binaryStream << multiSuperPodDiffServerNumMode;
binaryStream << isHCCSSWNumEqualToTwiceSIONum;
binaryStream << mainThread;
binaryStream << notifyNumOnMainThread;
binaryStream << topoLevelNums;
binaryStream << level0Topo;
binaryStream << Level0Nhr;
binaryStream << Level1Nhr;
binaryStream << Level1Hd;
binaryStream << is2DieFullMesh;
binaryStream << level0PcieMix;
binaryStream << level0BigClosRange;
binaryStream << topoInstDetailsOfLayerSize;
binaryStream << level0MeshType;
binaryStream << netLayerDetails.netLayerNum;
binaryStream << netLayerDetails.netLayers;
binaryStream << netLayerDetails.netInstNumOfLayer;
binaryStream << netLayerDetails.instSizeListOfLayer;
binaryStream << netLayerDetails.localNetInsSizeOfLayer;
for (uint32_t idx = 0; idx < topoInstDetailsOfLayerSize; idx++) {
binaryStream << topoInstDetailsOfLayer[idx].topoInstNum;
binaryStream << topoInstDetailsOfLayer[idx].sizeOfTopo;
binaryStream << topoInstDetailsOfLayer[idx].typeOfTopo;
binaryStream << topoInstDetailsOfLayer[idx].ranksInTopo;
binaryStream << topoInstDetailsOfLayer[idx].rankNumForTopoType;
}
std::vector<char> result;
binaryStream.Dump(result);
return result;
}
void DeSerialize(std::vector<char> &data)
{
BinaryStream binaryStream(data);
binaryStream >> userRank;
binaryStream >> userRankSize;
binaryStream >> serverIdx;
binaryStream >> superPodIdx;
binaryStream >> deviceType;
binaryStream >> deviceNumPerModule;
binaryStream >> serverNumPerSuperPod;
binaryStream >> serverNum;
binaryStream >> moduleNum;
binaryStream >> superPodNum;
binaryStream >> moduleIdx;
binaryStream >> isDiffDeviceModule;
binaryStream >> multiModuleDiffDeviceNumMode;
binaryStream >> multiSuperPodDiffServerNumMode;
binaryStream >> isHCCSSWNumEqualToTwiceSIONum;
binaryStream >> mainThread;
binaryStream >> notifyNumOnMainThread;
binaryStream >> topoLevelNums;
binaryStream >> level0Topo;
binaryStream >> Level0Nhr;
binaryStream >> Level1Nhr;
binaryStream >> Level1Hd;
binaryStream >> is2DieFullMesh;
binaryStream >> level0PcieMix;
binaryStream >> level0BigClosRange;
binaryStream >> topoInstDetailsOfLayerSize;
binaryStream >> level0MeshType;
binaryStream >> netLayerDetails.netLayerNum;
binaryStream >> netLayerDetails.netLayers;
binaryStream >> netLayerDetails.netInstNumOfLayer;
binaryStream >> netLayerDetails.instSizeListOfLayer;
binaryStream >> netLayerDetails.localNetInsSizeOfLayer;
topoInstDetailsOfLayer.resize(topoInstDetailsOfLayerSize);
for (uint32_t idx = 0; idx < topoInstDetailsOfLayerSize; idx++) {
binaryStream >> topoInstDetailsOfLayer[idx].topoInstNum;
binaryStream >> topoInstDetailsOfLayer[idx].sizeOfTopo;
binaryStream >> topoInstDetailsOfLayer[idx].typeOfTopo;
binaryStream >> topoInstDetailsOfLayer[idx].ranksInTopo;
binaryStream >> topoInstDetailsOfLayer[idx].rankNumForTopoType;
}
}
};
/**
 * Common base of CCU kernel argument objects. CcuKernelInfo::setKernelArg
 * upcasts a shared_ptr<T> to this type for ownership and also exposes the
 * object through a raw `void*` (CcuKernelInfo::kernelArg).
 * NOTE(review): because the object is handed out as `void*`, keep this layout
 * stable; `channelCount` presumably counts the valid entries of `channels` — confirm.
 */
struct CcuKernelArgBase {
    ChannelHandle channels[CCU_MAX_RANK_SIZE];
    uint32_t channelCount;
};
/**
 * Description of one CCU kernel requested by an algorithm.
 *
 * `kernelArg` is a raw view of the object whose ownership is held privately in
 * `kernelArgSmartPtr`; both are set together by setKernelArg(), so copies of
 * this struct keep the argument object alive.
 */
struct CcuKernelInfo {
    u32 resGroup = 0;
    // Previously uninitialized; these defaults equal what value-initialization
    // (`CcuKernelInfo{}`) produced, and stop default-initialized instances from
    // carrying dangling-looking garbage pointers / an unterminated name.
    char kernelFuncName[64] = "";
    void* kernelFunc = nullptr;
    void *kernelArg = nullptr;
    std::vector<HcclChannelDesc> channels;
private:
    std::shared_ptr<CcuKernelArgBase> kernelArgSmartPtr;
public:
    /**
     * Stores the kernel argument object, taking shared ownership of it.
     * @tparam T argument type; must be static_cast-convertible to CcuKernelArgBase*.
     * @param arg argument object; `kernelArg` is set to arg.get() as `void*`.
     */
    template<typename T>
    void setKernelArg(std::shared_ptr<T> arg) {
        kernelArgSmartPtr = std::static_pointer_cast<CcuKernelArgBase>(arg);
        kernelArg = static_cast<void*>(arg.get());
    }
};
// Capacity of CcuKernelSubmitInfo::cachedArgs (kept as a macro in case it is used in #if).
#define CCU_MAX_TASK_ARG_NUM 48
/**
 * One kernel-submission record; stored back-to-back in the trailing array of a
 * CcuFastLaunchCtx (see CcuFastLaunchCtx::GetCcuKernelSubmitInfoPtr), so its
 * size and layout are part of that blob format.
 * NOTE(review): `cachedArgs` presumably holds pre-computed 64-bit task arguments
 * reused on fast launch — confirm with the launcher.
 */
struct CcuKernelSubmitInfo {
    CcuKernelHandle kernelHandle;
    uint64_t cachedArgs[CCU_MAX_TASK_ARG_NUM];
};
/**
 * Fixed-size header of a variable-length "fast launch" context blob:
 *
 *   [CcuFastLaunchCtx header]
 *   [ThreadHandle          x threadNum]
 *   [CcuKernelSubmitInfo   x totalCcuKernelNum]
 *
 * GetCtxSize() gives the total allocation size; the two getters locate the
 * trailing arrays by byte offset from `this`. They are const but return mutable
 * pointers (via const_cast), so callers can fill the arrays through a const ctx.
 * NOTE(review): alignment of the trailing arrays relies on the header size and
 * sizeof(ThreadHandle) being multiples of the element alignments — re-check if
 * header fields change.
 */
struct CcuFastLaunchCtx {
    char algName[OP_ALG_LENGTH];
    u32 notifyNumOnMainThread{0};
    u32 threadNum;
    u32 ccuKernelNum[MAX_TEMP_NUM_IN_ALGO];

    /// Start of the ThreadHandle array, which begins right after ccuKernelNum.
    ThreadHandle *GetThreadHandlePtr() const
    {
        const size_t headerBytes = offsetof(CcuFastLaunchCtx, ccuKernelNum) + MAX_TEMP_NUM_IN_ALGO * sizeof(u32);
        char *blobBase = reinterpret_cast<char *>(const_cast<CcuFastLaunchCtx *>(this));
        return reinterpret_cast<ThreadHandle *>(blobBase + headerBytes);
    }

    /// Start of the CcuKernelSubmitInfo array, which follows the threadNum thread handles.
    CcuKernelSubmitInfo *GetCcuKernelSubmitInfoPtr() const
    {
        char *threadArrayBegin = reinterpret_cast<char *>(GetThreadHandlePtr());
        const size_t threadArrayBytes = threadNum * sizeof(ThreadHandle);
        return reinterpret_cast<CcuKernelSubmitInfo *>(threadArrayBegin + threadArrayBytes);
    }

    /**
     * Total bytes needed for a context blob.
     * @param threadNum number of trailing ThreadHandle entries.
     * @param totalCcuKernelNum number of trailing CcuKernelSubmitInfo entries.
     */
    static u64 GetCtxSize(u32 threadNum, u32 totalCcuKernelNum)
    {
        const size_t threadBytes = threadNum * sizeof(ThreadHandle);
        const size_t submitBytes = totalCcuKernelNum * sizeof(CcuKernelSubmitInfo);
        return sizeof(CcuFastLaunchCtx) + threadBytes + submitBytes;
    }
};
/**
 * Resources requested by an algorithm (per the struct name): notify counts,
 * slave threads, channel descriptors and CCU kernels.
 */
struct AlgResourceRequest {
    u32 notifyNumOnMainThread{0};
    u32 slaveThreadNum{0};
    std::vector<u32> notifyNumPerThread;                    // presumably one entry per slave thread — confirm
    std::vector<std::vector<HcclChannelDesc>> channels;
    std::vector<CcuKernelInfo> ccuKernelInfos;
    std::vector<u32> ccuKernelNum;
};
// Capacity of AlgHierarchyInfo::infos (number of describable hierarchy levels).
constexpr u32 HCCL_LOGIC_TOPO_LEVEL_NUM = 4;

/** Local rank and size within one hierarchy level; defaults describe rank 0 of a 1-member level. */
struct SubCommInfo {
    u32 localRank{0};
    u32 localRankSize{1};
};

/**
 * Fixed-capacity per-level sub-communicator info.
 * NOTE(review): presumably only the first `levels` entries of `infos` are meaningful — confirm.
 */
struct AlgHierarchyInfo {
    u32 levels{1};
    SubCommInfo infos[HCCL_LOGIC_TOPO_LEVEL_NUM];
};
/**
 * Description of a communication channel to a remote rank, as stored in
 * AlgResourceCtxSerializable::channels.
 */
struct ChannelInfo {
    bool isValid{false};
    u32 remoteRank = INVALID_VALUE_RANKID;
    CommProtocol protocol{CommProtocol::COMM_PROTOCOL_RESERVED};
    EndpointLocType locationType{EndpointLocType::ENDPOINT_LOC_TYPE_RESERVED};
    u32 notifyNum{0};
    u32 portGroupSize{1};
    ChannelHandle handle = 0;
    // Remote-side buffers associated with the channel; no initializer here.
    HcclMem remoteCclMem;
    HcclMem remoteInputGraphMode;
    HcclMem remoteOutputGraphMode;
    HcclMem remoteInput;
    HcclMem remoteOutput;
};
/**
 * Fixed-size resource context for an algorithm.
 * NOTE(review): apart from aivCommInfoPtr, members have no initializer, so the
 * creator must fill every field before use.
 */
struct AlgResourceCtx {
    AlgType algType;
    AlgHierarchyInfo algHierarchyInfo;
    HcclMem cclInputMem;
    HcclMem cclOutputMem;
    u32 notifyNumOnMainThread;
    u32 slaveThreadNum;
    u32 notifyNumPerThread;
    ThreadHandle opThread;
    uint32_t notifyIds[AICPU_CONTROL_NOTIFY_NUM];
    TopoInfo topoInfo;
    void* aivCommInfoPtr{nullptr};
};
/**
 * Variable-depth hierarchy description (serialized as a whole by
 * AlgResourceCtxSerializable).
 * NOTE(review): the meaning of each nesting level (level / group / rank list?)
 * is not visible here — confirm with the producer.
 */
struct AlgHierarchyInfoForAllLevel {
    std::vector<std::vector<std::vector<u32>>> infos;
};
/**
 * Serializable algorithm resource context.
 *
 * Byte layout produced by Serialize():
 *   [BinaryStream-encoded fields ..., topoInfoSeqSize][topoInfo.Serialize() bytes]
 * DeSerialize() recovers topoInfo from the last topoInfoSeqSize bytes of the buffer.
 *
 * NOTE(review): aivCommInfoPtr is not part of the serialized form (commInfoPtr
 * is) — confirm this is intended.
 */
struct AlgResourceCtxSerializable {
    AlgType algType;
    AlgHierarchyInfoForAllLevel algHierarchyInfo;
    HcclMem cclMem;
    u32 notifyNumOnMainThread;
    u32 slaveThreadNum;
    u32 waitTimeout = 0;
    u32 fullTimeout = 0;
    std::vector<u32> notifyNumPerThread;
    void* aivCommInfoPtr = nullptr;
    std::vector<ThreadHandle> threads;
    ThreadHandle unfoldThread = 0;
    std::vector<std::vector<ChannelInfo>> channels;
    bool isHcommBatchTransferOnThreadSupported = false;
    void* commInfoPtr = nullptr;
    void *npu2DpuShmemPtr = nullptr;
    void *dpu2NpuShmemPtr = nullptr;
    std::vector<u32> ccuKernelNum;
    std::vector<CcuKernelHandle> ccuKernels;
    // Byte length of the serialized topoInfo tail; refreshed by Serialize().
    u32 topoInfoSeqSize = 0;
    TopoInfoWithNetLayerDetails topoInfo;

    /**
     * Encodes all serialized fields, then appends topoInfo's own serialization.
     * Also updates topoInfoSeqSize to the length of that appended tail.
     */
    std::vector<char> Serialize()
    {
        BinaryStream binaryStream;
        binaryStream << algType;
        binaryStream << algHierarchyInfo.infos;
        binaryStream << cclMem;
        binaryStream << notifyNumOnMainThread;
        binaryStream << slaveThreadNum;
        binaryStream << waitTimeout;
        binaryStream << fullTimeout;
        binaryStream << notifyNumPerThread;
        binaryStream << commInfoPtr;
        binaryStream << threads;
        binaryStream << unfoldThread;
        binaryStream << channels;
        binaryStream << isHcommBatchTransferOnThreadSupported;
        binaryStream << npu2DpuShmemPtr;
        binaryStream << dpu2NpuShmemPtr;
        binaryStream << ccuKernelNum;
        binaryStream << ccuKernels;
        std::vector<char> seq = topoInfo.Serialize();
        topoInfoSeqSize = seq.size();
        binaryStream << topoInfoSeqSize;
        std::vector<char> result;
        binaryStream.Dump(result);
        // topoInfo travels as a raw tail after the stream-encoded fields.
        result.insert(result.end(), seq.begin(), seq.end());
        return result;
    }

    /**
     * Decodes bytes produced by Serialize().
     * @param data serialized bytes; topoInfo is taken from its last
     *        topoInfoSeqSize bytes. If that size exceeds data.size(), topoInfo is
     *        reset to a default-constructed value instead of reading out of range.
     */
    void DeSerialize(std::vector<char> &data)
    {
        BinaryStream binaryStream(data);
        binaryStream >> algType;
        binaryStream >> algHierarchyInfo.infos;
        binaryStream >> cclMem;
        binaryStream >> notifyNumOnMainThread;
        binaryStream >> slaveThreadNum;
        binaryStream >> waitTimeout;
        binaryStream >> fullTimeout;
        binaryStream >> notifyNumPerThread;
        binaryStream >> commInfoPtr;
        binaryStream >> threads;
        binaryStream >> unfoldThread;
        binaryStream >> channels;
        binaryStream >> isHcommBatchTransferOnThreadSupported;
        binaryStream >> npu2DpuShmemPtr;
        binaryStream >> dpu2NpuShmemPtr;
        binaryStream >> ccuKernelNum;
        binaryStream >> ccuKernels;
        binaryStream >> topoInfoSeqSize;
        // topoInfoSeqSize comes from the buffer itself. If it claims more bytes
        // than the buffer holds, `data.size() - topoInfoSeqSize` wraps around and
        // the iterator arithmetic below runs far past `data`.
        // NOTE(review): this API has no error channel, so malformed input only
        // shows up as a default topoInfo — consider surfacing it to callers.
        if (topoInfoSeqSize > data.size()) {
            topoInfo = TopoInfoWithNetLayerDetails();
            return;
        }
        size_t startPos = data.size() - topoInfoSeqSize;
        std::vector<char> tailData(data.begin() + startPos, data.end());
        TopoInfoWithNetLayerDetails topoTemp;
        topoTemp.DeSerialize(tailData);
        topoInfo = std::move(topoTemp);
    }
};
/** Device-side AICPU operation configuration (defaults: execTimeout 0, split ratio 0.8). */
struct DevAicpuOpConfig {
    u32 execTimeout{0};                       // NOTE(review): unit not visible here — confirm
    double multipleDimensionSplitRatio{0.8};
};
/**
 * Parameters of one collective operation call (per the struct name).
 *
 * NOTE(review): the trailing zero-length array `varData` (a GNU extension) makes
 * this a header for a variable-sized allocation; `varMemSize` presumably holds
 * the byte length of that tail — confirm with the allocating code. Because of
 * the tail, keep member order and types stable.
 */
struct OpParam {
    void* hcclComm;
    // Fixed-capacity C strings, zero-filled by default.
    char tag[TAG_LENGTH] = "";
    char algTag[ALG_TAG_LENGTH] = "";
    char fastLaunchTag[ALG_TAG_LENGTH] = "";
    char fallbackTag[ALG_MAX_LENGTH] = "";
    char commName[COMM_INDENTIFIER_MAX_LENGTH] = "";
    char commModeTag[TAG_LENGTH] = "";
    aclrtStream stream;
    // User buffers and their sizes.
    void* inputPtr = nullptr;
    u64 inputSize = 0;
    void* outputPtr = nullptr;
    u64 outputSize = 0;
    void* inputSymWindow = nullptr;
    void* outputSymWindow = nullptr;
    bool supportSymmetricMemory{false};
    u64 inputOffset = 0;
    u64 outputOffset = 0;
    HcclMem hcclBuff;
    HcclReduceOp reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED;
    u32 root = INVALID_VALUE_RANKID;
    u32 userRank = INVALID_VALUE_RANKID;
    u32 sendRecvRemoteRank = INVALID_VALUE_RANKID;
    OpMode opMode;
    bool enableDetour{false};
    bool isMc2{false};
    bool cacheValid{false};
    DevType deviceType = DevType::DEV_TYPE_COUNT;
    CommEngine engine = CommEngine::COMM_ENGINE_RESERVED;
    AlgType algType;
    char algTypeStr[ALG_MAX_LENGTH] = "";
    // Operation-specific data description. Only one member of this anonymous
    // union is active at a time; DataDes is the one that is default-initialized
    // (count 0, RESERVED data types, strideCount 0).
    // NOTE(review): the active member is presumably chosen by opType — confirm.
    union {
        struct {
            u64 count;
            HcclDataType dataType;
            HcclDataType outputType;
            u64 strideCount;
        } DataDes = {0, HCCL_DATA_TYPE_RESERVED, HCCL_DATA_TYPE_RESERVED, 0};
        struct {
            HcclDataType sendType;
            HcclDataType recvType;
            u64 sendCount;
            u64 recvCount;
        } all2AllDataDes;
        // counts/displs are untyped pointers; element type and length are defined by the caller.
        struct {
            void* counts;
            void* displs;
            HcclDataType dataType;
        } vDataDes;
        struct {
            HcclDataType sendType;
            HcclDataType recvType;
            void* sendCounts;
            void* recvCounts;
            void* sdispls;
            void* rdispls;
        } all2AllVDataDes;
        struct {
            HcclDataType sendType;
            HcclDataType recvType;
            void* sendCountMatrix;
        } all2AllVCDataDes;
        struct {
            HcclSendRecvItem* sendRecvItemsPtr;
            u32 itemNum;
        } batchSendRecvDataDes;
    };
    HcclCMDType opType = HcclCMDType::HCCL_CMD_INVALID;
    bool isZeroCopy = false;
    char algName[OP_ALG_LENGTH] = "";
    HcclOpExpansionMode commOpExpansionMode = HcclOpExpansionMode::HCCL_OP_EXPANSION_MODE_INVALID;
    OpExecuteConfig opExecuteConfig;
    u32 numBlocksLimit = 0;
    bool isAivClearEnable = false;
    // NOTE(review): presumably the size and address of a resource context blob — confirm.
    u64 ctxSize = 0;
    void* resCtx = nullptr;
    ThreadHandle opThread = 0;
    u32 aicpuRecordCpuIdx = 0;
    u32 dataCount = 0;
    DevAicpuOpConfig opConfig;
    u64 varMemSize{0};
    // Variable-length tail; must remain the last member.
    u8 varData[0];
};
/** Algorithm capability description: mode flags plus the supported algorithms per level. */
struct AlgDesc {
    bool isZeroCopy{false};
    bool isAivMode{false};
    std::vector<AlgTypeLevel0> level0SupportedAlgos;
    std::vector<AlgTypeLevel1> level1SupportedAlgos;
    std::vector<AlgTypeLevel2> level2SupportedAlgos;
};
/** A contiguous region described by an offset and a size (both default 0). */
struct Slice {
    u64 offset = 0;
    u64 size = 0;
};
/**
 * Fixed-layout description of one operation (data type, command, counts, names).
 * NOTE(review): the trailing `reserved` bytes and narrow uint8_t fields suggest
 * this record is consumed by another component (possibly profiling) — confirm
 * before reordering or resizing any field.
 */
struct HcomProInfo {
    uint8_t dataType;
    uint8_t cmdType;
    uint64_t dataCount;
    uint32_t rankSize;
    uint32_t userRank;
    uint32_t blockDim = 0;
    uint64_t beginTime;
    uint32_t root;
    uint32_t slaveThreadNum;
    // NOTE(review): presumably the string lengths of commName / algType — confirm.
    uint64_t commNameLen;
    uint64_t algTypeLen;
    char tag[MAX_LENGTH];
    char commName[MAX_LENGTH];
    char algType[MAX_LENGTH];
    bool isCapture = false;
    bool isAiv = false;
    uint8_t reserved[MAX_LENGTH];
};
/**
 * Operation parameters in graph mode. The leading fields have no initializer;
 * the trailing ones default to 0 / nullptr / RESERVED / INVALID / false.
 */
struct OpParamGraphMode {
    char opType[64];
    u64 dataCount;
    u32 rankSize;
    u64 hcclBufferSize;
    s64 comm;
    char group[MAX_LENGTH];
    u64 count{0};
    void* counts{nullptr};
    HcclDataType dataType{HCCL_DATA_TYPE_RESERVED};
    HcclReduceOp op{HcclReduceOp::HCCL_REDUCE_RESERVED};
    HcclCMDType opTypeAiv{HcclCMDType::HCCL_CMD_INVALID};
    u32 aivCoreLimit{0};
    bool ifAiv{false};
};
/** Graph-mode resource response: memory size and stream/task/AIV-core counts, all defaulting to 0. */
struct ResResponseGraphMode {
    u64 opMemSize{0};
    u32 streamNum{0};
    u32 taskNum{0};
    u32 aivCoreNum{0};
};
/**
 * Graph-mode resource pack: a tag, the streams to use and a scratch buffer.
 */
struct ResPackGraphMode {
    // Previously uninitialized: a default-initialized pack carried a garbage
    // scratch pointer/size and an unterminated tag. The defaults below are the
    // same values value-initialization (`ResPackGraphMode{}`) already produced.
    char tag[RES_PACK_TAG_LENGTH] = "";
    std::vector<aclrtStream> streams;
    void* scratchMemAddr = nullptr;
    u64 scratchMemSize = 0;
};
/**
 * Memory registration record: tags for the input and output buffers plus the
 * associated memory handles. The tag arrays have no initializer.
 */
struct MemRegInfo {
    char inputBuffTag[MAX_MEM_TAG_LENGTH];
    char outputBuffTag[MAX_MEM_TAG_LENGTH];
    std::vector<HcclMemHandle> memHandles;
};
/** Stored AIV settings: core limit (default 0) and clear-enable flag (default false). */
struct AivParamStorage {
    u32 aivCoreLimit{0};
    bool aivClearEnable{false};
};
/**
 * Operation summary with every field defaulted: sizes/counts 0, enums to their
 * INVALID/DEFAULT/RESERVED values, root invalid, AIV core limit MAX_NUM_BLOCKS,
 * and zero-filled group/tag strings.
 */
struct OpExchangeInfo {
    uint64_t cclBufferSize = 0;
    u32 root = INVALID_VALUE_RANKID;
    HcclCMDType opType{HcclCMDType::HCCL_CMD_INVALID};
    OpExecuteConfig opExecuteConfig{OpExecuteConfig::DEFAULT};
    HcclReduceOp reduceType{HcclReduceOp::HCCL_REDUCE_RESERVED};
    HcclDataType dataType{HcclDataType::HCCL_DATA_TYPE_RESERVED};
    u64 count = 0;
    u32 aivCoreLimit{MAX_NUM_BLOCKS};
    char group[MAX_LENGTH]{};
    char tag[TAG_LENGTH]{};
};
}
#endif