"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import torch
from msmodelslim import ir as qir
from msmodelslim.core.quantizer.base import AutoWeightQuantizer, QConfig
from msmodelslim.core.quantizer.impl.gptq import (
WeightPerChannelGPTQ,
WeightPerGroupGPTQ,
add_batch,
calculate_hessian_inv,
get_ext_value,
)
from msmodelslim.ir.qal.qbase import QDType, QScope, QScheme, QStorage
from msmodelslim.utils.exception import SpecError
def to_qconfig(q_scheme: QScheme, method: str) -> QConfig:
q_config = QConfig(
dtype=q_scheme.dtype.value,
scope=q_scheme.scope.value,
symmetric=q_scheme.symmetric,
method=method,
)
if q_scheme.scope == QScope.PER_GROUP:
q_config.ext["group_size"] = 256
return q_config
class DummyConfig:
"""用于测试 get_ext_value 的简易配置对象"""
def __init__(self, ext):
self.ext = ext
class TestGetExtValue:
"""测试 get_ext_value 辅助函数"""
def test_return_default_when_config_is_none(self):
default = 0.1
assert get_ext_value(None, "percdamp", default) == default
def test_return_default_when_ext_is_none(self):
cfg = DummyConfig(ext=None)
default = 0.2
assert get_ext_value(cfg, "percdamp", default) == default
def test_return_value_from_ext(self):
cfg = DummyConfig(ext={"percdamp": 0.3})
assert get_ext_value(cfg, "percdamp", 0.1) == pytest.approx(0.3)
def test_ignore_none_in_ext_and_use_default(self):
cfg = DummyConfig(ext={"percdamp": None})
assert get_ext_value(cfg, "percdamp", 0.4) == pytest.approx(0.4)
class TestAddBatch:
"""测试 add_batch 辅助函数"""
def test_initialize_hessian_when_none(self):
x = torch.randn(4, 8)
hessian, nsamples = add_batch(None, 0.0, x)
assert hessian.shape == (8, 8)
assert nsamples == 4
assert torch.all(torch.isfinite(hessian))
assert torch.trace(hessian) > 0
def test_accumulate_hessian_and_nsamples(self):
x1 = torch.randn(4, 8)
hessian, nsamples = add_batch(None, 0.0, x1)
x2 = torch.randn(6, 8)
hessian2, nsamples2 = add_batch(hessian.clone(), nsamples, x2)
assert hessian2.shape == (8, 8)
assert nsamples2 == nsamples + 6
assert not torch.allclose(hessian, hessian2)
def test_support_3d_input(self):
x = torch.randn(2, 3, 8)
hessian, nsamples = add_batch(None, 0.0, x)
assert hessian.shape == (8, 8)
assert nsamples == 6
class TestCalculateHessianInv:
"""测试 calculate_hessian_inv 辅助函数"""
def test_return_positive_definite_like_inverse(self):
columns = 6
rows = 10
weight = torch.randn(rows, columns, dtype=torch.float32)
a = torch.randn(columns, columns)
hessian = a.t().matmul(a)
percdamp = 0.01
h_inv = calculate_hessian_inv(hessian.clone(), percdamp, weight)
assert h_inv.shape == (columns, columns)
assert h_inv.dtype == weight.dtype
diag = torch.diag(h_inv)
assert torch.all(torch.isfinite(diag))
assert torch.all(diag > 0)
def test_handle_dead_columns(self):
columns = 4
rows = 6
weight = torch.randn(rows, columns, dtype=torch.float32)
hessian = torch.eye(columns)
hessian[1, 1] = 0.0
percdamp = 0.01
h_inv = calculate_hessian_inv(hessian, percdamp, weight)
assert h_inv.shape == (columns, columns)
assert torch.all(weight[:, 1] == 0)
class TestWeightPerChannelGPTQ:
"""测试 Per-Channel GPTQ 量化器"""
def setup_class(self):
self.config = QConfig(
dtype="int8",
scope="per_channel",
method="gptq",
symmetric=True,
)
def test_initialization(self):
quantizer = WeightPerChannelGPTQ(self.config)
assert quantizer.config == self.config
assert quantizer.weight is None
assert quantizer.bias is None
assert quantizer.w_q_param is None
assert quantizer.w_q_storage is None
assert quantizer.hessian is None
assert quantizer.nsamples == 0
assert isinstance(quantizer.percdamp, float)
assert isinstance(quantizer.block_size, int)
def test_ext_overrides_percdamp_and_block_size(self):
cfg = QConfig(
dtype="int8",
scope="per_channel",
method="gptq",
symmetric=True,
)
cfg.ext["percdamp"] = 0.2
cfg.ext["block_size"] = 64
quantizer = WeightPerChannelGPTQ(cfg)
assert quantizer.percdamp == pytest.approx(0.2)
assert quantizer.block_size == 64
def test_forward_collects_hessian(self):
quantizer = WeightPerChannelGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
bias = torch.randn(8)
quantizer.init_weight(weight, bias)
x = torch.randn(2, 3, 8)
out = quantizer(x)
assert quantizer.hessian is not None
assert quantizer.hessian.shape == (8, 8)
assert quantizer.nsamples == 6
assert torch.allclose(out, weight.value)
def test_get_q_storage_without_hessian_raises(self):
quantizer = WeightPerChannelGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
quantizer.init_weight(weight)
with pytest.raises(SpecError, match="No hessian was set"):
_ = quantizer.get_q_storage()
def test_get_q_param_without_hessian_raises(self):
quantizer = WeightPerChannelGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
quantizer.init_weight(weight)
with pytest.raises(SpecError, match="No hessian was set"):
_ = quantizer.get_q_param()
def test_quantization_flow_can_get_q_storage_and_param(self):
quantizer = WeightPerChannelGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
bias = torch.randn(8)
quantizer.init_weight(weight, bias)
x = torch.randn(3, 8)
_ = quantizer(x)
q_storage = quantizer.get_q_storage()
q_param = quantizer.get_q_param()
assert q_storage is not None
assert q_param is not None
assert q_param.scheme == self.config.to_scheme()
assert "scale" in q_param.ext
assert "offset" in q_param.ext
@pytest.mark.parametrize(
"qconfig",
[
to_qconfig(qir.int8_per_channel_sym, "gptq"),
to_qconfig(qir.int8_per_channel_asym, "gptq"),
to_qconfig(qir.int4_per_channel_sym, "gptq"),
to_qconfig(qir.int4_per_channel_asym, "gptq"),
],
)
def test_creation_with_auto_quantizer(self, qconfig):
quantizer = AutoWeightQuantizer.from_config(qconfig)
assert isinstance(quantizer, WeightPerChannelGPTQ)
class TestWeightPerGroupGPTQ:
"""测试 Per-Group GPTQ 量化器"""
def setup_class(self):
self.config = QConfig(
dtype="int8",
scope="per_group",
method="gptq",
symmetric=True,
)
self.config.ext["group_size"] = 2
def test_initialization(self):
quantizer = WeightPerGroupGPTQ(self.config)
assert quantizer.config == self.config
assert quantizer.weight is None
assert quantizer.bias is None
assert quantizer.w_q_param is None
assert quantizer.w_q_storage is None
assert quantizer.hessian is None
assert quantizer.nsamples == 0
assert isinstance(quantizer.percdamp, float)
assert isinstance(quantizer.block_size, int)
assert isinstance(quantizer.group_size, int)
def test_forward_collects_hessian(self):
quantizer = WeightPerGroupGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
bias = torch.randn(8)
quantizer.init_weight(weight, bias)
x = torch.randn(2, 3, 8)
out = quantizer(x)
assert quantizer.hessian is not None
assert quantizer.hessian.shape == (8, 8)
assert quantizer.nsamples == 6
assert torch.allclose(out, weight.value)
def test_get_q_storage_without_hessian_raises(self):
quantizer = WeightPerGroupGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
quantizer.init_weight(weight)
with pytest.raises(SpecError, match="No hessian was set"):
_ = quantizer.get_q_storage()
def test_get_q_param_without_hessian_raises(self):
quantizer = WeightPerGroupGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
quantizer.init_weight(weight)
with pytest.raises(SpecError, match="No hessian was set"):
_ = quantizer.get_q_param()
def test_quantization_flow_can_get_q_storage_and_param_and_group_ext(self):
quantizer = WeightPerGroupGPTQ(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(4, 8))
bias = torch.randn(8)
quantizer.init_weight(weight, bias)
x = torch.randn(3, 8)
_ = quantizer(x)
q_storage = quantizer.get_q_storage()
q_param = quantizer.get_q_param()
assert q_storage is not None
assert q_param is not None
assert q_param.scheme == self.config.to_scheme()
assert "scale" in q_param.ext
assert "offset" in q_param.ext
assert "group_size" in q_param.ext
out_channels = weight.value.shape[0]
num_groups = weight.value.shape[1] // self.config.ext["group_size"]
assert q_param.ext["scale"].shape == (out_channels, num_groups)
assert q_param.ext["offset"].shape == (out_channels, num_groups)
@pytest.mark.parametrize(
"qconfig",
[
to_qconfig(qir.int8_per_group_sym, "gptq"),
to_qconfig(qir.int8_per_group_asym, "gptq"),
to_qconfig(qir.int4_per_group_sym, "gptq"),
to_qconfig(qir.int4_per_group_asym, "gptq"),
],
)
def test_creation_with_auto_quantizer(self, qconfig):
quantizer = AutoWeightQuantizer.from_config(qconfig)
assert isinstance(quantizer, WeightPerGroupGPTQ)