算子名称:FinalRouting

产品支持情况

产品 是否支持
Atlas A2 训练系列产品

功能说明

  • 算子功能:用于MoE模型的最终路由阶段,将各个Expert计算的结果按照评分加权合并,得到每个Token的最终输出,完成MoE的Combine阶段。

  • 计算公式: 对于每个Token t

out[t]=∑e(in[token_table[t,e]]⋅score_table[t,e])out[t] = \sum_{e} \big(in[token\_table[t, e]] \cdot score\_table[t, e]\big)

其中,仅当 token_table[t, e] >= 0 时参与计算。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
blockDim 输入 AI CORE的数量,比如:Ascend910B是40。 int64_t -
in 输入 expert输出张量, shape为(expert_num*token_num, hidden_size) BFLOAT16 ND
token_table 输入 token到expert的映射表, shape为(token_num, expert_num) int32_t ND
score_table 输入 每个token在每个expert的评分,shape为(token_num, expert_num) BFLOAT16 ND
out 输出 加权合并后的输出张量,shape为(token_num, hidden_size) BFLOAT16 ND

约束说明

  • token_table中小于0的值表示该expert对此token无效

调用说明

torch.ops.npu_ops_transformer_ext.final_routing(block_dim, input, token_table,score_table, output)