MoeTokenUnpermuteWithEpGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:aclnnMoeTokenUnpermuteWithEp的反向传播。

  • 计算公式:

    sortedIndices=sortedIndices[sortedIndices[rangeOptional[0]]<=i<sortedIndices[rangeOptional[1]]]sortedIndices= sortedIndices[sortedIndices[rangeOptional[0]]<=i<sortedIndices[rangeOptional[1]]]

  • probs非None:

    unpermutedTokens[i]=permutedTokensOptional[sortedIndices[i]]unpermutedTokens[i] = permutedTokensOptional[sortedIndices[i]]

    unpermutedTokens=unpermutedTokens.reshape(−1,topkNum,hiddenSize)unpermutedTokens = unpermutedTokens.reshape(-1, topkNum, hiddenSize)

    unpermutedTokens=unpermutedTokensGrad.unsqueeze(1)∗unpermutedTokensunpermutedTokens = unpermutedTokensGrad.unsqueeze(1) * unpermutedTokens

    probsGrad=∑k=0topkNum(unpermutedTokensi,j,k)probsGrad = \sum_{k=0}^{topkNum}(unpermutedTokens_{i,j,k})

    permutedTokensGradOut[sortedIndices[i]]=((unpermutedTokensGrad.unsqueeze(1)∗probs.unsqueeze(−1)).reshape(−1,hiddenSize))[i]permutedTokensGradOut[sortedIndices[i]] = ((unpermutedTokensGrad.unsqueeze(1) * probs.unsqueeze(-1)).reshape(-1, hiddenSize))[i]

  • probs为None:

    permutedTokensGradOut[sortedIndices[i]]=unpermutedTokensGrad[i]permutedTokensGradOut[sortedIndices[i]] = unpermutedTokensGrad[i]

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
unpermutedTokensGrad 输入 正向输出unpermutedTokens的梯度,对应公式中的`unpermutedTokensGrad`。 FLOAT16、BFLOAT16、FLOAT32 ND
sortedIndices 输入 对应公式中的`sortedIndices`。 INT32 ND
permutedTokensOptional 可选输入 对应公式中的`permutedTokensOptional`。 FLOAT16、BFLOAT16、FLOAT32 ND
probsOptional 可选输入 对应公式中的`probsOptional`。 FLOAT16、BFLOAT16、FLOAT32 ND
paddedMode 属性 目前仅支持false。 BOOL -
restoreShapeOptional 属性 目前仅支持nullptr。 aclIntArray* -
rangeOptional 属性 ep切分的有效范围,对应公式中的`rangeOptional`。 aclIntArray* -
topkNum 属性 每个token被选中的专家个数,对应公式中的`topkNum`。 INT64 -
permutedTokensGradOut 输出 输入permutedTokens的梯度,对应公式中的`permutedTokensGradOut`。 FLOAT16、BFLOAT16、FLOAT32 ND
probsGradOut 可选输出 输入probs的梯度,对应公式中的`probsGradOut`。 FLOAT16、BFLOAT16、FLOAT32 ND

约束说明

  • topkNum <= 512。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_token_unpermute_with_ep_grad 通过aclnnMoeTokenUnpermuteWithEpGrad接口方式调用MoeTokenUnpermuteWithEpGrad算子。