MegaMoe

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

算子功能:该算子将 MoE(Mixture of Experts)以及 FFN(Feed-Forward Network)的完整计算流程融合为单个算子,实现 Dispatch + GroupMatmul1 + SwiGLUQuant + GroupMatmul2 + Combine 的端到端融合计算。

计算公式:

MegaMoe 是一个多阶段融合的 MoE 算子,在专家并行场景下依次完成 Token 路由分发与量化、分组矩阵乘法 + SwiGLU 激活、二次分组矩阵乘法与跨 Rank 聚合、以及加权求和,完整计算流程可以分解为以下四个阶段。

第一阶段对输入 Token 按专家分组收集后做 MXFP8 量化,生成各专家的量化输入与缩放因子:

X^e, SX,e=QMX ⁣(X[Te]),e=0,1,…,Elocal−1\hat{X}_e,\ S_{X,e} = \mathrm{Q}_{\text{MX}}\!\left(X[\mathcal{T}_e]\right), \quad e = 0, 1, \ldots, E_{\text{local}}-1

说明:根据 topkIds 将 Token 按专家排序收集,Te\mathcal{T}_e 为分配到专家 ee 的 Token 索引集合,ElocalE_{local} 表示当前专家收到的最大 Token 数,每个专家数值可能不同,X[Te]X[\mathcal{T}_e] 为对应的子矩阵。QMX\mathrm{Q}_{\text{MX}} 表示 MX 逐组量化(group size = 32),对每组 32 个元素提取共享指数后量化为 FP8 目标类型(FLOAT8_E5M2 或 FLOAT8_E4M3FN),同时输出 FLOAT8_E8M0 缩放因子。量化后的数据将作为 GMM1 的输入。

第二阶段对每个专家执行 GMM1 矩阵乘法(将 W1W_1 沿列方向分为两半分别计算)、SwiGLU 激活和 MX 量化:

Ze(x)=DQMX(X^e,SX,e)⋅DQMX(W1,e(x),S1,e(x)),Ze(y)=DQMX(X^e,SX,e)⋅DQMX(W1,e(y),S1,e(y))Z_e^{(x)} = \mathrm{DQ}_{\text{MX}}(\hat{X}_e, S_{X,e}) \cdot \mathrm{DQ}_{\text{MX}}(W_{1,e}^{(x)}, S_{1,e}^{(x)}), \quad Z_e^{(y)} = \mathrm{DQ}_{\text{MX}}(\hat{X}_e, S_{X,e}) \cdot \mathrm{DQ}_{\text{MX}}(W_{1,e}^{(y)}, S_{1,e}^{(y)})

Ue=Ze(x)⊙σ ⁣(Ze(x))⊙Ze(y)U_e = Z_e^{(x)} \odot \sigma\!\left(Z_e^{(x)}\right) \odot Z_e^{(y)}

U^e, SU,e=QMX(Ue)\hat{U}_e,\ S_{U,e} = \mathrm{Q}_{\text{MX}}(U_e)

说明:将 W1W_1 的前 N/2N/2W1,e(x)W_{1,e}^{(x)} 和后 N/2N/2W1,e(y)W_{1,e}^{(y)} 分别与 MX 反量化后的输入做矩阵乘法,得到 Swish 分支 Ze(x)Z_e^{(x)} 和门控分支 Ze(y)Z_e^{(y)}。SwiGLU 激活对两个分支做逐元素乘积 x⋅σ(x)⋅yx \cdot \sigma(x) \cdot y,其中 σ\sigma 为 Sigmoid 函数,将中间维度从 NN 减半为 N/2N/2。随后对 SwiGLU 输出做 MX 量化,得到 GMM2 的量化输入 U^e\hat{U}_e

第三阶段对每个专家执行 GMM2 矩阵乘法,并将结果按目标 Rank 分发:

Oe=DQMX(U^e,SU,e)⋅DQMX(W2,e,S2,e)O_e = \mathrm{DQ}_{\text{MX}}(\hat{U}_e, S_{U,e}) \cdot \mathrm{DQ}_{\text{MX}}(W_{2,e}, S_{2,e})

说明:将量化后的 SwiGLU 输出与第二组权重 W2W_2 做 MX 反量化后的矩阵乘法,将 N/2N/2 维中间表示映射回 HH 维隐藏空间,得到每个专家的输出 OeO_e。计算完成后通过 RDMA peermem 将结果按目标 Rank 的专家偏移地址写入远端,实现跨 Rank 聚合。

第四阶段对所有 Token 按路由权重加权求和,恢复为与输入相同形状的输出:

Y[i]=∑k=0K−1W[i, k]⋅O[π(i, k)]Y[i] = \sum_{k=0}^{K-1} W[i,\, k] \cdot O[\pi(i,\, k)]

说明:对每个 Token ii,根据排序后的路由索引 π(i,k)\pi(i,k) 从聚合后的专家结果中取出对应行,按 topkWeights 中的权重逐元素加权累加,得到最终输出 YY

其中,XX 表示参数 xWW 表示参数 topkWeightsW1W_1 表示参数 weight1W2W_2 表示参数 weight2YY 表示参数 yElocalE_{\text{local}} 表示属性 moeExpertNum / epWorldSize(每个 Rank 的专家数),KK 表示 topkIds 的第二维度(top-K 值,取值 6 或 8),NN 表示 hiddenDim(weight1.dim1)

局部变量说明:

  • Te\mathcal{T}_e:被路由到专家 ee 的 Token 索引集合,由 topkIds 排序后确定。
  • X^e, SX,e\hat{X}_e,\ S_{X,e}:专家 ee 的量化输入及其 MX 缩放因子,第一阶段中间结果。
  • W1,e(x)W_{1,e}^{(x)}W1,e(y)W_{1,e}^{(y)}W1W_1 对应专家 ee 的前 N/2N/2 列和后 N/2N/2 列子矩阵,由 weight1 按 SwiGLU 拆分推导。
  • S1,e(x)S_{1,e}^{(x)}S1,e(y)S_{1,e}^{(y)}W1,e(x)W_{1,e}^{(x)}W1,e(y)W_{1,e}^{(y)} 对应的 MX 缩放因子,从 weightScales1 按维度截取。
  • S2,eS_{2,e}W2,eW_{2,e} 对应的 MX 缩放因子,来自参数 weightScales2
  • Ze(x), Ze(y)Z_e^{(x)},\ Z_e^{(y)}:GMM1 的两路矩阵乘法输出(Swish 分支和门控分支),中间结果。
  • UeU_e:SwiGLU 激活输出,维度 me×N/2m_e \times N/2,中间结果。
  • U^e, SU,e\hat{U}_e,\ S_{U,e}:量化后的 SwiGLU 输出及其 MX 缩放因子,中间结果。
  • OeO_e:GMM2 的专家级输出,维度 me×Hm_e \times H,中间结果。
  • π(i,k)\pi(i, k):Token ii 的第 kk 个 top-k 专家在展开排序后的行索引,由路由排序确定。
  • QMX(⋅)\mathrm{Q}_{\text{MX}}(\cdot):MX 逐组量化操作,block size = 32,输出 FP8 数据和 E8M0 缩放因子。
  • DQMX(⋅)\mathrm{DQ}_{\text{MX}}(\cdot):MX 逐组反量化操作,在 matmul 内部隐式执行。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
context 输入 本卡通信域信息数据。 INT32 ND
x 输入 本卡发送的 token 数据。 BF16 ND
topk_ids 输入 每个 token 的 topK 个专家索引。 INT32 ND
topk_weights 输入 每个 token 的 topK 个专家权重。 FP32、BF16 ND
weight1 输入 GroupMatmul1 计算的右矩阵,用于计算 SwiGLU 激活前的线性变换。 FP8_E5M2、FP8_E4M3 ND
weight2 输入 GroupMatmul2 计算的右矩阵,用于 SwiGLU 激活后的线性变换。 FP8_E5M2、FP8_E4M3 ND
weight_scales1 可选输入 量化场景需要,GroupMatmul1 右矩阵反量化参数。 FP8_E8M0 ND
weight_cales2 可选输入 量化场景需要,GroupMatmul2 右矩阵反量化参数。 FP8_E8M0 ND
x_active_mask 可选输入 预留参数,暂不支持。 INT8 ND
scales 可选输入 预留参数,暂不支持。 FP32、FP8_E8M0 ND
moe_expert_num 属性 MoE 模型的总专家数量。 INT64
ep_world_size 属性 专家并行通信域大小。 INT64
ccl_buffer_size 属性 CCL 通信缓冲区大小。 INT64
max_recv_token_num 可选属性 每个 Rank 最大可接收 Token 数。默认值为0。该值为0时,会按最大值 bs*ep_world_size*min(topK,expertPerRank) 预留内存大小;不为0时,将按照输入值大小预留内存,该场景下需用户保证填入值大于等于每个rank最大可接收的 Token 数。 INT64
dispatch_quant_mode 可选属性 dispatch 通信时量化模式。目前仅支持4(MXFP模式)。默认值为0。 INT64
dispatch_quant_out_type 可选属性 dispatch量化后输出的数据类型。支持 23(FP8_E5M2)、24(FP8_E4M3)。 INT64
combine_quant_mode 可选属性 预留参数,暂不支持。默认值为0。 INT64
comm_alg 可选属性 预留参数,暂不支持。默认值为""。 STRING
global_bs 可选属性 全局 batch size,多卡场景下的总 Token 数,默认值为 0 表示使用单卡 BS。默认值按最大值 maxBs*ep_world_size 计算。 INT64
y 输出 计算输出结果,与输入x shape相同。 BF16 ND
expert_token_nums 输出 每个专家收到的 token 数量。 INT32 ND

约束说明

  • 参数一致性约束

    • 调用算子过程中使用的 ep_world_sizeglobal_bsHCCL_BUFFSIZE等参数取值,所有卡需保持一致,网络中不同层中也需保持一致。
  • 通信域使用约束

    • 仅支持 EP域,无 TP 域,不支持groupTptpWorldSizetpRankId属性。
    • 所有卡的 moe_expert_numep_world_sizeccl_buffer_sizemax_recv_token_numdispatch_quant_modedispatch_quant_out_typeglobal_bs 参数取值需保持一致。
  • 参数约束

    • BS(x.dim0)范围 [1, 512]。
    • H(x.dim1)仅支持 4096、5120、7168。
    • topK(topkIds.dim1)仅支持 6 或 8。
    • expertPerRank(weight1.dim0)范围 [1, 16]。
    • hiddenDim(weight1.dim1)仅支持 1024、2048、3072、4096、7168。
    • epWorldSize 范围 [2, 768]。
    • moeExpertNum 范围 [epWorldSize, 1024],且 moeExpertNum % epWorldSize == 0。
    • maxRecvTokenNum 范围 [0, BS × epWorldSize × min(topK, expertPerRank)]。
    • dispatchQuantOutType 仅支持 23(FLOAT8_E5M2)或 24(FLOAT8_E4M3FN)。
    • globalBs 为 0 或满足 BS × epWorldSize <= globalBs 且 globalBs % epWorldSize == 0。
    • 当前版本仅支持 MXFP 量化模式(dispatchQuantMode = 4),dispatch 阶段使用 MX 逐组量化(group size = 32),量化缩放因子类型为 FLOAT8_E8M0。
    • combineQuantMode 必须为 0,commAlg 必须为空字符串 ""。
    • y 的数据类型与 x 相同。
    • weight1 的 dim1(hiddenDim)必须等于 weight2 的 dim2 的二倍,这是因为 SwiGLU 激活需要将中间维度从 hiddenDim 减半为 hiddenDim/2。
    • expertPerRank = moeExpertNum / epWorldSize,必须为整数且在 [1, 16] 范围内。
  • MXFP量化场景约束

    • weight1 shape 为 (expertPerRank, hiddenDim, H),weight2 shape 为 (expertPerRank, H, hiddenDim/2)。
    • weightScales1 shape 为 (expertPerRank, hiddenDim, CeilDiv(H, 64), 2),其中 CeilDiv(H, 64) = (H + 63) / 64。
    • weightScales2 shape 为 (expertPerRank, H, CeilDiv(hiddenDim/2, 64), 2),其中 CeilDiv(hiddenDim/2, 64) = (hiddenDim/2 + 63) / 64。
    • weightScales1 的 dim3 和 weightScales2 的 dim3 必须等于 2。
    • MXFP 场景下,dispatchQuantOutType=23 时 weight1 和 weight2 必须为 FLOAT8_E5M2,dispatchQuantOutType=24 时必须为 FLOAT8_E4M3FN。

调用说明

调用方式 样例代码 说明
PyTorch接口调用 deepep.py 通过mega_moePyTorch接口方式调用mega_moe算子。