MoeFusedTopk

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:MoE计算中,对输入x做Sigmoid计算,对计算结果分组进行排序,最后根据分组排序的结果选取前k个专家。

  • 计算公式

    对输入做sigmoid:

    sigmoidRes=sigmoid(x)sigmoidRes=sigmoid(x)

    加上addNum:

    normOut=sigmoidRes+addNumnormOut = sigmoidRes + addNum

    对计算结果按照groupNum进行分组,每组按照topN的sum值对group进行排序,取前groupTopk个组:

    groupOut,groupId=TopK(ReduceSum(TopK(Split(normOut,groupCount),k=2,dim=−1),dim=−1),k=kGroup)groupOut, groupId = TopK(ReduceSum(TopK(Split(normOut, groupCount), k=2, dim=-1), dim=-1),k=kGroup)

    根据上一步的groupId获取normOut中对应的元素,将数据再做TopK,得到indices的结果:

    normY,indices=TopK(normOut[groupId,:],k=k)normY,indices=TopK(normOut[groupId, :],k=k)

    根据indices从sigmoidRes中选出y:

    y=gather(sigmoidRes,indices)y = gather(sigmoidRes, indices)

    如果isNorm为true,对y按照输入的scale参数进行计算,得到y的结果:

    y=y/(ReduceSum(y,dim=−1))∗scaley = y / (ReduceSum(y, dim=-1))*scale

    如果enableExpertMapping为true,再将indices中的物理专家按照输入的mappingNum和mappingTable映射到逻辑专家,得到输出的indices。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 每个token对应各个专家的分数,对应公式中的`x`。 FLOAT16、BFLOAT16、FLOAT32 ND
addNum 输入 与输入x进行计算的偏置值,对应公式中的`addNum`。 FLOAT16、BFLOAT16、FLOAT32 ND
mappingNum 输入 `enableExpertMapping`为false时不启用,每个物理专家被实际映射到的逻辑专家数量。 INT32 ND
mappingTable 输入 `enableExpertMapping`为false时不启用,每个物理专家/逻辑专家映射表。 INT32 ND
groupNum 属性 分组数量。 UINT32 -
groupTopk 属性 被选择的组的数量。 UINT32 -
topN 属性 组内选取的用于求和的专家数量。 UINT32 -
topK 属性 最终选取的专家数量。 UINT32 -
activateType 属性 激活类型,当前只支持0(ACTIVATION_SIGMOID)。 UINT32 -
isNorm 属性 是否对输出进行归一化。 BOOL -
scale 属性 归一化后的系数乘。 FLOAT -
enableExpertMapping 属性 是否使能物理专家到逻辑专家的映射。 BOOL -
y 输出 输出每个token的topK分数,对应公式中的`y`。 FLOAT32 ND
indices 输出 topK个专家和tokens的映射关系,对应公式中的`indices`。 INT32 ND

约束说明

  • x和addNum数据类型必须一致。
  • expertNum必须为groupNum的整数倍。
  • groupTopk小于等于groupNum。
  • maxMappingNum小于等于128。
  • TopK小于等于expertNum。
  • TopN小于等于expertNum / groupNum。
  • expertNum小于等于1024。
  • groupNum小于等于256。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_moe_fused_topk 通过aclnnMoeFusedTopk接口方式调用MoeFusedTopk算子。