"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import torch
from msmodelslim.ir.qal import QStorage, QDType
from msmodelslim.core.quantizer.base import QConfig
from msmodelslim.core.quantizer.impl.fouroversix import WeightFouroverSixQuantizer
class TestWeightFouroverSixQuantizer:
    """Tests for the Weight FouroverSix quantizer."""

    @pytest.fixture
    def standard_config(self):
        """Standard mxfp4 / per_block / fouroversix quantization config along axis -1."""
        return QConfig(dtype="mxfp4", scope="per_block", method="fouroversix", symmetric=True, ext={"axes": -1})

    def test_initialization_success_when_valid_config(self, standard_config):
        """A valid config mounts the expected attributes and parameters on the quantizer."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        assert quantizer.axes == -1
        assert quantizer.block_size == 32
        assert quantizer.w_q_storage is None
        assert quantizer.w_q_param.ext["block_size"] == 32
        assert quantizer.w_q_param.ext["axes"] == -1
        assert quantizer.w_q_param.ext["scale"] is None

    @pytest.mark.parametrize(
        "axes_value",
        [
            -1,
            0,
            1,
            [-1],
            [0, 1],
        ],
    )
    def test_initialization_success_with_various_axes(self, axes_value):
        """Initialization stores the configured `axes` value unchanged (int or list)."""
        config = QConfig(
            dtype="mxfp4", scope="per_block", method="fouroversix", symmetric=True, ext={"axes": axes_value}
        )
        quantizer = WeightFouroverSixQuantizer(config)
        assert quantizer.axes == axes_value

    @pytest.mark.parametrize(
        "weight_tensor",
        [
            torch.zeros(1, 32),
            # These tensors are built at collection time; seeded local generators keep the
            # inputs identical across runs (so failures are reproducible) without touching
            # the global RNG state used by other tests.
            torch.randn(1, 64, generator=torch.Generator().manual_seed(0)),
            torch.randn(2, 32, generator=torch.Generator().manual_seed(1)),
        ],
    )
    def test_init_weight_success_with_boundary_values(self, standard_config, weight_tensor):
        """init_weight populates the quantized storage and the scale for boundary-shaped inputs."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        assert quantizer.w_q_storage is not None
        # Check key presence first so a missing key fails as an assertion rather than a KeyError.
        assert "scale" in quantizer.w_q_param.ext
        assert quantizer.w_q_param.ext["scale"] is not None

    def test_init_weight_selects_best_scale_based_on_mse(self, standard_config):
        """init_weight produces a non-scalar scale for a single 32-element block.

        NOTE(review): this only checks that a scale tensor with rank > 0 was chosen; it does
        not verify which of the candidate scales the MSE comparison picked.
        """
        quantizer = WeightFouroverSixQuantizer(standard_config)
        weight_tensor = torch.linspace(-6.0, 6.0, 32).unsqueeze(0)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        assert quantizer.w_q_storage is not None
        assert quantizer.w_q_param.ext["scale"] is not None
        selected_scale = quantizer.w_q_param.ext["scale"]
        assert selected_scale.dim() > 0

    def test_init_weight_with_zero_tensor(self, standard_config):
        """An all-zero weight tensor can be quantized and still yields a scale."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        weight_tensor = torch.zeros(1, 32)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        assert quantizer.w_q_storage is not None
        assert quantizer.w_q_param.ext["scale"] is not None

    def test_forward_return_dequantized_value(self, standard_config):
        """forward returns a dequantized weight with the same shape as the input weight."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        weight_tensor = torch.linspace(-6.0, 6.0, 32).unsqueeze(0)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        output = quantizer(x=None)
        assert output is not None
        assert output.shape == weight_tensor.shape

    def test_forward_raise_error_when_not_initialized(self, standard_config):
        """Calling forward before init_weight raises."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        # NOTE(review): the concrete exception type is not pinned here; narrow this once the
        # quantizer's error contract for an uninitialized forward is confirmed.
        with pytest.raises(Exception):
            quantizer(x=None)

    def test_get_q_storage_return_valid_storage_after_init(self, standard_config):
        """get_q_storage returns a QStorage instance after init_weight."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        # Seeded local generator keeps the input reproducible across runs.
        weight_tensor = torch.randn(1, 32, generator=torch.Generator().manual_seed(0))
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        q_storage = quantizer.get_q_storage()
        assert q_storage is not None
        assert isinstance(q_storage, QStorage)

    def test_get_q_storage_return_none_before_init(self, standard_config):
        """get_q_storage returns None before init_weight has been called."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        q_storage = quantizer.get_q_storage()
        assert q_storage is None

    def test_get_q_param_return_valid_param(self, standard_config):
        """get_q_param returns a QParam whose scheme reflects the configured dtype and scope."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        q_param = quantizer.get_q_param()
        assert q_param is not None
        assert q_param.scheme.dtype.name == "MXFP4"
        assert q_param.scheme.scope.name == "PER_BLOCK"

    @pytest.mark.parametrize(
        "input_scale, expected_exp",
        [
            (torch.tensor(2.0), torch.tensor(1.0)),
            (torch.tensor(3.0), torch.tensor(2.0)),
            (torch.tensor(1.5), torch.tensor(1.0)),
            (torch.tensor(2.0**0.5), torch.tensor(1.0)),
            (torch.tensor(2.0**1.5), torch.tensor(1.0)),
            (torch.tensor(0.0), torch.tensor(0.0)),
            (torch.tensor(-1.0), torch.tensor(0.0)),
        ],
    )
    def test_nearest_neighbor_rounding_to_e8m0(self, standard_config, input_scale, expected_exp):
        """The private e8m0 rounding helper maps each input scale to the expected value."""
        quantizer = WeightFouroverSixQuantizer(standard_config)
        # The helper is name-mangled (double leading underscore), hence the explicit prefix.
        result = quantizer._WeightFouroverSixQuantizer__nearest_neighbor_rounding_to_e8m0(input_scale)
        assert torch.isclose(result, expected_exp, atol=1e-5)

    def test_mse_based_selection_between_two_scales(self, standard_config):
        """A two-block weight yields one scale per block, and the two scales differ.

        NOTE(review): the blocks have different absolute maxima (6.0 vs 3.0), so their scales
        would likely differ even without any MSE-based selection; this test does not isolate
        the selection logic itself.
        """
        quantizer = WeightFouroverSixQuantizer(standard_config)
        weight_tensor = torch.cat([torch.linspace(-6.0, 6.0, 32), torch.linspace(-3.0, 3.0, 32)]).unsqueeze(0)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        quantizer.init_weight(weight_storage)
        assert quantizer.w_q_storage is not None
        selected_scale = quantizer.w_q_param.ext["scale"]
        assert selected_scale.shape[1] == 2
        assert selected_scale[0, 0] != selected_scale[0, 1], "两个块应该根据 MSE 选择不同的缩放方案"