DynamicMxQuantWithDualAxis
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:在-1轴和-2轴上同时进行目的数据类型为FLOAT4类、FLOAT8类的MX量化。在给定的-1轴和-2轴上,每32个数,计算出这两组数对应的量化尺度mxscale1、mxscale2,然后分别对两组数所有元素除以对应的mxscale1或mxscale2,根据round_mode转换到对应的dst_type,得到量化结果y1和y2。在dst_type为FLOAT8_E4M3FN、FLOAT8_E5M2时,根据scale_alg的取值来指定计算mxscale的不同算法。
-
合轴说明:算子实现时,会对-2轴(不包含)之前的所有轴进行合轴处理。即对于输入shape为(d0,d1,...,dn−3,dn−2,dn−1)(d_0, d_1, ..., d_{n-3}, d_{n-2}, d_{n-1})的张量,-2轴之前的维度(d0,d1,...,dn−3)(d_0, d_1, ..., d_{n-3})会被合并为一个维度,等效于将输入reshape为(d0×d1×...×dn−3,dn−2,dn−1)(d_0 \times d_1 \times ... \times d_{n-3}, d_{n-2}, d_{n-1})后再进行量化计算。
-
计算公式:
-
场景1,当scale_alg为0时,即OCP Microscaling Formats (Mx) Specification实现:
-
将输入x在-1轴上按照32个数进行分组,一组32个数 {{Vi}i=132}\{\{V_i\}_{i=1}^{32}\} 量化为 {mxscale1,{Pi}i=132}\{mxscale1, \{P_i\}_{i=1}^{32}\}
shared_exp=floor(log2(maxi(∣Vi∣)))−emaxshared\_exp = floor(log_2(max_i(|V_i|))) - emax
mxscale1=2shared_expmxscale1 = 2^{shared\_exp}
Pi=cast_to_dst_type(Vi/mxscale1,round_mode), i from 1 to 32P_i = cast\_to\_dst\_type(V_i/mxscale1, round\_mode), \space i\space from\space 1\space to\space 32
-
同时,将输入x在-2轴上按照32个数进行分组,一组32个数 {{Vj}j=132}\{\{V_j\}_{j=1}^{32}\} 量化为 {mxscale2,{Pj}j=132}\{mxscale2, \{P_j\}_{j=1}^{32}\}
shared_exp=floor(log2(maxj(∣Vj∣)))−emaxshared\_exp = floor(log_2(max_j(|V_j|))) - emax
mxscale2=2shared_expmxscale2 = 2^{shared\_exp}
Pj=cast_to_dst_type(Vj/mxscale2,round_mode), j from 1 to 32P_j = cast\_to\_dst\_type(V_j/mxscale2, round\_mode), \space j\space from\space 1\space to\space 32
-
-1轴量化后的 PiP_{i} 按对应的 ViV_{i} 的位置组成输出y1,mxscale1按对应的-1轴维度上的分组组成输出mxscale1。-2轴量化后的 PjP_{j} 按对应的 VjV_{j} 的位置组成输出y2,mxscale2按对应的-2轴维度上的分组组成输出mxscale2。
-
emax: 对应数据类型的最大正则数的指数位。
DataType emax FLOAT4_E2M1 2 FLOAT4_E1M2 0 FLOAT8_E4M3FN 8 FLOAT8_E5M2 15 -
场景2,当scale_alg为1时,只涉及FP8类型(CuBALS Scale计算算法):
- 将输入x在-1轴(或-2轴)上按照32个数进行分组,对每组单独计算块缩放因子Sfp32bS_{fp32}^b,再把组内所有元素映射到目标低精度类型FP8。
- 找到该组中数值的最大绝对值:Amax(Dfp32b)=max({∣di∣}i=132)Amax(D_{fp32}^b) = max(\{|d_i|\}_{i=1}^{32})
- 将FP32映射到目标数据类型FP8可表示的范围内:Sfp32b=Amax(Dfp32b)Amax(DType)S_{fp32}^b = \frac{Amax(D_{fp32}^b)}{Amax(DType)}
- 从块缩放因子中提取无偏指数EintbE_{int}^b和尾数MfixpbM_{fixp}^b,并进行条件向上取整。
- 计算块缩放因子:Sue8m0b=2EintbS_{ue8m0}^b=2^{E_{int}^b},计算块转换因子:Rfp32b=1fp32(Sue8m0b)R_{fp32}^b=\frac{1}{fp32(S_{ue8m0}^b)}
- 对每个组内元素应用量化:di=DType(dfp32i⋅Rfp32b)d^i = DType(d_{fp32}^i \cdot R_{fp32}^b)
-
场景3,当scale_alg为2时,只涉及FP4_E2M1类型(DynamicDtypeRange算法,仅V2接口支持):
- 当dstTypeMax = 0.0/6.0/7.0时,使用指数域addValueBit进位法计算scale。
- 当dstTypeMax为其他自定义值时,使用FP32精度invDstTypeMax乘法法计算scale。
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 待量化数据,对应公式中Vi和di。 目的类型为FLOAT4_E2M1、FLOAT4_E1M2时,x的最后一维必须是偶数。 |
FLOAT16、BFLOAT16 | ND |
| round_mode | 可选属性 | 数据转换的模式。 当dst_type为40/41(FLOAT4_E2M1/FLOAT4_E1M2)时,支持{"rint", "floor", "round"}; 当dst_type为35/36(FLOAT8_E5M2/FLOAT8_E4M3FN)时,仅支持{"rint"}; 传入空指针时,采用"rint"模式。 |
STRING | - |
| dst_type | 可选属性 | 指定数据转换后y1和y2的类型。 输入范围为{35, 36, 40, 41},分别对应{35:FLOAT8_E5M2, 36:FLOAT8_E4M3FN, 40:FLOAT4_E2M1, 41:FLOAT4_E1M2}。 |
INT64 | - |
| scale_alg | 可选属性 | mxscale1和mxscale2的计算方法。 支持取值0、1和2,取值为0代表OCP实现(场景1),为1代表CuBALS实现(场景2),为2代表DynamicDtypeRange实现(场景3)。 当dst_type为FLOAT4_E1M2时仅支持取值为0。 当dst_type为FLOAT4_E2M1时仅支持取值为0和2。 当dst_type为FLOAT8时仅支持取值为0和1。 |
INT64 | - |
| dst_type_max | 可选属性 | maxType的取值,对应公式中的Amax(DType)。 支持取值0.0和6.0-12.0,取值为0.0代表Amax(DType)为量化结果数据类型的最大值;取值为6.0-12.0代表Amax(DType)为传入值。 仅支持在FP4_E2M1和scale_alg为2时设置该值。 |
DOUBLE | - |
| y1 | 输出 | 输入x量化-1轴后的对应结果,对应公式中的Pi和di。 shape和输入x一致。 |
FLOAT4_E2M1、FLOAT4_E1M2、FLOAT8_E4M3FN、FLOAT8_E5M2 | ND |
| mxscale1 | 输出 | -1轴每个分组对应的量化尺度,对应公式中的mxscale1和Sb。 shape为x的-1轴的值除以32向上取整,并对其进行偶数pad,pad填充值为0。 |
FLOAT8_E8M0 | ND |
| y2 | 输出 | 输入x量化-2轴后的对应结果,对应公式中的Pj和dj。 shape和输入x一致。 |
FLOAT4_E2M1、FLOAT4_E1M2、FLOAT8_E4M3FN、FLOAT8_E5M2 | ND |
| mxscale2 | 输出 | -2轴每个分组对应的量化尺度,对应公式中的mxscale2和Sb。 shape为x的-2轴的值除以32向上取整,并对其进行偶数pad,pad填充值为0。 mxscale2输出需要对每两行数据进行交织处理。 |
FLOAT8_E8M0 | ND |
约束说明
- 关于x、mxscale1、mxscale2的shape约束说明如下:
- x的维度应该大于等于2。
- rank(mxscale1) = rank(x) + 1。
- rank(mxscale2) = rank(x) + 1。
- mxscale1.shape[-2] = (ceil(x.shape[-1] / 32) + 2 - 1) / 2。
- mxscale2.shape[-3] = (ceil(x.shape[-2] / 32) + 2 - 1) / 2。
- mxscale1.shape[-1] = 2。
- mxscale2.shape[-1] = 2。
- 其他维度与输入x一致。
- 举例:输入x的shape为[B, M, N],目的数据类型为FP8类时,对应的y1和y2的shape为[B, M, N],mxscale1的shape为[B, M, (ceil(N/32)+2-1)/2, 2],mxscale2的shape为[B, (ceil(M/32)+2-1)/2, N, 2]。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口(V1) | test_aclnn_dynamic_mx_quant_with_dual_axis | 通过aclnnDynamicMxQuantWithDualAxis接口方式调用,支持scale_alg=0/1。 |
| aclnn接口(V2) | test_aclnn_dynamic_mx_quant_with_dual_axis_v2 | 通过aclnnDynamicMxQuantWithDualAxisV2接口方式调用,支持scale_alg=0/1/2。 |
| 图模式 | - | 通过算子IR构图方式调用DynamicMxQuantWithDualAxis算子。 |