* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef COLL_ALG_COMM_H
#define COLL_ALG_COMM_H
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>
#include "hccl_common.h"
#include "hccl_types.h"
#include "transport_pub.h"
#include "stream_pub.h"
#include "local_notify.h"
#include "hccl_trace_info.h"
#include "common.h"
#include "threadManage.h"
#include "template_v1_utils.h"
namespace hccl {
// Identifier of a rank; an alias of u32 used by TransportRequest and OpParam in this header.
using RankId = u32;
// Operation mode selector with two values; the meaning of each mode is defined by the callers.
enum class OpMode {
    OFFLOAD = 0,  // first mode, numeric value 0
    OPBASE = 1,   // second mode, numeric value 1
};
// Where a resource request targets: HOST (the default used by AlgResourceRequest) or AICPU.
enum class DeviceMode {
    HOST = 0,   // value 0
    AICPU = 1,  // value 1
};
// Algorithm expansion mode. The numeric values are explicit; note that 3 is not assigned.
enum class AlgExpansionMode {
    SUPERK_HOST = 0,
    SUPERK_AICPU = 1,
    SUPERK_AIV = 2,
    SUPERK_RECURSIVE = 4,  // value 3 is skipped; keep this value stable
};
// Lifecycle state stored per link in SingleSubCommTransport::status.
enum class TransportStatus {
    INIT = 0,
    READY = 1,
    STOP = 2,
};
// Kind of memory a transport reads from / writes to. Deliberately an unscoped enum: the
// enumerators (CCL_INPUT, ..., RESERVED) are visible directly in the enclosing namespace.
enum TransportMemType {
    CCL_INPUT = 0,
    CCL_OUTPUT = 1,
    SCRATCH = 2,
    PARAM_INPUT = 3,
    PARAM_OUTPUT = 4,
    AIV_INPUT = 5,
    AIV_OUTPUT = 6,
    USER_MEM = 7,
    RESERVED = 8,  // default / "not set" value used by TransportRequest
};
// Physical link kind of a transport, backed by int so RESERVED can be negative.
enum class TransportLinkType : int {
    RESERVED = -1,  // default / "not set"
    HCCS = 0,
    SIO = 1,
    RDMA = 2,
    MAX_NUM = 3,    // one past RDMA; presumably used as the number of link kinds — TODO confirm
};
// A request for a transport between localUserRank and remoteUserRank. Every field starts out
// in its "unset" state: invalid, rank 0, RESERVED memory/link types, no RDMA, no notifies.
struct TransportRequest {
    bool isValid{false};
    RankId localUserRank{0};
    RankId remoteUserRank{0};
    TransportMemType inputMemType{TransportMemType::RESERVED};
    TransportMemType outputMemType{TransportMemType::RESERVED};
    bool isUsedRdma{false};
    u32 notifyNum{0};
    TransportLinkType linkType{TransportLinkType::RESERVED};
};
// Transport description of one sub-communicator: its transport requests, the LINK objects and
// per-link status, and a pair of rank maps (user rank <-> sub-communicator rank; presumably
// inverses of each other — TODO confirm), plus link-behaviour flags.
struct SingleSubCommTransport {
    std::vector<TransportRequest> transportRequests;
    std::vector<LINK> links;
    std::vector<TransportStatus> status;
    u64 taskNum{0};
    std::map<u32, u32> userRank2subCommRank;
    std::map<u32, u32> subCommRank2UserRank;
    bool supportDataReceivedAck{false};
    LinkMode linkMode{LinkMode::LINK_DUPLEX_MODE};
    bool enableUseOneDoorbell{false};
    bool needVirtualLink{false};
    std::vector<LINK> virtualLinks;
    bool isZeroCopy{false};
};
// A list of sub-communicator transports (presumably all sub-communicators of one level).
using LevelNSubCommTransport = std::vector<SingleSubCommTransport>;
// A list of per-level transports (presumably one entry per communication level).
using OpCommTransport = std::vector<LevelNSubCommTransport>;
// Resources an algorithm asks for: scratch memory, streams, notifies, AIV buffer, target device
// mode and the transports it needs.
struct AlgResourceRequest {
    u64 scratchMemSize = 0;
    u32 streamNum = 0;
    u32 notifyNum = 0;
    u64 aivBufferRequest = 0;
    DeviceMode mode = DeviceMode::HOST;
    OpCommTransport opTransport;
    bool isInGraphCaptureZeroCopy = false;

    // Log the scalar fields of this request through HCCL_DEBUG. Does not modify the object.
    // Format specifiers must match the argument types exactly (variadic call):
    //   - scratchMemSize is u64, hence %llu (same as aivBufferRequest);
    //   - mode is a scoped enum, so it is cast explicitly to int for %d.
    void Describe() const
    {
        HCCL_DEBUG("[AlgResourceRequest], scratchMemSize[%llu], streamNum[%u], notifyNum[%u], aivBufferRequest[%llu], "
            "DeviceMode[%d].", scratchMemSize, streamNum, notifyNum, aivBufferRequest, static_cast<int>(mode));
    }
};
// Resources handed back to an algorithm (presumably the answer to an AlgResourceRequest —
// TODO confirm against the allocator).
struct AlgResourceResponse {
    // Device memory regions.
    DeviceMem cclInputMem;
    DeviceMem cclOutputMem;
    DeviceMem paramInputMem;
    DeviceMem paramOutputMem;
    DeviceMem scratchMem;
    DeviceMem aivInputMem;
    DeviceMem aivOutputMem;
    DeviceMem aivCommInfoMem;

    // Slave streams; the "Dev" list is kept separately from the plain one.
    std::vector<Stream> slaveStreams;
    std::vector<Stream> slaveDevStreams;

    // Shared local notifies, split into main/aux lists and their "Dev" counterparts.
    std::vector<std::shared_ptr<LocalNotify>> notifiesMain;
    std::vector<std::shared_ptr<LocalNotify>> notifiesAux;
    std::vector<std::shared_ptr<LocalNotify>> notifiesDevMain;
    std::vector<std::shared_ptr<LocalNotify>> notifiesDevAux;

    // Established transports, plus a second copy kept as a backup.
    OpCommTransport opTransportResponse;
    OpCommTransport opTransportResponseBackUp;

    std::vector<std::shared_ptr<ThreadManage>> threadManage;
};
// Which direction(s) the current batch send/recv iteration handles.
enum class BatchSendRecvCurMode {
    SEND = 0,
    RECV = 1,
    SEND_RECV = 2,
    SEND_RECV_RESERVED = 3,  // reserved / "not set"
};
// Parameters of one collective operation: buffers, sizes, reduce/sync settings, ranks, an
// operator-specific data description (union below) and assorted mode flags.
struct OpParam {
    std::string tag = "";
    Stream stream;
    void* inputPtr = nullptr;
    u64 inputSize = 0;
    void* outputPtr = nullptr;
    u64 outputSize = 0;
    HcclReduceOp reduceType = HcclReduceOp::HCCL_REDUCE_RESERVED;
    SyncMode syncMode = SyncMode::DEFAULT_TIMEWAITSYNCMODE;
    RankId root = INVALID_VALUE_RANKID;
    RankId dstRank = 0;
    RankId srcRank = 0;
    bool aicpuUnfoldMode = false;
    uint8_t aicpuCacheEnable = 0;
    bool isCapture = false;
    HcclTraceInfo* opBaseAtraceInfo = nullptr;
    // Operator-specific data description. Which member is meaningful depends on opType; e.g.
    // GetDataType()/GetDataCount() read VDataDes for the *_V variants and DataDes otherwise.
    // DataDes is the member initialized by default.
    union {
        struct {
            u64 count;
            HcclDataType dataType;
            u64 strideCount;
        } DataDes = {0, HCCL_DATA_TYPE_RESERVED, 0};
        struct {
            void* counts;
            void* displs;
            HcclDataType dataType;
        } VDataDes;
        struct {
            HcclDataType sendType;
            HcclDataType recvType;
            u64 sendCount;
            u64 recvCount;
            void* sendCounts;
            void* recvCounts;
            void* sdispls;
            void* rdispls;
            void* sendCountMatrix;
        } All2AllDataDes;
        struct {
            HcclSendRecvItem* sendRecvItemsPtr;
            u32 itemNum;
            u32 curIterNum;
            BatchSendRecvCurMode curMode;
            u8* isDirectRemoteRank;
        } BatchSendRecvDataDes;
        struct {
            u32 itemNum;
            u32 queueNum;
            u32 queueIdx;
        } BatchWriteDataDes;
    };
    HcclCMDType opType = HcclCMDType::HCCL_CMD_INVALID;
    bool supportZeroCopy = false;
    bool isZeroCopy = false;
    u8 aclGraphZeroCopyEnable = 0;
    bool supportRoceDirect = false;
    bool isNpuDirectRoce = false;
    s32 aivTag = 0;
    u32 index = 0;
    bool isInplaceError = false;
    u32 rankSize = 0;
    u32 aivCoreLimit = 0;
    u8 deterministic = 0;
    u32 srTag = 0;
    u32 localGroupRank = 0;
    bool isGroupMode = false;
    bool supportSymmetricMemory = false;
    void* inputSymWindow = nullptr;
    u64 inputOffset = 0;
    void* outputSymWindow = nullptr;
    u64 outputOffset = 0;
    bool needIncreLink = false;

    // Element data type: VDataDes.dataType for ALLGATHER_V / REDUCE_SCATTER_V, DataDes.dataType otherwise.
    inline HcclDataType GetDataType() const
    {
        if (opType == HcclCMDType::HCCL_CMD_ALLGATHER_V || opType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER_V) {
            return VDataDes.dataType;
        }
        return DataDes.dataType;
    }

    // Element count for rankId. For ALLGATHER_V / REDUCE_SCATTER_V it reads counts[rankId], which
    // assumes VDataDes.counts points to a u64 array with more than rankId entries (not checked here);
    // for every other op the single DataDes.count is returned and rankId is ignored.
    inline u64 GetDataCount(RankId rankId) const
    {
        if (opType == HcclCMDType::HCCL_CMD_ALLGATHER_V || opType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER_V) {
            return static_cast<const u64 *>(VDataDes.counts)[rankId];
        }
        return DataDes.count;
    }

    // Stride count: 0 for ALLGATHER_V / REDUCE_SCATTER_V, DataDes.strideCount otherwise.
    inline u64 GetStrideCount() const
    {
        if (opType == HcclCMDType::HCCL_CMD_ALLGATHER_V || opType == HcclCMDType::HCCL_CMD_REDUCE_SCATTER_V) {
            return 0;
        }
        return DataDes.strideCount;
    }

    // Strict weak ordering, so OpParam can be used as a key of std::map / std::set.
    // Params with different opType are ordered by opType alone. Params with the same opType are
    // ordered by the fields that identify that operator (plus tag, aivCoreLimit, deterministic,
    // isCapture).
    bool operator<(const OpParam &other) const noexcept
    {
        // Settle cross-operator ordering first so the result never depends on which operand's
        // opType the switch below dispatches on.
        if (opType != other.opType) {
            return opType < other.opType;
        }
        switch (opType) {
            case HcclCMDType::HCCL_CMD_ALLGATHER:
                return std::tie(DataDes.count, DataDes.dataType, tag, aivCoreLimit, deterministic, isCapture) <
                    std::tie(other.DataDes.count, other.DataDes.dataType, other.tag, other.aivCoreLimit,
                    other.deterministic, other.isCapture);
            case HcclCMDType::HCCL_CMD_ALLTOALL:
                return std::tie(All2AllDataDes.sendCount, All2AllDataDes.sendType, All2AllDataDes.recvCount,
                    All2AllDataDes.recvType, tag, aivCoreLimit, deterministic, isCapture) <
                    std::tie(other.All2AllDataDes.sendCount, other.All2AllDataDes.sendType,
                    other.All2AllDataDes.recvCount, other.All2AllDataDes.recvType, other.tag, other.aivCoreLimit,
                    other.deterministic, other.isCapture);
            case HcclCMDType::HCCL_CMD_BROADCAST:
                return std::tie(DataDes.count, DataDes.dataType, root, tag, aivCoreLimit, deterministic, isCapture) <
                    std::tie(other.DataDes.count, other.DataDes.dataType, other.root, other.tag, other.aivCoreLimit,
                    other.deterministic, other.isCapture);
            case HcclCMDType::HCCL_CMD_ALLREDUCE:
            case HcclCMDType::HCCL_CMD_REDUCE_SCATTER:
                return std::tie(DataDes.count, DataDes.dataType, reduceType, tag, aivCoreLimit, deterministic,
                    isCapture) <
                    std::tie(other.DataDes.count, other.DataDes.dataType, other.reduceType, other.tag,
                    other.aivCoreLimit, other.deterministic, other.isCapture);
            default:
                // Operator kinds without a dedicated key. Returning `true` unconditionally here would
                // make `a < a` true (and both `a < b` and `b < a` true), which breaks the ordering
                // contract of ordered containers. Only fields outside the data-description union
                // are compared, so no inactive union member is read for these operator kinds.
                return std::tie(inputSize, outputSize, reduceType, root, tag, aivCoreLimit, deterministic,
                    isCapture) <
                    std::tie(other.inputSize, other.outputSize, other.reduceType, other.root, other.tag,
                    other.aivCoreLimit, other.deterministic, other.isCapture);
        }
    }
};
// Description of an algorithm: zero-copy / AIV flags, the selected AlgType, and lists of
// supported algorithms for levels 0, 1 and 2.
struct AlgDesc {
    bool isZeroCopy{false};
    bool isAivMode{false};
    bool isAivCrossNode{false};
    bool isLastSelect{false};
    s32 deterministic{-1};  // defaults to -1; presumably "unspecified" — TODO confirm
    s32 aivTagNum{1};
    AlgType algType;
    std::vector<AlgTypeLevel0> level0SupportedAlgos;
    std::vector<AlgTypeLevel1> level1SupportedAlgos;
    std::vector<AlgTypeLevel2> level2SupportedAlgos;
};
// Optional resource limits: whether a limit applies, whether compiling for AIV, and the AIV
// core cap. All fields start disabled / zero.
struct ResourceLimit {
    bool ifLimit{false};
    bool ifCompileForAiv{false};
    u32 aivCoreLimit{0};
};
}
#endif