DynamicMxQuantWithDualAxis

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:在-1轴和-2轴上同时进行目的数据类型为FLOAT4类、FLOAT8类的MX量化。在给定的-1轴和-2轴上,每32个数,计算出这两组数对应的量化尺度mxscale1、mxscale2,然后分别对两组数所有元素除以对应的mxscale1或mxscale2,根据round_mode转换到对应的dst_type,得到量化结果y1和y2。在dst_type为FLOAT8_E4M3FN、FLOAT8_E5M2时,根据scale_alg的取值来指定计算mxscale的不同算法。

  • 合轴说明:算子实现时,会对-2轴(不包含)之前的所有轴进行合轴处理。即对于输入shape为(d0,d1,...,dn−3,dn−2,dn−1)(d_0, d_1, ..., d_{n-3}, d_{n-2}, d_{n-1})的张量,-2轴之前的维度(d0,d1,...,dn−3)(d_0, d_1, ..., d_{n-3})会被合并为一个维度,等效于将输入reshape为(d0×d1×...×dn−3,dn−2,dn−1)(d_0 \times d_1 \times ... \times d_{n-3}, d_{n-2}, d_{n-1})后再进行量化计算。

  • 计算公式:

    • 场景1,当scale_alg为0时,即OCP Microscaling Formats (Mx) Specification实现:

    • 将输入x在-1轴上按照32个数进行分组,一组32个数 {{Vi}i=132}\{\{V_i\}_{i=1}^{32}\} 量化为 {mxscale1,{Pi}i=132}\{mxscale1, \{P_i\}_{i=1}^{32}\}

      shared_exp=floor(log2(maxi(∣Vi∣)))−emaxshared\_exp = floor(log_2(max_i(|V_i|))) - emax

      mxscale1=2shared_expmxscale1 = 2^{shared\_exp}

      Pi=cast_to_dst_type(Vi/mxscale1,round_mode), i from 1 to 32P_i = cast\_to\_dst\_type(V_i/mxscale1, round\_mode), \space i\space from\space 1\space to\space 32

    • 同时,将输入x在-2轴上按照32个数进行分组,一组32个数 {{Vj}j=132}\{\{V_j\}_{j=1}^{32}\} 量化为 {mxscale2,{Pj}j=132}\{mxscale2, \{P_j\}_{j=1}^{32}\}

      shared_exp=floor(log2(maxj(∣Vj∣)))−emaxshared\_exp = floor(log_2(max_j(|V_j|))) - emax

      mxscale2=2shared_expmxscale2 = 2^{shared\_exp}

      Pj=cast_to_dst_type(Vj/mxscale2,round_mode), j from 1 to 32P_j = cast\_to\_dst\_type(V_j/mxscale2, round\_mode), \space j\space from\space 1\space to\space 32

    • -1轴​量化后的 PiP_{i} 按对应的 ViV_{i} 的位置组成输出y1,mxscale1按对应的-1轴维度上的分组组成输出mxscale1。-2轴​量化后的 PjP_{j} 按对应的 VjV_{j} 的位置组成输出y2,mxscale2按对应的-2轴维度上的分组组成输出mxscale2。

    • emax: 对应数据类型的最大正则数的指数位。

      DataType emax
      FLOAT4_E2M1 2
      FLOAT4_E1M2 0
      FLOAT8_E4M3FN 8
      FLOAT8_E5M2 15
    • 场景2,当scale_alg为1时,只涉及FP8类型(CuBALS Scale计算算法):

      • 将输入x在-1轴(或-2轴)上按照32个数进行分组,对每组单独计算块缩放因子Sfp32bS_{fp32}^b,再把组内所有元素映射到目标低精度类型FP8。
      • 找到该组中数值的最大绝对值:Amax(Dfp32b)=max({∣di∣}i=132)Amax(D_{fp32}^b) = max(\{|d_i|\}_{i=1}^{32})
      • 将FP32映射到目标数据类型FP8可表示的范围内:Sfp32b=Amax(Dfp32b)Amax(DType)S_{fp32}^b = \frac{Amax(D_{fp32}^b)}{Amax(DType)}
      • 从块缩放因子中提取无偏指数EintbE_{int}^b和尾数MfixpbM_{fixp}^b,并进行条件向上取整。
      • 计算块缩放因子:Sue8m0b=2EintbS_{ue8m0}^b=2^{E_{int}^b},计算块转换因子:Rfp32b=1fp32(Sue8m0b)R_{fp32}^b=\frac{1}{fp32(S_{ue8m0}^b)}
      • 对每个组内元素应用量化:di=DType(dfp32i⋅Rfp32b)d^i = DType(d_{fp32}^i \cdot R_{fp32}^b)
    • 场景3,当scale_alg为2时,只涉及FP4_E2M1类型(DynamicDtypeRange算法,仅V2接口支持):

      • 当dstTypeMax = 0.0/6.0/7.0时,使用指数域addValueBit进位法计算scale。
      • 当dstTypeMax为其他自定义值时,使用FP32精度invDstTypeMax乘法法计算scale。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 待量化数据,对应公式中Vi和di
目的类型为FLOAT4_E2M1、FLOAT4_E1M2时,x的最后一维必须是偶数。
FLOAT16、BFLOAT16 ND
round_mode 可选属性 数据转换的模式。
当dst_type为40/41(FLOAT4_E2M1/FLOAT4_E1M2)时,支持{"rint", "floor", "round"};
当dst_type为35/36(FLOAT8_E5M2/FLOAT8_E4M3FN)时,仅支持{"rint"};
传入空指针时,采用"rint"模式。
STRING -
dst_type 可选属性 指定数据转换后y1和y2的类型。
输入范围为{35, 36, 40, 41},分别对应{35:FLOAT8_E5M2, 36:FLOAT8_E4M3FN, 40:FLOAT4_E2M1, 41:FLOAT4_E1M2}。
INT64 -
scale_alg 可选属性 mxscale1和mxscale2的计算方法。
支持取值0、1和2,取值为0代表OCP实现(场景1),为1代表CuBALS实现(场景2),为2代表DynamicDtypeRange实现(场景3)。
当dst_type为FLOAT4_E1M2时仅支持取值为0。
当dst_type为FLOAT4_E2M1时仅支持取值为0和2。
当dst_type为FLOAT8时仅支持取值为0和1。
INT64 -
dst_type_max 可选属性 maxType的取值,对应公式中的Amax(DType)。
支持取值0.0和6.0-12.0,取值为0.0代表Amax(DType)为量化结果数据类型的最大值;取值为6.0-12.0代表Amax(DType)为传入值。
仅支持在FP4_E2M1和scale_alg为2时设置该值。
DOUBLE -
y1 输出 输入x量化-1轴后的对应结果,对应公式中的Pi和di
shape和输入x一致。
FLOAT4_E2M1、FLOAT4_E1M2、FLOAT8_E4M3FN、FLOAT8_E5M2 ND
mxscale1 输出 -1轴每个分组对应的量化尺度,对应公式中的mxscale1和Sb
shape为x的-1轴的值除以32向上取整,并对其进行偶数pad,pad填充值为0。
FLOAT8_E8M0 ND
y2 输出 输入x量化-2轴后的对应结果,对应公式中的Pj和dj
shape和输入x一致。
FLOAT4_E2M1、FLOAT4_E1M2、FLOAT8_E4M3FN、FLOAT8_E5M2 ND
mxscale2 输出 -2轴每个分组对应的量化尺度,对应公式中的mxscale2和Sb
shape为x的-2轴的值除以32向上取整,并对其进行偶数pad,pad填充值为0。
mxscale2输出需要对每两行数据进行交织处理。
FLOAT8_E8M0 ND

约束说明

  • 关于x、mxscale1、mxscale2的shape约束说明如下:
    • x的维度应该大于等于2。
    • rank(mxscale1) = rank(x) + 1。
    • rank(mxscale2) = rank(x) + 1。
    • mxscale1.shape[-2] = (ceil(x.shape[-1] / 32) + 2 - 1) / 2。
    • mxscale2.shape[-3] = (ceil(x.shape[-2] / 32) + 2 - 1) / 2。
    • mxscale1.shape[-1] = 2。
    • mxscale2.shape[-1] = 2。
    • 其他维度与输入x一致。
    • 举例:输入x的shape为[B, M, N],目的数据类型为FP8类时,对应的y1和y2的shape为[B, M, N],mxscale1的shape为[B, M, (ceil(N/32)+2-1)/2, 2],mxscale2的shape为[B, (ceil(M/32)+2-1)/2, N, 2]。

调用说明

调用方式 样例代码 说明
aclnn接口(V1) test_aclnn_dynamic_mx_quant_with_dual_axis 通过aclnnDynamicMxQuantWithDualAxis接口方式调用,支持scale_alg=0/1。
aclnn接口(V2) test_aclnn_dynamic_mx_quant_with_dual_axis_v2 通过aclnnDynamicMxQuantWithDualAxisV2接口方式调用,支持scale_alg=0/1/2。
图模式 - 通过算子IR构图方式调用DynamicMxQuantWithDualAxis算子。