算子名称:FinalRouting
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A2 训练系列产品 | 是 |
功能说明
-
算子功能:用于MoE模型的最终路由阶段,将各个Expert计算的结果按照评分加权合并,得到每个Token的最终输出,完成MoE的Combine阶段。
-
计算公式: 对于每个Token
t:
out[t]=∑e(in[token_table[t,e]]⋅score_table[t,e])out[t] = \sum_{e} \big(in[token\_table[t, e]] \cdot score\_table[t, e]\big)
其中,仅当 token_table[t, e] >= 0 时参与计算。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| blockDim | 输入 | AI CORE的数量,比如:Ascend910B是40。 | int64_t | - |
| in | 输入 | expert输出张量, shape为(expert_num*token_num, hidden_size) | BFLOAT16 | ND |
| token_table | 输入 | token到expert的映射表, shape为(token_num, expert_num) | int32_t | ND |
| score_table | 输入 | 每个token在每个expert的评分,shape为(token_num, expert_num) | BFLOAT16 | ND |
| out | 输出 | 加权合并后的输出张量,shape为(token_num, hidden_size) | BFLOAT16 | ND |
约束说明
- token_table中小于0的值表示该expert对此token无效
调用说明
torch.ops.npu_ops_transformer_ext.final_routing(block_dim, input, token_table,score_table, output)