AlltoAllvGroupedMatMul
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成路由专家AlltoAllv、Permute、GroupedMatMul融合并实现与共享专家MatMul并行融合,先通信后计算。
-
计算公式:
- 路由专家:
ataOut=AlltoAllv(gmmX)permuteOut=Permute(ataOut)gmmY=permuteOut×gmmWeightataOut = AlltoAllv(gmmX) \\ permuteOut = Permute(ataOut) \\ gmmY = permuteOut \times gmmWeight
- 共享专家:
mmY=mmX×mmWeightmmY = mmX \times mmWeight
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| gmmX | 输入 | 该输入进行AlltoAllv通信与Permute操作后结果作为GroupedMatMul计算的左矩阵。 | FLOAT16、BFLOAT16 | ND |
| gmmWeight | 输入 | GroupedMatMul计算的右矩阵。 | 与gmmX保持一致 | ND |
| sendCountsTensorOptional | 输入 | 预留参数,当前版本仅支持传nullptr。 | - | - |
| recvCountsTensorOptional | 输入 | 预留参数,当前版本仅支持传nullptr。 | - | - |
| mmXOptional | 输入 | 可选输入,共享专家MatMul计算中的左矩阵。 | 与gmmX保持一致 | ND |
| mmWeightOptional | 输入 | 可选输入,共享专家MatMul计算中的右矩阵。 | 与gmmX保持一致 | ND |
| group | 输入 | 专家并行的通信域名,字符串长度要求(0, 128)。 | STRING | - |
| epWorldSize | 输入 | ep通信域的大小。 | INT64 | - |
| sendCounts | 输入 | 表示发送给其他卡的token数。 | aclIntArray*(元素类型INT64) | - |
| recvCounts | 输入 | 表示接收其他卡的token数。 | aclIntArray*(元素类型INT64) | - |
| transGmmWeight | 输入 | GroupedMatMul的右矩阵是否需要转置。 | BOOL | - |
| transMmWeight | 输入 | 共享专家MatMul的右矩阵是否需要转置。 | BOOL | - |
| permuteOutFlag | 输入 | permuteOutOptional是否需要输出。 | BOOL | - |
| gmmY | 输出 | 路由专家计算的输出。 | 与gmmX保持一致 | ND |
| mmYOptional | 输出 | 共享专家计算的输出。 | 与mmXOptional保持一致 | ND |
| permuteOutOptional | 输出 | permute之后的输出。 | 与gmmX保持一致 | ND |
约束说明
-
通信引擎约束:
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持AICPU通信。
- Ascend 950PR/Ascend 950DT:支持CCU通信。
-
确定性计算:
- aclnnAlltoAllvGroupedMatMul默认确定性实现。
-
参数说明里shape使用的变量:
- BSK:本卡发送的token数,是sendCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,e<=32,e * epWorldSize最大支持256。
- N1:表示路由专家的head_num,取值范围(0, 65536)。
- N2:表示共享专家的head_num,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡收到的token数,是recvCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品: 单卡通信量在2MB以下可能存在性能劣化。
调用说明
- Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_allto_allv_grouped_mat_mul.cpp。 | 通过aclnnAlltoAllvGroupedMatMul接口方式调用allto_allv_grouped_mat_mul算子。 |
- Ascend 950PR/Ascend 950DT:
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_allto_allv_grouped_mat_mul.cpp。 | 通过aclnnAlltoAllvGroupedMatMul接口方式调用allto_allv_grouped_mat_mul算子。 |