MoeTokenPermuteWithEpGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

算子功能:aclnnMoeTokenPermuteWithEp的反向传播计算。

计算公式:

sortedIndices=sortedIndices[rangeOptional[0]<=i<rangeOptional[1]]sortedIndices = sortedIndices[rangeOptional[0]<=i<rangeOptional[1]]

tokenGradOut=permutedTokensOutputGrad.indexSelect(0,sortedIndices)tokenGradOut = permutedTokensOutputGrad.indexSelect(0, sortedIndices)

tokenGradOut=tokenGradOut.reshape(−1,numTopk,hiddenSize)tokenGradOut = tokenGradOut.reshape(-1, numTopk, hiddenSize)

tokenGradOut=tokenGradOut.sum(dim=1)tokenGradOut = tokenGradOut.sum(dim = 1)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
permutedTokensOutputGrad 输入 正向输出permutedTokens的梯度,公式中的`permutedTokensOutputGrad`。 BFLOAT16、FLOAT16、FLOAT32 ND
sortedIndices 输入 正向输出的permuteTokensOut和正向输入的tokens的映射关系,公式中的`sortedIndices`。 INT32 ND
permutedProbsOutputGradOptional 输入 可选计算输入,与计算输出probsGradOut对应,传入空则不输出probsGradOut。 BFLOAT16、FLOAT16、FLOAT32 ND
numTopk 属性 被选中的专家个数。 INT64 -
rangeOptional 属性 ep切分的有效范围。 aclIntArray -
paddedMode 属性 true表示开启paddedMode,false表示关闭paddedMode,目前仅支持false。 BOOL -
tokenGradOut 输出 输入token的梯度。 BFLOAT16、FLOAT16、FLOAT32 ND
probsGradOut 输出 输入probs的梯度。 FLOAT、FLOAT16、BFLOAT16 ND

约束说明

  • numTopk <= 512。
  • 不支持paddedMode为True
  • 当rangeOptional为空时,忽略permutedProbsOutputGradOptional和probsGradOut,执行逻辑回退到aclnnMoeTokenPermuteGrad

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_token_permute_with_ep_grad.cpp 通过aclnnMoeTokenPermuteWithEpGrad接口方式调用MoeTokenPermuteWithEpGrad算子。