"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2026 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import torch
from unittest.mock import MagicMock
from msmodelslim.ir.qal import QStorage, QDType, QParam
from msmodelslim.core.quantizer.base import QConfig
from msmodelslim.utils.exception import SchemaValidateError, UnexpectedError
from msmodelslim.core.quantizer.impl.dualscale import MXWeightDualScaleMinmax, MXActDualScaleMinmax
class TestMXWeightDualScaleMinmax:
    """Tests for the weight dual-scale MinMax quantizer (512-element outer block case)."""

    @pytest.fixture
    def standard_config(self):
        """Standard mxfp4 dual-scale config with a 512-element outer block on the last axis."""
        return QConfig(
            dtype="mxfp4",
            scope="per_block",
            method="dualscale",
            symmetric=True,
            ext={"axes": -1, "dual_block_size": 512},
        )

    def test_initialization_success_when_valid_config(self, standard_config):
        """A valid config sets axes / dual_block_size and w_q_param; storage stays unset."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        assert quantizer.axes == -1
        assert quantizer.dual_block_size == 512
        assert quantizer.w_q_storage is None
        assert quantizer.w_q_param.ext["dual_block_size"] == 512

    def test_initialization_raise_schema_validate_error_when_invalid_axes_type(self):
        """An 'axes' value that is neither an int nor a list is rejected with SchemaValidateError."""
        invalid_config = QConfig(
            dtype="mxfp4",
            scope="per_block",
            method="dualscale",
            symmetric=True,
            ext={"axes": "invalid_axis_string", "dual_block_size": 512},
        )
        with pytest.raises(SchemaValidateError, match="Invalid value for 'axes'"):
            MXWeightDualScaleMinmax(invalid_config)

    @pytest.mark.parametrize(
        "weight_tensor, expected_dual_scale",
        [
            # A single all-zero 512-wide block.
            (torch.zeros(1, 512), torch.zeros(1, 1)),
            # Two 512-wide blocks with absolute maxima 12 and 24.
            # NOTE(review): the expected values look like block absmax / 6.0
            # (12 / 6 = 2, 24 / 6 = 4) -- confirm against the observer.
            (
                torch.cat(
                    [torch.linspace(-12.0, 12.0, 512).unsqueeze(0), torch.linspace(-24.0, 24.0, 512).unsqueeze(0)],
                    dim=-1,
                ),
                torch.tensor([[[2.0], [4.0]]]),
            ),
        ],
        ids=["all_zero_single_block", "two_blocks_absmax_12_and_24"],
    )
    def test_dual_scale_calculate_correct_scale_when_init_weight_with_boundary_values(
        self, standard_config, weight_tensor, expected_dual_scale
    ):
        """init_weight computes the expected per-outer-block dual_scale for boundary inputs."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)

        # Isolate the outer quantizer: the inner quantizer's work is stubbed out.
        quantizer.inner_quantizer.init_weight = MagicMock()
        quantizer.inner_quantizer.get_q_param = MagicMock(
            return_value=QParam(scheme=standard_config.to_scheme(), ext={"axes": [1]})
        )
        quantizer.inner_quantizer.get_q_storage = MagicMock(return_value=weight_storage)

        quantizer.init_weight(weight_storage)

        assert "dual_scale" in quantizer.w_q_param.ext
        calculated_dual_scale = quantizer.w_q_param.ext["dual_scale"]
        # Compare flattened values: expected tensors use different nesting for readability.
        assert torch.allclose(calculated_dual_scale.flatten(), expected_dual_scale.flatten(), atol=1e-5)

    def test_forward_raise_unexpected_error_when_dual_scale_is_none(self, standard_config):
        """Calling forward before init_weight raises UnexpectedError because dual_scale is missing."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        quantizer.w_q_param.ext["axes"] = [1]
        quantizer.inner_quantizer.forward = MagicMock(return_value=torch.randn(1, 512))
        with pytest.raises(UnexpectedError, match="The parameter 'dual_scale' cannot be None"):
            quantizer(x=None)

    def test_forward_return_correct_dequant_value_when_valid_inner_dequant_provided(self, standard_config):
        """After init_weight, forward rescales the inner dequantized value by the outer dual_scale."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        weight_tensor = torch.linspace(-24.0, 24.0, 512).unsqueeze(0)
        weight_storage = QStorage(QDType.FLOAT, weight_tensor)
        mock_inner_dequant = torch.ones(1, 512)

        quantizer.inner_quantizer.forward = MagicMock(return_value=mock_inner_dequant)
        quantizer.inner_quantizer.get_q_param = MagicMock(
            return_value=QParam(scheme=standard_config.to_scheme(), ext={"axes": [1]})
        )
        quantizer.inner_quantizer.get_q_storage = MagicMock(return_value=weight_storage)

        quantizer.init_weight(weight_storage)
        output = quantizer(x=None)

        # NOTE(review): 4.0 presumably equals absmax 24 / 6.0 for this single block -- confirm.
        expected_output = mock_inner_dequant * 4.0
        assert torch.allclose(output, expected_output, atol=1e-5)

    def test_get_q_storage_raise_unexpected_error_when_storage_is_not_initialized(self, standard_config):
        """get_q_storage raises UnexpectedError while w_q_storage has not been initialized."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        # `match` is a regex search; escape the dot so it only matches a literal '.'.
        with pytest.raises(UnexpectedError, match=r"self\.w_q_storage' cannot be None"):
            quantizer.get_q_storage()

    def test_get_q_param_raise_unexpected_error_when_param_is_none(self, standard_config):
        """get_q_param raises UnexpectedError when w_q_param has been cleared to None."""
        quantizer = MXWeightDualScaleMinmax(standard_config)
        quantizer.w_q_param = None
        # `match` is a regex search; escape the dot so it only matches a literal '.'.
        with pytest.raises(UnexpectedError, match=r"self\.w_q_param' cannot be None"):
            quantizer.get_q_param()
class TestMXActDualScaleMinmax:
    """Tests for the activation dual-scale MinMax quantizer."""

    @pytest.fixture
    def act_config(self):
        """Standard activation config: mxfp4 dual-scale, 512-element outer block, last axis."""
        return QConfig(
            dtype="mxfp4",
            scope="per_block",
            method="dualscale",
            symmetric=True,
            ext={"axes": -1, "dual_block_size": 512},
        )

    def test_initialization_success_when_valid_config(self, act_config):
        """A valid config sets axes / dual_block_size and marks the quantizer as data-free."""
        quantizer = MXActDualScaleMinmax(act_config)
        assert quantizer.axes == -1
        assert quantizer.dual_block_size == 512
        assert quantizer.is_data_free() is True

    def test_initialization_raise_schema_validate_error_when_invalid_axes_type(self):
        """An invalid 'axes' type (here a dict) is rejected with SchemaValidateError."""
        invalid_config = QConfig(
            dtype="mxfp4",
            scope="per_block",
            method="dualscale",
            symmetric=True,
            ext={"axes": {}, "dual_block_size": 512},
        )
        with pytest.raises(SchemaValidateError, match="Invalid value for 'axes'"):
            MXActDualScaleMinmax(invalid_config)

    # Deterministic tensors: building them with torch.randn here would consume the
    # global RNG at collection time, and the identity check below needs no randomness.
    @pytest.mark.parametrize(
        "boundary_tensor",
        [
            torch.arange(1024, dtype=torch.float32).reshape(1, 1024),
            torch.linspace(-1.0, 1.0, 1024).reshape(2, 512),
            torch.tensor([]),
        ],
        ids=["1x1024", "2x512", "empty"],
    )
    def test_forward_return_original_tensor_when_any_boundary_values_passed(self, act_config, boundary_tensor):
        """The activation forward is a passthrough: the exact input tensor object is returned."""
        quantizer = MXActDualScaleMinmax(act_config)
        output = quantizer(boundary_tensor)
        assert output is boundary_tensor

    def test_get_q_param_return_fallback_scheme_when_q_param_is_none(self, act_config):
        """With q_param cleared to None, get_q_param falls back to a param built from the config scheme."""
        quantizer = MXActDualScaleMinmax(act_config)
        quantizer.q_param = None
        fallback_param = quantizer.get_q_param()
        assert fallback_param is not None
        assert fallback_param.scheme == act_config.to_scheme()