GroupedMatMulAlltoAllv
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成路由专家GroupedMatMul、Unpermute、AlltoAllv融合并实现与共享专家MatMul并行融合,先计算后通信。
-
计算公式:
-
路由专家:
gmmY=gmmX×gmmWeightunpermuteOut=Unpermute(gmmY)y=AlltoAllv(unpermuteOut)gmmY = gmmX \times gmmWeight \\ unpermuteOut = Unpermute(gmmY) \\ y = AlltoAllv(unpermuteOut)
-
共享专家:
mmY=mmX×mmWeightmmY = mmX \times mmWeight
-
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| gmmX | 输入 | 该输入进行AlltoAllv通信,通信后结果作为GroupedMatMul计算的左矩阵,支持2维,shape为(A, H1)。 | FLOAT16、BFLOAT16 | ND |
| gmmWeight | 输入 | GroupedMatMul计算的右矩阵,数据类型与gmmX保持一致,支持3维,shape为(e, H1, N1)。 | FLOAT16、BFLOAT16 | ND |
| sendCountsTensorOptional | 输入 | 可选输入,shape为(e * epWorldSize,),当前版本暂不支持,传nullptr。 | INT32、INT64 | ND |
| recvCountsTensorOptional | 输入 | 可选输入,shape为(e * epWorldSize,),当前版本暂不支持,传nullptr。 | INT32、INT64 | ND |
| mmXOptional | 输入 | 可选输入,共享专家MatMul计算中的左矩阵,需与mmWeightOptional同时传入或同为nullptr,数据类型与gmmX保持一致,支持2维,shape为(BS, H2)。 | FLOAT16、BFLOAT16 | ND |
| mmWeightOptional | 输入 | 可选输入,共享专家MatMul计算中的右矩阵,需与mmXOptional同时传入或同为nullptr,数据类型与gmmX保持一致,支持2维,shape为(H2, N2)。 | FLOAT16、BFLOAT16 | ND |
| group | 输入 | 专家并行的通信域,字符串长度要求(0, 128)。 | STRING | ND |
| epWorldSize | 输入 | ep通信域size: Atlas A3系列产品支持8、16、32、64、128; Ascend 950PR/Ascend 950DT支持2、4、8、16、32、64。 |
INT64 | ND |
| sendCounts | 输入 | 表示发送给其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。 | aclIntArray*(元素类型INT64) | ND |
| recvCounts | 输入 | 表示接收其他卡的token数,数据类型支持INT64,取值大小为e * epWorldSize,最大为256。 | aclIntArray*(元素类型INT64) | ND |
| transGmmWeight | 输入 | gmmWeight是否需要转置,true表示需要转置,false表示不转置。 | BOOL | ND |
| transMmWeight | 输入 | 共享专家mmWeightOptional是否需要转置,true表示需要转置,false表示不转置。 | BOOL | ND |
| y | 输出 | 最终计算结果,数据类型与输入gmmX保持一致,支持2维,shape为(BSK, N1)。 | FLOAT16、BFLOAT16 | ND |
| mmYOptional | 输出 | 共享专家MatMul的输出,数据类型与mmXOptional保持一致,支持2维,shape为(BS, N2),仅当传入mmXOptional与mmWeightOptional才输出。 | FLOAT16、BFLOAT16 | ND |
约束说明
-
通信引擎约束:
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持AICPU通信。
- Ascend 950PR/Ascend 950DT:支持CCU通信。
-
参数说明里shape使用的变量:
- BSK:本卡接收的token数,是recvCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
- N1:表示路由专家的head_num,取值范围(0, 65536)。
- N2:表示共享专家的head_num,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡发送的token数,是sendCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品: 单卡通信量在2MB以下可能存在性能劣化。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_grouped_mat_mul_allto_allv.cpp | 通过aclnnGroupedMatMulAlltoAllv接口方式调用grouped_mat_mul_allto_allv算子。 |