AlltoAllvQuantGroupedMatMul
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成路由专家AlltoAllv、量化GroupedMatMul融合并实现与共享专家量化MatMul并行融合,先通信后计算。
-
计算公式: 假设通信域中的总卡数为epWorldSize,每张卡上通信后路由专家个数为e,每张卡分组矩阵乘只负责本卡专家的计算。对于每张卡的计算公式如下:
-
本卡共享专家分组矩阵乘计算
mm_y=(mm_x × mm_x_scale) @ (mm_weight × mm_weight_scale) -
Alltoallv通信和permute
permute_out=Alltoallv(gmm_x) -
本卡路由专家按专家维度分组矩阵乘计算
gmm_y=(permute_out × gmm_x_scale) @ (gmm_weight × gmm_weight_scale)
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| gmmX | 输入 | 该输入进行AlltoAllv通信后结果作为GroupedMatMul计算的左矩阵。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND |
| gmmWeight | 输入 | GroupedMatMul计算的右矩阵。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND |
| gmmXScale | 输入 | gmmX的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND |
| gmmWeightScale | 输入 | gmmWeight的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND |
| sendCountsTensorOptional | 输入 | 预留参数,当前版本仅支持传nullptr。 | - | - |
| recvCountsTensorOptional | 输入 | 预留参数,当前版本仅支持传nullptr。 | - | - |
| mmXOptional | 输入 | 可选输入,共享专家MatMul计算中的左矩阵。 | 与gmmX保持一致 | ND |
| mmWeightOptional | 输入 | 可选输入,共享专家MatMul计算中的右矩阵。 | 与gmmWeight保持一致 | ND |
| mmXScaleOptional | 输入 | 可选输入,mmX的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND |
| mmWeightScaleOptional | 输入 | 可选输入,mmWeight的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND |
| gmmXQuantMode | 输入 | gmmX的量化模式。 | INT64 | - |
| gmmWeightQuantMode | 输入 | gmmWeight的量化模式。 | INT64 | - |
| mmXQuantMode | 输入 | mmX的量化模式。 | INT64 | - |
| mmWeightQuantMode | 输入 | mmWeight的量化模式。 | INT64 | - |
| group | 输入 | 专家并行的通信域名,字符串长度要求(0, 128)。 | STRING | - |
| epWorldSize | 输入 | ep通信域大小。 | INT64 | - |
| sendCounts | 输入 | 表示发送给其他卡的token数。 | aclIntArray*(元素类型INT64) | - |
| recvCounts | 输入 | 表示接收其他卡的token数。 | aclIntArray*(元素类型INT64) | - |
| transGmmWeight | 输入 | GroupedMatMul的右矩阵是否需要转置。 | BOOL | - |
| transMmWeight | 输入 | 共享专家MatMul的右矩阵是否需要转置。 | BOOL | - |
| groupSize | 输入 | 用于表示量化中gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入的一个数在其所在的对应维度方向上可以用于该方向gmmX/gmmWeight/mmX/mmWeight输入的多少个数的量化。 | INT64 | - |
| permuteOutFlag | 输入 | permuteOutOptional是否需要输出。 | BOOL | - |
| gmmY | 输出 | 路由专家计算的输出。 | FLOAT16、BFLOAT16 | ND |
| mmYOptional | 输出 | 共享专家计算的输出。 | 与gmmY保持一致 | ND |
| permuteOutOptional | 输出 | permute之后的输出。 | 与gmmX保持一致 | ND |
约束说明
-
通信引擎约束:仅支持CCU通信。
-
确定性计算:
aclnnAlltoAllvQuantGroupedMatMul默认确定性实现。
-
参数说明里shape使用的变量:
- BSK:本卡发送的token数,是sendCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,取值范围(0, 32],e * epWorldSize最大支持256。
- N1:表示路由专家的head_num,取值范围(0, 65536)。
- N2:表示共享专家的head_num,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡收到的token数,是recvCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
- mx量化且gmmX与gmmWeight为FLOAT4_E2M1时,H1和H2必须为偶数且不能为2,同时transGmmWeight和transMmWeight为false情况下,N1和N2必须为偶数。
- gmmWeight和gmmWeightScale的转置状态必须保持一致:同时转置或同时不转置。mmWeight和mmWeightScale同样需要保持转置状态一致。
- groupSize:
- 仅当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入都是2维及以上数据时,groupSize取值有效,其他场景需传入0。
- groupSize值支持公式推导:传入的groupSize内部会按如下公式分解得到groupSizeM、groupSizeN、groupSizeK,当其中有1个或多个为0,会根据gmmX/gmmWeight/mmX/mmWeight/gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入shape重新设置groupSizeM、groupSizeN、groupSizeK用于计算。设置原理:如果groupSizeM=0,表示m方向量化分组值由接口推导,推导公式为groupSizeM = m / scaleM(需保证m能被scaleM整除),其中m与gmmX/mmX shape中的m一致,scaleM与gmmXScale/mmXScale shape中的m一致;如果groupSizeK=0,表示k方向量化分组值由接口推导,推导公式为groupSizeK = k / scaleK(需保证k能被scaleK整除),其中k与gmmX/mmX shape中的k一致,scaleK与gmmXScale/mmXScale shape中的k一致;如果groupSizeN=0,表示n方向量化分组值由接口推导,推导公式为groupSizeN = n / scaleN(需保证n能被scaleN整除),其中n与gmmWeight/mmWeight shape中的n一致,scaleN与gmmWeightScale/mmWeightScale shape中的n一致。
groupSize=groupSizeK∣groupSizeN<<16∣groupSizeM<<32groupSize = groupSizeK | groupSizeN << 16 | groupSizeM << 32
- 如果满足重新设置条件,当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入是2维及以上时,且数据类型都为FLOAT8_E8M0时,[groupSizeM,groupSizeN,groupSizeK]取值组合会推导为[1, 1, 32],对应groupSize的值为4295032864。
-
量化参数约束:
- 当前版本支持pertensor量化、mx量化。
-
类型约束
-
pertensor量化
gmmX gmmWeight gmmXScale gmmWeightScale mmXScale mmWeightScale gmmY HIFLOAT8 HIFLOAT8 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT16/BFLOAT16 -
mx量化
gmmX gmmWeight gmmXScale gmmWeightScale mmXScale mmWeightScale gmmY FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E4M3FN FLOAT8_E5M2 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E5M2 FLOAT8_E4M3FN FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E5M2 FLOAT8_E5M2 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT4_E2M1 FLOAT4_E2M1 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 -
mmX类型与gmmX类型保持一致,mmWeight类型与gmmWeight类型保持一致,mmY类型与gmmY类型保持一致,permuteOut类型与gmmX保持一致。
-
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_allto_allv_quant_grouped_mat_mul.cpp | 通过aclnnAlltoAllvQuantGroupedMatMul接口方式调用AlltoAllvQuantGroupedMatMul算子。 |