QuantGroupedMatMulAlltoAllv
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:完成路由专家GroupedMatMul、Unpermute、AlltoAllv融合并实现与共享专家MatMul并行融合,先计算后通信。
-
计算公式:
-
路由专家:
gmmY=(gmmX@gmmWeight)∗gmmXScale∗gmmWeightScaleunpermuteOut=Unpermute(gmmY)y=AlltoAllv(unpermuteOut)gmmY = (gmmX @ gmmWeight) * gmmXScale * gmmWeightScale \\ unpermuteOut = Unpermute(gmmY) \\ y = AlltoAllv(unpermuteOut)
-
共享专家:
mmY=(mmX@mmWeight)∗mmXScaleOptional∗mmWeightScaleOptionalmmY = (mmX @ mmWeight) * mmXScaleOptional * mmWeightScaleOptional
-
参数说明
| 参数名 | 输入/输出 | 描述 | 使用说明 | 数据类型 | 数据格式 | 维度(shape) | 非连续Tensor |
|---|---|---|---|---|---|---|---|
| gmmX | 输入 | 公式中的输入 gmmX。 | shape (A, H1)。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND | 2 | x |
| gmmWeight | 输入 | 公式中的输入 gmmWeight。 | shape (e, H1, N1)。e 为每卡部署的专家数,H1 为 hidden size,N1 为路由专家 FFN 中间维度。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND | 3 | x |
| gmmXScale | 输入 | gmmX 的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND | x | ||
| gmmWeightScale | 输入 | gmmWeight 的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND | x | ||
| sendCountsTensorOptional | 输入 | AlltoAllv 使用的 send count。 | 当前仅支持空。shape (e * ep, )。e 为每卡部署的专家个数,ep 为 ep 域大小。 | INT64 | ND | 1 | x |
| recvCountsTensorOptional | 输入 | AlltoAllv 使用的 recv count。 | 默认为空 Tensor。shape (e * ep, )。e 为每卡部署的专家个数,ep 为 ep 域大小。 | INT64 | ND | 1 | x |
| mmXOptional | 输入 | 公式中的输入 mmX。 | shape (bs, H2)。bs 为每卡部署的专家个数,H2 为 hidden size。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND | 2 | x |
| mmWeightOptional | 输入 | 公式中的输入 mmWeight。 | shape (H2, N2)。H2 为 hidden size,N2 为共享专家 FFN 的中间层维度。 | HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1 | ND | 2 | x |
| mmXScaleOptional | 输入 | mmX 的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND | x | ||
| mmWeightScaleOptional | 输入 | mmWeight 的量化系数。 | FLOAT32、FLOAT8_E8M0 | ND | x | ||
| commQuantScaleOptional | 输入 | 低比特通信量化系数。 | 预留参数,当前仅支持空。 | FLOAT32 | ND | 1 | x |
| gmmXQuantMode | 输入 | gmmX 的量化模式。 | 必须传入量化模式,当前支持 1 (pertensor量化)和 6(mx量化)。 | INT64 | - | 1 | x |
| gmmWeightQuantMode | 输入 | gmmWeight 的量化模式。 | 必须传入量化模式,当前支持 1 (pertensor量化)和 6(mx量化)。 | INT64 | - | 1 | x |
| mmXQuantMode | 输入 | mmX 的量化模式。 | mmX 非空,则必须传入量化模式,当前支持 1 (pertensor量化)和 6(mx量化)。 | INT64 | - | 1 | x |
| mmWeightQuantMode | 输入 | mmWeight 的量化模式。 | mmWeight 不为空,则必须传入量化模式,当前支持 1 (pertensor量化)和 6(mx量化)。 | INT64 | - | 1 | x |
| commQuantMode | 输入 | 低比特通信量化模式。 | 当前低比特功能预留,必须传入 0,表示不量化。 | INT64 | - | 1 | x |
| commQuantDtypeOptional | 输入 | 低比特通信的数据类型。 | 当前低比特功能预留,必须传入 -1。 | INT64 | - | 1 | x |
| groupSize | 输入 | PerGroup 量化分组大小。 | 用于 Matmul 计算三个方向上的量化分组大小,预留参数,仅支持配置为 0,取值不生效。groupSize 输入由 3 个方向的 groupSizeM,groupSizeN,groupSizeK 三个值拼接组成,每个值占 16 位,共占用 int64_t 类型 groupSize 的低 48 位(高 16 位无效),计算公式为:groupSize = groupSizeK | groupSizeN << 16 | groupSizeM << 32。 | INT64 | - | - | - |
| group | 输入 | 通信域标识。 | 字符串长度需大于 0,小于 128。 | char* | - | - | - |
| epWorldSize | 输入 | 通信域大小。 | 支持 2/4/8/16/32/64/128/256。 | INT64 | - | - | - |
| sendCounts | 输入 | AlltoAllv 使用的 send count。表示其他Rank向当前rank上各expert发送的token数量。 | 支持的维度为 e * ep。按sendCounts[fromRank][expertId]一维展开, 例如e=3时顺序为e0,e1,e2,e0,e1,e2, ... |
aclIntArray*(元素类型 INT64) | ND | - | - |
| recvCounts | 输入 | AlltoAllv 使用的 recv count。表示AlltoAllv后本卡需要接收到的token数量。 | 支持的维度为 e * ep。按recvCounts[fromRank][expertId]一维展开, 例如e=3时顺序为e0,e1,e2,e0,e1,e2, ... |
aclIntArray*(元素类型 INT64) | ND | - | - |
| transGmmWeight | 输入 | gmm 的右矩阵是否转置。 | 必须传入,无默认值。 | BOOL | ND | - | - |
| transMmWeight | 输入 | mm 的右矩阵是否转置。 | 必须传入,无默认值。 | BOOL | ND | - | - |
| y | 输出 | grouped matmul 计算输出。 | 不支持空 Tensor。shape (BSK, N1)。 | FLOAT16、BFLOAT16 | ND | 2 | x |
| mmYOptional | 输出 | matmul 计算输出。 | shape (bs, N2)。 | FLOAT16、BFLOAT16 | ND | 2 | x |
| workspaceSize | 输出 | 返回需要在 Device 侧申请的 workspace 大小。 | - | UINT64 | ND | - | - |
| executor | 输出 | 返回 op 执行器,包含了算子计算流程。 | - | aclOpExecutor* | ND | - | - |
gmmXQuantMode、gmmWeightQuantMode、mmXQuantMode、mmWeightQuantMode、commQuantMode的枚举值跟量化模式关系如下:
- 0: 非量化--当前不支持
- 1: pertensor
- 2: perchannel
- 3: pertoken
- 4: pergroup
- 5: perblock
- 6: mx量化
- 7: pertoken动态量化
约束说明
-
确定性计算:
- aclnnQuantGroupedMatMulAlltoAllv默认确定性实现。
-
通信引擎约束:
- Ascend 950PR/Ascend 950DT:支持CCU通信。
-
e * epWorldSize最大支持256,e表示单卡上的专家数量,最大支持到32,epWorldSize支持2/4/8/16/32/64/128/256;
-
gmmX的shape(A, H1),A为sendCounts之和,H1取值范围(0, 65536);
-
gmmWeight的shape(e, H1, N1),N1取值范围(0, 65536);
-
y的shape为(BSK, N1),第一维其中K的范围[2, 8],BSK为recvCounts之和;
-
mmX是共享专家的左矩阵,shape为(BS, H2),H2的取值范围(0, 12288];
-
mmWeight是共享专家的右矩阵,shape为(H2, N2),N2的取值范围(0, 65536);
-
sendCounts为发送到其他卡的token数,数组大小为e * epWorldSize;
-
recvCounts从其他卡的token数,数组大小为e * epWorldSize;
-
路由专家和共享专家量化Scale、Mode等均为必选;
-
低比特通信Mode为必选参数,DType和Scale为可选,当Mode为非0时需要提供DType和Scale;
-
参数说明里shape使用的变量:
- BSK:本卡接收的token数,是recvCounts参数累加之和,取值范围(0, 52428800)。
- H1:表示路由专家hidden size隐藏层大小,取值范围(0, 65536)。
- H2:表示共享专家hidden size隐藏层大小,取值范围(0, 12288]。
- e:表示单卡上专家个数,0<e<=32,e * epWorldSize最大支持256。
- N1:表示路由专家 FFN 的中间层维度,取值范围(0, 65536)。
- N2:表示共享专家 FFN 的中间层维度,取值范围(0, 65536)。
- BS:batch sequence size。
- K:表示选取TopK个专家,K的范围[2, 8]。
- A:本卡发送的token数,是sendCounts参数累加之和。
- ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。
- mx量化且gmmX与gmmWeight为FLOAT4_E2M1时,H1和H2必须为偶数且不能为2,同时transGmmWeight和transMmWeight为false情况下,N1和N2必须为偶数。
- gmmWeight和gmmWeightScale的转置状态必须保持一致:同时转置或同时不转置。mmWeight和mmWeightScale同样需要保持转置状态一致。
- groupSize:
- 仅当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入都是2维及以上数据时,groupSize取值有效,其他场景需传入0。
- groupSize值支持公式推导:传入的groupSize内部会按如下公式分解得到groupSizeM、groupSizeN、groupSizeK,当其中有1个或多个为0,会根据gmmX/gmmWeight/mmX/mmWeight/gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入shape重新设置groupSizeM、groupSizeN、groupSizeK用于计算。设置原理:如果groupSizeM=0,表示m方向量化分组值由接口推导,推导公式为groupSizeM = m / scaleM(需保证m能被scaleM整除),其中m与gmmX/mmX shape中的m一致,scaleM与gmmXScale/mmXScale shape中的m一致;如果groupSizeK=0,表示k方向量化分组值由接口推导,推导公式为groupSizeK = k / scaleK(需保证k能被scaleK整除),其中k与gmmX/mmX shape中的k一致,scaleK与gmmXScale/mmXScale shape中的k一致;如果groupSizeN=0,表示n方向量化分组值由接口推导,推导公式为groupSizeN = n / scaleN(需保证n能被scaleN整除),其中n与gmmWeight/mmWeight shape中的n一致,scaleN与gmmWeightScale/mmWeightScale shape中的n一致。
groupSize=groupSizeK∣groupSizeN<<16∣groupSizeM<<32groupSize = groupSizeK | groupSizeN << 16 | groupSizeM << 32
- 如果满足重新设置条件,当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入是2维及以上时,且数据类型都为FLOAT8_E8M0时,[groupSizeM,groupSizeN,groupSizeK]取值组合会推导为[1, 1, 32],对应groupSize的值为4295032864。
-
量化参数约束:
- 当前版本支持pertensor量化、mx量化。
-
类型约束
-
pertensor量化
gmmX gmmWeight gmmXScale gmmWeightScale mmXScale mmWeightScale y HIFLOAT8 HIFLOAT8 FLOAT32 FLOAT32 FLOAT32 FLOAT32 FLOAT16/BFLOAT16 -
mx量化
gmmX gmmWeight gmmXScale gmmWeightScale mmXScale mmWeightScale y FLOAT8_E4M3FN FLOAT8_E4M3FN FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E4M3FN FLOAT8_E5M2 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E5M2 FLOAT8_E4M3FN FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT8_E5M2 FLOAT8_E5M2 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 FLOAT4_E2M1 FLOAT4_E2M1 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT16/BFLOAT16 -
mmX类型与gmmX类型保持一致,mmWeight类型与gmmWeight类型保持一致,mmY类型与y类型保持一致。
-
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_quant_grouped_mat_mul_allto_allv.cpp | 通过aclnnQuantGroupedMatMulAlltoAllv接口方式调用量化场景的quant_grouped_mat_mul_allto_allv算子。 |