MoeTokenUnpermuteGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:aclnnMoeTokenUnpermute的反向传播。
-
计算公式:
-
probs非None:
unpermutedTokens[i]=permutedTokens[sortedIndices[i]]unpermutedTokens[i] = permutedTokens[sortedIndices[i]]
unpermutedTokens=unpermutedTokens.reshape(−1,topK_num,hiddenSize)unpermutedTokens = unpermutedTokens.reshape(-1, topK\_num, hiddenSize)
unpermutedTokens=unpermutedTokensGrad.unsqueeze(1)∗unpermutedTokensunpermutedTokens = unpermutedTokensGrad.unsqueeze(1) * unpermutedTokens
probsGrad=∑k=0K(unpermutedTokensi,j,k)probsGrad = \sum_{k=0}^{K}(unpermutedTokens_{i,j,k})
permutedTokensGrad[sortedIndices[i]]=((unpermutedTokensGrad.unsqueeze(1)∗probs.unsqueeze(−1)).reshape(−1,hiddenSize))[i]permutedTokensGrad[sortedIndices[i]] = ((unpermutedTokensGrad.unsqueeze(1) * probs.unsqueeze(-1)).reshape(-1, hiddenSize))[i]
-
probs为None:
permutedTokensGrad[sortedIndices[i]]=unpermutedOutputGrad[i]permutedTokensGrad[sortedIndices[i]] = unpermutedOutputGrad[i]
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| permutedTokens | 输入 | 输入token,对应公式中的`permutedTokens`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| unpermutedTokensGrad | 输入 | 正向输出unpermutedTokens的梯度,对应公式中的`unpermutedTokensGrad`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| sortedIndices | 输入 | 对应公式中的`sortedIndices`。 | INT32 | ND |
| probsOptional | 可选输入 | 对应公式中的`probs`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| paddedMode | 属性 | 目前仅支持false。 | BOOL | - |
| restoreShapeOptional | 属性 | 目前仅支持nullptr。 | aclIntArray* | - |
| permutedTokensGradOut | 输出 | 对应公式中的`permutedTokensGrad`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| probsGradOut | 输出 | 对应公式中的`probsGrad`。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
约束说明
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:topK_num <= 512。
- Ascend 950PR/Ascend 950DT:
在调用本接口时,框架内部会转调用aclnnMoeFinalizeRoutingV2Grad接口,如果出现参数错误提示,请参考以下参数映射关系:
- permutedTokens输入等同于aclnnMoeFinalizeRoutingV2Grad接口的expandedXOptional输入。
- unpermutedTokensGrad输入等同于aclnnMoeFinalizeRoutingV2Grad接口的gradY输入。
- sortedIndices输入等同于aclnnMoeFinalizeRoutingV2Grad接口的expandedRowIdx输入。
- probsOptional输入等同于aclnnMoeFinalizeRoutingV2Grad接口的scalesOptional输入。
- paddedMode输入等同于aclnnMoeFinalizeRoutingV2Grad接口的dropPadMode输入。
- permutedTokensGradOut输出等同于aclnnMoeFinalizeRoutingV2Grad接口的gradExpandedXOut输出。
- probsGradOut输出等同于aclnnMoeFinalizeRoutingV2Grad接口的gradScalesOut输出。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_moe_token_unpermute_grad | 通过aclnnMoeTokenUnpermuteGrad接口方式调用MoeTokenUnpermuteGrad算子。 |