GroupedMatmulSwigluQuantV2
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
- 算子功能:融合GroupedMatmul 、dequant、swiglu和quant,详细解释见计算公式。
- 计算公式:
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
量化场景A8W8(A指激活矩阵,W指权重矩阵,8指INT8数据类型):
-
定义:
- ⋅ 表示矩阵乘法。
- ⊙ 表示逐元素乘法。
- ⌊x⌉\left \lfloor x\right \rceil 表示将x四舍五入到最近的整数。
- Z8={x∈Z∣−128≤x≤127}\mathbb{Z_8} = \{ x \in \mathbb{Z} | −128≤x≤127 \}
- Z32={x∈Z∣−2147483648≤x≤2147483647}\mathbb{Z_{32}} = \{ x \in \mathbb{Z} | -2147483648≤x≤2147483647 \}
-
输入:
- X∈Z8M×KX∈\mathbb{Z_8}^{M \times K}:激活矩阵(左矩阵),M是总token数,K是特征维度。
- W∈Z8E×K×NW∈\mathbb{Z_8}^{E \times K \times N}:分组权重矩阵(右矩阵),E是专家个数,K是特征维度,N是输出维度。
- w_scale∈RE×Nw\_scale∈\mathbb{R}^{E \times N}:分组权重矩阵(右矩阵)的逐通道缩放因子,E是专家个数,N是输出维度。
- x_scale∈RMx\_scale∈\mathbb{R}^{M}:激活矩阵(左矩阵)的逐 token缩放因子,M是总token数。
- grouplist∈NEgrouplist∈\mathbb{N}^{E}:cumsum或count的分组索引列表。
-
输出:
- Q∈Z8M×N/2Q∈\mathbb{Z_8}^{M \times N / 2}:量化后的输出矩阵。
- Q_scale∈RMQ\_scale∈\mathbb{R}^{M}:量化缩放因子。
-
计算过程
-
1.根据groupList[i]确定当前分组的 token ,i∈[0,Len(groupList)]i \in [0,Len(groupList)]。
例子:假设groupList=[3,4,4,6]、groupListType=cumsum或groupList=[3,1,0,2]、groupListType=count。
注:以上两种不同的分组方式,实际为相同的分组结果。
第0个右矩阵
W[0,:,:],对应索引位置[0,3)的tokenx[0:3](共3-0=3个token),对应x_scale[0:3]、w_scale[0]、bias[0]、offset[0]、Q[0:3]、Q_scale[0:3]、Q_offset[0:3];第1个右矩阵
W[1,:,:],对应索引位置[3,4)的tokenx[3:4](共4-3=1个token),对应x_scale[3:4]、w_scale[1]、bias[1]、offset[1]、Q[3:4]、Q_scale[3:4]、Q_offset[3:4];第2个右矩阵
W[2,:,:],对应索引位置[4,4)的tokenx[4:4](共4-4=0个token),对应x_scale[4:4]、w_scale[2]、bias[2]、offset[2]、Q[4:4]、Q_scale[4:4]、Q_offset[4:4];第3个右矩阵
W[3,:,:],对应索引位置[4,6)的tokenx[4:6](共6-4=2个token),对应x_scale[4:6]、w_scale[3]、bias[3]、offset[3]、Q[4:6]、Q_scale[4:6]、Q_offset[4:6];注:grouplist中未指定的部分将不会参与更新。 例如当groupList=[12,14,18]、GroupListType=cumsum,X的shape为[30,:]时。
则第一个输出Q的shape为[30,:],其中Q[18:,:]的部分不会进行更新和初始化,其中数据为显存空间申请时的原数据。
同理,第二个输出Q的shape为[30],其中Q_scale[18:]的部分不会进行更新或初始化,其中数据为显存空间申请时的原数据。
即输出的Q[:grouplist[-1],:]和Q_scale[:grouplist[-1]]为有效数据部分。
-
2.根据分组确定的入参进行如下计算:
Ci=(Xi⋅Wi)⊙x_scalei BroadCast⊙w_scalei BroadCastC_{i} = (X_{i}\cdot W_{i} )\odot x\_scale_{i\ BroadCast} \odot w\_scale_{i\ BroadCast}
Ci,act,gatei=split(Ci)C_{i,act}, gate_{i} = split(C_{i})
Si=Swish(Ci,act)⊙gateiS_{i}=Swish(C_{i,act})\odot gate_{i} 其中Swish(x)=x1+e−xSwish(x)=\frac{x}{1+e^{-x}}
-
3.量化输出结果
Q_scalei=max(∣Si∣)127Q\_scale_{i} = \frac{max(|S_{i}|)}{127}
Qi=⌊SiQ_scalei⌉Q_{i} = \left\lfloor \frac{S_{i}}{Q\_scale_{i}} \right\rceil
-
MSD场景A8W4(A指激活矩阵,W指权重矩阵,8指INT8数据类型,4指INT4数据类型):
- 定义:
- ⋅ 表示矩阵乘法。
- ⊙ 表示逐元素乘法。
- ⌊x⌉\left \lfloor x\right \rceil 表示将x四舍五入到最近的整数。
- Z8={x∈Z∣−128≤x≤127}\mathbb{Z_8} = \{ x \in \mathbb{Z} | −128≤x≤127 \}
- Z4={x∈Z∣−8≤x≤7}\mathbb{Z_4} = \{ x \in \mathbb{Z} | −8≤x≤7 \}
- Z32={x∈Z∣−2147483648≤x≤2147483647}\mathbb{Z_{32}} = \{ x \in \mathbb{Z} | -2147483648≤x≤2147483647 \}
- 输入:
- X∈Z8M×KX∈\mathbb{Z_8}^{M \times K}:激活矩阵(左矩阵),M是总token数,K是特征维度。
- W∈Z4E×K×NW∈\mathbb{Z_4}^{E \times K \times N}:分组权重矩阵(右矩阵),E是专家个数,K是特征维度,N是输出维度。
- weightAsistMatrix∈RE×NweightAsistMatrix∈\mathbb{R}^{E \times N}:计算矩阵乘时的辅助矩阵(生成辅助矩阵的计算过程见下文)。
- w_scale∈RE×K_group_num×Nw\_scale∈\mathbb{R}^{E \times K\_group\_num \times N}:分组权重矩阵(右矩阵)的逐通道缩放因子,E是专家个数,K_group_num 是在K轴维 度上的分组数,N是输出维度。
- x_scale∈RMx\_scale∈\mathbb{R}^{M}:激活矩阵(左矩阵)的逐token缩放因子,M是总token数。
- grouplist∈NEgrouplist∈\mathbb{N}^{E}:cumsum或count的分组索引列表。
- 输出:
- Q∈Z8M×N/2Q∈\mathbb{Z_8}^{M \times N / 2}:量化后的输出矩阵。
- Q_scale∈RMQ\_scale∈\mathbb{R}^{M}:量化缩放因子。
- 计算过程
-
1.根据groupList[i]确定当前分组的token,i∈[0,Len(groupList)]i \in [0,Len(groupList)]。
- 分组逻辑与A8W8相同。
-
2.生成辅助矩阵(weightAsistMatrix)的计算过程(请注意weightAsistMatrix部分计算为离线生成作为输入,并非算子内部完成):
-
当为per-channel量化(w_scalew\_scale为2维):
weightAsistMatrixi=8×weightScale×Σk=0K−1weight[:,k,:]weightAsistMatrix_{i} = 8 × weightScale × Σ_{k=0}^{K-1} weight[:,k,:]
-
当为per-group量化(w_scalew\_scale为3维):
weightAsistMatrixi=8×Σk=0K−1(weight[:,k,:]×weightScale[:,⌊k/num_per_group⌋,:])weightAsistMatrix_{i} = 8 × Σ_{k=0}^{K-1} (weight[:,k,:] × weightScale[:, ⌊k/num\_per\_group⌋, :])
注:num_per_group=K//K_group_numnum\_per\_group = K // K\_group\_num
-
-
3.根据分组确定的入参进行如下计算:
-
3.1.将左矩阵Z8\mathbb{Z_8},转变为高低位 两部分的Z4\mathbb{Z_4} X_high_4bitsi=⌊Xi16⌋X\_high\_4bits_{i} = \lfloor \frac{X_{i}}{16} \rfloor X_low_4bitsi=Xi&0x0f−8X\_low\_4bits_{i} = X_{i} \& 0x0f - 8
-
3.2.做矩阵乘时,使能per-channel或per-group量化 per-channel:
C_highi=(X_high_4bitsi⋅Wi)⊙w_scaleiC\_high_{i} = (X\_high\_4bits_{i} \cdot W_{i}) \odot w\_scale_{i}
C_lowi=(X_low_4bitsi⋅Wi)⊙w_scaleiC\_low_{i} = (X\_low\_4bits_{i} \cdot W_{i}) \odot w\_scale_{i}
per-group:
C_highi=Σk=0K−1((X_high_4bitsi[:,k∗num_per_group:(k+1)∗num_per_group]⋅Wi[k∗num_per_group:(k+1)∗num_per_group,:])⊙w_scalei[k,:])C\_high_{i} = \\ Σ_{k=0}^{K-1}((X\_high\_4bits_{i}[:, k * num\_per\_group : (k+1) * num\_per\_group] \cdot W_{i}[k * num\_per\_group : (k+1) * num\_per\_group, :]) \odot w\_scale_{i}[k, :] )
C_lowi=Σk=0K−1((X_low_4bitsi[:,k∗num_per_group:(k+1)∗num_per_group]⋅Wi[k∗num_per_group:(k+1)∗num_per_group,:])⊙w_scalei[k,:])C\_low_{i} = \\ Σ_{k=0}^{K-1}((X\_low\_4bits_{i}[:, k * num\_per\_group : (k+1) * num\_per\_group] \cdot W_{i}[k * num\_per\_group : (k+1) * num\_per\_group, :]) \odot w\_scale_{i}[k, :] )
-
3.3.将高低位的矩阵乘结果还原为整体的结果
Ci=(C_highi∗16+C_lowi+weightAsistMatrixi)⊙x_scaleiC_{i} = (C\_high_{i} * 16 + C\_low_{i} + weightAsistMatrix_{i}) \odot x\_scale_{i}
Ci,act,gatei=split(Ci)C_{i,act}, gate_{i} = split(C_{i})
Si=Swish(Ci,act)⊙gateiS_{i}=Swish(C_{i,act})\odot gate_{i} 其中Swish(x)=x1+e−xSwish(x)=\frac{x}{1+e^{-x}}
-
-
3.量化输出结果
Q_scalei=max(∣Si∣)127Q\_scale_{i} = \frac{max(|S_{i}|)}{127}
Qi=⌊SiQ_scalei⌉Q_{i} = \left\lfloor \frac{S_{i}}{Q\_scale_{i}} \right\rceil
-
-
-
Ascend 950PR/Ascend 950DT:
MX量化场景:
-
定义:
- ⋅ 表示矩阵乘法。
- ⊙ 表示逐元素乘法。
-
计算过程
-
1.根据groupList[i]确定当前分组的 token ,i∈[0,Len(groupList)]i \in [0,Len(groupList)]
-
2.根据分组确定的入参进行如下计算:
Ci=(Xi⋅Wi)⊙xScalei BroadCast⊙wScalei BroadCastC_{i} = (X_{i}\cdot W_{i} )\odot xScale_{i\ BroadCast} \odot wScale_{i\ BroadCast}
Ci,act,gatei=split(Ci)C_{i,act}, gate_{i} = split(C_{i})
Si=Swish(Ci,act)⊙gateiS_{i}=Swish(C_{i,act})\odot gate_{i},其中Swish(x)=x1+e−xSwish(x)=\frac{x}{1+e^{-x}}
-
3.量化输出结果
shared_exp=⌊log2(maxi(∣Si∣))⌉−emaxshared\_exp = \left\lfloor \log_2(max_i(|S_i|)) \right\rceil - emax
QScale=2shared_expQScale = 2 ^ {shared\_exp}
Qi=quantize_to_element_format(Si/Qscale), i from 1 to blocksizeQ_i = quantize\_to\_element\_format(S_i/Qscale), \space i\space from\space 1\space to\space blocksize
-
emaxemax: 对应数据类型的最大正则数的指数位。
DataType emax FLOAT8_E4M3FN 8 FLOAT8_E5M2 15 FLOAT4_E2M1 2 -
blocksizeblocksize:指每次量化的元素个数,仅支持32。
-
-
-
-
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 表示左矩阵,对应公式中的X。 | FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1、INT8 | ND |
| weight | 输入 | 表示权重矩阵,对应公式中的W。 | FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1、INT8、INT4、INT32 | ND、FRACTAL_NZ |
| weightScale | 输入 | 表示右矩阵的量化因子,公式中的wScale。 | FLOAT8_E8M0、UINT64、FLOAT、FLOAT16、BFLOAT16 | ND |
| weightAssistMatrix | 可选输入 | 表示计算矩阵乘时的辅助矩阵,公式中的weightAssistMatrix。 | FLOAT | ND |
| bias | 可选输入 | 表示矩阵乘计算的偏移值。 | - | - |
| xScale | 输入 | 表示左矩阵的量化因子,公式中的xScale。 | FLOAT8_E8M0、FLOAT | ND |
| smoothScale | 可选输入 | 表示左矩阵的量化因子。 | - | - |
| groupList | 输入 | 表示每个分组参与计算的Token个数,公式中的grouplist。 | INT64 | ND |
| dequantMode | 输入 | 表示反量化计算类型,用于确定激活矩阵与权重矩阵的反量化方式。 | INT64 | - |
| dequantDtype | 输入 | 表示中间GroupedMatmul的结果数据类型。 | INT64 | - |
| quantMode | 输入 | 表示量化计算类型,用于确定swiglu结果的量化模式。 | INT64 | - |
| groupListType | 输入 | 表示分组的解释方式,用于确定groupList的语义。 | INT64 | - |
| tuningConfig | 输入 | 用于算子预估M/E的大小,走不同的算子模板,以适配不不同场景性能要求。 | - | - |
| output | 输出 | 表示输出的量化结果,公式中的Q。 | FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1、INT8 | ND |
| outputScale | 输出 | 表示输出的量化因子,公式中的QScale。 | FLOAT8_E8M0、FLOAT | ND |
| workspaceSize | 输出 | 返回需要在Device侧申请的workspace大小。 | - | - |
| executor | 输出 | 返回op执行器,包含了算子计算流程。 | - | - |
- Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
- x仅支持INT8量化数据类型、不支持其他数据类型。
- weight仅支持非转置,支持INT8、INT4、INT32数据类型,ND格式shape形如{(E, K, N)},NZ格式下,当weight数据类型是INT8时shape形如{(E, N / 32, K / 16, 16, 32)},INT4时shape形如{(E, N / 64, K / 16, 16, 64)},INT32时shape形如{(E, N / 64, K / 16, 16, 8)}。
- weightScale,A8W8场景支持FLOAT、FLOAT16、BFLOAT16数据类型,shape只支持2维,形如{(E, N)};A8W4场景支持UINT64数据类型,shape支持2维和3维,其中per-channel的shape形如{(E, N)},per-group的shape形如{(E, KGroupCount, N)}。
- 支持dequantMode参数:0表示激活矩阵per-token,权重矩阵per-channel;1表示激活矩阵per-token,权重矩阵per-group。
- 不支持dequantDtype参数。
- 不支持quantMode参数。
- A8W8/A8W4场景,不支持N轴长度超过10240。
- A8W8场景,不支持x的尾轴长度大于等于65536。
- A8W4场景,不支持x的尾轴长度大于等于20000。
- output仅支持数据类型INT8,shape支持2维,形如(M, N / 2)。
- outputScale仅支持数据类型FLOAT,shape支持1维,形如(M,)。
- Ascend 950PR/Ascend 950DT:
- 仅支持FLOAT8、FLOAT4量化数据类型,不支持其他数据类型,支持weight转置。
- x支持FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1数据类型。
- weight支持FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1数据类型,非转置shape形如{(E, K, N)},weight转置shape形如{(E, N, K)}。
- weightScale支持FLOAT8_E8M0数据类型,shape支持4维:weightScale非转置shape形如{(E, ceil(K / 64), N, 2)},weightScale转置shape形如{(E, N, ceil(K / 64), 2)}。
- xScale: FLOAT8_E8M0数据类型:shape支持3维,形如(M, ceil(K / 64), 2)。
- 支持dequantMode参数:当前仅支持取值2,2表示MX量化。
- 支持dequantDtype参数:当前仅支持取值0,0表示DT_FLOAT。
- 支持quantMode参数:当前仅支持取值2,2表示MX量化。
- output支持数据类型FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1,shape支持2维,形如(M, N / 2)。
- outputScale支持数据类型FLOAT8_E8M0,shape支持3维,形如(M, ceil((N / 2) / 64), 2)。
约束说明
-
Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:
-
A8W8/A8W4量化场景下需满足以下约束条件:
- 数据类型需要满足下表:
量化场景 x weight weightScale xScale output outputScale A8W8 INT8 INT8 FLOAT、FLOAT16、BFLOAT16 FLOAT INT8 FLOAT A8W4 INT8 INT4、INT32 UINT64 FLOAT INT8 FLOAT - A8W8场景下,不支持N轴长度超过10240,不支持x的尾轴长度大于等于65536。
- A8W4场景下,不支持N轴长度超过10240,不支持x的尾轴长度大于等于20000。
-
-
Ascend 950PR/Ascend 950DT:
-
MX量化场景下需满足以下约束条件:
- 数据类型需要满足下表:
MX量化场景 x weight weightScale xScale output outputScale MXFP8 FLOAT8_E4M3FN、FLOAT8_E5M2 FLOAT8_E4M3FN、FLOAT8_E5M2 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT8_E4M3FN、FLOAT8_E5M2 FLOAT8_E8M0 MXFP4 FLOAT4_E2M1 FLOAT4_E2M1 FLOAT8_E8M0 FLOAT8_E8M0 FLOAT4_E2M1、FLOAT8_E4M3FN、FLOAT8_E5M2 FLOAT8_E8M0 - MX量化场景下,需满足N为128对齐。
- MXFP4场景不支持K=2。
- MXFP4场景需满足K为偶数;当output的数据类型为FLOAT4_E2M1时,需满足N为大于等于4的偶数。
-
-
确定性计算:
- aclnnGroupedMatmulSwigluQuantV2默认为确定性实现。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_quant_grouped_matmul_swiglu_quant_V2 | 通过接口方式调用GroupedMatmulSwigluQuantV2算子。 |