量化计算类算子精度验证
1. 算子定义
定义特征: 低精度(整型)与高精度(浮点)之间的转换与计算
常见数据类型: INT4/INT8/FLOAT4/FLOAT8
2. 验证方法
根据输出类型选择比对方法:
| 输出类型 | 比对方法 | 验证脚本 |
|---|---|---|
| 整型输出(量化结果) | 单标杆比对 | quantization_check.py |
| 浮点输出(反量化结果) | 双标杆比对 | quantization_check.py |
3. 使用示例
3.1 整型输出(量化结果)
from scripts.quantization_check import check_quantization_with_level
# 执行量化算子(FP16 → INT8)
npu_output = run_quantize_on_npu() # dtype: int8
golden_output = run_quantize_on_cpu() # dtype: int8
result = check_quantization_with_level(
npu_output, golden_output,
input_dtype='float16',
output_dtype='int8'
)
assert result['is_pass'], f"量化精度不达标: max_abs_error={result['max_abs_error']}"
3.2 浮点输出(反量化结果)
from scripts.quantization_check import check_quantization_with_level
# 执行反量化算子(INT8 → FP16)
npu_output = run_dequantize_on_npu() # dtype: float16
golden_output = run_dequantize_on_cpu() # dtype: float16 (高精度实现)
third_party_output = run_dequantize_on_gpu() # dtype: float16
result = check_quantization_with_level(
npu_output, golden_output, third_party_output,
precision_level='L1',
input_dtype='int8',
output_dtype='float16'
)
assert result['is_pass'], f"反量化精度不达标"
4. 通过标准
4.1 整型输出标准
核心标准: 绝对误差 ≤ 1
判定代码:
abs_error = np.abs(npu_output - golden_output)
max_abs_error = np.max(abs_error)
is_pass = (max_abs_error <= 1)
4.2 浮点输出标准
参考浮点计算类标准: 使用MARE/MERE/RMSE Ratio
精度等级阈值: 根据precision_level确定
- L0: MARE ratio ≤ 10, MERE ratio ≤ 2, RMSE ratio ≤ 2
- L1: MARE ratio ≤ 5, MERE ratio ≤ 1.5, RMSE ratio ≤ 1.5
- L2: MARE ratio ≤ 2, MERE ratio ≤ 1.2, RMSE ratio ≤ 1.2
5. 完整通过标准表
| 输入类型 | 输出类型 | 通过标准 | 验证方法 |
|---|---|---|---|
| 整型(INT4/INT8/INT16等) | 整型(INT4/INT8/INT16等) | N/A(不常见场景) | - |
| 整型(INT4/INT8/INT16等) | 浮点(FLOAT4/FLOAT8/FLOAT16/FLOAT32等) | 参考浮点类标准 | 双标杆比对 |
| 浮点(FLOAT4/FLOAT8/FLOAT16/FLOAT32等) | 整型(INT4/INT8/INT16等) | 绝对误差 ≤ 1 | 单标杆比对 |
| 浮点(FLOAT4/FLOAT8/FLOAT16/FLOAT32等) | 浮点(FLOAT4/FLOAT8/FLOAT16/FLOAT32等) | 参考浮点类标准 | 双标杆比对 |
6. 脚本详细使用
6.1 check_quantization函数
from scripts.quantization_check import check_quantization
# 基础验证(自动判断输出类型)
result = check_quantization(
npu_output,
golden_output,
third_party_output=None, # 浮点输出时需要
input_dtype='float16',
output_dtype='int8'
)
# 返回结果(整型输出)
# {
# 'is_pass': True/False,
# 'max_abs_error': 1,
# 'mean_abs_error': 0.5,
# 'threshold': 1,
# 'comparison_method': 'single_benchmark'
# }
# 返回结果(浮点输出)
# {
# 'comparison_method': 'dual_benchmark',
# 'mare_npu': 0.001,
# 'mere_npu': 0.0005,
# 'rmse_npu': 0.001,
# 'mare_third': 0.0008,
# 'mere_third': 0.0004,
# 'rmse_third': 0.0008,
# 'mare_ratio': 1.25,
# 'mere_ratio': 1.25,
# 'rmse_ratio': 1.25,
# 'is_pass': None # 需要外部指定精度等级后判断
# }
6.2 check_quantization_with_level函数
from scripts.quantization_check import check_quantization_with_level
# 完整验证(包含精度等级判定)
result = check_quantization_with_level(
npu_output, golden_output, third_party_output,
precision_level='L1',
input_dtype='int8',
output_dtype='float16'
)
# 返回结果
# {
# 'is_pass': True/False,
# 'precision_level': 'L1',
# 'mare_ratio': 1.25,
# 'mere_ratio': 1.25,
# 'rmse_ratio': 1.25,
# 'mare_pass': True,
# 'mere_pass': True,
# 'rmse_pass': True,
# 'thresholds_used': {'mare_ratio': 5, 'mere_ratio': 1.5, 'rmse_ratio': 1.5}
# }
7. 参考文档
- 浮点计算类标准:见
float_compute.md - 详细精度标准:见
golden/COMMERCIAL_OPS_PRECISION_DOCS.md - 标杆构造方法:见
benchmark_construction.md