MoeGatingTopK
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:MoE计算中,对输入x做Sigmoid或者SoftMax计算,对计算结果分组进行排序,最后根据分组排序的结果选取前k个专家。
-
计算公式:
对输入做Sigmoid或者SoftMax:
ifnormType==1:normOut=Sigmoid(x)else:normOut=SoftMax(x)if normType==1: normOut=Sigmoid(x) else: normOut=SoftMax(x)
如果bias不为空:
normValue=normOut+biasnormValue = normOut + bias
对计算结果按照groupCount进行分组,每组按照groupSelectMode取max或topk2的sum值对group进行排序,取前kGroup个组:
groupOut,groupId=TopK(ReduceSum(TopK(Split(normValue,groupCount),k=2,dim=−1),dim=−1),k=kGroup)groupOut, groupId = TopK(ReduceSum(TopK(Split(normValue, groupCount), k=2, dim=-1), dim=-1),k=kGroup)
根据上一步的groupId获取normValue中对应的元素,将数据再做TopK,得到expertIdxOut的结果:
y,expertIdxOut=TopK(normOut[groupId,:],k=k)y,expertIdxOut=TopK(normOut[groupId, :],k=k)
对y按照输入的routedScalingFactor和eps参数进行计算,得到yOut的结果:
yOut=y/(ReduceSum(y,dim=−1)+eps)∗routedScalingFactoryOut = y / (ReduceSum(y, dim=-1)+eps)*routedScalingFactor
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 待计算输入,对应公式中的`x`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| biasOptional | 输入 | 与输入x进行计算的bias值,对应公式中的`bias`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| k | 输入 | topk的k值,对应公式中的`k`。 | INT64 | - |
| kGroup | 输入 | 分组排序后取的group个数,对应公式中的`kGroup`。 | INT64 | - |
| groupCount | 输入 | 分组的总个数,对应公式中的`groupCount`。 | INT64 | - |
| routedScalingFactor | 输入 | 计算yOut使用的routedScalingFactor系数,对应公式中的`routedScalingFactor`。 | DOUBLE | - |
| eps | 输入 | 用于计算yOut使用的eps系数,对应公式中的`eps`。 | DOUBLE | - |
| yOut | 输出 | 对x做norm、分组排序topk后计算的结果,对应公式中的`yOut`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| expertIdxOut | 输出 | 对x做norm、分组排序topk后的索引,对应公式中的`expertIdxOut`。 | INT32 | ND |
| normOut | 输出 | norm计算的输出结果,对应公式中的`normOut`。 | FLOAT32 | ND |
| groupSelectMode | 输入 | 分组排序方式。 | INT64 | - |
| renorm | 输入 | renorm标记。 | INT64 | - |
| normType | 输入 | norm函数类型。 | INT64 | - |
| outFlag | 输入 | 表示是否输出norm操作结果。 | BOOL | - |
约束说明
- 输入shape限制:
- x最后一维(即专家数)要求不大于2048。
- 输入值域限制:
- 要求1 <= k <= x_shape[-1] / groupCount * kGroup。
- 要求1 <= kGroup <= groupCount,并且kGroup * x_shape[-1] / groupCount的值要大于等于k。
- 要求groupCount > 0,x_shape[-1]能够被groupCount整除且整除后的结果大于groupSelectMode,并且整除的结果按照32个数对齐后乘groupCount的结果不大于2048。
- renorm仅支持0,表示先进行norm操作,再计算topk。
- 其他限制:
- groupSelectMode取值0和1,0表示使用最大值对group进行排序, 1表示使用topk2的sum值对group进行排序。
- normType取值0和1,0表示使用Softmax函数,1表示使用Sigmoid函数。
- outFlag取值true和false,true表示输出,false表示不输出。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_moe_gating_top_k | 通过aclnnMoeGatingTopK接口方式调用MoeGatingTopK算子。 |
| 图模式 | test_geir_moe_gating_top_k | 通过算子IR构图方式调用MoeGatingTopK算子。 |