aclnnMoeDistributeCombineV3
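Product Support
| Product | Supported |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 training series products/Atlas A3 inference series products | √ |
| Atlas A2 training series products/Atlas A2 inference series products | √ |
| Atlas 200I/500 A2 inference products | × |
| Atlas inference series products | × |
| Atlas training series products | × |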
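Function Description
- Interface function: when TP-domain communication is present, the operator first performs ReduceScatterV communication, then AllToAllV communication, and finally combines the received data (multiply by the weights, then sum). When TP-domain communication is absent, it performs AllToAllV communication and then combines the received data in the same way.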
- Calculation formulas:
  - Without TP-domain communication:

    $$
    ataOut = AllToAllV(expandX)\\
    xOut = Sum(expertScales * ataOut + expertScales * sharedExpertX)
    $$
  - With TP-domain communication:

    $$
    rsOut = ReduceScatterV(expandX)\\
    ataOut = AllToAllV(rsOut)\\
    xOut = Sum(expertScales * ataOut + expertScales * sharedExpertX)
    $$
Note: this interface must be used together with aclnnMoeDistributeDispatchV3. It returns the data along the same path that aclnnMoeDistributeDispatchV3 used to collect it.
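As a rough host-side illustration of the formula without TP-domain communication, the sketch below applies the weighted sum to data that has already been gathered. It is not the operator's implementation: the function name `CombineReference`, the dense `[BS][K][H]` layout of `ataOut`, and the use of `float` are assumptions made for illustration, and the `sharedExpertX` term is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference of xOut = Sum_k(expertScales * ataOut).
// ataOut is assumed to be laid out as [BS][K][H], expertScales as [BS][K],
// and the result as [BS][H]. The sharedExpertX term is intentionally omitted.
std::vector<float> CombineReference(const std::vector<float> &ataOut,
                                    const std::vector<float> &expertScales,
                                    std::size_t bs, std::size_t k, std::size_t h)
{
    assert(ataOut.size() == bs * k * h);
    assert(expertScales.size() == bs * k);
    std::vector<float> xOut(bs * h, 0.0f);
    for (std::size_t t = 0; t < bs; ++t) {          // each token
        for (std::size_t e = 0; e < k; ++e) {       // each of its top-K experts
            const float w = expertScales[t * k + e];
            for (std::size_t c = 0; c < h; ++c) {   // each hidden channel
                xOut[t * h + c] += w * ataOut[(t * k + e) * h + c];
            }
        }
    }
    return xOut;
}
```

A reference of this kind can help when comparing `xOut` from the device against expected values on small inputs.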
Compared with aclnnMoeDistributeCombineV2, this interface changes as follows:
- Added support for special-expert scenarios:
  - zeroExpertNum≠0: enabled by passing a zeroExpertNum greater than 0.

    $$Moe(oriXOptional) = 0$$
  - copyExpertNum≠0: enabled by passing a copyExpertNum greater than 0; a valid oriXOptional must also be passed.

    $$Moe(oriXOptional) = oriXOptional$$
  - constExpertNum≠0: enabled by passing a constExpertNum greater than 0; valid oriXOptional, constExpertAlpha1Optional, constExpertAlpha2Optional, and constExpertVOptional must also be passed.

    $$Moe(oriXOptional) = constExpertAlpha1Optional * oriXOptional + constExpertAlpha2Optional * constExpertVOptional$$
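The three special-expert formulas can be sketched per element as follows. The ID ranges used by `ClassifyExpert` follow the zeroExpertNum, copyExpertNum, and constExpertNum notes in the parameter descriptions of this document. The names `ExpertKind`, `ClassifyExpert`, and `SpecialExpertOutput` are hypothetical, and scalar `float` values stand in for the Tensor elements.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

enum class ExpertKind { Moe, Zero, Copy, Const, Invalid };

// Maps an expert ID to its kind. The half-open ID ranges are laid out
// consecutively after the MoE experts: [MoE | zero | copy | const].
ExpertKind ClassifyExpert(int64_t id, int64_t moeExpertNum, int64_t zeroExpertNum,
                          int64_t copyExpertNum, int64_t constExpertNum)
{
    if (id < 0) return ExpertKind::Invalid;
    int64_t end = moeExpertNum;
    if (id < end) return ExpertKind::Moe;
    end += zeroExpertNum;
    if (id < end) return ExpertKind::Zero;
    end += copyExpertNum;
    if (id < end) return ExpertKind::Copy;
    end += constExpertNum;
    if (id < end) return ExpertKind::Const;
    return ExpertKind::Invalid;
}

// Element-wise result of a special expert (illustrative scalar version).
float SpecialExpertOutput(ExpertKind kind, float oriX, float alpha1, float alpha2, float v)
{
    switch (kind) {
        case ExpertKind::Zero:  return 0.0f;                        // Moe(oriX) = 0
        case ExpertKind::Copy:  return oriX;                        // Moe(oriX) = oriX
        case ExpertKind::Const: return alpha1 * oriX + alpha2 * v;  // alpha1*oriX + alpha2*V
        default: throw std::invalid_argument("not a special expert");
    }
}
```

With moeExpertNum = 7 and one expert of each special kind, IDs 7, 8, and 9 map to the zero, copy, and constant experts respectively.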
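Function Prototype
Each operator uses a two-stage interface. First call "aclnnMoeDistributeCombineV3GetWorkspaceSize" to obtain the workspace size required for the computation and an executor that contains the operator's computation flow, then call "aclnnMoeDistributeCombineV3" to run the computation.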
aclnnStatus aclnnMoeDistributeCombineV3GetWorkspaceSize(
const aclTensor* expandX,
const aclTensor* expertIds,
const aclTensor* assistInfoForCombine,
const aclTensor* epSendCounts,
const aclTensor* expertScales,
const aclTensor* tpSendCountsOptional,
const aclTensor* xActiveMaskOptional,
const aclTensor* activationScaleOptional,
const aclTensor* weightScaleOptional,
const aclTensor* groupListOptional,
const aclTensor* expandScalesOptional,
const aclTensor* sharedExpertXOptional,
const aclTensor* elasticInfoOptional,
const aclTensor* oriXOptional,
const aclTensor* constExpertAlpha1Optional,
const aclTensor* constExpertAlpha2Optional,
const aclTensor* constExpertVOptional,
const char* groupEp,
int64_t epWorldSize,
int64_t epRankId,
int64_t moeExpertNum,
const char* groupTp,
int64_t tpWorldSize,
int64_t tpRankId,
int64_t expertShardType,
int64_t sharedExpertNum,
int64_t sharedExpertRankNum,
int64_t globalBS,
int64_t outDtype,
int64_t commQuantMode,
int64_t groupListType,
const char* commAlg,
int64_t zeroExpertNum,
int64_t copyExpertNum,
int64_t constExpertNum,
aclTensor* xOut,
uint64_t* workspaceSize,
aclOpExecutor** executor)
aclnnStatus aclnnMoeDistributeCombineV3(
void* workspace,
uint64_t workspaceSize,
aclOpExecutor* executor,
aclrtStream stream)
aclnnMoeDistributeCombineV3GetWorkspaceSize
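- Parameter description

| Parameter | Input/Output | Description | Usage notes | Data type | Data format | Shape | Non-contiguous Tensor |
|---|---|---|---|---|---|---|---|
| expandX | Input | Token features expanded according to expertIds. | Must be a 2D Tensor. | FLOAT16, BFLOAT16 | ND | (max(tpWorldSize, 1) * A, H) | √ |
| expertIds | Input | Indices of the topK experts for each token. | Must be a 2D Tensor. | INT32 | ND | (BS, K) | √ |
| assistInfoForCombine | Input | Corresponds to the assistInfoForCombineOut output of aclnnMoeDistributeDispatchV3. | Must be a 1D Tensor. | INT32 | ND | (A * 128, ) | √ |
| epSendCounts | Input | Corresponds to the epRecvCounts output of aclnnMoeDistributeDispatchV3. | Must be a 1D Tensor. | INT32 | ND | - | √ |
| expertScales | Input | Weights of the topK experts for each token. | Must be a 2D Tensor. | FLOAT32 | ND | (BS, K) | √ |
| tpSendCountsOptional | Input | Corresponds to the tpRecvCounts output of aclnnMoeDistributeDispatchV3. | Pass a value when TP-domain communication is used; pass a null pointer otherwise. | INT32 | ND | (tpWorldSize, ) when TP-domain communication is used | √ |
| xActiveMaskOptional | Input | Indicates whether each token takes part in communication. | Must be a 1D or 2D Tensor. Valid data or a null pointer may be passed; by default all tokens take part in communication. When BS differs across cards, all tokens must be valid. | BOOL | ND | (BS, ) for 1D input; (BS, K) for 2D input | √ |
| activationScaleOptional | Input | Reserved parameter. | Not supported in the current version; pass a null pointer. | - | ND | - | - |
| weightScaleOptional | Input | Reserved parameter. | Not supported in the current version; pass a null pointer. | - | ND | - | - |
| groupListOptional | Input | Reserved parameter. | Not supported in the current version; pass a null pointer. | - | ND | - | - |
| expandScalesOptional | Input | Corresponds to the expandScales output of aclnnMoeDistributeDispatchV3. | - | FLOAT32 | ND | (A, ) | √ |
| sharedExpertXOptional | Input | Tokens after shared-expert computation. | - | FLOAT16, BFLOAT16 | ND | (BS, H) | √ |
| elasticInfoOptional | Input | Dynamic scale-in information for the EP communication domain. | - | INT32 | ND | - | √ |
| oriXOptional | Input | Token data that has not passed through the FFN (Feed-Forward Neural network). | Required when copyExpert or constExpert is enabled. Valid data or a null pointer may be passed; valid input is mandatory when copyExpertNum or constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor with the same data type as expandX. | FLOAT16, BFLOAT16 | ND | (BS, H) | √ |
| constExpertAlpha1Optional | Input | Coefficient required when constExpert is enabled. | - | FLOAT16, BFLOAT16 | ND | - | √ |
| constExpertAlpha2Optional | Input | Coefficient required when constExpert is enabled. | - | FLOAT16, BFLOAT16 | ND | - | √ |
| constExpertVOptional | Input | Device-side aclTensor; coefficient required when constExpert is enabled. | - | FLOAT16, BFLOAT16 | ND | - | √ |
| groupEp | Input | Name of the EP communication domain (expert-parallel communication domain). | String length must be in [1, 128); must differ from groupTp. | STRING | - | - | - |
| epWorldSize | Input | Size of the EP communication domain. | - | INT64 | - | - | - |
| epRankId | Input | ID of this card in the EP domain. | Value range [0, epWorldSize); epRankId values are unique within one EP communication domain. | INT64 | - | - | - |
| moeExpertNum | Input | Number of MoE experts. | Must satisfy moeExpertNum % (epWorldSize - sharedExpertRankNum) = 0. | INT64 | - | - | - |
| groupTp | Input | Name of the TP communication domain (data-parallel communication domain). | Must differ from groupEp. | STRING | - | - | - |
| tpWorldSize | Input | Size of the TP communication domain. | - | INT64 | - | - | - |
| tpRankId | Input | ID of this card in the TP domain. | tpRankId values are unique within one TP communication domain. | INT64 | - | - | - |
| expertShardType | Input | Distribution type of shared-expert cards. | - | INT64 | - | - | - |
| sharedExpertNum | Input | Number of shared experts. | - | INT64 | - | - | - |
| sharedExpertRankNum | Input | Number of shared-expert cards (one shared expert can be replicated across several cards). | - | INT64 | - | - | - |
| globalBS | Input | Global batch size of the EP domain. | When BS is the same on every rank, globalBS = BS * epWorldSize or 0. When BS differs across ranks, globalBS = maxBS * epWorldSize, where maxBS is the largest per-card BS. | INT64 | - | - | - |
| outDtype | Input | Specifies the data type of the output x; reserved parameter. | Not supported in the current version; pass 0. | INT64 | - | - | - |
| commQuantMode | Input | Communication quantization type. | - | INT64 | - | - | - |
| groupListType | Input | Group list format; reserved parameter. | Not supported in the current version; pass 0. | INT64 | - | - | - |
| commAlg | Input | Communication-affinity memory layout algorithm. | - | STRING | - | - | - |
| zeroExpertNum | Input | Number of zero experts. | - | INT64 | - | - | - |
| copyExpertNum | Input | Number of copy experts. | - | INT64 | - | - | - |
| constExpertNum | Input | Number of constant experts. | - | INT64 | - | - | - |
| xOut | Output | Processed tokens. | Must be a 2D Tensor; data type and format match expandX. | FLOAT16, BFLOAT16 | ND | (BS, H) | - |
| workspaceSize | Output | Size of the workspace to allocate on the Device side. | - | UINT64 | - | - | - |
| executor | Output | Op executor containing the operator's computation flow. | - | aclOpExecutor* | - | - | - |

Atlas A2 training series products/Atlas A2 inference series products: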
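- commAlg: supports nullptr, "", "fullmesh", and "hierarchy". "hierarchy" is recommended, together with driver version ≥ 25.0.RC1.1. With nullptr or "", the algorithm is selected from the HCCL environment variables (not recommended). "fullmesh" sends tokens directly over RDMA; "hierarchy" sends in two steps, intra-server then inter-server, to reduce inter-server traffic.
- Shared-expert scenarios are not supported.
- epSendCounts: shape is (moeExpertNum + 2 * globalBS * K * serverNum, ), where K is the number of topK experts. The first moeExpertNum entries are the token counts received from each card in the EP communication domain. The remaining 2 * globalBS * K * serverNum entries store, ahead of inter-server and intra-server communication, the number of tokens that combine can reduce in advance and the communication-buffer offsets. When globalBS = 0, BS * epWorldSize is used in its place.
- TP-domain communication is not supported in the current version.
- xActiveMaskOptional: depends on commAlg. With "fullmesh", it must be a 1D Tensor of shape (BS, ), and all true values must come before any false value (for example, {true, false, true} is invalid). With "hierarchy", it is not supported in the current version; pass a null pointer.
- expandScalesOptional: must be a 1D Tensor of shape (A, ).
- sharedExpertXOptional: reserved parameter, not supported in the current version; pass a null pointer.
- epWorldSize: depends on commAlg. "fullmesh" supports 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 192, 256, and 384; "hierarchy" supports 16, 32, and 64.
- moeExpertNum: depends on commAlg. "fullmesh" supports (0, 1024]; "hierarchy" supports (0, 512].
- groupTp: not supported in the current version; pass an empty string.
- tpWorldSize: not supported in the current version; pass 0.
- tpRankId: not supported in the current version; pass 0.
- expertShardType: not supported in the current version; pass 0.
- sharedExpertNum: not supported in the current version; pass 0.
- sharedExpertRankNum: not supported in the current version; pass 0.
- commQuantMode: 0 or 2 (0 means no quantization, 2 means int8 quantization). A value of 2 is supported only when commAlg is "hierarchy", or HCCL_INTRA_PCIE_ENABLE=1 and HCCL_INTRA_ROCE_ENABLE=0, with driver version ≥ 25.0.RC1.1.
- elasticInfoOptional: not supported in the current version; pass a null pointer.
- oriXOptional: when commAlg="hierarchy", not supported in the current version; pass a null pointer.
- constExpertAlpha1Optional: reserved parameter, not supported in the current version; pass a null pointer.
- constExpertAlpha2Optional: reserved parameter, not supported in the current version; pass a null pointer.
- constExpertVOptional: reserved parameter, not supported in the current version; pass a null pointer.
- zeroExpertNum: when commAlg="fullmesh", the value range is [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid zero-expert IDs are [moeExpertNum, moeExpertNum + zeroExpertNum).
- copyExpertNum: when commAlg="fullmesh", the value range is [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid copy-expert IDs are [moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum).
- constExpertNum: not supported in the current version; pass 0.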
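Two of the A2 rules above are easy to get wrong by hand: the length of epSendCounts and the ordering rule for a 1D xActiveMaskOptional. The following is a small host-side sketch of both checks. The helper names `EpSendCountsLenA2` and `IsTruePrefixMask` are hypothetical and are not part of the aclnn API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Length of epSendCounts on A2, following the shape
// (moeExpertNum + 2 * globalBS * K * serverNum, ).
// When globalBS is 0, BS * epWorldSize is used in its place.
int64_t EpSendCountsLenA2(int64_t moeExpertNum, int64_t globalBS, int64_t bs,
                          int64_t epWorldSize, int64_t k, int64_t serverNum)
{
    const int64_t effectiveGlobalBS = (globalBS == 0) ? bs * epWorldSize : globalBS;
    return moeExpertNum + 2 * effectiveGlobalBS * k * serverNum;
}

// Returns true when every true entry precedes every false entry,
// which is the ordering required of a 1D xActiveMaskOptional.
bool IsTruePrefixMask(const std::vector<bool> &mask)
{
    bool seenFalse = false;
    for (bool active : mask) {
        if (!active) {
            seenFalse = true;
        } else if (seenFalse) {
            return false;  // a true value appears after a false value
        }
    }
    return true;
}
```

For instance, with moeExpertNum = 256, BS = 32, epWorldSize = 16, K = 8, serverNum = 2, and globalBS = 0, the length works out to 256 + 2 * 512 * 8 * 2 = 16640.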
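Atlas A3 training series products/Atlas A3 inference series products:
- commAlg: not supported in the current version; pass a null pointer.
- epSendCounts: shape is (epWorldSize * max(tpWorldSize, 1) * localExpertNum, ).
- tpSendCountsOptional: when TP-domain communication is used, it is a 1D Tensor of shape (tpWorldSize, ).
- xActiveMaskOptional: must be a 1D or 2D Tensor (shape (BS, ) for 1D, (BS, K) for 2D). For 1D, all true values must come before any false value. For 2D, a token whose K values are all false does not take part in communication.
- expandScalesOptional: reserved parameter, not supported in the current version; pass a null pointer.
- sharedExpertXOptional: must be a 2D or 3D Tensor (for 2D the shape is (BS, H); for 3D the product of the first two dimensions equals BS and the third dimension equals H). It is optional; when it is passed, sharedExpertRankNum must be 0.
- epWorldSize: supported values are [2, 768].
- moeExpertNum: value range (0, 1024].
- groupTp: string length must be in [0, 128); must differ from groupEp; may be empty only when there is no TP-domain communication.
- tpWorldSize: value range [0, 2]. 0 and 1 mean no TP-domain communication; with TP-domain communication only 2 is supported.
- tpRankId: value range [0, 1]; tpRankId values are unique within one TP communication domain. Pass 0 when there is no TP-domain communication.
- expertShardType: only 0 is currently supported, meaning that shared-expert cards are placed before MoE-expert cards.
- sharedExpertNum: current value range is [0, 4].
- sharedExpertRankNum: value range [0, epWorldSize). When it is 0, sharedExpertNum must be 0 or 1; when it is nonzero, sharedExpertRankNum % sharedExpertNum = 0 must hold.
- commQuantMode: 0 or 2 (0 means no quantization, 2 means int8 quantization). A value of 2 can be enabled only when tpWorldSize < 2.
- elasticInfoOptional: not supported in the current version; pass a null pointer.
- constExpertAlpha1Optional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- constExpertAlpha2Optional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- constExpertVOptional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- zeroExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid zero-expert IDs are [moeExpertNum, moeExpertNum + zeroExpertNum).
- copyExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid copy-expert IDs are [moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum).
- constExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid constant-expert IDs are [moeExpertNum + zeroExpertNum + copyExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum).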
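The A3 range rules for epWorldSize, moeExpertNum, sharedExpertNum, and sharedExpertRankNum interact, so a pre-flight check on the host may be useful before calling the first-stage interface. The sketch below is one possible reading of those rules. `ValidateA3ExpertLayout` is a hypothetical helper, and rejecting a nonzero sharedExpertRankNum together with sharedExpertNum = 0 is an assumption made to avoid a modulo by zero.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical pre-flight check of the A3 expert layout parameters.
bool ValidateA3ExpertLayout(int64_t epWorldSize, int64_t moeExpertNum,
                            int64_t sharedExpertNum, int64_t sharedExpertRankNum)
{
    if (epWorldSize < 2 || epWorldSize > 768) return false;        // [2, 768]
    if (moeExpertNum <= 0 || moeExpertNum > 1024) return false;    // (0, 1024]
    if (sharedExpertNum < 0 || sharedExpertNum > 4) return false;  // [0, 4]
    if (sharedExpertRankNum < 0 || sharedExpertRankNum >= epWorldSize) return false;

    if (sharedExpertRankNum == 0) {
        // With no shared-expert cards, sharedExpertNum must be 0 or 1.
        if (sharedExpertNum > 1) return false;
    } else {
        // Assumption: shared-expert cards require at least one shared expert.
        if (sharedExpertNum == 0) return false;
        if (sharedExpertRankNum % sharedExpertNum != 0) return false;
    }
    // MoE experts must divide evenly across the MoE-expert cards.
    return moeExpertNum % (epWorldSize - sharedExpertRankNum) == 0;
}
```

The A3 sample configuration in this document (epWorldSize = 2, moeExpertNum = 7, sharedExpertNum = 1, sharedExpertRankNum = 1) passes this check.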
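Ascend 950PR/Ascend 950DT:
- commAlg: not supported in the current version; pass a null pointer.
- epSendCounts: shape is (epWorldSize * max(tpWorldSize, 1) * localExpertNum, ).
- The TP communication domain is not supported.
- xActiveMaskOptional: must be a 1D or 2D Tensor (shape (BS, ) for 1D, (BS, K) for 2D). For 1D, all true values must come before any false value. For 2D, a token whose K values are all false does not take part in communication.
- expandScalesOptional: reserved parameter, not supported in the current version; pass a null pointer.
- sharedExpertXOptional: must be a 2D or 3D Tensor (for 2D the shape is (BS, H); for 3D the product of the first two dimensions equals BS and the third dimension equals H). It is optional; when it is passed, sharedExpertRankNum must be 0.
- epWorldSize: supported values are [2, 768].
- moeExpertNum: value range (0, 1024].
- groupTp: not supported in the current version; pass an empty string.
- tpWorldSize: not supported in the current version; pass 0.
- tpRankId: not supported in the current version; pass 0.
- expertShardType: only 0 is currently supported, meaning that shared-expert cards are placed before MoE-expert cards.
- sharedExpertNum: current value range is [0, 4].
- sharedExpertRankNum: value range [0, epWorldSize). When it is 0, sharedExpertNum must be 0 or 1; when it is nonzero, sharedExpertRankNum % sharedExpertNum = 0 must hold.
- commQuantMode: 0, 2, 3, or 4 (0 means no quantization, 2 means int8 quantization, 3 means mxfp8 quantization with e5m2, 4 means mxfp8 quantization with e4m3).
- elasticInfoOptional: reserved parameter, not supported in the current version; pass a null pointer.
- constExpertAlpha1Optional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- constExpertAlpha2Optional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- constExpertVOptional: valid data or a null pointer may be passed; valid input is mandatory when constExpertNum is nonzero. When valid data is passed, it must be a 2D Tensor of shape (constExpertNum, H) with the same data type as expandX.
- zeroExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid zero-expert IDs are [moeExpertNum, moeExpertNum + zeroExpertNum).
- copyExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid copy-expert IDs are [moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum).
- constExpertNum: value range [0, MAX_INT32), where MAX_INT32 = 2^31 - 1. Valid constant-expert IDs are [moeExpertNum + zeroExpertNum + copyExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum).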
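The accepted commQuantMode values differ by product, which the sketch below summarizes in one place. It is only an illustration of the per-product notes: `Product` and `IsCommQuantModeAllowed` are hypothetical names, and for A2 the caller is assumed to decide separately (commAlg, HCCL environment variables, driver version) whether the int8 path is available and to pass that in as `a2Int8PathEnabled`.

```cpp
#include <cassert>
#include <cstdint>

enum class Product { AtlasA2, AtlasA3, Ascend950 };

// Returns whether a commQuantMode value is accepted for the given product.
// a2Int8PathEnabled is a caller-supplied assumption that the A2 conditions
// for int8 quantization are already met; it is ignored for other products.
bool IsCommQuantModeAllowed(Product product, int64_t mode, int64_t tpWorldSize,
                            bool a2Int8PathEnabled)
{
    if (mode == 0) {
        return true;  // 0 = no quantization, accepted on every product
    }
    switch (product) {
        case Product::AtlasA2:
            return mode == 2 && a2Int8PathEnabled;
        case Product::AtlasA3:
            return mode == 2 && tpWorldSize < 2;  // int8 only without TP communication
        case Product::Ascend950:
            return mode == 2 || mode == 3 || mode == 4;  // int8, mxfp8 e5m2, mxfp8 e4m3
    }
    return false;
}
```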
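- Return value

aclnnStatus: returns a status code; see the aclnn return codes for details.
The first-stage interface validates the input parameters and reports an error in the following cases:

| Return value | Error code | Description |
|---|---|---|
| ACLNN_ERR_PARAM_NULLPTR | 161001 | A mandatory input or output Tensor is a null pointer. |
| ACLNN_ERR_PARAM_INVALID | 161002 | The data type of an input or output is outside the supported range. |
| ACLNN_ERR_INNER_TILING_ERROR | 561002 | The shape of an input or output is outside the supported range. A parameter value is outside the supported range. |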
aclnnMoeDistributeCombineV3
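- Parameter description

| Parameter | Input/Output | Description |
|---|---|---|
| workspace | Input | Address of the workspace memory allocated on the Device side. |
| workspaceSize | Input | Size of the workspace allocated on the Device side, obtained from the first-stage interface aclnnMoeDistributeCombineV3GetWorkspaceSize. |
| executor | Input | Op executor containing the operator's computation flow. |
| stream | Input | Stream on which the task runs. |

- Return value

Returns an aclnnStatus status code; see the aclnn return codes for details.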
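Constraints
- Deterministic computation:
  - aclnnMoeDistributeCombineV3 is deterministic by default.
- Driver constraints:
  - All nodes in the operator's communication domain should run the same driver version.
- Interface pairing constraints:
  - aclnnMoeDistributeDispatchV3 and aclnnMoeDistributeCombineV3 must be used as a pair. The assistInfoForCombineOut, epRecvCountsOut, tpRecvCountsOut, and expandScalesOut outputs of the former must be passed directly to the corresponding parameters of the latter, and business logic must not depend on the specific values in these Tensors.
- Parameter consistency constraints:
  - The groupEp, epWorldSize, moeExpertNum, groupTp, tpWorldSize, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBS, and commAlg parameters and the HCCL_BUFFSIZE value must be identical on all cards, and must match the corresponding parameters of aclnnMoeDistributeDispatchV3.
- Product-specific constraints:
  - Atlas A3 training series products/Atlas A3 inference series products: a single card contains two dies, and "this card" in the parameter descriptions refers to a single die.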
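- Shape variable constraints:

| Variable | Definition and value range |
|---|---|
| A | Maximum number of tokens this card needs to dispatch. For shared experts, A = BS * epWorldSize * sharedExpertNum / sharedExpertRankNum must hold. For MoE experts, A >= BS * epWorldSize * min(localExpertNum, K) must hold when globalBS is 0, and A >= globalBS * min(localExpertNum, K) must hold when globalBS is nonzero. |
| H | Hidden size. Atlas A2 training series products/Atlas A2 inference series products: (0, 10240] and a multiple of 32. Atlas A3 training series products/Atlas A3 inference series products/Ascend 950PR/Ascend 950DT: [1024, 8192]. |
| BS | Number of tokens finally output by this card. Atlas A2 training series products/Atlas A2 inference series products: depends on commAlg; "fullmesh" supports (0, 256], and "hierarchy" with driver version ≥ 25.0.RC1.1 supports (0, 512]. Atlas A3 training series products/Atlas A3 inference series products/Ascend 950PR/Ascend 950DT: 0 < BS ≤ 512. |
| K | Number of experts selected (topK): 0 < K ≤ 16 and 0 < K ≤ moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum. |
| serverNum | Number of server nodes. Atlas A2 training series products/Atlas A2 inference series products: only shapes in this scenario use this variable; only 2, 4, and 8 are supported. |
| localExpertNum | Number of experts on this card. For a shared-expert card, localExpertNum = 1. For a MoE-expert card, localExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum); TP communication is not supported when localExpertNum > 1. Atlas A3 training series products/Atlas A3 inference series products and Ascend 950PR/Ascend 950DT: 0 < localExpertNum * epWorldSize ≤ 2048 must hold. |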
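The definitions of localExpertNum and A can be turned into a small host-side helper for sizing the Tensors that depend on A. The sketch below follows the shape-variable definitions and treats cards with epRankId < sharedExpertRankNum as shared-expert cards, as the sample code in this document does. `RankShape` and `ComputeRankShape` are hypothetical names, and for MoE-expert cards the value returned is the lower bound on A.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct RankShape {
    int64_t localExpertNum;  // experts hosted on this card
    int64_t minA;            // smallest A that satisfies the constraint
};

// All divisions are integer divisions, as in the formulas of this document.
RankShape ComputeRankShape(int64_t epRankId, int64_t epWorldSize, int64_t moeExpertNum,
                           int64_t sharedExpertNum, int64_t sharedExpertRankNum,
                           int64_t bs, int64_t globalBS, int64_t k)
{
    if (epRankId < sharedExpertRankNum) {
        // Shared-expert card: A = BS * epWorldSize * sharedExpertNum / sharedExpertRankNum.
        return {1, bs * epWorldSize * sharedExpertNum / sharedExpertRankNum};
    }
    // MoE-expert card.
    const int64_t local = moeExpertNum / (epWorldSize - sharedExpertRankNum);
    const int64_t tokens = (globalBS == 0) ? bs * epWorldSize : globalBS;
    return {local, tokens * std::min(local, k)};
}
```

With the A3 sample values (epWorldSize = 2, moeExpertNum = 7, one shared expert on one card, BS = 8, K = 3), rank 0 gets localExpertNum = 1 and A = 16, while rank 1 gets localExpertNum = 7 and A ≥ 48.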
- Environment variable constraints:
  - HCCL_BUFFSIZE: before calling this interface, check that the HCCL_BUFFSIZE environment variable is set to a reasonable value. It specifies the memory occupied by a single communication domain, in MB, and defaults to 200 MB when not set.
    - Atlas A2 training series products/Atlas A2 inference series products:
      - commAlg set to "" or nullptr: the "fullmesh" or "hierarchy" formula applies, depending on the HCCL_INTRA_PCIE_ENABLE and HCCL_INTRA_ROCE_ENABLE settings.
      - commAlg set to "fullmesh": required size = 2 * (BS * epWorldSize * min(localExpertNum, K) * H * sizeof(uint16) + 2MB).
      - commAlg set to "hierarchy": required size ≥ (moeExpertNum + epWorldSize / 4) * Align512(maxBS * (H * 2 + 16 * Align8(K))) * 1B + 8MB, where Align8(x) = ((x + 8 - 1) / 8) * 8 and Align512(x) = ((x + 512 - 1) / 512) * 512.
    - Atlas A3 training series products/Atlas A3 inference series products/Ascend 950PR/Ascend 950DT:
      - Within the EP communication domain: required size >= 2 and >= 2 * (localExpertNum * maxBS * epWorldSize * Align512(Align32(2 * H) + 44) + (K + sharedExpertNum) * maxBS * Align512(2 * H)). localExpertNum here is the per-card expert count of a MoE-expert card, Align512(x) = ((x + 512 - 1) / 512) * 512, and Align32(x) = ((x + 32 - 1) / 32) * 32.
      - Within the TP communication domain: required size >= A * (H * 2 + 128) * 2.
  - HCCL_INTRA_PCIE_ENABLE/HCCL_INTRA_ROCE_ENABLE:
    - Atlas A2 training series products/Atlas A2 inference series products: these environment variables are no longer recommended; set commAlg to "hierarchy" instead.
    - Atlas A3 training series products/Atlas A3 inference series products/Ascend 950PR/Ascend 950DT: these environment variables are not supported.
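The HCCL_BUFFSIZE formulas are easier to check once written as code. The sketch below evaluates three of them on the host. It rests on assumptions that the text does not state explicitly: the formula results are treated as bytes, "MB" is taken as 2^20 bytes, and the result is rounded up to whole MB for comparison with the environment variable. All function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

constexpr int64_t kMiB = 1024 * 1024;  // assumption: "MB" in the formulas means 2^20 bytes

// Rounds x up to the next multiple of a (Align8 / Align32 / Align512 in the text).
int64_t AlignUp(int64_t x, int64_t a) { return ((x + a - 1) / a) * a; }

// A2, commAlg = "fullmesh".
int64_t A2FullmeshBytes(int64_t bs, int64_t epWorldSize, int64_t localExpertNum,
                        int64_t k, int64_t h)
{
    const int64_t elem = static_cast<int64_t>(sizeof(uint16_t));
    return 2 * (bs * epWorldSize * std::min(localExpertNum, k) * h * elem + 2 * kMiB);
}

// A2, commAlg = "hierarchy" (lower bound).
int64_t A2HierarchyBytes(int64_t moeExpertNum, int64_t epWorldSize, int64_t maxBS,
                         int64_t h, int64_t k)
{
    return (moeExpertNum + epWorldSize / 4) *
               AlignUp(maxBS * (h * 2 + 16 * AlignUp(k, 8)), 512) +
           8 * kMiB;
}

// A3 / Ascend 950, EP communication domain (lower bound).
int64_t A3EpBytes(int64_t localExpertNum, int64_t maxBS, int64_t epWorldSize,
                  int64_t h, int64_t k, int64_t sharedExpertNum)
{
    return 2 * (localExpertNum * maxBS * epWorldSize * AlignUp(AlignUp(2 * h, 32) + 44, 512) +
                (k + sharedExpertNum) * maxBS * AlignUp(2 * h, 512));
}

// Converts a byte count to whole MB, rounding up.
int64_t BytesToMiBCeil(int64_t bytes) { return (bytes + kMiB - 1) / kMiB; }
```

Under these assumptions, the A2 sample configuration (BS = 32, epWorldSize = 8, localExpertNum = 32, K = 8, H = 7168) evaluates to 60 MB for "fullmesh", which fits inside the 200 MB default.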
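- Communication domain usage constraints:
  - Within a model, the aclnnMoeDistributeCombineV3 family of operators and aclnnMoeDistributeDispatchV3 support only the same EP communication domain, and no other operators are allowed in that domain.
  - Within a model, the aclnnMoeDistributeCombineV3 family of operators and aclnnMoeDistributeDispatchV3 support only the same TP communication domain, or neither uses a TP communication domain. When a TP communication domain is used, no other operators are allowed in that domain.
  - Atlas A3 training series products/Atlas A3 inference series products: the nodes of a communication domain must be within a single supernode; spanning supernodes is not supported.
- Networking constraints:
  - Atlas A2 training series products/Atlas A2 inference series products: multi-server scenarios support only switch-based networking; direct connection between two servers is not supported.
- Other constraints:
  - "/" in the formulas denotes integer division.
  - moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum < MAX_INT32.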
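The bound on the total expert count and the bound on K from the shape-variable table can be checked together on the host. The following is a minimal sketch. `ValidateExpertCounts` is a hypothetical helper; it uses 64-bit arithmetic so that the sum itself cannot overflow, and it leaves the product-specific upper bound on moeExpertNum to the caller.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Checks 0 < K <= 16, K <= total expert count, and total expert count < MAX_INT32.
bool ValidateExpertCounts(int64_t k, int64_t moeExpertNum, int64_t zeroExpertNum,
                          int64_t copyExpertNum, int64_t constExpertNum)
{
    constexpr int64_t kMaxInt32 = std::numeric_limits<int32_t>::max();  // 2^31 - 1

    // Each special-expert count must lie in [0, MAX_INT32).
    const int64_t special[] = {zeroExpertNum, copyExpertNum, constExpertNum};
    for (int64_t n : special) {
        if (n < 0 || n >= kMaxInt32) return false;
    }
    if (moeExpertNum <= 0 || moeExpertNum >= kMaxInt32) return false;

    // Four values below 2^31 cannot overflow a 64-bit sum.
    const int64_t total = moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum;
    if (total >= kMaxInt32) return false;

    return k > 0 && k <= 16 && k <= total;
}
```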
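Calling Examples
- Atlas A2 training series products/Atlas A2 inference series products:
  This example runs the A2 operator in a single-server environment with [2, 8] cards. Set EP_WORLD_SIZE_A2 in the sample code to the number of cards as needed, and change moeExpertNum so that it is divisible by EP_WORLD_SIZE_A2.
  - Build the operators: the build command is shown below. Both moe_distribute_dispatch_v2 and moe_distribute_combine_v2 must be built, because the two operators run as a pair.
    bash build.sh --pkg --soc=ascend910b --ops=moe_distribute_dispatch_v2,moe_distribute_combine_v2
  - Create the A2 sample code: after the build completes, refer to the existing test_aclnn_moe_distribute_combine_v2.cpp file in the operator's examples directory and create the test file test_aclnn_moe_distribute_combine_v2.cpp from the A2 sample code.
  - Run the operator sample: the command below runs all sample code files in the operator's examples directory.
    bash build.sh --run_example --ops=moe_distribute_combine_v2 eager cust
  - A2 sample code: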
#include <thread> #include <iostream> #include <string> #include <cstring> #include <vector> #include <memory> #include <cstdio> #include "acl/acl.h" #include "hccl/hccl.h" #include "aclnn/opdev/fp16_t.h" #include "aclnnop/aclnn_moe_distribute_dispatch_v3.h" #include "aclnnop/aclnn_moe_distribute_combine_v3.h" #define CHECK_RET(cond, return_expr) \ do { \ if (!(cond)) { \ return_expr; \ } \ } while (0) #define LOG_PRINT(message, ...) \ do { \ printf(message, ##__VA_ARGS__); \ } while(0) struct Args { uint32_t rankId; uint32_t epRankId; uint32_t tpRankId; HcclComm hcclEpComm; HcclComm hcclTpComm; aclrtStream dispatchV3Stream; aclrtStream combineV3Stream; aclrtContext context; }; const uint32_t EP_WORLD_SIZE_A2 = 8; const uint32_t TP_WORLD_SIZE_A2 = 1; const uint32_t DEV_NUM_A2 = EP_WORLD_SIZE_A2 * TP_WORLD_SIZE_A2; int64_t GetShapeSize(const std::vector<int64_t> &shape) { int64_t shape_size = 1; for (auto i : shape) { shape_size *= i; } return shape_size; } template<typename T> int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr, aclDataType dataType, aclTensor **tensor) { auto size = GetShapeSize(shape) * sizeof(T); auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret); ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. 
ret: %d\n", ret); return ret); std::vector<int64_t> strides(shape.size(), 1); for (int64_t i = shape.size() - 2; i >= 0; i--) { strides[i] = shape[i +1] * strides[i + 1]; } *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr); return 0; } void DestroyTensor(aclTensor *tensor) { if (tensor != nullptr) { aclDestroyTensor(tensor); } } void FreeDeviceAddr(void *deviceAddr) { if (deviceAddr != nullptr) { aclrtFree(deviceAddr); } } int launchOneThreadDispatchV3AndCombineV3_A2(Args &args) { int ret = aclrtSetCurrentContext(args.context); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret); char hcomEpName[128] = {0}; ret = HcclGetCommName(args.hcclEpComm, hcomEpName); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1); LOG_PRINT("[INFO] rank = %d, hcomEpName = %s, dispatchV3Stream = %p, combineV3Stream = %p, \ context = %p\n", args.rankId, hcomEpName, args.dispatchV3Stream, args.combineV3Stream, \ args.context); int64_t BS = 32; int64_t H = 7168; int64_t K = 8; int64_t expertShardType = 0; int64_t sharedExpertNum = 0; int64_t sharedExpertRankNum = 0; int64_t moeExpertNum = 256; int64_t quantMode = 0; int64_t globalBS = BS * EP_WORLD_SIZE_A2; int64_t expertTokenNumsType = 1; int64_t outDtype = 0; int64_t commQuantMode = 0; int64_t groupList_type = 1; int64_t localExpertNum; int64_t A; int64_t zeroExpertNum = 0; int64_t copyExpertNum = 0; int64_t constExpertNum = 0; // 仅A3 std::string commAlg = "fullmesh"; if (args.epRankId < sharedExpertRankNum) { localExpertNum = 1; A = globalBS / sharedExpertRankNum; } else { localExpertNum = moeExpertNum / (EP_WORLD_SIZE_A2 - sharedExpertRankNum); A = globalBS * (localExpertNum < K ? 
localExpertNum : K); } void *xDeviceAddr = nullptr; void *expertIdsDeviceAddr = nullptr; void *scalesDeviceAddr = nullptr; void *expertScalesDeviceAddr = nullptr; void *expandXDeviceAddr = nullptr; void *dynamicScalesDeviceAddr = nullptr; void *assistInfoForCombineDeviceAddr = nullptr; void *expertTokenNumsDeviceAddr = nullptr; void *epRecvCountsDeviceAddr = nullptr; void *tpRecvCountsDeviceAddr = nullptr; void *expandScalesDeviceAddr = nullptr; // 零专家场景输入 void *oriXDeviceAddr = nullptr; void *xOutDeviceAddr = nullptr; aclTensor *x = nullptr; aclTensor *expertIds = nullptr; aclTensor *scales = nullptr; aclTensor *xActiveMask = nullptr; aclTensor *expertScales = nullptr; aclTensor *elasticInfo = nullptr; // A3 aclTensor *expandX = nullptr; aclTensor *dynamicScales = nullptr; aclTensor *assistInfoForCombine = nullptr; // expandIdx aclTensor *expertTokenNums = nullptr; aclTensor *epRecvCounts = nullptr; aclTensor *tpRecvCounts = nullptr; aclTensor *expandScales = nullptr; aclTensor *activationScale = nullptr; // 预留参数 aclTensor *weightScale = nullptr; // 预留参数 aclTensor *groupList = nullptr; // 预留参数 aclTensor *sharedExpertX = nullptr; // A3 aclTensor *oriX = nullptr; aclTensor *constExpertAlpha1 = nullptr; // A3 aclTensor *constExpertAlpha2 = nullptr; // A3 aclTensor *constExpertV = nullptr; // A3 aclTensor *xOut = nullptr; //定义当前场景下各变量维度 std::vector<int64_t> xShape{BS, H}; std::vector<int64_t> expertIdsShape{BS, K}; std::vector<int64_t> scalesShape{moeExpertNum + 1, H}; std::vector<int64_t> expertScalesShape{BS, K}; std::vector<int64_t> expandXShape{TP_WORLD_SIZE_A2 * A, H}; std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE_A2 * A}; std::vector<int64_t> assistInfoForCombineShape{A * 128}; std::vector<int64_t> expertTokenNumsShape{localExpertNum}; std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE_A2 * localExpertNum * EP_WORLD_SIZE_A2}; // 不分层 std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE_A2}; std::vector<int64_t> expandScalesShape{A}; std::vector<int64_t> 
oriXShape{BS, H}; std::vector<int64_t> xOutShape{BS, H}; int64_t xShapeSize = GetShapeSize(xShape); int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape); int64_t scalesShapeSize = GetShapeSize(scalesShape); int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape); int64_t expandXShapeSize = GetShapeSize(expandXShape); int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape); int64_t assistInfoForCombineShapeSize = GetShapeSize(assistInfoForCombineShape); int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape); int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape); int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape); int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape); int64_t oriXSize = GetShapeSize(oriXShape); int64_t xOutShapeSize = GetShapeSize(xOutShape); std::vector<int16_t> xHostData(xShapeSize, 1); std::vector<int32_t> expertIdsHostData; for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) { for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) { expertIdsHostData.push_back(k_id); } } std::vector<float> scalesHostData(scalesShapeSize, 0.1); std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1); std::vector<int16_t> expandXHostData(expandXShapeSize, 0); std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0); std::vector<int32_t> assistInfoForCombineHostData(assistInfoForCombineShapeSize, 0); std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0); std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0); std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0); std::vector<float> expandScalesHostData(expandScalesShapeSize, 0); std::vector<int16_t> oriXHostData(oriXSize, 1); std::vector<int16_t> xOutHostData(xOutShapeSize, 0); ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertIdsHostData, expertIdsShape, 
&expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(assistInfoForCombineHostData, assistInfoForCombineShape, &assistInfoForCombineDeviceAddr, aclDataType::ACL_INT32, &assistInfoForCombine); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(oriXHostData, oriXShape, &oriXDeviceAddr, aclDataType::ACL_BF16, &oriX); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut); CHECK_RET(ret == ACL_SUCCESS, return ret); uint64_t dispatchWorkspaceSize = 0; 
aclOpExecutor *dispatchExecutor = nullptr; void *dispatchWorkspaceAddr = nullptr; uint64_t combineWorkspaceSize = 0; aclOpExecutor *combineExecutor = nullptr; void *combineWorkspaceAddr = nullptr; /**************************************** 调用dispatch ********************************************/ // 调用第一阶段接口 ret = aclnnMoeDistributeDispatchV3GetWorkspaceSize(x, expertIds, (quantMode > 0 ? scales : nullptr), xActiveMask, expertScales, elasticInfo, hcomEpName, EP_WORLD_SIZE_A2, args.epRankId, moeExpertNum, "", TP_WORLD_SIZE_A2, args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBS, expertTokenNumsType, commAlg.c_str(), zeroExpertNum, copyExpertNum, constExpertNum, expandX, dynamicScales, assistInfoForCombine, expertTokenNums, epRecvCounts, tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV3GetWorkspaceSize failed. ret = %d \n", ret); return ret); if (dispatchWorkspaceSize > 0) { ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret); } // 调用第二阶段接口 ret = aclnnMoeDistributeDispatchV3(dispatchWorkspaceAddr, dispatchWorkspaceSize, dispatchExecutor, args.dispatchV3Stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV3 failed. ret = %d \n", ret); \ return ret); ret = aclrtSynchronizeStreamWithTimeout(args.dispatchV3Stream, 10000); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] dispatch aclrtSynchronizeStreamWithTimeout failed. 
ret = %d \n", ret); \ return ret); LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV3 execute successfully.\n", args.rankId); /**************************************** 调用combine ********************************************/ // 调用第一阶段接口 ret = aclnnMoeDistributeCombineV3GetWorkspaceSize(expandX, expertIds, assistInfoForCombine, epRecvCounts, expertScales, tpRecvCounts, xActiveMask, activationScale, weightScale, groupList, expandScales, sharedExpertX, elasticInfo, oriX, constExpertAlpha1, constExpertAlpha2, constExpertV, hcomEpName, EP_WORLD_SIZE_A2, args.epRankId, moeExpertNum, "", TP_WORLD_SIZE_A2, args.tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBS, outDtype, commQuantMode, groupList_type, commAlg.c_str(), zeroExpertNum, copyExpertNum, constExpertNum, xOut, &combineWorkspaceSize, &combineExecutor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV3GetWorkspaceSize failed. ret = %d \n", ret); return ret); // 根据第一阶段接口计算出的workspaceSize申请device内存 if (combineWorkspaceSize > 0) { ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret); } // 调用第二阶段接口 ret = aclnnMoeDistributeCombineV3(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor, args.combineV3Stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV3 failed. ret = %d \n", ret); return ret); // (固定写法)同步等待任务执行结束 ret = aclrtSynchronizeStreamWithTimeout(args.combineV3Stream, 10000); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. 
ret = %d \n", ret); return ret); LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV3 and aclnnMoeDistributeCombineV3 \ execute successfully.\n", args.rankId); // 释放device资源 if (dispatchWorkspaceSize > 0) { aclrtFree(dispatchWorkspaceAddr); } if (combineWorkspaceSize > 0) { aclrtFree(combineWorkspaceAddr); } DestroyTensor(x); DestroyTensor(expertIds); DestroyTensor(scales); DestroyTensor(xActiveMask); DestroyTensor(expertScales); DestroyTensor(elasticInfo); DestroyTensor(expandX); DestroyTensor(dynamicScales); DestroyTensor(assistInfoForCombine); DestroyTensor(expertTokenNums); DestroyTensor(epRecvCounts); DestroyTensor(tpRecvCounts); DestroyTensor(expandScales); DestroyTensor(activationScale); DestroyTensor(weightScale); DestroyTensor(groupList); DestroyTensor(sharedExpertX); DestroyTensor(oriX); DestroyTensor(constExpertAlpha1); DestroyTensor(constExpertAlpha2); DestroyTensor(constExpertV); DestroyTensor(xOut); FreeDeviceAddr(xDeviceAddr); FreeDeviceAddr(expertIdsDeviceAddr); FreeDeviceAddr(scalesDeviceAddr); FreeDeviceAddr(expertScalesDeviceAddr); FreeDeviceAddr(expandXDeviceAddr); FreeDeviceAddr(dynamicScalesDeviceAddr); FreeDeviceAddr(assistInfoForCombineDeviceAddr); FreeDeviceAddr(expertTokenNumsDeviceAddr); FreeDeviceAddr(epRecvCountsDeviceAddr); FreeDeviceAddr(tpRecvCountsDeviceAddr); FreeDeviceAddr(expandScalesDeviceAddr); FreeDeviceAddr(oriXDeviceAddr); FreeDeviceAddr(xOutDeviceAddr); HcclCommDestroy(args.hcclEpComm); aclrtDestroyStream(args.dispatchV3Stream); aclrtDestroyStream(args.combineV3Stream); aclrtDestroyContext(args.context); LOG_PRINT("[INFO] device_%d DeStroy.\n", args.rankId); aclrtResetDevice(args.rankId); LOG_PRINT("[INFO] device_%d Reset.\n", args.rankId); return 0; } int main(int argc, char *argv[]) { LOG_PRINT("[INFO] run_example_on_A2.\n"); int ret = aclInit(nullptr); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret); aclrtStream dispatchV3Stream[DEV_NUM_A2]; aclrtStream 
combineV3Stream[DEV_NUM_A2]; aclrtContext context[DEV_NUM_A2]; for (uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) { ret = aclrtSetDevice(rankId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret); ret = aclrtCreateContext(&context[rankId], rankId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret); ret = aclrtCreateStream(&dispatchV3Stream[rankId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret); ret = aclrtCreateStream(&combineV3Stream[rankId]); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret); } int32_t devicesEp[EP_WORLD_SIZE_A2]; for (int32_t epId = 0; epId < EP_WORLD_SIZE_A2; epId++) { devicesEp[epId] = epId; } HcclComm commsEp[EP_WORLD_SIZE_A2]; ret = HcclCommInitAll(EP_WORLD_SIZE_A2, devicesEp, commsEp); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll ep failed, ret %d\n", ret); return ret); Args args[DEV_NUM_A2]; std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM_A2); for (uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) { uint32_t epRankId = rankId / TP_WORLD_SIZE_A2; uint32_t tpRankId = rankId % TP_WORLD_SIZE_A2; args[rankId].rankId = rankId; args[rankId].epRankId = epRankId; args[rankId].tpRankId = tpRankId; args[rankId].hcclEpComm = commsEp[epRankId]; args[rankId].dispatchV3Stream = dispatchV3Stream[rankId]; args[rankId].combineV3Stream = combineV3Stream[rankId]; args[rankId].context = context[rankId]; threads[rankId].reset(new(std::nothrow) std::thread(&launchOneThreadDispatchV3AndCombineV3_A2, std::ref(args[rankId]))); } for(uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) { threads[rankId]->join(); } aclFinalize(); LOG_PRINT("[INFO] aclFinalize success\n"); return 0; }
-
-
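Ascend 950PR/Ascend 950DT: Refer to the preparation steps and sample code in the calling example of aclnnMoeDistributeCombineV2, and reset the variables involved according to the constraints described above. For the scenario parameters that the V4 interface adds relative to the V3 interface, pass values as described in the parameter descriptions above.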
-
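Atlas A3 training series products/Atlas A3 inference series products:
For the detailed compilation and execution procedure, see "Compiling and Running the Sample".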
-
The sample code is as follows and is for reference only.
#include <cstdint>
#include <cstdio>
#include <memory>
#include <thread>
#include <iostream>
#include <string>
#include <vector>
#include <unordered_set>
#include "acl/acl.h"
#include "hccl/hccl.h"
#include "aclnnop/aclnn_moe_distribute_dispatch_v3.h"
#include "aclnnop/aclnn_moe_distribute_combine_v3.h"

#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (!(cond)) {               \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)         \
    do {                                \
        printf(message, ##__VA_ARGS__); \
    } while (0)

struct Args {
    uint32_t rankId;
    uint32_t epRankId;
    uint32_t tpRankId;
    HcclComm hcclEpComm;
    HcclComm hcclTpComm;
    aclrtStream dispatchStream;
    aclrtStream combineStream;
    aclrtContext context;
};

constexpr uint32_t EP_WORLD_SIZE = 2;
constexpr uint32_t TP_WORLD_SIZE = 1;
constexpr uint32_t DEV_NUM = EP_WORLD_SIZE * TP_WORLD_SIZE;

// Number of elements in a tensor of the given shape
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    int64_t shape_size = 1;
    for (auto i : shape) {
        shape_size *= i;
    }
    return shape_size;
}

// Allocate device memory, copy the host data to it, and create a contiguous ND aclTensor
template<typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);

    // Row-major strides for a contiguous tensor
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}

int LaunchOneProcessDispatchAndCombine(Args &args)
{
    int ret = aclrtSetCurrentContext(args.context);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret);

    char hcomEpName[128] = {0};
    ret = HcclGetCommName(args.hcclEpComm, hcomEpName);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1);
    char hcomTpName[128] = {0};
    LOG_PRINT("[INFO] rank = %d, hcomEpName = %s, hcomTpName = %s, dispatchStream = %p, combineStream = %p, "
              "context = %p\n",
              args.rankId, hcomEpName, hcomTpName, args.dispatchStream, args.combineStream, args.context);

    int64_t BS = 8;
    int64_t H = 7168;
    int64_t K = 3;
    int64_t expertShardType = 0;
    int64_t sharedExpertNum = 1;
    int64_t sharedExpertRankNum = 1;
    int64_t moeExpertNum = 7;
    int64_t quantMode = 0;
    int64_t globalBS = BS * EP_WORLD_SIZE;
    int64_t expertTokenNumsType = 1;
    int64_t outDtype = 0;
    int64_t commQuantMode = 0;
    int64_t groupList_type = 1;
    int64_t localExpertNum;
    int64_t A;
    int64_t zeroExpertNum = 1;
    int64_t copyExpertNum = 1;
    int64_t constExpertNum = 1;
    if (args.epRankId < sharedExpertRankNum) {
        // Shared-expert rank
        localExpertNum = 1;
        A = globalBS / sharedExpertRankNum;
    } else {
        // MoE-expert rank
        localExpertNum = moeExpertNum / (EP_WORLD_SIZE - sharedExpertRankNum);
        A = globalBS * (localExpertNum < K ? localExpertNum : K);
    }

    void *xDeviceAddr = nullptr;
    void *expertIdsDeviceAddr = nullptr;
    void *scalesDeviceAddr = nullptr;
    void *expertScalesDeviceAddr = nullptr;
    void *expandXDeviceAddr = nullptr;
    void *dynamicScalesDeviceAddr = nullptr;
    void *expandIdxDeviceAddr = nullptr;
    void *expertTokenNumsDeviceAddr = nullptr;
    void *epRecvCountsDeviceAddr = nullptr;
    void *tpRecvCountsDeviceAddr = nullptr;
    void *expandScalesDeviceAddr = nullptr;
    void *residualXDeviceAddr = nullptr;
    void *sharedExpertXDeviceAddr = nullptr;
    void *elasticInfoDeviceAddr = nullptr;
    void *oriXDeviceAddr = nullptr;
    void *constExpertAlpha1DeviceAddr = nullptr;
    void *constExpertAlpha2DeviceAddr = nullptr;
    void *constExpertVDeviceAddr = nullptr;
    void *xOutDeviceAddr = nullptr;

    aclTensor *x = nullptr;
    aclTensor *expertIds = nullptr;
    aclTensor *scales = nullptr;
    aclTensor *expertScales = nullptr;
    aclTensor *expandX = nullptr;
    aclTensor *dynamicScales = nullptr;
    aclTensor *expandIdx = nullptr;
    aclTensor *expertTokenNums = nullptr;
    aclTensor *epRecvCounts = nullptr;
    aclTensor *tpRecvCounts = nullptr;
    aclTensor *expandScales = nullptr;
    aclTensor *residualX = nullptr;
    aclTensor *sharedExpertX = nullptr;
    aclTensor *elasticInfo = nullptr;
    aclTensor *oriX = nullptr;
    aclTensor *constExpertAlpha1 = nullptr;
    aclTensor *constExpertAlpha2 = nullptr;
    aclTensor *constExpertV = nullptr;
    aclTensor *xOut = nullptr;

    // Define the shape of each variable for this scenario
    std::vector<int64_t> xShape{BS, H};
    std::vector<int64_t> expertIdsShape{BS, K};
    std::vector<int64_t> scalesShape{moeExpertNum + 1, H};
    std::vector<int64_t> expertScalesShape{BS, K};
    std::vector<int64_t> expandXShape{TP_WORLD_SIZE * A, H};
    std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE * A};
    std::vector<int64_t> expandIdxShape{A * 128};
    std::vector<int64_t> expertTokenNumsShape{localExpertNum};
    std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE * localExpertNum * EP_WORLD_SIZE};
    std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE};
    std::vector<int64_t> expandScalesShape{A};
    std::vector<int64_t> sharedExpertXShape{BS, 1, H};
    std::vector<int64_t> elasticInfoShape{4 + EP_WORLD_SIZE * 2};
    std::vector<int64_t> oriXShape{BS, H};
    std::vector<int64_t> constExpertAlpha1Shape{constExpertNum, H};
    std::vector<int64_t> constExpertAlpha2Shape{constExpertNum, H};
    std::vector<int64_t> constExpertVShape{constExpertNum, H};
    std::vector<int64_t> xOutShape{BS, H};

    int64_t xShapeSize = GetShapeSize(xShape);
    int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
    int64_t scalesShapeSize = GetShapeSize(scalesShape);
    int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
    int64_t expandXShapeSize = GetShapeSize(expandXShape);
    int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
    int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape);
    int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
    int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
    int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
    int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
    int64_t sharedExpertXShapeSize = GetShapeSize(sharedExpertXShape);
    int64_t elasticInfoSize = GetShapeSize(elasticInfoShape);
    int64_t oriXSize = GetShapeSize(oriXShape);
    int64_t constExpertAlpha1Size = GetShapeSize(constExpertAlpha1Shape);
    int64_t constExpertAlpha2Size = GetShapeSize(constExpertAlpha2Shape);
    int64_t constExpertVSize = GetShapeSize(constExpertVShape);
    int64_t xOutShapeSize = GetShapeSize(xOutShape);

    // Host-side initial data
    std::vector<int16_t> xHostData(xShapeSize, 1);
    std::vector<int32_t> expertIdsHostData;
    for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) {
        for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) {
            expertIdsHostData.push_back(k_id);
        }
    }
    std::vector<float> scalesHostData(scalesShapeSize, 0.1);
    std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1);
    std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
    std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
    std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0);
    std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
    std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
    std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
    std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
    std::vector<int16_t> sharedExpertXHostData(sharedExpertXShapeSize, 1);
    std::vector<int16_t> oriXHostData(oriXSize, 1);
    std::vector<int16_t> constExpertAlpha1HostData(constExpertAlpha1Size, 0);
    std::vector<int16_t> constExpertAlpha2HostData(constExpertAlpha2Size, 0);
    std::vector<int16_t> constExpertVHostData(constExpertVSize, 0);
    std::vector<int16_t> xOutHostData(xOutShapeSize, 0);

    ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32,
                          &expertIds);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT,
                          &expertScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr,
                          (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr,
                          aclDataType::ACL_FLOAT, &dynamicScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32,
                          &expandIdx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr,
                          aclDataType::ACL_INT64, &expertTokenNums);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32,
                          &epRecvCounts);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32,
                          &tpRecvCounts);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT,
                          &expandScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(sharedExpertXHostData, sharedExpertXShape, &sharedExpertXDeviceAddr,
                          aclDataType::ACL_BF16, &sharedExpertX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(oriXHostData, oriXShape, &oriXDeviceAddr, aclDataType::ACL_BF16, &oriX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(constExpertAlpha1HostData, constExpertAlpha1Shape, &constExpertAlpha1DeviceAddr,
                          aclDataType::ACL_BF16, &constExpertAlpha1);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(constExpertAlpha2HostData, constExpertAlpha2Shape, &constExpertAlpha2DeviceAddr,
                          aclDataType::ACL_BF16, &constExpertAlpha2);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(constExpertVHostData, constExpertVShape, &constExpertVDeviceAddr, aclDataType::ACL_BF16,
                          &constExpertV);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    uint64_t dispatchWorkspaceSize = 0;
    aclOpExecutor *dispatchExecutor = nullptr;
    void *dispatchWorkspaceAddr = nullptr;

    uint64_t combineWorkspaceSize = 0;
    aclOpExecutor *combineExecutor = nullptr;
    void *combineWorkspaceAddr = nullptr;

    /**************************************** Call dispatch ********************************************/
    // Call the first-stage interface
    ret = aclnnMoeDistributeDispatchV3GetWorkspaceSize(
        x, expertIds, (quantMode > 0 ? scales : nullptr), nullptr, expertScales, elasticInfo,
        hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE, args.tpRankId,
        expertShardType, sharedExpertNum, sharedExpertRankNum, quantMode, globalBS, expertTokenNumsType, nullptr,
        zeroExpertNum, copyExpertNum, constExpertNum,
        expandX, dynamicScales, expandIdx, expertTokenNums, epRecvCounts, tpRecvCounts, expandScales,
        &dispatchWorkspaceSize, &dispatchExecutor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV3GetWorkspaceSize failed. ret = %d \n", ret); return ret);
    // Allocate device memory according to the workspaceSize computed by the first-stage interface
    if (dispatchWorkspaceSize > 0) {
        ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
    }
    // Call the second-stage interface
    ret = aclnnMoeDistributeDispatchV3(dispatchWorkspaceAddr, dispatchWorkspaceSize, dispatchExecutor,
                                       args.dispatchStream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV3 failed. ret = %d \n", ret);
              return ret);
    // (Fixed pattern) Synchronize and wait for the task to finish
    ret = aclrtSynchronizeStreamWithTimeout(args.dispatchStream, 10000);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] dispatch aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret); return ret);

    /**************************************** Call combine ********************************************/
    // Call the first-stage interface
    ret = aclnnMoeDistributeCombineV3GetWorkspaceSize(
        expandX, expertIds, expandIdx, epRecvCounts, expertScales, tpRecvCounts,
        nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        elasticInfo, oriX, constExpertAlpha1, constExpertAlpha2, constExpertV,
        hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE, args.tpRankId,
        expertShardType, sharedExpertNum, sharedExpertRankNum, globalBS, outDtype, commQuantMode, groupList_type,
        nullptr, zeroExpertNum, copyExpertNum, constExpertNum,
        xOut, &combineWorkspaceSize, &combineExecutor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV3GetWorkspaceSize failed. ret = %d \n", ret); return ret);
    // Allocate device memory according to the workspaceSize computed by the first-stage interface
    if (combineWorkspaceSize > 0) {
        ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
    }
    // Call the second-stage interface
    ret = aclnnMoeDistributeCombineV3(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor,
                                      args.combineStream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV3 failed. ret = %d \n", ret);
              return ret);
    // (Fixed pattern) Synchronize and wait for the task to finish
    ret = aclrtSynchronizeStreamWithTimeout(args.combineStream, 10000);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
              return ret);
    LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV3 and aclnnMoeDistributeCombineV3 execute successfully.\n",
              args.rankId);

    // Release device resources
    if (dispatchWorkspaceSize > 0) {
        aclrtFree(dispatchWorkspaceAddr);
    }
    if (combineWorkspaceSize > 0) {
        aclrtFree(combineWorkspaceAddr);
    }

    if (x != nullptr) { aclDestroyTensor(x); }
    if (expertIds != nullptr) { aclDestroyTensor(expertIds); }
    if (scales != nullptr) { aclDestroyTensor(scales); }
    if (expertScales != nullptr) { aclDestroyTensor(expertScales); }
    if (expandX != nullptr) { aclDestroyTensor(expandX); }
    if (dynamicScales != nullptr) { aclDestroyTensor(dynamicScales); }
    if (expandIdx != nullptr) { aclDestroyTensor(expandIdx); }
    if (expertTokenNums != nullptr) { aclDestroyTensor(expertTokenNums); }
    if (epRecvCounts != nullptr) { aclDestroyTensor(epRecvCounts); }
    if (tpRecvCounts != nullptr) { aclDestroyTensor(tpRecvCounts); }
    if (expandScales != nullptr) { aclDestroyTensor(expandScales); }
    if (residualX != nullptr) { aclDestroyTensor(residualX); }
    if (sharedExpertX != nullptr) { aclDestroyTensor(sharedExpertX); }
    if (elasticInfo != nullptr) { aclDestroyTensor(elasticInfo); }
    if (oriX != nullptr) { aclDestroyTensor(oriX); }
    if (constExpertAlpha1 != nullptr) { aclDestroyTensor(constExpertAlpha1); }
    if (constExpertAlpha2 != nullptr) { aclDestroyTensor(constExpertAlpha2); }
    if (constExpertV != nullptr) { aclDestroyTensor(constExpertV); }
    if (xOut != nullptr) { aclDestroyTensor(xOut); }

    if (xDeviceAddr != nullptr) { aclrtFree(xDeviceAddr); }
    if (expertIdsDeviceAddr != nullptr) { aclrtFree(expertIdsDeviceAddr); }
    if (scalesDeviceAddr != nullptr) { aclrtFree(scalesDeviceAddr); }
    if (expertScalesDeviceAddr != nullptr) { aclrtFree(expertScalesDeviceAddr); }
    if (expandXDeviceAddr != nullptr) { aclrtFree(expandXDeviceAddr); }
    if (dynamicScalesDeviceAddr != nullptr) { aclrtFree(dynamicScalesDeviceAddr); }
    if (expandIdxDeviceAddr != nullptr) { aclrtFree(expandIdxDeviceAddr); }
    if (expertTokenNumsDeviceAddr != nullptr) { aclrtFree(expertTokenNumsDeviceAddr); }
    if (epRecvCountsDeviceAddr != nullptr) { aclrtFree(epRecvCountsDeviceAddr); }
    if (expandScalesDeviceAddr != nullptr) { aclrtFree(expandScalesDeviceAddr); }
    if (tpRecvCountsDeviceAddr != nullptr) { aclrtFree(tpRecvCountsDeviceAddr); }
    if (sharedExpertXDeviceAddr != nullptr) { aclrtFree(sharedExpertXDeviceAddr); }
    if (elasticInfoDeviceAddr != nullptr) { aclrtFree(elasticInfoDeviceAddr); }
    if (oriXDeviceAddr != nullptr) { aclrtFree(oriXDeviceAddr); }
    if (constExpertAlpha1DeviceAddr != nullptr) { aclrtFree(constExpertAlpha1DeviceAddr); }
    if (constExpertAlpha2DeviceAddr != nullptr) { aclrtFree(constExpertAlpha2DeviceAddr); }
    if (constExpertVDeviceAddr != nullptr) { aclrtFree(constExpertVDeviceAddr); }
    if (xOutDeviceAddr != nullptr) { aclrtFree(xOutDeviceAddr); }

    HcclCommDestroy(args.hcclEpComm);
    HcclCommDestroy(args.hcclTpComm);
    aclrtDestroyStream(args.dispatchStream);
    aclrtDestroyStream(args.combineStream);
    aclrtDestroyContext(args.context);
    aclrtResetDevice(args.rankId);
    return 0;
}

int main(int argc, char *argv[])
{
    int ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret);

    // One context and two streams (dispatch, combine) per device
    aclrtStream dispatchStream[DEV_NUM];
    aclrtStream combineStream[DEV_NUM];
    aclrtContext context[DEV_NUM];
    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        ret = aclrtSetDevice(rankId);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateContext(&context[rankId], rankId);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateStream(&dispatchStream[rankId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateStream(&combineStream[rankId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
    }

    // Initialize one EP communication domain per TP rank
    int32_t devicesEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
    for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            devicesEp[tpId][epId] = epId * TP_WORLD_SIZE + tpId;
        }
    }
    HcclComm commsEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
    for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
        ret = HcclCommInitAll(EP_WORLD_SIZE, devicesEp[tpId], commsEp[tpId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll ep %d failed, ret %d\n", tpId, ret);
                  return ret);
    }

    // Initialize one TP communication domain per EP rank
    int32_t devicesTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
    for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            devicesTp[epId][tpId] = epId * TP_WORLD_SIZE + tpId;
        }
    }
    HcclComm commsTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
    for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
        ret = HcclCommInitAll(TP_WORLD_SIZE, devicesTp[epId], commsTp[epId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll tp %d failed, ret %d\n", epId, ret);
                  return ret);
    }

    // Launch one worker thread per device
    Args args[DEV_NUM];
    std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM);
    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        uint32_t epRankId = rankId / TP_WORLD_SIZE;
        uint32_t tpRankId = rankId % TP_WORLD_SIZE;
        args[rankId].rankId = rankId;
        args[rankId].epRankId = epRankId;
        args[rankId].tpRankId = tpRankId;
        args[rankId].hcclEpComm = commsEp[tpRankId][epRankId];
        args[rankId].hcclTpComm = commsTp[epRankId][tpRankId];
        args[rankId].dispatchStream = dispatchStream[rankId];
        args[rankId].combineStream = combineStream[rankId];
        args[rankId].context = context[rankId];
        threads[rankId].reset(new(std::nothrow) std::thread(&LaunchOneProcessDispatchAndCombine,
                                                            std::ref(args[rankId])));
    }
    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        threads[rankId]->join();
    }

    aclFinalize();
    LOG_PRINT("[INFO] aclFinalize success\n");
    return 0;
}
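The buffer sizes in the Atlas A3 sample depend on whether a rank hosts a shared expert or MoE experts. The following is a minimal host-only sketch of that computation, copied from the `localExpertNum` and `A` logic in `LaunchOneProcessDispatchAndCombine`; `RankLayout` and `ComputeRankLayout` are hypothetical names introduced here for illustration and are not part of the aclnn API.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper type (not part of aclnn): per-rank sizing values
// used by the sample to derive tensor shapes.
struct RankLayout {
    int64_t localExpertNum;  // experts hosted on this rank
    int64_t A;               // first dimension of expandX when TP_WORLD_SIZE is 1
};

// Reproduces the sample's branch: ranks below sharedExpertRankNum host a
// shared expert, all other ranks split the MoE experts evenly.
RankLayout ComputeRankLayout(int64_t epRankId, int64_t bs, int64_t k, int64_t epWorldSize,
                             int64_t sharedExpertRankNum, int64_t moeExpertNum)
{
    const int64_t globalBS = bs * epWorldSize;
    RankLayout layout{};
    if (epRankId < sharedExpertRankNum) {
        layout.localExpertNum = 1;
        layout.A = globalBS / sharedExpertRankNum;
    } else {
        layout.localExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum);
        layout.A = globalBS * std::min(layout.localExpertNum, k);
    }
    return layout;
}
```

With the sample's values (BS = 8, K = 3, EP_WORLD_SIZE = 2, sharedExpertRankNum = 1, moeExpertNum = 7), rank 0 gets localExpertNum = 1 and A = 16, while rank 1 gets localExpertNum = 7 and A = 48, so `expandXShape` is {16, 7168} on rank 0 and {48, 7168} on rank 1.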