MoeTokenUnpermute
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
| Kirin X90 处理器系列产品 | √ |
| Kirin 9030 处理器系列产品 | √ |
功能说明
-
算子功能:根据sortedIndices存储的下标,获取permutedTokens中存储的输入数据;如果存在probs数据,permutedTokens会与probs相乘;最后进行累加求和,并输出计算结果。
-
计算公式:
-
probs非None时,计算公式如下:
T[k]=T[S[k]]T[k] = T[S[k]]
T[k]=T[k]∗P[i][j]T[k] = T[k] * P[i][j]
O[i]=∑k=i∗topK(i+1)∗topK−1T[k]O[i] = \sum_{k=i*topK}^{(i+1)*topK - 1 } T[k]
其中i∈0,1,...,tokens−1i \in {0,1,...,tokens-1};j∈0,1,...,topK−1j \in {0,1,...,topK-1};k∈0,1,...,tokens∗topK−1k \in {0,1,...,tokens*topK-1};T表示permutedTokens;S表示sortedIndices;P表示probs;O表示out;topK表示topK_num,表示处理每个token的专家个数;tokens表示tokens_num,表示输入token的个数。
-
probs为None时,此时topK_num=1,计算公式如下:
T[i]=T[S[i]]T[i] = T[S[i]]
O[i]=T[i]O[i] = T[i]
其中 i∈0,1,...,tokens−1i \in {0,1,...,tokens-1};T表示permutedTokens;S表示sortedIndices;O表示out;tokens表示tokens_num。
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| permutedTokens | 输入 | 待计算输入,对应公式中的`T`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| sortedIndices | 输入 | 表示需要计算的数据在permutedTokens中的位置,对应公式中的`S`。 | INT32 | ND |
| probsOptional | 可选输入 | 对应公式中的`P`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| paddedMode | 属性 | 目前仅支持false。 | BOOL | - |
| restoreShapeOptional | 属性 | 目前仅支持nullptr。 | aclIntArray* | - |
| out | 输出 | 对应公式中的`O`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
- Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16。
约束说明
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:topK_num <= 512。
- Ascend 950PR/Ascend 950DT:
在调用本接口时,框架内部会调用aclnnMoeFinalizeRoutingV2接口,如果出现参数错误提示,请参考以下参数映射关系:
- permutedTokens输入等同于aclnnMoeFinalizeRoutingV2接口的expandedX输入。
- sortedIndices输入等同于aclnnMoeFinalizeRoutingV2接口的expandedRowIdx输入。
- probsOptional输入等同于aclnnMoeFinalizeRoutingV2接口的scalesOptional输入。
- paddedMode输入等同于aclnnMoeFinalizeRoutingV2接口的dropPadMode输入。
- out输出等同于aclnnMoeFinalizeRoutingV2接口的out输出。
- Atlas 推理系列产品:
- permutedTokens与probsOptional支持的数据类型为FLOAT16、FLOAT32。
- topK_num <= 512。
- hiddensize是128的倍数且小于10240。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_moe_token_unpermute | 通过aclnnMoeTokenUnpermute接口方式调用MoeTokenUnpermute算子。 |