MoeGatingTopKBackward
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成MoE(Mixture of Experts)门控Top-K选择的反向梯度计算。该算子是MoeGatingTopK的反向算子,根据前向算子输出的归一化得分(xNorm)、上游梯度(gradY)和专家索引(expertIdx),计算输入得分矩阵的梯度(gradX)。支持sigmoid模式(normType=1)。
-
计算公式(sigmoid模式,normType=1):
- 缩放梯度
gradYScaledip=routedScalingFactor⋅gradYipgradYScaled_{ip} = routedScalingFactor \cdot gradY_{ip}
- 正向renorm的反向传播
wPrimeip=xNormi, expertIdxipwPrime_{ip} = xNorm_{i,\ expertIdx_{ip}}
Di=∑pwPrimeip+epsD_i = \sum_{p} wPrime_{ip} + eps
wip=wPrimeipDiw_{ip} = \frac{wPrime_{ip}}{D_i}
betai=∑pwip⋅gradYScaledipbeta_i = \sum_{p} w_{ip} \cdot gradYScaled_{ip}
gradWPrimeip=gradYScaledip−betaiDigradWPrime_{ip} = \frac{gradYScaled_{ip} - beta_i}{D_i}
- 散射到完整维度
gradNormXij=∑p: expertIdxip=jgradWPrimeipgradNormX_{ij} = \sum_{p:\ expertIdx_{ip}=j} gradWPrime_{ip}
- Sigmoid反向传播
gradXij=xNormij⋅(1−xNormij)⋅gradNormXijgradX_{ij} = xNorm_{ij} \cdot (1 - xNorm_{ij}) \cdot gradNormX_{ij}
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x_norm | 输入 | 计算的输入,对应公式中的xNorm。要求是一个2D的Tensor,维度为[M,N]。最后一维(专家数N)要求大于等于2,并小于等于2048。 | FLOAT32 | ND |
| grad_y | 输入 | 表示前向算子输出yOut的上游梯度,对应公式中的gradY。要求是一个2D的Tensor,维度为[M,K],K的范围要求大于等于1,并小于等于N。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| expert_idx | 输入 | 表示前向算子输出的top-k专家索引,对应公式中的expertIdx。shape要求与grad_y一致,维度为[M,K]。 | INT32 | ND |
| renorm | 可选属性 | 表示前向算子在softmax模式下renorm标记。0:不做renorm;1:需要做renorm;预留参数,当前仅支持sigmoid模式。 | INT64 | - |
| norm_type | 可选属性 | 表示norm函数类型。1表示使用Sigmoid函数,0表示Softmax函数。当前仅支持1。 | INT64 | - |
| routed_scaling_factor | 可选属性 | 表示前向计算中使用的routed_scaling_factor系数,对应公式中的routedScalingFactor。默认值为1.0。 | FLOAT32 | - |
| eps | 可选属性 | 表示前向计算使用的防除零常数,对应公式中的eps。默认值为1e-20。 | FLOAT32 | - |
| grad_x | 输出 | 表示前向算子输入参数x的梯度,对应公式中的gradX。数据类型与grad_y需要保持一致。shape与x_norm需要一致,维度为[M,N]。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
约束说明
无
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn API | test_aclnn_moe_gating_top_k_backward | 通过aclnnMoeGatingTopKBackward接口方式调用MoeGatingTopKBackward算子。 |