MoeTokenUnpermuteWithRoutingMap

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×
Kirin X90 处理器系列产品
Kirin 9030 处理器系列产品

功能说明

  • 算子功能:对经过aclnnMoeTokenpermuteWithRoutingMap处理的permutedTokens,累加回原unpermutedTokens。根据sortedIndices存储的下标,获取permutedTokens中存储的输入数据;如果存在probs数据,permutedTokens会与probs相乘,最后进行累加求和,并输出计算结果。

  • 计算公式:

    topK_num=permutedTokens.size(0)//routingMapOptional.size(0)topK\_num= permutedTokens.size(0) // routingMapOptional.size(0)

    numExperts=probs.size(1)numExperts = probs.size(1)

    numTokens=probs.size(0)numTokens = probs.size(0)

    capacity=sortedIndices.size(0)//numExpertscapacity = sortedIndices.size(0) // numExperts

    (1)probs不为None,padMode为true时:

    permutedProbs[i//capacity,sortedIndices[i]]=probs[i]permutedProbs [i//capacity,sortedIndices[i]]=probs[i]

    permutedTokens=permutedTokens∗permutedProbspermutedTokens = permutedTokens * permutedProbs

    unpermutedTokens=zeros(restoreShape,dtype=permutedTokens.dtype,device=permutedTokens.device)unpermutedTokens= zeros(restoreShape, dtype=permutedTokens.dtype, device=permutedTokens.device)

    permuteTokenId,outIndex=sortedIndices.sort(dim=−1)permuteTokenId, outIndex= sortedIndices.sort(dim=-1)

    unpermutedTokens[permuteTokenId[i]]+=permutedTokens[outIndex[i]]unpermutedTokens[permuteTokenId[i]] += permutedTokens[outIndex[i]]

    (2)probs不为None,padMode为false时:

    permutedProbs=probs.T.maskedSelect(routingMap.T)permutedProbs = probs.T.maskedSelect(routingMap.T)

    permutedTokens=permutedTokens∗permutedProbspermutedTokens = permutedTokens * permutedProbs

    unpermutedTokens=zeros(restoreShape,dtype=permutedTokens.dtype,device=permutedTokens.device)unpermutedTokens= zeros(restoreShape, dtype=permutedTokens.dtype, device=permutedTokens.device)

    unpermutedTokens[i//topK_num]+=permutedTokens[sortedIndices[i]]unpermutedTokens[i//topK\_num] += permutedTokens[sortedIndices[i]]

    (3)probs为None,padMode为true时:

    permuteTokenId,outIndex=sortedIndices.sort(dim=−1)permuteTokenId, outIndex= sortedIndices.sort(dim=-1)

    unpermutedTokens[permuteTokenId[i]]+=permutedTokens[outIndex[i]]unpermutedTokens[permuteTokenId[i]] += permutedTokens[outIndex[i]]

    (4)probs为None,padMode为false时:

    unpermutedTokens[i//topK_num]+=permutedTokens[sortedIndices[i]]unpermutedTokens[i//topK\_num] += permutedTokens[sortedIndices[i]]

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
permutedTokens 输入 待计算输入,对应公式中的`permutedTokens`。 FLOAT16、BFLOAT16、FLOAT32 ND
sortedIndices 输入 对应公式中的`sortedIndices`。 INT32 ND
routingMapOptional 可选输入 代表对应位置的Token是否被对应专家处理,对应公式中的`routingMapOptional`。 INT8、BOOL ND
probsOptional 可选输入 代表对应位置的Token被对应专家处理后的结果在最终结果中的权重,对应公式中的`probs`。 FLOAT16、BFLOAT16、FLOAT32 ND
paddedMode 属性 true表示开启paddedMode,false表示关闭paddedMode。开启paddedMode时,每个专家固定处理capacity个token。关闭paddedMode时,每个token固定被topK_num个专家处理。 BOOL -
restoreShapeOptional 属性 代表unpermutedTokens的shape。 aclIntArray* -
unpermutedTokens 输出 对应公式中的`unpermutedTokens`。 FLOAT16、BFLOAT16、FLOAT32 ND
outIndex 输出 对应公式中的`outIndex`。 FLOAT16、BFLOAT16、FLOAT32 ND
permuteTokenId 输出 对应公式中的`permuteTokenId`。 FLOAT16、BFLOAT16、FLOAT32 ND
permuteProbs 输出 表示输出经过排序后的probs,对应公式中的`permutedProbs`。 FLOAT16、BFLOAT16、FLOAT32 ND
  • Kirin X90/Kirin 9030 处理器系列产品: 不支持BFLOAT16。

约束说明

  • topkNum <= 512, paddedMode为false时routingMap中每行为1或true的个数固定且小于512

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_token_unpermute_with_routing_map 通过aclnnMoeTokenUnpermuteWithRoutingMap接口方式调用MoeTokenUnpermuteWithRoutingMap算子。