aclnnMoeUpdateExpert


Supported Products

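| Product | Supported |
| --- | --- |
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 Training Series Products/Atlas A3 Inference Series Products | √ |
| Atlas A2 Training Series Products/Atlas A2 Inference Series Products | × |
| Atlas 200I/500 A2 Inference Products | × |
| Atlas Inference Series Products | × |
| Atlas Training Series Products | × |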

Function Description

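This API provides load balancing and expert pruning. The mapped expert table and mask it produces can be passed to the MoE layer for data dispatch and processing.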

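  • Load balancing: to handle load imbalance, the operator maps the logical expert IDs of each token's top-K experts to physical card IDs. The computation is as follows: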

    For the i-th value in expert_ids (that is, the i-th token):

    F = number of columns of eplb_table        # e.g. 2
    expert_id = expert_ids[i]
    table_offset = expert_id * F               # eplb_table indexed as a flat row-major array
    place_num = eplb_table[table_offset]       # number of deployed instances of this expert
    if place_num == 1:
        new_expert_id = eplb_table[table_offset + 1]
    else:
        if balance_mode == 0:                  # distribute by rank
            mode_value = ceil(world_size / place_num)
            place_idx = local_rank_id // mode_value + 1
        else:                                  # distribute by token
            place_idx = i % place_num + 1
        new_expert_id = eplb_table[table_offset + place_idx]
    
  • Expert pruning: the top-K experts that each token is sent to can be pruned according to a threshold. The computation is as follows:

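    Broadcast active_mask of shape (BS,) into active_mask_tensor of shape (BS, K); all experts of a token whose active_mask value is False are pruned directly. Elements whose active_mask_tensor value is True are also pruned when they meet the condition below.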

    active_mask_tensor = broadcast(active_mask, (BS, K))
    for i in range(BS):
        expert_scales_vec[:] = sum(expert_scales[i, :] * pruning_threshold[:])
        # True = expert kept; experts whose scale falls below the threshold are pruned
        balanced_active_mask[i, :] = (expert_scales[i, :] >= expert_scales_vec[:]) & active_mask_tensor[i, :]
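The load-balancing mapping in the first bullet can be cross-checked on the host with a small reference model. The C++ sketch below is illustrative only and is not the operator's implementation: `MapExpert` and `BalanceExpertIds` are hypothetical names, eplbTable is assumed to be a row-major `std::vector<int32_t>` of shape (moeExpertNum, F), and for balanceMode = 1 it assumes that the token index picks one of the deployed instances round-robin (`i % place_num + 1`).

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference model of the load-balancing mapping (not the
// operator's implementation). eplbTable is row-major with shape (moeExpertNum, F):
//   column 0      -> number of deployed instances of the logical expert
//   columns 1..n  -> physical instance IDs of those n instances
int64_t MapExpert(const std::vector<int32_t>& eplbTable, int64_t F, int64_t expertId,
                  int64_t tokenIdx, int64_t localRankId, int64_t worldSize, int64_t balanceMode)
{
    const int64_t tableOffset = expertId * F;
    const int64_t placeNum = eplbTable.at(tableOffset);  // deployed instance count
    if (placeNum == 1) {
        return eplbTable.at(tableOffset + 1);            // single instance: nothing to choose
    }
    int64_t placeIdx = 0;
    if (balanceMode == 0) {
        // Distribute by rank: each group of ceil(worldSize / placeNum) consecutive
        // ranks is assigned to the same instance.
        const int64_t modeValue = (worldSize + placeNum - 1) / placeNum;  // integer ceil
        placeIdx = localRankId / modeValue + 1;
    } else {
        // Distribute by token. ASSUMPTION: round-robin over the deployed instances.
        placeIdx = tokenIdx % placeNum + 1;
    }
    // .at() throws std::out_of_range if the table is malformed.
    return eplbTable.at(tableOffset + placeIdx);
}

// Applies MapExpert to every entry of expertIds (shape (BS, K), row-major);
// the row index is used as the token index.
std::vector<int64_t> BalanceExpertIds(const std::vector<int64_t>& expertIds, int64_t K,
                                      const std::vector<int32_t>& eplbTable, int64_t F,
                                      int64_t localRankId, int64_t worldSize, int64_t balanceMode)
{
    std::vector<int64_t> out(expertIds.size());
    for (std::size_t idx = 0; idx < expertIds.size(); ++idx) {
        const int64_t token = static_cast<int64_t>(idx) / K;
        out[idx] = MapExpert(eplbTable, F, expertIds[idx], token,
                             localRankId, worldSize, balanceMode);
    }
    return out;
}
```

For example, with world_size = 4 and an expert deployed on instances {0, 2}, balanceMode 0 gives mode_value = ceil(4 / 2) = 2, so ranks 0 and 1 map to instance 0 and ranks 2 and 3 map to instance 2.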
    

Function Prototypes

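Each operator uses a two-phase interface: first call aclnnMoeUpdateExpertGetWorkspaceSize to obtain the workspace size required for the computation and an executor that contains the operator's computation flow, then call aclnnMoeUpdateExpert to execute the computation.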

aclnnStatus aclnnMoeUpdateExpertGetWorkspaceSize(
    const aclTensor* expertIds,
    const aclTensor* eplbTable,
    const aclTensor* expertScalesOptional,
    const aclTensor* pruningThresholdOptional,
    const aclTensor* activeMaskOptional,
    int64_t          localRankId,
    int64_t          worldSize,
    int64_t          balanceMode,
    aclTensor*       balancedExpertIds,
    aclTensor*       balancedActiveMask,
    uint64_t*        workspaceSize,
    aclOpExecutor**  executor)
aclnnStatus aclnnMoeUpdateExpert(
    void*           workspace,
    uint64_t        workspaceSize,
    aclOpExecutor*  executor,
    aclrtStream     stream)

aclnnMoeUpdateExpertGetWorkspaceSize

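  • Parameters

    | Parameter | Input/Output | Description | Usage notes | Data type | Data format | Dimensions (shape) | Non-contiguous tensor |
    | --- | --- | --- | --- | --- | --- | --- | --- |
    | expertIds (aclTensor*) | Input | Indices of the top-K experts of each token. | Must be a 2D tensor of shape (BS, K). | INT32, INT64 | ND | 2 | × |
    | eplbTable (aclTensor*) | Input | Mapping table from logical experts to physical experts (the caller must ensure the values are correct). | There are world_size * place_per_rank expert instances in total. In each row, the first column is the number of deployed instances of the corresponding logical expert (range [1, world_size]), and columns [1, count] hold the instance IDs (range [0, world_size * place_per_rank), no duplicates). Must be a 2D tensor of shape (moeExpertNum, F). | INT32 | ND | 2 | × |
    | expertScalesOptional (aclTensor*) | Input | Scale weights of the top-K experts of each token. | Values within each token must be sorted in descending order. May be a valid tensor or a null pointer; if it is valid, pruningThresholdOptional must also be valid. Must be a 2D tensor of shape (BS, K). | FLOAT16, BFLOAT16, FLOAT | ND | 2 | × |
    | pruningThresholdOptional (aclTensor*) | Input | Minimum threshold for the expert scale weights (an expert whose scale for a token is below the threshold is pruned). | May be a valid tensor or a null pointer; if it is valid, expertScalesOptional must also be valid. Must be a 1D or 2D tensor of shape (K,) or (1, K). | FLOAT | ND | 1/2 | × |
    | activeMaskOptional (aclTensor*) | Input | Whether each token participates in communication. | May be a valid tensor or a null pointer; if it is valid, expertScalesOptional and pruningThresholdOptional must also be valid. true means the token participates, and all true values must come before all false values (for example, {true, false, true} is invalid). If a null pointer is passed, all tokens participate. Must be a 1D tensor of shape (BS,). | BOOL | ND | 1 | × |
    | localRankId (int64_t) | Input | ID of this card. | When balanceMode=0, the range is [0, worldSize). localRankId must be unique among the cards in the same communication domain. | INT64 | - | - | - |
    | worldSize (int64_t) | Input | Size of the communication domain. | When balanceMode=0, the range is [2, 768]. | INT64 | - | - | - |
    | balanceMode (int64_t) | Input | Balancing rule; defaults to 0. | 0: distribute by rank. 1: distribute by token. Range [0, 1]. | INT64 | - | - | - |
    | balancedExpertIds (aclTensor*) | Output | Physical expert instance IDs of each token's top-K experts after mapping. | Must be a 2D tensor of shape (BS, K). Data type and data format must match expertIds. | INT32, INT64 | ND | 2 | × |
    | balancedActiveMask (aclTensor*) | Output | activeMask after pruning and balancing; valid only when expertScalesOptional and pruningThresholdOptional are valid tensors. | Must be a 2D tensor of shape (BS, K). | BOOL | ND | 2 | × |
    | workspaceSize (uint64_t*) | Output | Returns the size of the workspace to allocate on the device. | - | UINT64 | - | - | - |
    | executor (aclOpExecutor**) | Output | Returns the op executor, which contains the operator's computation flow. | - | aclOpExecutor* | - | - | - |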
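  • Return value

    aclnnStatus: the status code. For details, see aclnn return codes.

    The first-phase interface validates the input parameters and reports an error in the following cases:

    | Return value | Error code | Description |
    | --- | --- | --- |
    | ACLNN_ERR_PARAM_NULLPTR | 161001 | A required input or output tensor is a null pointer. |
    | ACLNN_ERR_PARAM_INVALID | 161002 | The data type of an input or output is not supported. |
    | ACLNN_ERR_INNER_TILING_ERROR | 561002 | 1. The shape of an input or output is out of the supported range. 2. A parameter value is out of the supported range. |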
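Because the caller must guarantee that the values in eplbTable are correct, it can help to pre-check the table and activeMaskOptional on the host before building the tensors. The helpers below, `CheckEplbTable` and `CheckActiveMask`, are hypothetical and not part of the aclnn API; they encode the rules from the parameter table, and interpreting "no duplicates" as applying across the whole table is an assumption.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical pre-check of eplbTable (row-major, shape (moeExpertNum, F)).
bool CheckEplbTable(const std::vector<int32_t>& eplbTable, int64_t moeExpertNum,
                    int64_t F, int64_t worldSize, int64_t placePerRank)
{
    if (static_cast<int64_t>(eplbTable.size()) != moeExpertNum * F) {
        return false;
    }
    const int64_t instanceTotal = worldSize * placePerRank;
    std::set<int32_t> seen;  // ASSUMPTION: instance IDs are unique across the whole table
    for (int64_t e = 0; e < moeExpertNum; ++e) {
        const int32_t count = eplbTable[e * F];
        // count must be in [1, worldSize], and its IDs must fit in columns 1..F-1.
        if (count < 1 || count > worldSize || count > F - 1) {
            return false;
        }
        for (int32_t c = 1; c <= count; ++c) {
            const int32_t id = eplbTable[e * F + c];
            if (id < 0 || id >= instanceTotal || !seen.insert(id).second) {
                return false;  // out of range or duplicated instance ID
            }
        }
    }
    return true;
}

// activeMask: every true must come before every false.
bool CheckActiveMask(const std::vector<bool>& activeMask)
{
    bool seenFalse = false;
    for (bool v : activeMask) {
        if (!v) {
            seenFalse = true;
        } else if (seenFalse) {
            return false;  // a true after a false, e.g. {true, false, true}
        }
    }
    return true;
}
```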

aclnnMoeUpdateExpert

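  • Parameters

    | Parameter | Input/Output | Description |
    | --- | --- | --- |
    | workspace | Input | Address of the workspace memory allocated on the device. |
    | workspaceSize | Input | Size of the workspace allocated on the device, obtained from the first-phase interface aclnnMoeUpdateExpertGetWorkspaceSize. |
    | executor | Input | Op executor, which contains the operator's computation flow. |
    | stream | Input | Stream on which the task is executed. |
  • Return value

    Returns an aclnnStatus status code. For details, see aclnn return codes.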

Constraints

  1. Deterministic computation:

    • aclnnMoeUpdateExpert is deterministic by default.
  2. Interface pairing and call order
    This interface must be used together with aclnnMoeDistributeDispatchV2 and aclnnMoeDistributeCombineV2/aclnnMoeDistributeCombineAddRmsNorm, in the fixed call order
    aclnnMoeUpdateExpert → aclnnMoeDistributeDispatchV2 → aclnnMoeDistributeCombineV2/aclnnMoeDistributeCombineAddRmsNorm

    or together with aclnnMoeDistributeDispatchV3 and aclnnMoeDistributeCombineV3/aclnnMoeDistributeCombineAddRmsNormV2, in the fixed call order
    aclnnMoeUpdateExpert → aclnnMoeDistributeDispatchV3 → aclnnMoeDistributeCombineV3/aclnnMoeDistributeCombineAddRmsNormV2. For details, see the calling example.

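  3. Parameter consistency
    The values of worldSize and moeExpertNum used in the call must be identical on all cards and across all layers of the network, and must match the corresponding parameters of aclnnMoeDistributeDispatchV2 and aclnnMoeDistributeCombineV2/aclnnMoeDistributeCombineAddRmsNorm.

  4. Hardware definitions
    Atlas A3 Training Series Products/Atlas A3 Inference Series Products: each card contains two dies, so "this card" in the parameter descriptions always refers to a single die.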

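  5. Parameter shape constraints

    • BS: number of tokens finally output by this card; 0 < BS ≤ 512.
    • K: number of top-K experts selected per token; 0 < K ≤ 16 and K ≤ moeExpertNum.
    • moeExpertNum: number of MoE experts; range (0, 1024].
    • F: number of columns of the mapping table eplbTable; range [2, worldSize + 1]. The first column is the number of deployed instances of the logical expert (value > 0), and the remaining F-1 columns are the corresponding physical card IDs.
    • Total instances: all cards together deploy at most 1024 MoE expert instances, that is, place_per_rank * world_size ≤ 1024 (place_per_rank is the number of instances deployed per card).
    • Per-card consistency: every card must deploy the same number of expert instances.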
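The shape constraints above can be checked on the host before any device memory is allocated. `CheckMoeUpdateExpertShapes` is a hypothetical helper that simply mirrors the listed ranges; the first-phase call still performs its own validation. Passing a single placePerRank value reflects the requirement that every card deploys the same number of instances.

```cpp
#include <cstdint>

// Hypothetical helper mirroring the documented shape constraints.
bool CheckMoeUpdateExpertShapes(int64_t bs, int64_t k, int64_t moeExpertNum,
                                int64_t f, int64_t worldSize, int64_t placePerRank)
{
    const bool bsOk = bs > 0 && bs <= 512;                        // 0 < BS <= 512
    const bool kOk = k > 0 && k <= 16 && k <= moeExpertNum;       // 0 < K <= 16, K <= moeExpertNum
    const bool expertNumOk = moeExpertNum > 0 && moeExpertNum <= 1024;
    const bool fOk = f >= 2 && f <= worldSize + 1;                // F in [2, worldSize + 1]
    // Total instances across all cards: placePerRank * worldSize <= 1024.
    const bool instancesOk = placePerRank > 0 && placePerRank * worldSize <= 1024;
    return bsOk && kOk && expertNumOk && fOk && instancesOk;
}
```

With the values used in the calling example (BS = 8, K = 2, moeExpertNum = 2, F = 2, worldSize = 2, one instance per card) every check passes.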

Calling Example

Taking Atlas A3 Training Series Products/Atlas A3 Inference Series Products and Ascend 950PR/Ascend 950DT as examples, the sample below calls the MoeUpdateExpert, MoeDistributeDispatchV2, and MoeDistributeCombineV2 operators.

The sample code is for reference only. For details on compiling and running it, see Compiling and Running Samples.

```Cpp
#include <thread>
#include <iostream>
#include <string>
#include <cstring>
#include <cstdio>
#include <memory>
#include <new>
#include <vector>
#include "acl/acl.h"
#include "hccl/hccl.h"
#include "aclnnop/aclnn_moe_update_expert.h"
#include "aclnnop/aclnn_moe_distribute_dispatch_v2.h"
#include "aclnnop/aclnn_moe_distribute_combine_v2.h"

#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (!(cond)) {               \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)         \
    do {                                \
        printf(message, ##__VA_ARGS__); \
    } while(0)

struct Args {
    uint32_t rankId;
    uint32_t epRankId;
    uint32_t tpRankId;
    HcclComm hcclEpComm;
    HcclComm hcclTpComm;
    aclrtStream eplbStream;
    aclrtStream dispatchStream;
    aclrtStream combineStream;
    aclrtContext context;
};

constexpr uint32_t EP_WORLD_SIZE = 2;
constexpr uint32_t TP_WORLD_SIZE = 1;
constexpr uint32_t DEV_NUM = EP_WORLD_SIZE * TP_WORLD_SIZE;

inline int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    int64_t shape_size = 1;
    for (auto i : shape) {
        shape_size *= i;
    }
    return shape_size;
}

template<typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
    aclDataType dataType, aclTensor **tensor)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret: %d\n", ret); return ret);
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMemcpy failed. ret: %d\n", ret); return ret);
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    *tensor = aclCreateTensor(
        shape.data(), shape.size(), dataType, strides.data(), 0,
        aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr);
    return 0;
}

int LaunchOneProcessUpdateExpertAndDispatchAndCombine(Args &args)
{
    int ret = aclrtSetCurrentContext(args.context);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetCurrentContext failed, ret %d\n", ret); return ret);

    char hcomEpName[128] = {0};
    ret = HcclGetCommName(args.hcclEpComm, hcomEpName);
    CHECK_RET(ret == HCCL_SUCCESS, LOG_PRINT("[ERROR] HcclGetCommName (EP) failed, ret %d\n", ret); return -1);
    char hcomTpName[128] = {0};

    int64_t BS = 8;
    int64_t H = 7168;
    int64_t K = 2;
    int64_t F = 2;
    int64_t expertShardType = 0;
    int64_t sharedExpertNum = 0;
    int64_t sharedExpertRankNum = 0;
    int64_t moeExpertNum = 2;
    int64_t quantMode = 0;
    int64_t globalBS = BS * EP_WORLD_SIZE;
    int64_t balanceMode = 0;
    int64_t expertTokenNumsType = 1;
    int64_t outDtype = 0;
    int64_t commQuantMode = 0;
    int64_t groupList_type = 1;
    int64_t localExpertNum;
    int64_t A;
    if (args.epRankId < sharedExpertRankNum) {
        // Shared-expert card
        localExpertNum = 1;
        A = globalBS / sharedExpertRankNum;
    } else {
        // MoE-expert card
        localExpertNum = moeExpertNum / (EP_WORLD_SIZE - sharedExpertRankNum);
        A = globalBS * (localExpertNum < K ? localExpertNum : K);
    }

    /* Build the device-side inputs and outputs for the current scenario */
    // Declare device-side input and output variables
    void *xDeviceAddr = nullptr;
    void *expertIdsDeviceAddr = nullptr;
    void *eplbTableDeviceAddr = nullptr;
    void *scalesDeviceAddr = nullptr;
    void *expertScalesDeviceAddr = nullptr;
    void *expandXDeviceAddr = nullptr;
    void *dynamicScalesDeviceAddr = nullptr;
    void *expandIdxDeviceAddr = nullptr;
    void *expertTokenNumsDeviceAddr = nullptr;
    void *epRecvCountsDeviceAddr = nullptr;
    void *tpRecvCountsDeviceAddr = nullptr;
    void *expandScalesDeviceAddr = nullptr;
    void *residualXDeviceAddr = nullptr;
    void *sharedExpertXDeviceAddr = nullptr;
    void *gammaDeviceAddr = nullptr;
    void *yOutDeviceAddr = nullptr;
    void *rstdOutDeviceAddr = nullptr;
    void *xOutDeviceAddr = nullptr;
    void *balancedExpertIdsDeviceAddr = nullptr;
    void *balancedActiveMaskDeviceAddr = nullptr;

    aclTensor *x = nullptr;
    aclTensor *expertIds = nullptr;
    aclTensor *eplbTable = nullptr;
    aclTensor *scales = nullptr;
    aclTensor *expertScales = nullptr;
    aclTensor *expandX = nullptr;
    aclTensor *dynamicScales = nullptr;
    aclTensor *expandIdx = nullptr;
    aclTensor *expertTokenNums = nullptr;
    aclTensor *epRecvCounts = nullptr;
    aclTensor *tpRecvCounts = nullptr;
    aclTensor *expandScales = nullptr;
    aclTensor *residualX = nullptr;
    aclTensor *sharedExpertX = nullptr;
    aclTensor *gamma = nullptr;
    aclTensor *yOut = nullptr;
    aclTensor *rstdOut = nullptr;
    aclTensor *xOut = nullptr;
    aclTensor *balancedExpertIds = nullptr;
    aclTensor *balancedActiveMask = nullptr;

    // Shape of each variable in the current scenario
    std::vector<int64_t> xShape{BS, H};
    std::vector<int64_t> expertIdsShape{BS, K};
    std::vector<int64_t> eplbTableShape{moeExpertNum, F};
    std::vector<int64_t> scalesShape{(sharedExpertRankNum > 0) ? moeExpertNum + 1 : moeExpertNum, H};
    std::vector<int64_t> expertScalesShape{BS, K};
    std::vector<int64_t> expandXShape{TP_WORLD_SIZE * A, H};
    std::vector<int64_t> dynamicScalesShape{TP_WORLD_SIZE * A};
    std::vector<int64_t> expandIdxShape{A * 128};
    std::vector<int64_t> expertTokenNumsShape{localExpertNum};
    std::vector<int64_t> epRecvCountsShape{TP_WORLD_SIZE * localExpertNum * EP_WORLD_SIZE};
    std::vector<int64_t> tpRecvCountsShape{TP_WORLD_SIZE * localExpertNum};
    std::vector<int64_t> expandScalesShape{A};
    std::vector<int64_t> residualXShape{BS, 1, H};
    std::vector<int64_t> sharedExpertXShape{BS, 1, H};
    std::vector<int64_t> gammaShape{H, };
    std::vector<int64_t> yOutShape{BS, 1, H};
    std::vector<int64_t> rstdOutShape{BS, 1, 1};
    std::vector<int64_t> xOutShape{BS, 1, H};
    std::vector<int64_t> balancedExpertIdsShape{BS, K};
    std::vector<int64_t> balancedActiveMaskShape{BS, K};

    int64_t xShapeSize = GetShapeSize(xShape);
    int64_t expertIdsShapeSize = GetShapeSize(expertIdsShape);
    int64_t scalesShapeSize = GetShapeSize(scalesShape);
    int64_t expertScalesShapeSize = GetShapeSize(expertScalesShape);
    int64_t expandXShapeSize = GetShapeSize(expandXShape);
    int64_t dynamicScalesShapeSize = GetShapeSize(dynamicScalesShape);
    int64_t expandIdxShapeSize = GetShapeSize(expandIdxShape);
    int64_t expertTokenNumsShapeSize = GetShapeSize(expertTokenNumsShape);
    int64_t epRecvCountsShapeSize = GetShapeSize(epRecvCountsShape);
    int64_t tpRecvCountsShapeSize = GetShapeSize(tpRecvCountsShape);
    int64_t expandScalesShapeSize = GetShapeSize(expandScalesShape);
    int64_t residualXShapeSize = GetShapeSize(residualXShape);
    int64_t sharedExpertXShapeSize = GetShapeSize(sharedExpertXShape);
    int64_t gammaShapeSize = GetShapeSize(gammaShape);
    int64_t yOutShapeSize = GetShapeSize(yOutShape);
    int64_t rstdOutShapeSize = GetShapeSize(rstdOutShape);
    int64_t xOutShapeSize = GetShapeSize(xOutShape);
    int64_t balancedExpertIdsShapeSize = GetShapeSize(balancedExpertIdsShape);
    int64_t balancedActiveMaskShapeSize = GetShapeSize(balancedActiveMaskShape);

    // Build host-side data
    std::vector<int16_t> xHostData(xShapeSize, 1);
    std::vector<int32_t> expertIdsHostData;
    for (int32_t token_id = 0; token_id < expertIdsShape[0]; token_id++) {
        // Each token is routed to MoE experts {0, 1, ..., K - 1}
        for (int32_t k_id = 0; k_id < expertIdsShape[1]; k_id++) {
            expertIdsHostData.push_back(k_id);
        }
    }
    // Build eplb_table: 2 MoE experts (moeExpertNum = 2), each with 1 instance, one instance per card.
    // Each row is {instance count, instance ID}; e.g. {1, 0} means expert 0 has 1 instance, at place 0.
    std::vector<int32_t> eplbTableHostData = {1, 0, 1, 1};

    std::vector<float> scalesHostData(scalesShapeSize, 0.1);
    std::vector<float> expertScalesHostData(expertScalesShapeSize, 0.1);
    std::vector<int16_t> expandXHostData(expandXShapeSize, 0);
    std::vector<float> dynamicScalesHostData(dynamicScalesShapeSize, 0);
    std::vector<int32_t> expandIdxHostData(expandIdxShapeSize, 0);
    std::vector<int64_t> expertTokenNumsHostData(expertTokenNumsShapeSize, 0);
    std::vector<int32_t> epRecvCountsHostData(epRecvCountsShapeSize, 0);
    std::vector<int32_t> tpRecvCountsHostData(tpRecvCountsShapeSize, 0);
    std::vector<float> expandScalesHostData(expandScalesShapeSize, 0);
    std::vector<int16_t> residualXHostData(residualXShapeSize, 1);
    std::vector<int16_t> sharedExpertXHostData(sharedExpertXShapeSize, 1);
    std::vector<int16_t> gammaHostData(gammaShapeSize, 1);
    std::vector<int16_t> yOutHostData(yOutShapeSize, 0);
    std::vector<float> rstdOutHostData(rstdOutShapeSize, 0);
    std::vector<int16_t> xOutHostData(xOutShapeSize, 0);
    std::vector<int32_t> balancedExpertIdsHostData(balancedExpertIdsShapeSize, 0);
    std::vector<uint8_t> balancedActiveMaskHostData(balancedActiveMaskShapeSize, 0);  // one byte per ACL_BOOL element

    // Build device-side tensors
    ret = CreateAclTensor(expertIdsHostData, expertIdsShape, &expertIdsDeviceAddr, aclDataType::ACL_INT32, &expertIds);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(eplbTableHostData, eplbTableShape, &eplbTableDeviceAddr, aclDataType::ACL_INT32, &eplbTable);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(balancedExpertIdsHostData, balancedExpertIdsShape, &balancedExpertIdsDeviceAddr, aclDataType::ACL_INT32, &balancedExpertIds);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(balancedActiveMaskHostData, balancedActiveMaskShape, &balancedActiveMaskDeviceAddr, aclDataType::ACL_BOOL, &balancedActiveMask);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(scalesHostData, scalesShape, &scalesDeviceAddr, aclDataType::ACL_FLOAT, &scales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expertScalesHostData, expertScalesShape, &expertScalesDeviceAddr, aclDataType::ACL_FLOAT, &expertScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandXHostData, expandXShape, &expandXDeviceAddr, (quantMode > 0) ? aclDataType::ACL_INT8 : aclDataType::ACL_BF16, &expandX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dynamicScalesHostData, dynamicScalesShape, &dynamicScalesDeviceAddr, aclDataType::ACL_FLOAT, &dynamicScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandIdxHostData, expandIdxShape, &expandIdxDeviceAddr, aclDataType::ACL_INT32, &expandIdx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expertTokenNumsHostData, expertTokenNumsShape, &expertTokenNumsDeviceAddr, aclDataType::ACL_INT64, &expertTokenNums);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(epRecvCountsHostData, epRecvCountsShape, &epRecvCountsDeviceAddr, aclDataType::ACL_INT32, &epRecvCounts);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(tpRecvCountsHostData, tpRecvCountsShape, &tpRecvCountsDeviceAddr, aclDataType::ACL_INT32, &tpRecvCounts);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(expandScalesHostData, expandScalesShape, &expandScalesDeviceAddr, aclDataType::ACL_FLOAT, &expandScales);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(residualXHostData, residualXShape, &residualXDeviceAddr, aclDataType::ACL_BF16, &residualX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(sharedExpertXHostData, sharedExpertXShape, &sharedExpertXDeviceAddr, aclDataType::ACL_BF16, &sharedExpertX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gammaHostData, gammaShape, &gammaDeviceAddr, aclDataType::ACL_BF16, &gamma);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(yOutHostData, yOutShape, &yOutDeviceAddr, aclDataType::ACL_BF16, &yOut);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(rstdOutHostData, rstdOutShape, &rstdOutDeviceAddr, aclDataType::ACL_FLOAT, &rstdOut);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(xOutHostData, xOutShape, &xOutDeviceAddr, aclDataType::ACL_BF16, &xOut);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    
    /* Variables required to execute the operators */
    uint64_t eplbworkspaceSize = 0;
    aclOpExecutor *eplbexecutor = nullptr;
    void *eplbWorkspaceAddr = nullptr;

    uint64_t dispatchWorkspaceSize = 0;
    aclOpExecutor *dispatchExecutor = nullptr;
    void *dispatchWorkspaceAddr = nullptr;

    uint64_t combineWorkspaceSize = 0;
    aclOpExecutor *combineExecutor = nullptr;
    void *combineWorkspaceAddr = nullptr;

    /**************************************** Call eplb (aclnnMoeUpdateExpert) ********************************************/
    ret = aclnnMoeUpdateExpertGetWorkspaceSize(expertIds, eplbTable, nullptr, nullptr, nullptr, args.epRankId,
        EP_WORLD_SIZE, balanceMode, balancedExpertIds, balancedActiveMask, &eplbworkspaceSize, &eplbexecutor);

    CHECK_RET(ret == ACL_SUCCESS,
            LOG_PRINT("[ERROR] aclnnMoeUpdateExpertGetWorkspaceSize failed. ret = %d \n", ret); return ret);

    if (eplbworkspaceSize > 0) {
        ret = aclrtMalloc(&eplbWorkspaceAddr, eplbworkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
    }
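    // Call the second-phase interface
    ret = aclnnMoeUpdateExpert(eplbWorkspaceAddr, eplbworkspaceSize, eplbexecutor, args.eplbStream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeUpdateExpert failed. ret = %d \n", ret); return ret);
    ret = aclrtSynchronizeStreamWithTimeout(args.eplbStream, 10000);
    CHECK_RET(ret == ACL_SUCCESS,
        LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout (eplb) failed. ret = %d \n", ret); return ret);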

    /**************************************** Call dispatch ********************************************/

    ret = aclnnMoeDistributeDispatchV2GetWorkspaceSize(x, balancedExpertIds, (quantMode > 0 ? scales : nullptr), nullptr, 
            expertScales, hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE,
            args.tpRankId, expertShardType, sharedExpertNum,sharedExpertRankNum, quantMode, globalBS,
            expertTokenNumsType, nullptr, expandX, dynamicScales, expandIdx, expertTokenNums, epRecvCounts,
            tpRecvCounts, expandScales, &dispatchWorkspaceSize, &dispatchExecutor);
    
    CHECK_RET(ret == ACL_SUCCESS,
        LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2GetWorkspaceSize failed. ret = %d \n", ret); return ret);

    if (dispatchWorkspaceSize > 0) {
        ret = aclrtMalloc(&dispatchWorkspaceAddr, dispatchWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc workspace failed. ret = %d \n", ret); return ret);
    }
    // Call the second-phase interface
    ret = aclnnMoeDistributeDispatchV2(dispatchWorkspaceAddr, dispatchWorkspaceSize,
                                        dispatchExecutor, args.dispatchStream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeDispatchV2 failed. ret = %d \n", ret);
        return ret);
    ret = aclrtSynchronizeStreamWithTimeout(args.dispatchStream, 10000);
    CHECK_RET(ret == ACL_SUCCESS,
        LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout (dispatch) failed. ret = %d \n", ret); return ret);
    
    /**************************************** Call combine ********************************************/
    
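    // Call the first-phase interface of the combine operator.
    // Pass the same (balanced) expert IDs that were used for dispatch.
    ret = aclnnMoeDistributeCombineV2GetWorkspaceSize(expandX, balancedExpertIds, expandIdx, epRecvCounts, expertScales,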
        tpRecvCounts, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
        hcomEpName, EP_WORLD_SIZE, args.epRankId, moeExpertNum, hcomTpName, TP_WORLD_SIZE, args.tpRankId,
        expertShardType, sharedExpertNum, sharedExpertRankNum, globalBS, outDtype, commQuantMode,
        groupList_type, nullptr, x, &combineWorkspaceSize, &combineExecutor);
    CHECK_RET(
        ret == ACL_SUCCESS,
        LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV2GetWorkspaceSize failed. ret = %d \n", ret); return ret
    );
    // Allocate device memory based on the workspaceSize returned by the first-phase combine interface
    if (combineWorkspaceSize > 0) {
        ret = aclrtMalloc(&combineWorkspaceAddr, combineWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtMalloc failed. ret = %d \n", ret);
                return ret);
    }

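    // Call the second-phase interface of the combine operator
    ret = aclnnMoeDistributeCombineV2(combineWorkspaceAddr, combineWorkspaceSize, combineExecutor, args.combineStream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclnnMoeDistributeCombineV2 failed. ret = %d \n", ret);
            return ret);
    // Wait for combine to finish before reporting success and releasing its buffers
    ret = aclrtSynchronizeStreamWithTimeout(args.combineStream, 10000);
    CHECK_RET(ret == ACL_SUCCESS,
        LOG_PRINT("[ERROR] aclrtSynchronizeStreamWithTimeout (combine) failed. ret = %d \n", ret); return ret);

    LOG_PRINT("[INFO] device_%u aclnnMoeUpdateExpert, aclnnMoeDistributeDispatchV2 and aclnnMoeDistributeCombineV2 "
              "execute successfully.\n", args.rankId);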

    // Release device resources
    if (eplbworkspaceSize > 0) {
        aclrtFree(eplbWorkspaceAddr);
    }
    if (dispatchWorkspaceSize > 0) {
        aclrtFree(dispatchWorkspaceAddr);
    }
    if (combineWorkspaceSize > 0) {
        aclrtFree(combineWorkspaceAddr);
    }
    if (x != nullptr) {
        aclDestroyTensor(x);
    }
    if (expertIds != nullptr) {
        aclDestroyTensor(expertIds);
    }
    if (eplbTable != nullptr) {
        aclDestroyTensor(eplbTable);
    }
    if (scales != nullptr) {
        aclDestroyTensor(scales);
    }
    if (expertScales != nullptr) {
        aclDestroyTensor(expertScales);
    }
    if (expandX != nullptr) {
        aclDestroyTensor(expandX);
    }
    if (dynamicScales != nullptr) {
        aclDestroyTensor(dynamicScales);
    }
    if (expandIdx != nullptr) {
        aclDestroyTensor(expandIdx);
    }
    if (expertTokenNums != nullptr) {
        aclDestroyTensor(expertTokenNums);
    }
    if (epRecvCounts != nullptr) {
        aclDestroyTensor(epRecvCounts);
    }
    if (tpRecvCounts != nullptr) {
        aclDestroyTensor(tpRecvCounts);
    }
    if (expandScales != nullptr) {
        aclDestroyTensor(expandScales);
    }
    if (residualX != nullptr) {
        aclDestroyTensor(residualX);
    }
    if (sharedExpertX != nullptr) {
        aclDestroyTensor(sharedExpertX);
    }
    if (gamma != nullptr) {
        aclDestroyTensor(gamma);
    }
    if (yOut != nullptr) {
        aclDestroyTensor(yOut);
    }
    if (rstdOut != nullptr) {
        aclDestroyTensor(rstdOut);
    }
    if (xOut != nullptr) {
        aclDestroyTensor(xOut);
    }
    if (balancedExpertIds != nullptr) {
        aclDestroyTensor(balancedExpertIds);
    }
    if (balancedActiveMask != nullptr) {
        aclDestroyTensor(balancedActiveMask);
    }
    if (xDeviceAddr != nullptr) {
        aclrtFree(xDeviceAddr);
    }
    if (expertIdsDeviceAddr != nullptr) {
        aclrtFree(expertIdsDeviceAddr);
    }
    if (eplbTableDeviceAddr != nullptr) {
        aclrtFree(eplbTableDeviceAddr);
    }
    if (scalesDeviceAddr != nullptr) {
        aclrtFree(scalesDeviceAddr);
    }
    if (expertScalesDeviceAddr != nullptr) {
        aclrtFree(expertScalesDeviceAddr);
    }
    if (expandXDeviceAddr != nullptr) {
        aclrtFree(expandXDeviceAddr);
    }
    if (dynamicScalesDeviceAddr != nullptr) {
        aclrtFree(dynamicScalesDeviceAddr);
    }
    if (expandIdxDeviceAddr != nullptr) {
        aclrtFree(expandIdxDeviceAddr);
    }
    if (expertTokenNumsDeviceAddr != nullptr) {
        aclrtFree(expertTokenNumsDeviceAddr);
    }
    if (epRecvCountsDeviceAddr != nullptr) {
        aclrtFree(epRecvCountsDeviceAddr);
    }
    if (expandScalesDeviceAddr != nullptr) {
        aclrtFree(expandScalesDeviceAddr);
    }
    if (tpRecvCountsDeviceAddr != nullptr) {
        aclrtFree(tpRecvCountsDeviceAddr);
    }
    if (residualXDeviceAddr != nullptr) {
        aclrtFree(residualXDeviceAddr);
    }
    if (sharedExpertXDeviceAddr != nullptr) {
        aclrtFree(sharedExpertXDeviceAddr);
    }
    if (gammaDeviceAddr != nullptr) {
        aclrtFree(gammaDeviceAddr);
    }
    if (yOutDeviceAddr != nullptr) {
        aclrtFree(yOutDeviceAddr);
    }
    if (rstdOutDeviceAddr != nullptr) {
        aclrtFree(rstdOutDeviceAddr);
    }
    if (xOutDeviceAddr != nullptr) {
        aclrtFree(xOutDeviceAddr);
    }
    if (balancedExpertIdsDeviceAddr != nullptr) {
        aclrtFree(balancedExpertIdsDeviceAddr);
    }
    if (balancedActiveMaskDeviceAddr != nullptr) {
        aclrtFree(balancedActiveMaskDeviceAddr);
    }
    HcclCommDestroy(args.hcclEpComm);
    HcclCommDestroy(args.hcclTpComm);
    aclrtDestroyStream(args.eplbStream);
    aclrtDestroyStream(args.dispatchStream);
    aclrtDestroyStream(args.combineStream);
    aclrtDestroyContext(args.context);
    aclrtResetDevice(args.rankId);

    return 0;
}

int main(int argc, char *argv[])
{
    int ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclInit failed, ret = %d\n", ret); return ret);

    aclrtStream eplbStream[DEV_NUM];
    aclrtStream dispatchStream[DEV_NUM];
    aclrtStream combineStream[DEV_NUM];
    aclrtContext context[DEV_NUM];
    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        ret = aclrtSetDevice(rankId);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtSetDevice failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateContext(&context[rankId], rankId);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateContext failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateStream(&eplbStream[rankId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateStream(&dispatchStream[rankId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
        ret = aclrtCreateStream(&combineStream[rankId]);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("[ERROR] aclrtCreateStream failed, ret = %d\n", ret); return ret);
    }

    int32_t devicesEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
    for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
        for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
            devicesEp[tpId][epId] = epId * TP_WORLD_SIZE + tpId;
        }
    }

    HcclComm commsEp[TP_WORLD_SIZE][EP_WORLD_SIZE];
    for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
        ret = HcclCommInitAll(EP_WORLD_SIZE, devicesEp[tpId], commsEp[tpId]);
        CHECK_RET(ret == ACL_SUCCESS,
                    LOG_PRINT("[ERROR] HcclCommInitAll ep %d failed, ret %d\n", tpId, ret); return ret);
    }

    int32_t devicesTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
    for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
        for (int32_t tpId = 0; tpId < TP_WORLD_SIZE; tpId++) {
            devicesTp[epId][tpId] = epId * TP_WORLD_SIZE + tpId;
        }
    }

    HcclComm commsTp[EP_WORLD_SIZE][TP_WORLD_SIZE];
    for (int32_t epId = 0; epId < EP_WORLD_SIZE; epId++) {
        ret = HcclCommInitAll(TP_WORLD_SIZE, devicesTp[epId], commsTp[epId]);
        CHECK_RET(ret == ACL_SUCCESS,
                    LOG_PRINT("[ERROR] HcclCommInitAll tp %d failed, ret %d\n", epId, ret); return ret);
    }

    Args args[DEV_NUM];
    // One thread per card runs the operators
    std::vector<std::unique_ptr<std::thread>> threads(DEV_NUM);
    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        uint32_t epRankId = rankId / TP_WORLD_SIZE;
        uint32_t tpRankId = rankId % TP_WORLD_SIZE;

        args[rankId].rankId = rankId;
        args[rankId].epRankId = epRankId;
        args[rankId].tpRankId = tpRankId;
        args[rankId].hcclEpComm = commsEp[tpRankId][epRankId];
        args[rankId].hcclTpComm = commsTp[epRankId][tpRankId];
        args[rankId].eplbStream = eplbStream[rankId];
        args[rankId].dispatchStream = dispatchStream[rankId];
        args[rankId].combineStream = combineStream[rankId];
        args[rankId].context = context[rankId];
        threads[rankId].reset(new(std::nothrow) std::thread(&LaunchOneProcessUpdateExpertAndDispatchAndCombine, std::ref(args[rankId])));
    }

    for (uint32_t rankId = 0; rankId < DEV_NUM; rankId++) {
        threads[rankId]->join();
    }

    aclFinalize();
    LOG_PRINT("[INFO] aclFinalize success\n");

    return 0;
}
```