AlltoAllvGroupedMatMul

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成路由专家AlltoAllv、Permute、GroupedMatMul融合并实现与共享专家MatMul并行融合,先通信后计算

  • 计算公式

    • 路由专家:

    ataOut=AlltoAllv(gmmX)permuteOut=Permute(ataOut)gmmY=permuteOut×gmmWeightataOut = AlltoAllv(gmmX) \\ permuteOut = Permute(ataOut) \\ gmmY = permuteOut \times gmmWeight

    • 共享专家:

    mmY=mmX×mmWeightmmY = mmX \times mmWeight

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
gmmX 输入 该输入进行AlltoAllv通信与Permute操作后结果作为GroupedMatMul计算的左矩阵。 FLOAT16、BFLOAT16 ND
gmmWeight 输入 GroupedMatMul计算的右矩阵。 与gmmX保持一致 ND
sendCountsTensorOptional 输入 预留参数,当前版本仅支持传nullptr。 - -
recvCountsTensorOptional 输入 预留参数,当前版本仅支持传nullptr。 - -
mmXOptional 输入 可选输入,共享专家MatMul计算中的左矩阵。 与gmmX保持一致 ND
mmWeightOptional 输入 可选输入,共享专家MatMul计算中的右矩阵。 与gmmX保持一致 ND
group 输入 专家并行的通信域名,字符串长度要求(0, 128)。 STRING -
epWorldSize 输入 ep通信域的大小。 INT64 -
sendCounts 输入 表示发送给其他卡的token数。 aclIntArray*(元素类型INT64) -
recvCounts 输入 表示接收其他卡的token数。 aclIntArray*(元素类型INT64) -
transGmmWeight 输入 GroupedMatMul的右矩阵是否需要转置。 BOOL -
transMmWeight 输入 共享专家MatMul的右矩阵是否需要转置。 BOOL -
permuteOutFlag 输入 permuteOutOptional是否需要输出。 BOOL -
gmmY 输出 路由专家计算的输出。 与gmmX保持一致 ND
mmYOptional 输出 共享专家计算的输出。 与mmXOptional保持一致 ND
permuteOutOptional 输出 permute之后的输出。 与gmmX保持一致 ND

约束说明

  • 通信引擎约束:

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持AICPU通信。
    • Ascend 950PR/Ascend 950DT:支持CCU通信。
  • 确定性计算:

    • aclnnAlltoAllvGroupedMatMul默认确定性实现。
  • 参数说明里shape使用的变量:

    • BSK:本卡发送的token数,是sendCounts参数累加之和,取值范围(0, 52428800)。
    • H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
    • H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
    • e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
    • N1:表示路由专家的head_num,取值范围(0, 65536)。
    • N2:表示共享专家的head_num,取值范围(0, 65536)。
    • BS:batch sequence size。
    • K:表示选取TopK个专家,K的范围[2, 8]。
    • A:本卡收到的token数,是recvCounts参数累加之和。
    • ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品: 单卡通信量在2MB以下可能存在性能劣化。

调用说明

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
调用方式 样例代码 说明
aclnn接口 test_aclnn_allto_allv_grouped_mat_mul.cpp 通过aclnnAlltoAllvGroupedMatMul接口方式调用allto_allv_grouped_mat_mul算子。
  • Ascend 950PR/Ascend 950DT:
调用方式 样例代码 说明
aclnn接口 test_aclnn_allto_allv_grouped_mat_mul.cpp 通过aclnnAlltoAllvGroupedMatMul接口方式调用allto_allv_grouped_mat_mul算子。