aclnnMoeDistributeCombineV4

📄 查看源码

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 接口功能:当存在TP域通信时,先进行ReduceScatterV通信,再进行AllToAllV通信,最后将接收的数据整合(乘权重再相加);当不存在TP域通信时,进行AllToAllV通信,最后将接收的数据整合(乘权重再相加)。

    相较于aclnnMoeDistributeCombineV3接口,该接口变更如下:

    新增采集通信耗时功能,记录每张卡的通信时间,通过传入performanceInfoOptional参数使能该特性。该功能推荐结合DeepXTrace工具使用。单次算子调用各卡通信耗时会累加到该Tensor上,用户使用前按需清零。

  • 计算公式:

    • 不存在TP域通信时:

    ataOut=AllToAllV(expandX)xOut=Sum(expertScales∗ataOut+expertScales∗sharedExpertX)ataOut = AllToAllV(expandX)\\ xOut = Sum(expertScales * ataOut + expertScales * sharedExpertX)

    • 存在TP域通信时:

    rsOut=ReduceScatterV(expandX)ataOut=AllToAllV(rsOut)xOut=Sum(expertScales∗ataOut+expertScales∗sharedExpertX)rsOut = ReduceScatterV(expandX)\\ ataOut = AllToAllV(rsOut)\\ xOut = Sum(expertScales * ataOut + expertScales * sharedExpertX)

注意:该接口必须与aclnnMoeDistributeDispatchV4配套使用,相当于按aclnnMoeDistributeDispatchV4接口收集数据的路径原路返还。

函数原型

每个算子分为两段式接口,必须先调用 “aclnnMoeDistributeCombineV4GetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMoeDistributeCombineV4”接口执行计算。

aclnnStatus aclnnMoeDistributeCombineV4GetWorkspaceSize(
    const aclTensor* expandX,
    const aclTensor* expertIds,
    const aclTensor* assistInfoForCombine,
    const aclTensor* epSendCounts,
    const aclTensor* expertScales,
    const aclTensor* tpSendCountsOptional,
    const aclTensor* xActiveMaskOptional,
    const aclTensor* activationScaleOptional,
    const aclTensor* weightScaleOptional,
    const aclTensor* groupListOptional,
    const aclTensor* expandScalesOptional,
    const aclTensor* sharedExpertXOptional,
    const aclTensor* elasticInfoOptional,
    const aclTensor* oriXOptional,
    const aclTensor* constExpertAlpha1Optional,
    const aclTensor* constExpertAlpha2Optional,
    const aclTensor* constExpertVOptional,
    const aclTensor* performanceInfoOptional,
    const char*      groupEp,
    int64_t          epWorldSize,
    int64_t          epRankId,
    int64_t          moeExpertNum,
    const char*      groupTp,
    int64_t          tpWorldSize,
    int64_t          tpRankId,
    int64_t          expertShardType,
    int64_t          sharedExpertNum,
    int64_t          sharedExpertRankNum,
    int64_t          globalBS,
    int64_t          outDtype,
    int64_t          commQuantMode,
    int64_t          groupListType,
    const char*      commAlg,
    int64_t          zeroExpertNum,
    int64_t          copyExpertNum,
    int64_t          constExpertNum,
    aclTensor*       xOut,
    uint64_t*        workspaceSize,
    aclOpExecutor**  executor)
aclnnStatus aclnnMoeDistributeCombineV4(
    void*           workspace,
    uint64_t        workspaceSize,
    aclOpExecutor*  executor,
    aclrtStream     stream)

aclnnMoeDistributeCombineV4GetWorkspaceSize

  • 参数说明

    参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor
    expandX 输入 根据expertIds进行扩展过的token特征。 要求2D Tensor。 FLOAT16、BFLOAT16 ND (max(tpWorldSize, 1) * A , H)
    expertIds 输入 每个token的topK个专家索引。 要求2D Tensor。 INT32 ND (BS, K)
    assistInfoForCombine 输入 对应aclnnMoeDistributeDispatchV4assistInfoForCombineOut输出。 要求1D Tensor。 INT32 ND (A * 128, )
    epSendCounts 输入 对应aclnnMoeDistributeDispatchV4epRecvCounts输出。 要求1D Tensor。 INT32 ND -
    expertScales 输入 每个token的topK个专家的权重。 要求2D Tensor。 FLOAT32 ND (BS, K)
    tpSendCountsOptional 输入 对应aclnnMoeDistributeDispatchV4tpRecvCounts输出。 有TP域通信需传参,无TP域通信传空指针。 INT32 ND 当有TP域通信时,shape为 (tpWorldSize, )
    xActiveMaskOptional 输入 标识token是否参与通信。
    • 要求是1D或者2D Tensor。可传有效数据或空指针,默认所有token参与通信。
    • 各卡BS不一致时所有token需有效。
    BOOL ND 当输入为1D时,shape为(BS, );当输入为2D时,shape为(BS, K)
    activationScaleOptional 输入 预留参数。 当前版本不支持,传空指针即可。 - ND - -
    weightScaleOptional 输入 预留参数。 当前版本不支持,传空指针即可。 - ND - -
    groupListOptional 输入 预留参数。 当前版本不支持,传空指针即可。 - ND - -
    expandScalesOptional 输入 对应aclnnMoeDistributeDispatchV4expandScales输出。 - FLOAT32 ND (A, )
    sharedExpertXOptional 输入 表示共享专家计算后的token。 - FLOAT16、BFLOAT16 ND (BS, H)
    elasticInfoOptional 输入 EP通信域动态缩容信息。 - INT32 ND -
    oriXOptional 输入 表示未经过FFN(Feed-Forward Neural network)的token数据。 在使能copyExpert或使能constExpert的场景下需要本输入数据。可选择传入有效数据或填空指针,当copyExpertNum不为0或constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,数据类型需跟expandX保持一致。 FLOAT16、BFLOAT16 ND (BS, H)
    constExpertAlpha1Optional 输入 在使能constExpert的场景下需要输入的计算系数。 - FLOAT16、BFLOAT16 ND -
    constExpertAlpha2Optional 输入 在使能constExpert的场景下需要输入的计算系数。 - FLOAT16、BFLOAT16 ND -
    constExpertVOptional 输入 Device侧的aclTensor,在使能constExpert的场景下需要输入的计算系数。 - FLOAT16、BFLOAT16 ND -
    performanceInfoOptional 输入 表示本卡等待各卡数据的通信时间,单位为us(微秒)。 单次算子调用各卡通信耗时会累加到该Tensor上,算子内部不进行自动清零,因此用户每次启用此Tensor开始记录耗时前需对Tensor清零。 INT64 ND -
    groupEp 输入 EP通信域名称(专家并行通信域)。 字符串长度范围为[1, 128),不能和groupTp相同。 STRING - - -
    epWorldSize 输入 EP通信域大小。 - INT64 - - -
    epRankId 输入 EP域本卡ID。 取值范围[0, epWorldSize),同一个EP通信域中各卡的epRankId不重复。 INT64 - - -
    moeExpertNum 输入 MoE专家数量。 满足moeExpertNum % (epWorldSize - sharedExpertRankNum) = 0 INT64 - - -
    groupTp 输入 TP通信域名称(数据并行通信域)。 不能和groupEp相同。 STRING - - -
    tpWorldSize 输入 TP通信域大小。 - INT64 - - -
    tpRankId 输入 TP域本卡ID。 同一个TP通信域中各卡的tpRankId不重复 INT64 - - -
    expertShardType 输入 共享专家卡分布类型。 - INT64 - - -
    sharedExpertNum 输入 共享专家数。 - INT64 - - -
    sharedExpertRankNum 输入 共享专家卡数量(一个共享专家可复制部署到多个卡上)。 - INT64 - - -
    globalBS 输入 EP域全局batch size。
    • 各rank BS一致时,globalBS = BS * epWorldSize 或 0。
    • 各rank BS不一致时,globalBS = maxBS * epWorldSize(maxBS为单卡BS最大值)。
    INT64 - - -
    outDtype 输入 用于指定输出x的数据类型,预留参数。 当前版本不支持,传0即可。 INT64 - - -
    commQuantMode 输入 通信量化类型。 - INT64 - - -
    groupListType 输入 group List格式,预留参数。 当前版本不支持,传0。 INT64 - - -
    commAlg 输入 通信亲和内存布局算法。 - STRING - - -
    zeroExpertNum 输入 零专家数量。 - INT64 - - -
    copyExpertNum 输入 copy专家数量。 - INT64 - - -
    constExpertNum 输入 常量专家数量。 - INT64 - - -
    xOut 输出 处理后的token。 要求为2D Tensor,数据类型/格式与expandX一致。 FLOAT16、BFLOAT16 ND (BS, H) -
    workspaceSize 输出 返回Device侧需申请的workspace大小。 - UINT64 - - -
    executor 输出 返回包含算子计算流程的op执行器。 - aclOpExecutor* - - -
    Atlas A2 训练系列产品/Atlas A2 推理系列产品:
    • commAlg 支持nullptr、""、"fullmesh"、"hierarchy";推荐配置"hierarchy"并搭配≥25.0.RC1.1版本驱动;nullptr和""依HCCL环境变量选择算法(不推荐);"fullmesh"通过RDMA直传token;"hierarchy"经机内、跨机两次发送减少跨机数据量。
    • 不支持共享专家场景。
    • epSendCounts 的shape为 (moeExpertNum + 2 * globalBS * K * serverNum, ),其中K指topK个专家数,前moeExpertNum个数表示从EP通信域各卡接收的token数,后2 * globalBS * K * serverNum个数用于存储机间/机内通信前,combine可提前做reduce的token个数和通信区偏移,globalBS=0时按BS * epWorldSize计算。
    • 当前不支持TP域通信。
    • xActiveMaskOptional 依commAlg取值,"fullmesh"要求为1D Tensor,shape为(BS, );true需排在false前(例:{true, false, true}非法);"hierarchy"当前版本不支持,传空指针即可。
    • expandScalesOptional 要求为1D Tensor,shape为 (A, )。
    • sharedExpertXOptional 为预留参数,当前版本不支持,传空指针即可。
    • epWorldSize 依commAlg取值,"fullmesh"支持2、3、4、5、6、7、8、16、32、64、128、192、256、384;"hierarchy"支持16、32、64。
    • moeExpertNum 依commAlg取值,"fullmesh"支持(0, 1024],"hierarchy"支持(0, 512]。
    • groupTp 当前版本不支持,传空字符即可。
    • tpWorldSize 当前版本不支持,传0即可。
    • tpRankId 当前版本不支持,传0即可。
    • expertShardType 当前版本不支持,传0即可。
    • sharedExpertNum 当前版本不支持,传0即可。
    • sharedExpertRankNum 当前版本不支持,传0即可。
    • commQuantMode 取值范围0或2(0表示不量化,2表示int8量化),取值为2仅当commAlg为"hierarchy"或HCCL_INTRA_PCIE_ENABLE=1且HCCL_INTRA_ROCE_ENABLE=0且驱动版本≥25.0.RC1.1时支持。
    • expandScalesOptional 要求是一个1D的Tensor。
    • elasticInfoOptional 当前版本不支持,传空指针即可。
    • oriXOptional 当commAlg="hierarchy"时,当前版本不支持,传空指针即可。
    • constExpertAlpha1Optional 预留参数,当前版本不支持,传空指针即可。
    • constExpertAlpha2Optional 预留参数,当前版本不支持,传空指针即可。
    • constExpertVOptional 预留参数,当前版本不支持,传空指针即可。
    • zeroExpertNum 当commAlg="fullmesh"时,取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID的值是[moeExpertNum, moeExpertNum + zeroExpertNum)。
    • copyExpertNum 当commAlg="fullmesh"时,取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID的值是[moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum)。
    • constExpertNum 当前版本不支持,传0即可。
    • performanceInfoOptional 可选择传入有效数据或填空指针,传入空指针时表示不使能记录通信耗时功能;当传入有效数据时,要求是一个1D的Tensor,shape为(epWorldSize,),数据类型支持int64;数据格式要求为ND。
    Atlas A3 训练系列产品/Atlas A3 推理系列产品:
    • commAlg 支持"","hierarchy"两种输入方式。"":默认值,不使能hierarchy跨超模板;"hierarchy": 使能跨超模板,该模板仅支持tpWorldSize为1、共享专家为0的场景,且不支持可变BS、二维mask、特殊专家、performanceInfo场景。
    • epSendCounts 的shape为 (epWorldSize * max(tpWorldSize, 1) * localExpertNum, )。
    • 有TP域通信时 tpSendCountsOptional 为1D Tensor,shape为 (tpWorldSize, )。
    • xActiveMaskOptional 要求为1D或2D Tensor(1D时shape为(BS, ),2D时shape为(BS, K));1D时true需排在false前,2D时token对应K个值全为false则不参与通信。
    • expandScalesOptional 预留参数,当前版本不支持,传空指针即可。
    • sharedExpertXOptional 要求为2D或3D Tensor(2D时shape为 (BS, H);3D时前两位乘积等于BS、第三维等于H);可传或不传,传入时sharedExpertRankNum需为0。
    • epWorldSize 取值范围[2, 768];当commAlg="hierarchy"场景时,取值范围为[16, 256],且为16的整数倍。
    • moeExpertNum 取值范围(0, 1024];当commAlg="hierarchy"场景时,取值范围为(0, 512]。
    • groupTp 字符串长度范围为[0, 128),不能和groupEp相同,仅在无tp域通信时支持传空。
    • tpWorldSize 取值范围[0, 2],0和1表示无TP域通信,有TP域通信时仅支持2。
    • tpRankId 取值范围[0, 1],同一个TP通信域中各卡的tpRankId不重复;无TP域通信时传0即可。
    • expertShardType 当前仅支持传0,表示共享专家卡排在MoE专家卡前面。
    • sharedExpertNum 当前取值范围[0, 4]。
    • sharedExpertRankNum 取值范围[0, epWorldSize);为0时需满足sharedExpertNum为0或1,不为0时需满足sharedExpertRankNum % sharedExpertNum = 0。
    • commQuantMode 取值范围0或2(0表示不量化,2表示int8量化),取值为2仅当tpWorldSize < 2时可使能。
    • expandScalesOptional 预留参数,当前版本不支持,传空指针即可。
    • elasticInfoOptional 预留参数,当前版本不支持,传空指针即可。
    • constExpertAlpha1Optional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为(constExpertNum, H),数据类型需与expandX保持一致。
    • constExpertAlpha2Optional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为(constExpertNum, H),数据类型需与expandX保持一致。
    • constExpertVOptional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为 (constExpertNum, H),数据类型需与expandX保持一致。
    • zeroExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID的值是[moeExpertNum, moeExpertNum + zeroExpertNum)。
    • copyExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID的值是[moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum)。
    • constExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1, 合法的常量专家的ID的值是[moeExpertNum + zeroExpertNum + copyExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum)。
    • performanceInfoOptional 可选择传入有效数据或填空指针,传入空指针时表示不使能记录通信耗时功能;当传入有效数据时,要求是一个1D的Tensor,shape为(epWorldSize,),数据类型支持int64;数据格式要求为ND。
    Ascend 950PR/Ascend 950DT:
    • commAlg 当前版本不支持,传空指针即可。
    • epSendCounts 的shape为 (epWorldSize * max(tpWorldSize, 1) * localExpertNum, )。
    • 不支持TP通信域。
    • xActiveMaskOptional 要求为1D或2D Tensor(1D时shape为(BS, ),2D时shape为(BS, K));1D时true需排在false前,2D时token对应K个值全为false则不参与通信。
    • expandScalesOptional 预留参数,当前版本不支持,传空指针即可。
    • sharedExpertXOptional 要求为2D或3D Tensor(2D时shape为 (BS, H);3D时前两位乘积等于BS、第三维等于H);可传或不传,传入时sharedExpertRankNum需为0。
    • epWorldSize 取值支持[2, 768]。
    • moeExpertNum 取值范围(0, 1024]。
    • groupTp 当前版本不支持,传空字符即可。
    • tpWorldSize 当前版本不支持,传0即可。
    • tpRankId 当前版本不支持,传0即可。
    • expertShardType 当前仅支持传0,表示共享专家卡排在MoE专家卡前面。
    • sharedExpertNum 当前取值范围[0, 4]。
    • sharedExpertRankNum 取值范围[0, epWorldSize);为0时需满足sharedExpertNum为0或1,不为0时需满足sharedExpertRankNum % sharedExpertNum = 0。
    • commQuantMode 取值范围0、2、3或4(0表示不量化,2表示int8量化,3表示mxfp8量化e5m2,4表示mxfp8量化e4m3)。
    • expandScalesOptional 预留参数,当前版本不支持,传空指针即可。
    • elasticInfoOptional 预留参数,当前版本不支持,传空指针即可。
    • constExpertAlpha1Optional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为(constExpertNum, H),数据类型需与expandX保持一致。
    • constExpertAlpha2Optional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为(constExpertNum, H),数据类型需与expandX保持一致。
    • constExpertVOptional 可选择传入有效数据或填空指针,当constExpertNum不为0时必须传入有效输入;当传入有效数据时,要求是一个2D的Tensor,shape为 (constExpertNum, H),数据类型需与expandX保持一致。
    • zeroExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的零专家的ID的值是[moeExpertNum, moeExpertNum + zeroExpertNum)。
    • copyExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1,合法的拷贝专家的ID的值是[moeExpertNum + zeroExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum)。
    • constExpertNum 取值范围:[0, MAX_INT32),MAX_INT32 = 2^31 - 1, 合法的常量专家的ID的值是[moeExpertNum + zeroExpertNum + copyExpertNum, moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum)。
    • performanceInfoOptional 可选择传入有效数据或填空指针,传入空指针时表示不使能记录通信耗时功能;当传入有效数据时,要求是一个1D的Tensor,shape为(epWorldSize,),数据类型支持int64;数据格式要求为ND。
  • 返回值

    aclnnStatus:返回状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    返回值 错误码 描述
    ACLNN_ERR_PARAM_NULLPTR 161001 输入和输出的必选参数Tensor是空指针。
    ACLNN_ERR_PARAM_INVALID 161002 输入和输出的数据类型不在支持的范围内。
    ACLNN_ERR_INNER_TILING_ERROR 561002 输入和输出的shape不在支持的范围内。
    参数的取值不在支持的范围内。

aclnnMoeDistributeCombineV4

  • 参数说明

    参数名 输入/输出 描述
    workspace 输入 在Device侧申请的workspace内存地址。
    workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnMoeDistributeCombineV4GetWorkspaceSize获取。
    executor 输入 op执行器,包含了算子计算流程。
    stream 输入 指定执行任务的Stream。
  • 返回值

    返回aclnnStatus状态码,具体参见aclnn返回码

约束说明

  • 确定性计算

    • aclnnMoeDistributeCombineV4默认确定性实现。
  • 驱动约束

    • 算子通信域各节点的驱动版本应当相同。
  • 接口配套约束

    • aclnnMoeDistributeDispatchV4aclnnMoeDistributeCombineV4必须配套使用,前者输出的assistInfoForCombineOutepRecvCountsOuttpRecvCountsOutexpandScalesOut需直接传入后者对应参数,业务逻辑不可依赖这些Tensor的具体值。
  • 参数一致性约束

    • 所有卡的groupEpepWorldSizemoeExpertNumgroupTptpWorldSizeexpertShardTypesharedExpertNumsharedExpertRankNumglobalBScommAlg参数及HCCL_BUFFSIZE取值需保持一致,且与aclnnMoeDistributeDispatchV4对应参数一致。
  • 产品特定约束

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:单卡包含双DIE(晶粒/裸片),参数说明中的“本卡”均指单DIE。
  • Shape变量约束

    变量 定义与取值范围
    A 本卡需分发的最大token数,取值范围如下:
    • 对于共享专家,要满足A = BS * epWorldSize * sharedExpertNum / sharedExpertRankNum
    • 对于MoE专家,当globalBS为0时,要满足A >= BS * epWorldSize * min(localExpertNum, K);当globalBS非0时,要满足A >= globalBS * min(localExpertNum, K)
    H 表示hidden size隐藏层大小:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:(0, 10240]且为32的整数倍。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:依commAlg取值,当commAlg为"hierarchy"时,H取值范围为[1024, 7168],且为32的整数倍;其余场景下取值范围[1024, 8192]。
    • Ascend 950PR/Ascend 950DT:取值范围[1024, 8192]。
    BS 本卡最终输出token数:
    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:依commAlg取值,"fullmesh"支持(0, 256];"hierarchy"并且驱动版本≥25.0.RC1.1时支持(0, 512];
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:依commAlg取值,"fullmesh_v2"和"hierarchy"取值范围为 (0 < BS ≤ 256), "fullmesh_v1"和""取值范围为 (0 < BS ≤ 512)。
    • Ascend 950PR/Ascend 950DT:0 < BS ≤512。
    K 表示选取topK个专家:
    0 < K ≤16,且0 < K ≤ moeExpertNum+zeroExpertNum+copyExpertNum+constExpertNum
    serverNum 服务器节点数:
    Atlas A2 训练系列产品/Atlas A2 推理系列产品:仅该场景的shape使用了该变量,仅支持2、4、8。
    localExpertNum 本卡专家数:
    • 对于共享专家卡,localExpertNum = 1;
    • 对于MoE专家卡,localExpertNum = moeExpertNum/(epWorldSize-sharedExpertRankNum),localExpertNum > 1时不支持TP通信。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:应满足 0 < localExpertNum * epWorldSize ≤ 2048;当commAlg="hierarchy"时,需满足localExpertNum ≤ 24 且 localExpertNum * epWorldSize ≤ 512。
    • Ascend 950PR/Ascend 950DT:应满足 0 < localExpertNum * epWorldSize ≤ 2048。
  • 环境变量约束

    • HCCL_BUFFSIZE:调用本接口前需检查HCCL_BUFFSIZE环境变量取值是否合理,该环境变量表示单个通信域占用内存大小,单位MB,不配置时默认为200MB。

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品:
        • commAlg配置为""或nullptr:依照HCCL_INTRA_PCIE_ENABLE和HCCL_INTRA_ROCE_ENABLE环境变量配置,选择"fullmesh"或"hierarchy"公式。
        • commAlg配置为"fullmesh": 要求 = 2 * (BS * epWorldSize * min(localExpertNum, K) * H * sizeof(uint16) + 2MB)
        • commAlg配置为"hierarchy": 要求 >= (moeExpertNum + epWorldSize / 4) * Align512(maxBS * (H * 2 + 16 * Align8(K))) * 1B + 8MB,其中Align8(x) = ((x + 8 - 1) / 8) * 8,Align512(x) = ((x + 512 - 1) / 512) * 512。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:
        • commAlg配置为""或者未配置时:要求 >= 2且满足>= 2 * (localExpertNum * maxBS * epWorldSize * Align512(Align32(2 * H) + 44) + (K + sharedExpertNum) * maxBS * Align512(2 * H))localExpertNum需使用MoE专家卡的本卡专家数,其中Align512(x) = ((x + 512 - 1) / 512) * 512,Align32(x) = ((x + 32 - 1) / 32) * 32
        • commAlg配置为"hierarchy"时:要求取值满足 (moeExpertNum * maxBS * (H * 2 + (3 * (K + 7) / 8 * 8)) * 4 + 64) + 404 * 1024 * 1024
      • Ascend 950PR/Ascend 950DT:
        • commAlg配置为""或者未配置时:要求 >= 2且满足>= 2 * (localExpertNum * maxBS * epWorldSize * Align512(Align32(2 * H) + 44) + (K + sharedExpertNum) * maxBS * Align512(2 * H))localExpertNum需使用MoE专家卡的本卡专家数,其中Align512(x) = ((x + 512 - 1) / 512) * 512,Align32(x) = ((x + 32 - 1) / 32) * 32
    • HCCL_INTRA_PCIE_ENABLE/HCCL_INTRA_ROCE_ENABLE

      • Atlas A2 训练系列产品/Atlas A2 推理系列产品:该环境变量不再推荐使用,建议commAlg配置"hierarchy"。
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品/Ascend 950PR/Ascend 950DT:不支持该环境变量。
    • HCCL_LOGIC_SUPERPOD_ID

      • Atlas A3 训练系列产品/Atlas A3 推理系列产品:当commAlg配置"hierarchy"需根据不同的超节点配置该环境变量,例如两机分别设为export HCCL_LOGIC_SUPERPOD_ID=0export HCCL_LOGIC_SUPERPOD_ID=1
  • 通信域使用约束

    • 一个模型中的aclnnMoeDistributeCombineV4系列算子和aclnnMoeDistributeDispatchV4仅支持相同EP通信域,且该通信域中不允许有其他算子。
    • 一个模型中的aclnnMoeDistributeCombineV4系列算子和aclnnMoeDistributeDispatchV4仅支持相同TP通信域或都不支持TP通信域,有TP通信域时该通信域中不允许有其他算子。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:仅在commAlg配置为"hierarchy"场景下通信域支持跨超节点,其余场景要求一个通信域内的节点需在一个超节点内,不支持跨超节点。
  • 组网约束

    • Atlas A2 训练系列产品/Atlas A2 推理系列产品:多机场景仅支持交换机组网,不支持双机直连组网。
  • 其他约束

    • 公式中的“/”表示整除。
    • moeExpertNum + zeroExpertNum + copyExpertNum + constExpertNum < MAX_INT32

调用示例

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品 :

    本示例支持A2算子运行在卡数为[2, 8]的单机环境中,用户可以根据需要在示例代码中设置EP_WORLD_SIZE_A2为卡数,并更改moeExpertNum,使得moeExpertNum可以被EP_WORLD_SIZE_A2整除。

    • 编译算子:算子编译命令如下,moe_distribute_dispatch_v2和moe_distribute_combine_v2算子都需要编译,这两个算子需要成对执行。

      bash build.sh --pkg --soc=ascend910b --ops=moe_distribute_dispatch_v2,moe_distribute_combine_v2
      
    • 创建A2示例代码:编译完成后请在算子examples目录下参考已有test_aclnn_moe_distribute_combine_v2.cpp文件,用A2示例代码新建测试文件test_aclnn_moe_distribute_combine_v2.cpp。

    • 执行算子样例:示例算子执行命令如下,该命令会执行算子examples目录下所有的示例代码文件。

      bash build.sh --run_example --ops=moe_distribute_combine_v2 eager cust
      
    • A2示例代码:

      #include <thread>
      #include <iostream>
      #include <string>
      #include <cstring>
      #include <vector>
      #include <memory>
      #include <cstdio>
      #include "acl/acl.h"
      #include "hccl/hccl.h"
      #include "aclnn/opdev/fp16_t.h"
      #include "aclnnop/aclnn_moe_distribute_dispatch_v4.h"
      #include "aclnnop/aclnn_moe_distribute_combine_v4.h"
      
      #define CHECK_RET(cond, return_expr) \
          do {                             \
              if (!(cond)) {               \
                  return_expr;             \
              }                            \
          } while (0)
      
      #define LOG_PRINT(message, ...)         \
          do {                                \
              printf(message, ##__VA_ARGS__); \
          } while(0)
      
      struct Args {
          uint32_t rankId;
          uint32_t epRankId;
          uint32_t tpRankId;
          HcclComm hcclEpComm;
          HcclComm hcclTpComm;
          aclrtStream dispatchV4Stream;
          aclrtStream combineV4Stream;
          aclrtContext context;
      };
      
      const uint32_t EP_WORLD_SIZE_A2 = 8;
      const uint32_t TP_WORLD_SIZE_A2 = 1;
      const uint32_t DEV_NUM_A2 = EP_WORLD_SIZE_A2 * TP_WORLD_SIZE_A2;
      
      int64_t GetShapeSize(const std::vector<int64_t> &shape)
      {
          int64_t shape_size = 1;
          for (auto i : shape) {
              shape_size *= i;
          }
          return shape_size;
      }
      
      template<typename T>
      int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
          aclDataType dataType, aclTensor **tensor)
      {
          auto size = GetShapeSize(shape) * sizeof(T);
          auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
          ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
          std::vector<int64_t> strides(shape.size(), 1);
          for (int64_t i = shape.size() - 2; i >= 0; i--) {
              strides[i] = shape[i +1] * strides[i + 1];
          }
          *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
              shape.data(), shape.size(), *deviceAddr);
          return 0;
      }
      
      void DestroyTensor(aclTensor *tensor) {
          if (tensor != nullptr) {
              aclDestroyTensor(tensor);
          }
      }
      
      void FreeDeviceAddr(void *deviceAddr) {
          if (deviceAddr != nullptr) {
              aclrtFree(deviceAddr);
          }
      }
      
      int launchOneThreadDispatchV4AndCombineV4_A2(Args &args)
      {
          int ret = aclrtSetCurrentContext(args.context);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret);
          char hcomEpName[128] = {0};
          ret = HcclGetCommName(args.hcclEpComm, hcomEpName);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1);
          LOG_PRINT("[INFO] rank = %d, hcomEpName = %s, dispatchV4Stream = %p, combineV4Stream = %p, \
                      context = %p\n", args.rankId, hcomEpName, args.dispatchV4Stream, args.combineV4Stream,                 \
                      args.context);
      
          int64_t BS = 32;
          int64_t H = 7168;
          int64_t K = 8;
          int64_t expertShardType = 0;
          int64_t sharedExpertNum = 0;
          int64_t sharedExpertRankNum = 0;
          int64_t moeExpertNum = 256;
          int64_t quantMode = 0;
          int64_t globalBS = BS * EP_WORLD_SIZE_A2;
          int64_t expertTokenNumsType = 1;
          int64_t outDtype = 0;
          int64_t commQuantMode = 0;
          int64_t groupList_type = 1;
          int64_t localExpertNum;
          int64_t A;
          int64_t zeroExpertNum = 0;
          int64_t copyExpertNum = 0;
          int64_t constExpertNum = 0; // 仅A3
          std::string commAlg = "fullmesh";
          if (args.epRankId < sharedExpertRankNum) {
              localExpertNum = 1;
              A = globalBS / sharedExpertRankNum;
          } else {
              localExpertNum = moeExpertNum / (EP_WORLD_SIZE_A2 - sharedExpertRankNum);
              A = globalBS * (localExpertNum < K ? localExpertNum : K);
          }
      
          void *xDeviceAddr = nullptr;
          void *expertIdsDeviceAddr = nullptr;
          void *scalesDeviceAddr = nullptr;
          void *expertScalesDeviceAddr = nullptr;
      
          void *expandXDeviceAddr = nullptr;
          void *dynamicScalesDeviceAddr = nullptr;
          void *assistInfoForCombineDeviceAddr = nullptr;
          void *expertTokenNumsDeviceAddr = nullptr;
          void *epRecvCountsDeviceAddr = nullptr;
          void *tpRecvCountsDeviceAddr = nullptr;
          void *expandScalesDeviceAddr = nullptr;
      
          // 零专家场景输入
          void *oriXDeviceAddr = nullptr;
      
          // 性能打点
          void *performanceInfoDeviceAddr = nullptr;
      
          void *xOutDeviceAddr = nullptr;
      
          aclTensor *x = nullptr;
          aclTensor *expertIds = nullptr;
          aclTensor *scales = nullptr;
          aclTensor *xActiveMask = nullptr;
          aclTensor *expertScales = nullptr;
      
          aclTensor *elasticInfo = nullptr; // A3
          aclTensor *performanceInfo = nullptr;
          aclTensor *expandX = nullptr;
          aclTensor *dynamicScales = nullptr;
          aclTensor *assistInfoForCombine = nullptr; // expandIdx
          aclTensor *expertTokenNums = nullptr;
          aclTensor *epRecvCounts = nullptr;
          aclTensor *tpRecvCounts = nullptr;
          aclTensor *expandScales = nullptr;
      
          aclTensor *activationScale = nullptr; // 预留参数
          aclTensor *weightScale = nullptr; // 预留参数
          aclTensor *groupList = nullptr; // 预留参数
      
          aclTensor *sharedExpertX = nullptr; // A3
      
          aclTensor *oriX = nullptr;
          aclTensor *constExpertAlpha1 = nullptr; // A3
          aclTensor *constExpertAlpha2 = nullptr; // A3
          aclTensor *constExpertV = nullptr; // A3
      
          aclTensor *xOut = nullptr;
      
          //定义当前场景下各变量维度
          std::vector<int64_t> xShape{BS, H};
          std::vector<int64_t> expertIdsShape{BS, K};
          std::vector<int64_t> scalesShape{moeExpertNum + 1, H};
          std::vector<int64_t> expertScalesShape{BS, K};
      
          std::vector<int64_t> expandXShape{TP_WORLD_SIZE_A2 * A, H};
          std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE_A2 * A};
          std::vector<int64_t> assistInfoForCombineShape{A * 128};
          std::vector<int64_t> expertTokenNumsShape{localExpertNum};
          std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE_A2 * localExpertNum * EP_WORLD_SIZE_A2}; // 不分层
          std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE_A2};
          std::vector<int64_t> expandScalesShape{A};
      
          std::vector<int64_t> oriXShape{BS, H};
          std::vector<int64_t> xOutShape{BS, H};
          std::vector<int64_t> performanceInfoShape{EP_WORLD_SIZE_A2};
      
          int64_t xShapeSize = GetShapeSize(xShape);
          int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
          int64_t scalesShapeSize = GetShapeSize(scalesShape);
          int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
      
          int64_t expandXShapeSize = GetShapeSize(expandXShape);
          int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
          int64_t assistInfoForCombineShapeSize = GetShapeSize(assistInfoForCombineShape);
          int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
          int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
          int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
          int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
      
          int64_t performanceInfoShapeSize = GetShapeSize(performanceInfoShape);
          int64_t oriXSize = GetShapeSize(oriXShape);
      
          int64_t xOutShapeSize = GetShapeSize(xOutShape);
      
          std::vector<int16_t> xHostData(xShapeSize, 1);
          std::vector<int32_t> expertIdsHostData;
          for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) {
              for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) {
                  expertIdsHostData.push_back(k_id);
              }
          }
      
          std::vector<float> scalesHostData(scalesShapeSize, 0.1);
          std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1);
      
          std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
          std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
          std::vector<int32_t> assistInfoForCombineHostData(assistInfoForCombineShapeSize, 0);
          std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
          std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
          std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
          std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
      
          std::vector<int16_t> oriXHostData(oriXSize, 1);
          std::vector<int16_t> performanceInfoHostData(performanceInfoShapeSize, 0);
          std::vector<int16_t> xOutHostData(xOutShapeSize, 0);
      
      
          ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
      
          ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(assistInfoForCombineHostData, assistInfoForCombineShape, &assistInfoForCombineDeviceAddr, aclDataType::ACL_INT32, &assistInfoForCombine);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
      
          ret = CreateAclTensor(oriXHostData, oriXShape, &oriXDeviceAddr, aclDataType::ACL_BF16, &oriX);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
          ret = CreateAclTensor(performanceInfoHostData, performanceInfoShape, &performanceInfoDeviceAddr, aclDataType::ACL_BF16, &performanceInfo);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
      
          ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut);
          CHECK_RET(ret == ACL_SUCCESS, return ret);
      
      
          uint64_t dispatchWorkspaceSize = 0;
          aclOpExecutor *dispatchExecutor = nullptr;
          void *dispatchWorkspaceAddr = nullptr;
      
          uint64_t combineWorkspaceSize = 0;
          aclOpExecutor *combineExecutor = nullptr;
          void *combineWorkspaceAddr = nullptr;
      
          /**************************************** 调用dispatch ********************************************/
          // 调用第一阶段接口
          ret = aclnnMoeDistributeDispatchV4GetWorkspaceSize(x, expertIds, (quantMode > 0 ? scales : nullptr), xActiveMask,
                  expertScales, elasticInfo, performanceInfo, hcomEpName, EP_WORLD_SIZE_A2, args.epRankId, moeExpertNum, "", TP_WORLD_SIZE_A2,
                  args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBS,
                  expertTokenNumsType, commAlg.c_str(), zeroExpertNum, copyExpertNum, constExpertNum, expandX, dynamicScales, assistInfoForCombine, expertTokenNums, epRecvCounts,
                  tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor);
      
          CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV4GetWorkspaceSize failed. ret = %d \n", ret); return ret);
      
          if (dispatchWorkspaceSize > 0) {
              ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
          }
          // 调用第二阶段接口
          ret = aclnnMoeDistributeDispatchV4(dispatchWorkspaceAddr, dispatchWorkspaceSize,
                                              dispatchExecutor, args.dispatchV4Stream);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV4 failed. ret = %d \n", ret);  \
                  return ret);
          ret = aclrtSynchronizeStreamWithTimeout(args.dispatchV4Stream, 10000);
                      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] dispatch aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);  \
                  return ret);
          LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV4 execute successfully.\n", args.rankId);
          /**************************************** 调用combine ********************************************/
          // 调用第一阶段接口
          ret = aclnnMoeDistributeCombineV4GetWorkspaceSize(expandX, expertIds,
                                                              assistInfoForCombine, epRecvCounts,
                                                              expertScales, tpRecvCounts,
                                                              xActiveMask, activationScale, weightScale,
                                                              groupList, expandScales, sharedExpertX,
                                                              elasticInfo, oriX, constExpertAlpha1, constExpertAlpha2, constExpertV, performanceInfo,
                                                              hcomEpName, EP_WORLD_SIZE_A2, args.epRankId, moeExpertNum,
                                                              "", TP_WORLD_SIZE_A2, args.tpRankId, expertShardType,
                                                              sharedExpertNum, sharedExpertRankNum, globalBS, outDtype,
                                                              commQuantMode, groupList_type, commAlg.c_str(), zeroExpertNum, copyExpertNum, constExpertNum, xOut,
                                                              &combineWorkspaceSize, &combineExecutor);
          CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV4GetWorkspaceSize failed. ret = %d \n", ret); return ret);
          // 根据第一阶段接口计算出的workspaceSize申请device内存
          if (combineWorkspaceSize > 0) {
              ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
          }
      
          // 调用第二阶段接口
          ret = aclnnMoeDistributeCombineV4(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor, args.combineV4Stream);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV4 failed. ret = %d \n", ret);
              return ret);
          // (固定写法)同步等待任务执行结束
          ret = aclrtSynchronizeStreamWithTimeout(args.combineV4Stream, 10000);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
              return ret);
          LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV4 and aclnnMoeDistributeCombineV4                      \
                      execute successfully.\n", args.rankId);
          // 释放device资源
          if (dispatchWorkspaceSize > 0) {
              aclrtFree(dispatchWorkspaceAddr);
          }
          if (combineWorkspaceSize > 0) {
              aclrtFree(combineWorkspaceAddr);
          }
          DestroyTensor(x);
          DestroyTensor(expertIds);
          DestroyTensor(scales);
          DestroyTensor(xActiveMask);
          DestroyTensor(expertScales);
          DestroyTensor(elasticInfo);
          DestroyTensor(performanceInfo);
          DestroyTensor(expandX);
          DestroyTensor(dynamicScales);
          DestroyTensor(assistInfoForCombine);
          DestroyTensor(expertTokenNums);
          DestroyTensor(epRecvCounts);
          DestroyTensor(tpRecvCounts);
          DestroyTensor(expandScales);
          DestroyTensor(activationScale);
          DestroyTensor(weightScale);
          DestroyTensor(groupList);
          DestroyTensor(sharedExpertX);
          DestroyTensor(oriX);
          DestroyTensor(constExpertAlpha1);
          DestroyTensor(constExpertAlpha2);
          DestroyTensor(constExpertV);
          DestroyTensor(xOut);
      
          FreeDeviceAddr(xDeviceAddr);
          FreeDeviceAddr(expertIdsDeviceAddr);
          FreeDeviceAddr(scalesDeviceAddr);
          FreeDeviceAddr(expertScalesDeviceAddr);
          FreeDeviceAddr(expandXDeviceAddr);
          FreeDeviceAddr(dynamicScalesDeviceAddr);
          FreeDeviceAddr(assistInfoForCombineDeviceAddr);
          FreeDeviceAddr(expertTokenNumsDeviceAddr);
          FreeDeviceAddr(epRecvCountsDeviceAddr);
          FreeDeviceAddr(tpRecvCountsDeviceAddr);
          FreeDeviceAddr(expandScalesDeviceAddr);
          FreeDeviceAddr(oriXDeviceAddr);
          FreeDeviceAddr(performanceInfoDeviceAddr);
          FreeDeviceAddr(xOutDeviceAddr);
      
          HcclCommDestroy(args.hcclEpComm);
          aclrtDestroyStream(args.dispatchV4Stream);
          aclrtDestroyStream(args.combineV4Stream);
          aclrtDestroyContext(args.context);
          LOG_PRINT("[INFO] device_%d DeStroy.\n", args.rankId);
          aclrtResetDevice(args.rankId);
          LOG_PRINT("[INFO] device_%d Reset.\n", args.rankId);
          return 0;
      }
      int main(int argc, char *argv[])
      {
          LOG_PRINT("[INFO] run_example_on_A2.\n");
          int ret = aclInit(nullptr);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret);
          aclrtStream dispatchV4Stream[DEV_NUM_A2];
          aclrtStream combineV4Stream[DEV_NUM_A2];
          aclrtContext context[DEV_NUM_A2];
          for (uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) {
              ret = aclrtSetDevice(rankId);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret);
              ret = aclrtCreateContext(&context[rankId], rankId);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret);
              ret = aclrtCreateStream(&dispatchV4Stream[rankId]);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
              ret = aclrtCreateStream(&combineV4Stream[rankId]);
              CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
          }
      
          int32_t devicesEp[EP_WORLD_SIZE_A2];
          for (int32_t epId = 0; epId < EP_WORLD_SIZE_A2; epId++) {
              devicesEp[epId] = epId;
          }
      
          HcclComm commsEp[EP_WORLD_SIZE_A2];
          ret = HcclCommInitAll(EP_WORLD_SIZE_A2, devicesEp, commsEp);
          CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclCommInitAll ep failed, ret %d\n", ret); return ret);
      
          Args args[DEV_NUM_A2];
          std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM_A2);
          for (uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) {
              uint32_t epRankId = rankId / TP_WORLD_SIZE_A2;
              uint32_t tpRankId = rankId % TP_WORLD_SIZE_A2;
      
              args[rankId].rankId = rankId;
              args[rankId].epRankId = epRankId;
              args[rankId].tpRankId = tpRankId;
              args[rankId].hcclEpComm = commsEp[epRankId];
              args[rankId].dispatchV4Stream = dispatchV4Stream[rankId];
              args[rankId].combineV4Stream = combineV4Stream[rankId];
              args[rankId].context = context[rankId];
              threads[rankId].reset(new(std::nothrow) std::thread(&launchOneThreadDispatchV4AndCombineV4_A2, std::ref(args[rankId])));
          }
      
          for(uint32_t rankId = 0; rankId < DEV_NUM_A2; rankId++) {
              threads[rankId]->join();
          }
      
          aclFinalize();
          LOG_PRINT("[INFO] aclFinalize success\n");
          return 0;
      }
      
  • Ascend 950PR/Ascend 950DT :请参考aclnnMoeDistributeCombineV2中调用示例的准备部分和示例代码,按照上文的约束说明重新设置涉及的变量,V4接口相较于V3接口新增的场景参数按上述参数说明传值即可。

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:

    具体编译和执行过程请参考编译与运行样例

  • 示例代码如下,仅供参考

    #include <thread>
    #include <iostream>
    #include <string>
    #include <vector>
    #include <unordered_set>
    #include "acl/acl.h"
    #include "hccl/hccl.h"
    #include "aclnnop/aclnn_moe_distribute_dispatch_v4.h"
    #include "aclnnop/aclnn_moe_distribute_combine_v4.h"
    
    #define CHECK_RET(cond, return_expr) \
        do {                             \
            if (!(cond)) {               \
                return_expr;             \
            }                            \
        } while (0)
    
    #define LOG_PRINT(message, ...)         \
        do {                                \
            printf(message, ##__VA_ARGS__); \
        } while(0)
    
    struct Args {
        uint32_t rankId;
        uint32_t epRankId;
        uint32_t tpRankId;
        HcclComm hcclEpComm;
        HcclComm hcclTpComm;
        aclrtStream dispatchStream;
        aclrtStream combineStream;
        aclrtContext context;
    };
    
    constexpr uint32_t EP_WORLD_SIZE = 2;
    constexpr uint32_t TP_WORLD_SIZE = 1;
    constexpr uint32_t DEV_NUM = EP_WORLD_SIZE * TP_WORLD_SIZE;
    
    int64_t GetShapeSize(const std::vector<int64_t> &shape)
    {
        int64_t shape_size = 1;
        for (auto i : shape) {
            shape_size *= i;
        }
        return shape_size;
    }
    
    template<typename T>
    int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
        aclDataType dataType, aclTensor **tensor)
    {
        auto size = GetShapeSize(shape) * sizeof(T);
        auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
        ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
        std::vector<int64_t> strides(shape.size(), 1);
        for (int64_t i = shape.size() - 2; i >= 0; i--) {
            strides[i] = shape[i +1] * strides[i + 1];
        }
        *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
            shape.data(), shape.size(), *deviceAddr);
        return 0;
    }
    
    int LaunchOneProcessDispatchAndCombine(Args &args)
    {
        int ret = aclrtSetCurrentContext(args.context);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret);
    
        char hcomEpName[128] = {0};
        ret = HcclGetCommName(args.hcclEpComm, hcomEpName);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] HcclGetEpCommName failed, ret %d\n", ret); return -1);
        char hcomTpName[128] = {0};
        LOG_PRINT("[INFO] rank = %d, hcomEpName = %s, hcomTpName = %s, dispatchStream = %p, combineStream = %p, \
                    context = %p\n", args.rankId, hcomEpName, hcomTpName, args.dispatchStream, args.combineStream,                 \
                    args.context);
    
        int64_t BS = 8;
        int64_t H = 7168;
        int64_t K = 3;
        int64_t expertShardType = 0;
        int64_t sharedExpertNum = 1;
        int64_t sharedExpertRankNum = 1;
        int64_t moeExpertNum = 7;
        int64_t quantMode = 0;
        int64_t globalBS = BS * EP_WORLD_SIZE;
        int64_t expertTokenNumsType = 1;
        int64_t outDtype = 0;
        int64_t commQuantMode = 0;
        int64_t groupList_type = 1;
        int64_t localExpertNum;
        int64_t A;
        int64_t zeroExpertNum = 1;
        int64_t copyExpertNum = 1;
        int64_t constExpertNum = 1;
        if (args.epRankId < sharedExpertRankNum) {
            localExpertNum = 1;
            A = globalBS / sharedExpertRankNum;
        } else {
            localExpertNum = moeExpertNum / (EP_WORLD_SIZE - sharedExpertRankNum);
            A = globalBS * (localExpertNum < K ? localExpertNum : K);
        }
    
        void *xDeviceAddr = nullptr;
        void *expertIdsDeviceAddr = nullptr;
        void *scalesDeviceAddr = nullptr;
        void *expertScalesDeviceAddr = nullptr;
    
        void *expandXDeviceAddr = nullptr;
        void *dynamicScalesDeviceAddr = nullptr;
        void *expandIdxDeviceAddr = nullptr;
        void *expertTokenNumsDeviceAddr = nullptr;
        void *epRecvCountsDeviceAddr = nullptr;
        void *tpRecvCountsDeviceAddr = nullptr;
        void *expandScalesDeviceAddr = nullptr;
        void *residualXDeviceAddr = nullptr;
        void *sharedExpertXDeviceAddr = nullptr;
    
        void *elasticInfoDeviceAddr = nullptr;
        void *oriXDeviceAddr = nullptr;
        void *constExpertAlpha1DeviceAddr = nullptr;
        void *constExpertAlpha2DeviceAddr = nullptr;
        void *constExpertVDeviceAddr = nullptr;
    
        void *xOutDeviceAddr = nullptr;
    
        aclTensor *x = nullptr;
        aclTensor *expertIds = nullptr;
        aclTensor *scales = nullptr;
        aclTensor *expertScales = nullptr;
    
        aclTensor *expandX = nullptr;
        aclTensor *dynamicScales = nullptr;
        aclTensor *expandIdx = nullptr;
        aclTensor *expertTokenNums = nullptr;
        aclTensor *epRecvCounts = nullptr;
        aclTensor *tpRecvCounts = nullptr;
        aclTensor *expandScales = nullptr;
        aclTensor *residualX = nullptr;
        aclTensor *sharedExpertX = nullptr;
    
    
        aclTensor *elasticInfo = nullptr;
        aclTensor *oriX = nullptr;
        aclTensor *constExpertAlpha1 = nullptr;
        aclTensor *constExpertAlpha2 = nullptr;
        aclTensor *constExpertV = nullptr;
    
        aclTensor *xOut = nullptr;
    
        //定义当前场景下各变量维度
        std::vector<int64_t> xShape{BS, H};
        std::vector<int64_t> expertIdsShape{BS, K};
        std::vector<int64_t> scalesShape{moeExpertNum + 1, H};
        std::vector<int64_t> expertScalesShape{BS, K};
    
        std::vector<int64_t> expandXShape{TP_WORLD_SIZE * A, H};
        std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE * A};
        std::vector<int64_t> expandIdxShape{A * 128};
        std::vector<int64_t> expertTokenNumsShape{localExpertNum};
        std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE * localExpertNum * EP_WORLD_SIZE};
        std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE};
        std::vector<int64_t> expandScalesShape{A};
        std::vector<int64_t> sharedExpertXShape{BS, 1, H};
    
    
        std::vector<int64_t> elasticInfoShape{4 + EP_WORLD_SIZE * 2};
        std::vector<int64_t> oriXShape{BS, H};
        std::vector<int64_t> constExpertAlpha1Shape{constExpertNum, H};
        std::vector<int64_t> constExpertAlpha2Shape{constExpertNum, H};
        std::vector<int64_t> constExpertVShape{constExpertNum, H};
    
        std::vector<int64_t> xOutShape{BS, H};
    
        int64_t xShapeSize = GetShapeSize(xShape);
        int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
        int64_t scalesShapeSize = GetShapeSize(scalesShape);
        int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
    
        int64_t expandXShapeSize = GetShapeSize(expandXShape);
        int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
        int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape);
        int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
        int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
        int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
        int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
        int64_t sharedExpertXShapeSize = GetShapeSize(sharedExpertXShape);
    
        int64_t elasticInfoSize = GetShapeSize(elasticInfoShape);
        int64_t oriXSize = GetShapeSize(oriXShape);
        int64_t constExpertAlpha1Size = GetShapeSize(constExpertAlpha1Shape);
        int64_t constExpertAlpha2Size = GetShapeSize(constExpertAlpha2Shape);
        int64_t constExpertVSize = GetShapeSize(constExpertVShape);
    
        int64_t xOutShapeSize = GetShapeSize(xOutShape);
    
        std::vector<int16_t> xHostData(xShapeSize, 1);
        std::vector<int32_t> expertIdsHostData;
        for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) {
            for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) {
                expertIdsHostData.push_back(k_id);
            }
        }
    
        std::vector<float> scalesHostData(scalesShapeSize, 0.1);
        std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1);
    
        std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
        std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
        std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0);
        std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
        std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
        std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
        std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
        std::vector<int16_t> sharedExpertXHostData(sharedExpertXShapeSize, 1);
    
        std::vector<int16_t> oriXHostData(oriXSize, 1);
        std::vector<int16_t> constExpertAlpha1HostData(constExpertAlpha1Size, 0);
        std::vector<int16_t> constExpertAlpha2HostData(constExpertAlpha2Size, 0);
        std::vector<int16_t> constExpertVHostData(constExpertVSize, 0);
    
        std::vector<int16_t> xOutHostData(xOutShapeSize, 0);
    
    
        ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
    
        ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
            ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32, &expandIdx);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(sharedExpertXHostData, sharedExpertXShape, &sharedExpertXDeviceAddr, aclDataType::ACL_BF16, &sharedExpertX);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
    
        ret = CreateAclTensor(oriXHostData, oriXShape, &oriXDeviceAddr, aclDataType::ACL_BF16, &oriX);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(constExpertAlpha1HostData, constExpertAlpha1Shape, &constExpertAlpha1DeviceAddr, aclDataType::ACL_BF16, &constExpertAlpha1);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(constExpertAlpha2HostData, constExpertAlpha2Shape, &constExpertAlpha2DeviceAddr, aclDataType::ACL_BF16, &constExpertAlpha2);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        ret = CreateAclTensor(constExpertVHostData, constExpertVShape, &constExpertVDeviceAddr, aclDataType::ACL_BF16, &constExpertV);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
    
        ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
    
    
    
        uint64_t dispatchWorkspaceSize = 0;
        aclOpExecutor *dispatchExecutor = nullptr;
        void *dispatchWorkspaceAddr = nullptr;
    
        uint64_t combineWorkspaceSize = 0;
        aclOpExecutor *combineExecutor = nullptr;
        void *combineWorkspaceAddr = nullptr;
    
        /**************************************** 调用dispatch ********************************************/
        // 调用第一阶段接口
        ret = aclnnMoeDistributeDispatchV4GetWorkspaceSize(x, expertIds, (quantMode > 0 ? scales : nullptr), nullptr,
                expertScales, elasticInfo, nullptr, // performanceInfoOptional
                hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE,
                args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBS,
                expertTokenNumsType, nullptr, zeroExpertNum, copyExpertNum, constExpertNum, expandX, dynamicScales, expandIdx, expertTokenNums, epRecvCounts,
                tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor);
    
        CHECK_RET(ret == ACL_SUCCESS,
            LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV4GetWorkspaceSize failed. ret = %d \n", ret); return ret);
    
        if (dispatchWorkspaceSize > 0) {
            ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
        }
        // 调用第二阶段接口
        ret = aclnnMoeDistributeDispatchV4(dispatchWorkspaceAddr, dispatchWorkspaceSize,
                                            dispatchExecutor, args.dispatchStream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV4 failed. ret = %d \n", ret);  \
                return ret);
        ret = aclrtSynchronizeStreamWithTimeout(args.dispatchStream, 10000);
                    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] dispatch aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);  \
                return ret);
        /**************************************** 调用combine ********************************************/
        // 调用第一阶段接口
        ret = aclnnMoeDistributeCombineV4GetWorkspaceSize(expandX, expertIds,
                                                            expandIdx, epRecvCounts,
                                                            expertScales, tpRecvCounts,
                                                            nullptr, nullptr, nullptr,
                                                            nullptr, nullptr, nullptr,
                                                            elasticInfo, oriX, constExpertAlpha1, constExpertAlpha2, constExpertV,
                                                            nullptr, // performanceInfoOptional
                                                            hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum,
                                                            hcomTpName, TP_WORLD_SIZE, args.tpRankId, expertShardType,
                                                            sharedExpertNum, sharedExpertRankNum, globalBS, outDtype,
                                                            commQuantMode, groupList_type, nullptr, zeroExpertNum, copyExpertNum, constExpertNum, xOut,
                                                            &combineWorkspaceSize, &combineExecutor);
        CHECK_RET(ret == ACL_SUCCESS,
            LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV4GetWorkspaceSize failed. ret = %d \n", ret); return ret);
        // 根据第一阶段接口计算出的workspaceSize申请device内存
        if (combineWorkspaceSize > 0) {
            ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
        }
    
        // 调用第二阶段接口
        ret = aclnnMoeDistributeCombineV4(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor, args.combineStream);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV4 failed. ret = %d \n", ret);
            return ret);
        // (固定写法)同步等待任务执行结束
        ret = aclrtSynchronizeStreamWithTimeout(args.combineStream, 10000);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout failed. ret = %d \n", ret);
            return ret);
        LOG_PRINT("[INFO] device_%d aclnnMoeDistributeDispatchV4 and aclnnMoeDistributeCombineV4                      \
                    execute successfully.\n", args.rankId);
        // 释放device资源
        if (dispatchWorkspaceSize > 0) {
            aclrtFree(dispatchWorkspaceAddr);
        }
        if (combineWorkspaceSize > 0) {
            aclrtFree(combineWorkspaceAddr);
        }
        if (x != nullptr) {
            aclDestroyTensor(x);
        }
        if (expertIds != nullptr) {
            aclDestroyTensor(expertIds);
        }
        if (scales != nullptr) {
            aclDestroyTensor(scales);
        }
        if (expertScales != nullptr) {
            aclDestroyTensor(expertScales);
        }
    
        if (expandX != nullptr) {
            aclDestroyTensor(expandX);
        }
        if (dynamicScales != nullptr) {
            aclDestroyTensor(dynamicScales);
        }
        if (expandIdx != nullptr) {
            aclDestroyTensor(expandIdx);
        }
        if (expertTokenNums != nullptr) {
            aclDestroyTensor(expertTokenNums);
        }
        if (epRecvCounts != nullptr) {
            aclDestroyTensor(epRecvCounts);
        }
        if (tpRecvCounts != nullptr) {
            aclDestroyTensor(tpRecvCounts);
        }
        if (expandScales != nullptr) {
            aclDestroyTensor(expandScales);
        }
        if (residualX != nullptr) {
            aclDestroyTensor(residualX);
        }
        if (sharedExpertX != nullptr) {
            aclDestroyTensor(sharedExpertX);
        }
        if (elasticInfo != nullptr) {
            aclDestroyTensor(elasticInfo);
        }
        if (oriX != nullptr) {
            aclDestroyTensor(oriX);
        }
        if (constExpertAlpha1 != nullptr) {
            aclDestroyTensor(constExpertAlpha1);
        }
        if (constExpertAlpha2 != nullptr) {
            aclDestroyTensor(constExpertAlpha2);
        }
        if (constExpertV != nullptr) {
            aclDestroyTensor(constExpertV);
        }
    
        if (xOut != nullptr) {
            aclDestroyTensor(xOut);
        }
        if (xDeviceAddr != nullptr) {
            aclrtFree(xDeviceAddr);
        }
        if (expertIdsDeviceAddr != nullptr) {
            aclrtFree(expertIdsDeviceAddr);
        }
        if (scalesDeviceAddr != nullptr) {
            aclrtFree(scalesDeviceAddr);
        }
        if (expertScalesDeviceAddr != nullptr) {
            aclrtFree(expertScalesDeviceAddr);
        }
        if (expandXDeviceAddr != nullptr) {
            aclrtFree(expandXDeviceAddr);
        }
        if (dynamicScalesDeviceAddr != nullptr) {
            aclrtFree(dynamicScalesDeviceAddr);
        }
        if (expandIdxDeviceAddr != nullptr) {
            aclrtFree(expandIdxDeviceAddr);
        }
        if (expertTokenNumsDeviceAddr != nullptr) {
            aclrtFree(expertTokenNumsDeviceAddr);
        }
        if (epRecvCountsDeviceAddr != nullptr) {
            aclrtFree(epRecvCountsDeviceAddr);
        }
        if (expandScalesDeviceAddr != nullptr) {
            aclrtFree(expandScalesDeviceAddr);
        }
        if (tpRecvCountsDeviceAddr != nullptr) {
            aclrtFree(tpRecvCountsDeviceAddr);
        }
        if (sharedExpertXDeviceAddr != nullptr) {
            aclrtFree(sharedExpertXDeviceAddr);
        }
    
        if (elasticInfoDeviceAddr != nullptr) {
            aclrtFree(elasticInfoDeviceAddr);
        }
        if (oriXDeviceAddr != nullptr) {
            aclrtFree(oriXDeviceAddr);
        }
        if (constExpertAlpha1DeviceAddr != nullptr) {
            aclrtFree(constExpertAlpha1DeviceAddr);
        }
        if (constExpertAlpha2DeviceAddr != nullptr) {
            aclrtFree(constExpertAlpha2DeviceAddr);
        }
        if (constExpertVDeviceAddr != nullptr) {
            aclrtFree(constExpertVDeviceAddr);
        }
    
        if (xOutDeviceAddr != nullptr) {
            aclrtFree(xOutDeviceAddr);
        }
    
        HcclCommDestroy(args.hcclEpComm);
        HcclCommDestroy(args.hcclTpComm);
        aclrtDestroyStream(args.dispatchStream);
        aclrtDestroyStream(args.combineStream);
        aclrtDestroyContext(args.context);
        aclrtResetDevice(args.rankId);
    
        return 0;
    }
    
    int main(int argc, char *argv[])
    {
        int ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret);
    
        aclrtStream dispatchStream[DEV_NUM];
        aclrtStream combineStream[DEV_NUM];
        aclrtContext context[DEV_NUM];
        for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            ret = aclrtSetDevice(rankId);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateContext(&context[rankId], rankId);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateStream(&dispatchStream[rankId]);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
            ret = aclrtCreateStream(&combineStream[rankId]);
            CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
        }
    
        int32_t devicesEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
                devicesEp[tpId][epId] = epId * TP_WORLD_SIZE + tpId;
            }
        }
    
        HcclComm commsEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            ret = HcclCommInitAll(EP_WORLD_SIZE, devicesEp[tpId], commsEp[tpId]);
            CHECK_RET(ret == ACL_SUCCESS,
                        LOG_PRINT("[ERROR] HcclCommInitAll ep %d failed, ret %d\n", tpId, ret); return ret);
        }
    
        int32_t devicesTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
                devicesTp[epId][tpId] = epId * TP_WORLD_SIZE + tpId;
            }
        }
    
        HcclComm commsTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            ret = HcclCommInitAll(TP_WORLD_SIZE, devicesTp[epId], commsTp[epId]);
            CHECK_RET(ret == ACL_SUCCESS,
                        LOG_PRINT("[ERROR] HcclCommInitAll tp %d failed, ret %d\n", epId, ret); return ret);
        }
    
        Args args[DEV_NUM];
        std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM);
        for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            uint32_t epRankId = rankId / TP_WORLD_SIZE;
            uint32_t tpRankId = rankId % TP_WORLD_SIZE;
    
            args[rankId].rankId = rankId;
            args[rankId].epRankId = epRankId;
            args[rankId].tpRankId = tpRankId;
            args[rankId].hcclEpComm = commsEp[tpRankId][epRankId];
            args[rankId].hcclTpComm = commsTp[epRankId][tpRankId];
            args[rankId].dispatchStream = dispatchStream[rankId];
            args[rankId].combineStream = combineStream[rankId];
            args[rankId].context = context[rankId];
            threads[rankId].reset(new(std::nothrow) std::thread(&LaunchOneProcessDispatchAndCombine, std::ref(args[rankId])));
        }
    
        for(uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
            threads[rankId]->join();
        }
    
        aclFinalize();
        LOG_PRINT("[INFO] aclFinalize success\n");
    
        return 0;
    }