MoeTokenPermuteGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

算子功能:aclnnMoeTokenPermute的反向传播计算。

计算公式:

inputGrad=permutedOutputGrad.indexSelect(0,sortedIndices)inputGrad = permutedOutputGrad.indexSelect(0, sortedIndices)

inputGrad=inputGrad.reshape(−1,numTopk,hiddenSize)inputGrad = inputGrad.reshape(-1, numTopk, hiddenSize)

inputGrad=inputGrad.sum(dim=1)inputGrad = inputGrad.sum(dim = 1)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
permutedOutputGrad 输入 正向输出permutedTokens的梯度。 BFLOAT16、FLOAT16、FLOAT32 ND
sortedIndices 输入 排序的索引值。 INT32 ND
numTopk 属性 被选中的专家个数。 INT64 -
paddedMode 属性 pad模式的开关。 BOOL -
out 输出 输出token的梯度。 BFLOAT16、FLOAT16、FLOAT32 ND

约束说明

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:numTopk <= 512。

  • Ascend 950PR/Ascend 950DT: 在调用本接口时,框架内部会转调用aclnnMoeInitRoutingV2Grad接口,如果出现参数错误提示,请参考以下参数映射关系:

    • permutedOutputGrad输入等同于aclnnMoeInitRoutingV2Grad接口的gradExpandedX输入。
    • sortedIndices输入等同于aclnnMoeInitRoutingV2Grad接口的expandedRowIdx输入。
    • numTopk输入等同于aclnnMoeInitRoutingV2Grad接口的topK输入。
    • paddedMode输入等同于aclnnMoeInitRoutingV2Grad接口的dropPadMode输入。
    • out输出等同于aclnnMoeInitRoutingV2Grad接口的out输出。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品: 单卡通信量取值范围[2MB,100MB]。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_token_permute_grad.cpp 通过aclnnMoeTokenPermuteGrad接口方式调用MoeTokenPermuteGrad算子。