MoeTokenPermuteWithRoutingMapGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
算子功能:aclnnMoeTokenPermuteWithRoutingMap的反向传播。
计算公式:
permuteTokenId,outIndex=sortedIndices.sort(dim=−1)permuteTokenId, outIndex= sortedIndices.sort(dim=-1)
capacity=permutedTokenOutputGrad.size(0)/experts_numcapacity = permutedTokenOutputGrad.size(0) / experts\_num
-
probs不为None:
probsGradOutOptional=zeros(tokens_num,experts_num)probsGradOutOptional = zeros(tokens\_num, experts\_num)
- paddedMode为true时
probsGradOutOptional[sortedIndices[i],i/capacity]=permutedProbsOutputGradOptional[i]probsGradOutOptional [sortedIndices[i], i/capacity] = permutedProbsOutputGradOptional[i]
- paddedMode为false时
probsGradOutOptional=maskedscatter(probsGradOutOptional,routingMap,permutedProbsOutputGradOptional)probsGradOutOptional = maskedscatter(probsGradOutOptional,routingMap,permutedProbsOutputGradOptional)
-
probs为None:
tokensGradOut=zeros(restoreShape,dtype=permutedTokens.dtype,device=permutedTokens.device)tokensGradOut= zeros(restoreShape, dtype=permutedTokens.dtype, device=permutedTokens.device)
tokensGradOut[permuteTokenId[i]]+=permutedTokens[outIndex[i]]tokensGradOut[permuteTokenId[i]] += permutedTokens[outIndex[i]]
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| permutedTokenOutputGrad | 输入 | 正向输出permutedTokens的梯度。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| permutedProbsOutputGradOptional | 输入 | 可选输入,不传则表示不需要计算probsGradOutOptional。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| sortedIndices | 输入 | 非dropPadded模式要求shape为一个1D的(tokens_num \* topK_num,)。 | INT32 | ND |
| routingMap | 输入 | 代表token到expert的映射关系。 | INT8 | ND |
| experts_num | 属性 | 参与运算的专家个数。 | INT64 | - |
| tokens_num | 属性 | 参与运算的token个数。 | INT64 | - |
| dropAndPad | 属性 | true表示开启dropPaddedMode,false表示关闭dropPaddedMode。 | BOOL | - |
| tokensGradOut | 输出 | 输入permutedTokens的梯度。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
| probsGradOutOptional | 输出 | 输入probs的梯度,可选输出。 | BFLOAT16、FLOAT16、FLOAT32 | ND |
约束说明
- 非dropPaddedMode 场景topK_num <= 512。
- 不支持混合精度输入,即permutedTokenOutputGrad、permutedProbsOutputGradOptional、tokensGradOut、probsGradOutOptional需要保持相同的数据类型。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_moe_token_permute_with_routing_map_grad.cpp | 通过aclnnMoeTokenPermuteWithRoutingMap接口方式调用MoeTokenPermuteWithRoutingMap算子。 |