GroupedMatMulAlltoAllv

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成路由专家GroupedMatMul、Unpermute、AlltoAllv融合并实现与共享专家MatMul并行融合,先计算后通信

  • 计算公式:

    • 路由专家:

      gmmY=gmmX×gmmWeightunpermuteOut=Unpermute(gmmY)y=AlltoAllv(unpermuteOut)gmmY = gmmX \times gmmWeight \\ unpermuteOut = Unpermute(gmmY) \\ y = AlltoAllv(unpermuteOut)

    • 共享专家:

      mmY=mmX×mmWeightmmY = mmX \times mmWeight

参数说明

参数名 输入/输出 描述 数据类型 数据格式
gmmX 输入 该输入进行AlltoAllv通信,通信后结果作为GroupedMatMul计算的左矩阵,支持2维,shape为(A, H1)。 FLOAT16、BFLOAT16 ND
gmmWeight 输入 GroupedMatMul计算的右矩阵,数据类型与gmmX保持一致,支持3维,shape为(e, H1, N1)。 FLOAT16、BFLOAT16 ND
sendCountsTensorOptional 输入 可选输入,shape为(e * epWorldSize,),当前版本暂不支持,传nullptr。 INT32、INT64 ND
recvCountsTensorOptional 输入 可选输入,shape为(e * epWorldSize,),当前版本暂不支持,传nullptr。 INT32、INT64 ND
mmXOptional 输入 可选输入,共享专家MatMul计算中的左矩阵,需与mmWeightOptional同时传入或同为nullptr,数据类型与gmmX保持一致,支持2维,shape为(BS, H2)。 FLOAT16、BFLOAT16 ND
mmWeightOptional 输入 可选输入,共享专家MatMul计算中的右矩阵,需与mmXOptional同时传入或同为nullptr,数据类型与gmmX保持一致,支持2维,shape为(H2, N2)。 FLOAT16、BFLOAT16 ND
group 输入 专家并行的通信域,字符串长度要求(0, 128)。 STRING ND
epWorldSize 输入 ep通信域size:
Atlas A3系列产品支持8、16、32、64、128;
Ascend 950PR/Ascend 950DT支持2、4、8、16、32、64。
INT64 ND
sendCounts 输入 表示发送给其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。 aclIntArray*(元素类型INT64) ND
recvCounts 输入 表示接收其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。 aclIntArray*(元素类型INT64) ND
transGmmWeight 输入 gmmWeight是否需要转置,true表示需要转置,false表示不转置。 BOOL ND
transMmWeight 输入 共享专家mmWeightOptional是否需要转置,true表示需要转置,false表示不转置。 BOOL ND
y 输出 最终计算结果,数据类型与输入gmmX保持一致,支持2维,shape为(BSK, N1)。 FLOAT16、BFLOAT16 ND
mmYOptional 输出 共享专家MatMul的输出,数据类型与mmXOptional保持一致,支持2维,shape为(BS, N2),仅当传入mmXOptional与mmWeightOptional才输出。 FLOAT16、BFLOAT16 ND

约束说明

  • 通信引擎约束:

    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持AICPU通信。
    • Ascend 950PR/Ascend 950DT:支持CCU通信。
  • 参数说明里shape使用的变量:

    • BSK:本卡接收的token数,是recvCounts参数累加之和,取值范围(0, 52428800)。
    • H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
    • H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
    • e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
    • N1:表示路由专家的head_num,取值范围(0, 65536)。
    • N2:表示共享专家的head_num,取值范围(0, 65536)。
    • BS:batch sequence size。
    • K:表示选取TopK个专家,K的范围[2, 8]。
    • A:本卡发送的token数,是sendCounts参数累加之和。
    • ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品: 单卡通信量在2MB以下可能存在性能劣化。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_grouped_mat_mul_allto_allv.cpp 通过aclnnGroupedMatMulAlltoAllv接口方式调用grouped_mat_mul_allto_allv算子。