MoeTokenPermuteWithRoutingMap

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

算子功能:MoE的permute计算,将token和expert的标签作为routingMap传入,根据routingMap将tokens和可选probsOptional广播后排序。

计算公式: tokens_num 为routingMap的第0维大小,expert_num为routingMap的第1维大小。

1.dropAndPad为false时:

expertIndex=arrange(tokens_num).expand(expert_num,−1)expertIndex=arrange(tokens\_num).expand(expert\_num,-1)

sortedIndicesFirst=expertIndex.maskedselect(routingMap.T)sortedIndicesFirst=expertIndex.maskedselect(routingMap.T)

sortedIndicesOut=argsort(sortedIndicesFirst)sortedIndicesOut=argsort(sortedIndicesFirst)

topK=numOutTokens//tokens_numtopK = numOutTokens // tokens\_num

outToken=topK∗tokens_numoutToken = topK * tokens\_num

permutedTokensOut[sortedIndicesOut[i]]=tokens[i//topK]permutedTokensOut[sortedIndicesOut[i]]=tokens[i//topK]

permuteProbsOutOptional=probsOptional.T.maskedselect(routingMap.T)permuteProbsOutOptional=probsOptional.T.maskedselect(routingMap.T)

2.dropAndPad为true时:

capacity=numOutTokens//expert_numcapacity = numOutTokens // expert\_num

outToken=capacity∗expert_numoutToken = capacity * expert\_num

sortedIndicesOut=argsort(routingMap.T,dim=−1)[:,:capacity]sortedIndicesOut = argsort(routingMap.T,dim=-1)[:, :capacity]

permutedTokensOut=tokens.indexselect(0,sortedIndicesOut)permutedTokensOut = tokens.index_select(0, sortedIndicesOut)

  • 如果probsOptional不是none时:

    probs_T_1D=probsOptional.T.view(−1)probs\_T\_1D = probsOptional.T.view(-1)

    indices_dim0=arange(expert_num)indices\_dim0 = arange(expert\_num)

    indices_dim1=sortedIndicesOut.view(expert_num,capacity)indices\_dim1 = sortedIndicesOut.view(expert\_num, capacity)

    indices_1D=(indicesdim0∗tokens_num+indices_dim1).view(−1)indices\_1D = (indices_dim0 * tokens\_num + indices\_dim1).view(-1)

    permuteProbsOutOptional=probs_T_1D.indexselect(0,indices1D)permuteProbsOutOptional = probs\_T\_1D.index_select(0, indices_1D)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
tokens 输入 permute中的输入tokens,公式中的tokens。 BFLOAT16、FLOAT16、FLOAT32 ND
routingMap 输入 公式中的routingMap,代表token到expert的映射关系。 INT8、BOOL ND
probsOptional 输入 可选输入,公式中的probsOptional。 BFLOAT16、FLOAT16、FLOAT32 ND
numOutTokens 属性 公式中的numOutTokens,用于计算公式中topK 和capacity 的有效输出token数。 INT64 -
dropAndPad 属性 公式中的dropAndPad,表示是否开启dropAndPad模式。 BOOL -
permutedTokensOut 输出 公式中的permutedTokensOut,根据indices进行扩展并排序筛选过的tokens。 BFLOAT16、FLOAT16、FLOAT32 ND
sortedIndicesOut 输出 公式中的sortedIndicesOut,permute_tokens和tokens的映射关系。 INT32 ND
permuteProbsOutOptional 输出 公式中的permuteProbsOutOptional,根据indices进行排序并筛选过的probsOptional。 BFLOAT16、FLOAT16、FLOAT32 ND

约束说明

  • tokens_num和expert_num要求小于16777215
  • pad模式为false时routingMap中每行为1或true的个数固定且小于512

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_token_permute_with_routing_map 通过aclnnMoeTokenPermuteWithRoutingMap接口方式调用MoeTokenPermuteWithRoutingMap算子。