MoeTokenUnpermuteWithEpGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:aclnnMoeTokenUnpermuteWithEp的反向传播。
-
计算公式:
sortedIndices=sortedIndices[sortedIndices[rangeOptional[0]]<=i<sortedIndices[rangeOptional[1]]]sortedIndices= sortedIndices[sortedIndices[rangeOptional[0]]<=i<sortedIndices[rangeOptional[1]]]
-
probs非None:
unpermutedTokens[i]=permutedTokensOptional[sortedIndices[i]]unpermutedTokens[i] = permutedTokensOptional[sortedIndices[i]]
unpermutedTokens=unpermutedTokens.reshape(−1,topkNum,hiddenSize)unpermutedTokens = unpermutedTokens.reshape(-1, topkNum, hiddenSize)
unpermutedTokens=unpermutedTokensGrad.unsqueeze(1)∗unpermutedTokensunpermutedTokens = unpermutedTokensGrad.unsqueeze(1) * unpermutedTokens
probsGrad=∑k=0topkNum(unpermutedTokensi,j,k)probsGrad = \sum_{k=0}^{topkNum}(unpermutedTokens_{i,j,k})
permutedTokensGradOut[sortedIndices[i]]=((unpermutedTokensGrad.unsqueeze(1)∗probs.unsqueeze(−1)).reshape(−1,hiddenSize))[i]permutedTokensGradOut[sortedIndices[i]] = ((unpermutedTokensGrad.unsqueeze(1) * probs.unsqueeze(-1)).reshape(-1, hiddenSize))[i]
-
probs为None:
permutedTokensGradOut[sortedIndices[i]]=unpermutedTokensGrad[i]permutedTokensGradOut[sortedIndices[i]] = unpermutedTokensGrad[i]
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| unpermutedTokensGrad | 输入 | 正向输出unpermutedTokens的梯度,对应公式中的`unpermutedTokensGrad`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| sortedIndices | 输入 | 对应公式中的`sortedIndices`。 | INT32 | ND |
| permutedTokensOptional | 可选输入 | 对应公式中的`permutedTokensOptional`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| probsOptional | 可选输入 | 对应公式中的`probsOptional`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| paddedMode | 属性 | 目前仅支持false。 | BOOL | - |
| restoreShapeOptional | 属性 | 目前仅支持nullptr。 | aclIntArray* | - |
| rangeOptional | 属性 | ep切分的有效范围,对应公式中的`rangeOptional`。 | aclIntArray* | - |
| topkNum | 属性 | 每个token被选中的专家个数,对应公式中的`topkNum`。 | INT64 | - |
| permutedTokensGradOut | 输出 | 输入permutedTokens的梯度,对应公式中的`permutedTokensGradOut`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| probsGradOut | 可选输出 | 输入probs的梯度,对应公式中的`probsGradOut`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
约束说明
- topkNum <= 512。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_moe_token_unpermute_with_ep_grad | 通过aclnnMoeTokenUnpermuteWithEpGrad接口方式调用MoeTokenUnpermuteWithEpGrad算子。 |