"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import torch
from pydantic import ValidationError
from msmodelslim.ir.api import calculate_qparam
from msmodelslim.ir.qal.qbase import QStorage, QDType, QScheme, QScope
from msmodelslim import ir as qir
from msmodelslim.core.observer import MsMinMaxObserver
from msmodelslim.core.quantizer.base import QConfig
from msmodelslim.core.quantizer.impl.ssz import (
WeightPerChannelSsz,
get_ext_scale,
get_ext_offset,
_broadcast_quantizer_state,
)
from msmodelslim.utils.distributed.task_scheduler import DTSMixin
from msmodelslim.utils.exception import SpecError
from unittest import mock
def to_qconfig(q_scheme: QScheme, method: str) -> QConfig:
q_config = QConfig(
dtype=q_scheme.dtype.value,
scope=q_scheme.scope.value,
symmetric=q_scheme.symmetric,
method=method,
)
if q_scheme.scope == QScope.PER_GROUP:
q_config.ext['group_size'] = 256
return q_config
class TestWeightPerChannelSsz:
"""测试Per-Channel ssz量化器"""
def setup_class(self):
self.config = QConfig(dtype="int8", scope="per_channel", method="ssz", symmetric=True)
def test_initialization(self):
"""测试初始化"""
quantizer = WeightPerChannelSsz(self.config)
assert quantizer.config == self.config, "config is not correct, expected: %s, actual: %s" % (
self.config,
quantizer.config,
)
assert quantizer.minmax_observer is not None, "minmax_observer is not correct, expected: %s, actual: %s" % (
MsMinMaxObserver,
quantizer.minmax_observer,
)
assert isinstance(quantizer.minmax_observer, MsMinMaxObserver), (
"minmax_observer is not correct, expected: %s, actual: %s"
% (MsMinMaxObserver, type(quantizer.minmax_observer))
)
assert quantizer.weight is None, "weight is not correct, expected: %s, actual: %s" % (None, quantizer.weight)
assert quantizer.bias is None, "bias is not correct, expected: %s, actual: %s" % (None, quantizer.bias)
assert quantizer.w_q_param is None, "w_q_param is not correct, expected: %s, actual: %s" % (
None,
quantizer.w_q_param,
)
assert quantizer.w_q_storage is None, "w_q_storage is not correct, expected: %s, actual: %s" % (
None,
quantizer.w_q_storage,
)
assert quantizer.is_quantized is False, "is_quantized is not correct, expected: %s, actual: %s" % (
False,
quantizer.is_quantized,
)
def test_init_weight_validation(self):
"""测试权重初始化验证"""
quantizer = WeightPerChannelSsz(self.config)
with pytest.raises(ValidationError, match="instance of QStorage"):
quantizer.init_weight(torch.randn(10, 20))
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
with pytest.raises(ValidationError, match="instance of Tensor"):
quantizer.init_weight(weight, bias="invalid")
def test_get_q_storage_and_q_param_after_forward(self):
"""测试在forward之后获取q_storage和q_param"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
bias = torch.randn(20)
quantizer.init_weight(weight, bias)
quantizer()
q_storage = quantizer.get_q_storage()
q_param = quantizer.get_q_param()
assert q_storage is not None, "q_storage is not correct, expected: %s, actual: %s" % (None, q_storage)
assert q_param is not None, "q_param is not correct, expected: %s, actual: %s" % (None, q_param)
def test_forward_before_init_weight(self):
"""测试在初始化权重之前前向传播"""
quantizer = WeightPerChannelSsz(self.config)
with pytest.raises(SpecError, match="No weight was set"):
quantizer()
def test_forward_after_init_weight(self):
"""测试权重初始化并前向传播"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
bias = torch.randn(20)
quantizer.init_weight(weight, bias)
assert quantizer.weight == weight, "weight is not correct, expected: %s, actual: %s" % (
weight,
quantizer.weight,
)
assert quantizer.bias is bias, "bias is not correct, expected: %s, actual: %s" % (bias, quantizer.bias)
result = quantizer()
q_param = quantizer.get_q_param()
assert q_param is not None, "q_param is not correct, expected: %s, actual: %s" % (None, q_param)
assert q_param.scheme == self.config.to_scheme(), "q_param.scheme is not correct, expected: %s, actual: %s" % (
self.config.to_scheme(),
q_param.scheme,
)
assert isinstance(q_param.ext, dict), "q_param.ext is not correct, expected: %s, actual: %s" % (
dict,
type(q_param.ext),
)
assert "scale" in q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"scale",
q_param.ext.keys(),
)
assert "offset" in q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"offset",
q_param.ext.keys(),
)
assert isinstance(q_param.ext["scale"], torch.Tensor), (
"q_param.ext['scale'] is not correct, expected: %s, actual: %s" % (torch.Tensor, type(q_param.ext["scale"]))
)
assert isinstance(q_param.ext["offset"], torch.Tensor), (
"q_param.ext['offset'] is not correct, expected: %s, actual: %s"
% (torch.Tensor, type(q_param.ext["offset"]))
)
assert q_param.ext["scale"].shape == (weight.value.shape[0],), (
"q_param.ext['scale'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), q_param.ext["scale"].shape)
)
assert q_param.ext["offset"].shape == (weight.value.shape[0],), (
"q_param.ext['offset'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), q_param.ext["offset"].shape)
)
q_storage = quantizer.get_q_storage()
assert q_storage is not None, "q_storage is not correct, expected: %s, actual: %s" % (None, q_storage)
assert result.shape == weight.value.shape, "result.shape is not correct, expected: %s, actual: %s" % (
weight.value.shape,
result.shape,
)
def test_forward_with_invalid_one_dim_weight(self):
"""测试无效权重形状"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(10))
bias = torch.randn(10)
quantizer.init_weight(weight, bias)
with pytest.raises(SpecError, match="Weight must be a 2D tensor"):
quantizer()
def test_forward_with_invalid_three_dim_weight(self):
"""测试无效权重形状"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(10, 20, 30))
bias = torch.randn(20)
quantizer.init_weight(weight, bias)
with pytest.raises(SpecError, match="Weight must be a 2D tensor"):
quantizer()
def test_different_weight_shapes(self):
"""测试不同权重形状的处理"""
weight_shapes = [(10, 20), (32, 64), (128, 256)]
for shape in weight_shapes:
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(*shape))
bias = torch.randn(shape[1])
quantizer.init_weight(weight, bias)
result = quantizer()
q_param = quantizer.get_q_param()
assert result.shape == weight.value.shape, "result.shape is not correct, expected: %s, actual: %s" % (
weight.value.shape,
result.shape,
)
assert q_param is not None, "q_param is not correct, expected: %s, actual: %s" % (None, q_param)
assert q_param.scheme == self.config.to_scheme(), (
"q_param.scheme is not correct, expected: %s, actual: %s" % (self.config.to_scheme(), q_param.scheme)
)
assert q_param.ext["scale"].shape == (shape[0],), (
"q_param.ext['scale'].shape is not correct, expected: %s, actual: %s"
% ((shape[0],), q_param.ext["scale"].shape)
)
assert q_param.ext["offset"].shape == (shape[0],), (
"q_param.ext['offset'].shape is not correct, expected: %s, actual: %s"
% ((shape[0],), q_param.ext["offset"].shape)
)
@pytest.mark.parametrize(
"qconfig",
[
to_qconfig(qir.int8_per_channel_sym, "ssz"),
],
)
def test_creation_with_auto_quantizer(self, qconfig):
"""测试通过自动量化器创建"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer
quantizer = AutoWeightQuantizer.from_config(qconfig)
assert isinstance(quantizer, WeightPerChannelSsz)
def test_dts_mixin_inheritance(self):
"""验证 WeightPerChannelSsz 通过 AutoWeightQuantizer 继承 DTSMixin"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer as AWQ
quantizer = WeightPerChannelSsz(self.config)
assert isinstance(quantizer, DTSMixin), "WeightPerChannelSsz should be a DTSMixin instance"
assert hasattr(quantizer, "distributed_sync"), "WeightPerChannelSsz should have distributed_sync"
assert WeightPerChannelSsz.distributed_sync is not AWQ.distributed_sync, (
"WeightPerChannelSsz should override distributed_sync"
)
def test_dts_mixin_default_forward(self):
"""验证 AutoWeightQuantizer 默认 distributed_sync 触发 forward"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer as AWQ
assert issubclass(AWQ, DTSMixin), "AutoWeightQuantizer should inherit DTSMixin"
wq = WeightPerChannelSsz(self.config)
assert isinstance(wq, DTSMixin), "All weight quantizers should inherit DTSMixin via AutoWeightQuantizer"
class TestWeightPerChannelSszDistributedSync:
"""测试 WeightPerChannelSsz 的分布式同步行为"""
def setup_class(self):
self.config = QConfig(dtype="int8", scope="per_channel", method="ssz", symmetric=True)
self.asymmetric_config = QConfig(dtype="int8", scope="per_channel", method="ssz", symmetric=False)
def test_load_quantized_from_broadcast_tensors_should_set_quantized_state(self):
"""验证 load_quantized_from_broadcast_tensors 正确重建量化状态"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
bias = torch.randn(20)
quantizer.init_weight(weight, bias)
quantizer()
ref_scale = get_ext_scale(quantizer.w_q_param)
ref_offset = get_ext_offset(quantizer.w_q_param)
ref_storage_value = quantizer.w_q_storage.value
ref_storage_dtype = quantizer.w_q_storage.dtype
new_quantizer = WeightPerChannelSsz(self.config)
new_quantizer.load_quantized_from_broadcast_tensors(
scale=ref_scale,
offset=ref_offset,
w_q_storage_value=ref_storage_value,
q_storage_dtype=ref_storage_dtype,
)
assert new_quantizer.is_quantized is True, "quantizer should be marked as quantized"
assert new_quantizer.weight is None, "weight should be released after loading quantized state"
assert new_quantizer.w_q_param is not None, "w_q_param should be set"
assert new_quantizer.w_q_storage is not None, "w_q_storage should be set"
assert torch.equal(get_ext_scale(new_quantizer.w_q_param), ref_scale), "scale should match reference"
assert torch.equal(get_ext_offset(new_quantizer.w_q_param), ref_offset), "offset should match reference"
assert torch.equal(new_quantizer.w_q_storage.value, ref_storage_value), (
"quantized weight should match reference"
)
assert new_quantizer.w_q_storage.dtype == ref_storage_dtype, "storage dtype should match reference"
assert new_quantizer.w_q_param.scheme == self.config.to_scheme(), "scheme should match config"
ref_output = quantizer.forward(None)
new_output = new_quantizer.forward(None)
assert torch.equal(ref_output, new_output), "forward output should be identical"
def test_load_quantized_from_broadcast_tensors_releases_old_weight(self):
"""验证 load_quantized_from_broadcast_tensors 释放原有 weight"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(5, 10))
bias = torch.randn(10)
quantizer.init_weight(weight, bias)
scale = torch.randn(5)
offset = torch.zeros(5)
storage_value = torch.randint(-8, 7, (5, 10), dtype=torch.int8)
quantizer.load_quantized_from_broadcast_tensors(
scale=scale,
offset=offset,
w_q_storage_value=storage_value,
q_storage_dtype=QDType.INT8,
)
assert quantizer.weight is None, "weight should be released after loading broadcast state"
def test_load_quantized_from_broadcast_tensors_with_asymmetric(self):
"""验证非对称量化下 load_quantized_from_broadcast_tensors 正确工作"""
quantizer = WeightPerChannelSsz(self.asymmetric_config)
weight = QStorage(QDType.FLOAT, torch.randn(8, 16))
bias = torch.randn(16)
quantizer.init_weight(weight, bias)
quantizer()
ref_scale = get_ext_scale(quantizer.w_q_param)
ref_offset = get_ext_offset(quantizer.w_q_param)
assert not (torch.all(ref_offset == 0)), "asymmetric quantization should have non-zero offsets"
new_quantizer = WeightPerChannelSsz(self.asymmetric_config)
new_quantizer.load_quantized_from_broadcast_tensors(
scale=ref_scale,
offset=ref_offset,
w_q_storage_value=quantizer.w_q_storage.value,
q_storage_dtype=quantizer.w_q_storage.dtype,
)
assert new_quantizer.w_q_param.scheme.symmetric is False, "scheme should be asymmetric"
assert torch.equal(get_ext_offset(new_quantizer.w_q_param), ref_offset), (
"offset should match in asymmetric mode"
)
def test_broadcast_quantizer_state_raises_when_not_quantized(self):
"""验证未量化时 _broadcast_quantizer_state 抛出 RuntimeError"""
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(5, 10))
quantizer.init_weight(weight, bias=None)
with mock.patch("msmodelslim.core.quantizer.impl.ssz.dist") as mock_dist:
mock_dist.get_rank.return_value = 0
mock_dist.is_initialized.return_value = True
with pytest.raises(SpecError, match="Cannot broadcast quantizer state"):
_broadcast_quantizer_state(quantizer, owner_rank=0)
@mock.patch("msmodelslim.core.quantizer.impl.ssz.dist")
@mock.patch("msmodelslim.core.quantizer.impl.ssz.broadcast_tensor_process_group_safe")
def test_broadcast_quantizer_state_sends_scale_offset_storage(self, mock_broadcast, mock_dist):
"""验证 owner rank 正确 broadcast 三个张量"""
mock_dist.get_rank.return_value = 0
mock_dist.is_initialized.return_value = True
mock_dist.get_world_size.return_value = 2
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(5, 10))
quantizer.init_weight(weight, bias=None)
quantizer()
_broadcast_quantizer_state(quantizer, owner_rank=0)
mock_dist.broadcast_object_list.assert_called_once()
assert mock_broadcast.call_count == 3, "should broadcast 3 tensors: scale, offset, storage"
for call_arg in mock_broadcast.call_args_list:
assert call_arg.kwargs.get('src') == 0, "all broadcasts should use owner_rank=0 as src"
@mock.patch("msmodelslim.core.quantizer.impl.ssz.dist")
@mock.patch("msmodelslim.core.quantizer.impl.ssz.broadcast_tensor_process_group_safe")
def test_broadcast_quantizer_state_non_owner_calls_load(self, mock_broadcast, mock_dist):
"""验证非 owner rank 接收 broadcast 后调用 load_quantized_from_broadcast_tensors"""
mock_dist.get_rank.return_value = 1
mock_dist.is_initialized.return_value = True
mock_dist.get_world_size.return_value = 2
fake_meta = [
{
"scale_shape": [5],
"scale_dtype": "torch.float32",
"offset_shape": [5],
"offset_dtype": "torch.float32",
"storage_shape": [5, 10],
"storage_dtype": "torch.int8",
"storage_qdtype": "INT8",
}
]
def fake_broadcast_object_list(meta_list, src):
meta_list[0] = fake_meta[0]
mock_dist.broadcast_object_list.side_effect = fake_broadcast_object_list
received_tensors = {}
def fake_broadcast(tensor, src):
received_tensors[id(tensor)] = tensor
mock_broadcast.side_effect = fake_broadcast
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(5, 10))
quantizer.init_weight(weight, bias=None)
_broadcast_quantizer_state(quantizer, owner_rank=0)
assert quantizer.is_quantized is True, "non-owner quantizer should be quantized after broadcast"
assert quantizer.w_q_param is not None, "non-owner quantizer should have q_param"
assert quantizer.w_q_storage is not None, "non-owner quantizer should have q_storage"
@mock.patch("msmodelslim.core.quantizer.impl.ssz.dist")
@mock.patch("msmodelslim.core.quantizer.impl.ssz.broadcast_tensor_process_group_safe")
def test_broadcast_quantizer_state_contiguous_tensors(self, mock_broadcast, mock_dist):
"""验证 owner rank 发送的 tensor 是 contiguous 的"""
mock_dist.get_rank.return_value = 0
mock_dist.is_initialized.return_value = True
mock_dist.get_world_size.return_value = 2
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(8, 16))
quantizer.init_weight(weight, bias=None)
quantizer()
quantizer.w_q_storage = QStorage(
quantizer.w_q_storage.dtype,
quantizer.w_q_storage.value.T.contiguous().T,
)
_broadcast_quantizer_state(quantizer, owner_rank=0)
for call_arg in mock_broadcast.call_args_list:
t = call_arg.args[0]
assert t.is_contiguous(), "broadcast tensor should be contiguous, got shape=%s stride=%s" % (
t.shape,
t.stride(),
)
def test_distributed_sync_override_is_different_from_base(self):
"""验证 WeightPerChannelSsz.distributed_sync 与基类不同(已覆盖)"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer as AWQ
assert WeightPerChannelSsz.distributed_sync is not AWQ.distributed_sync, (
"WeightPerChannelSsz should override distributed_sync"
)
def test_distributed_sync_single_rank_skips_broadcast(self):
"""验证单卡环境下 distributed_sync 不执行 broadcast"""
with mock.patch("msmodelslim.core.quantizer.impl.ssz.dist.is_initialized", return_value=False):
quantizer = WeightPerChannelSsz(self.config)
weight = QStorage(QDType.FLOAT, torch.randn(5, 10))
quantizer.init_weight(weight, bias=None)
quantizer()
from msmodelslim.utils.distributed.task_scheduler.types import TaskExecutionRecord, TaskSyncContext
record = TaskExecutionRecord(task_id="test", executor_rank=0)
sync_ctx = TaskSyncContext(model=quantizer, rank=0, world_size=1)
quantizer.distributed_sync(record, sync_ctx)
def test_distributed_sync_raises_spec_error_when_not_data_free(self):
"""验证非 data-free quantizer 调用 distributed_sync 抛出 SpecError"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer
class _NotDataFreeQuantizer(AutoWeightQuantizer):
def is_data_free(self):
return False
def init_weight(self, weight, bias=None):
pass
def forward(self, x=None):
return torch.empty(0)
def get_q_storage(self):
return None
def get_q_param(self):
return None
q = _NotDataFreeQuantizer()
from msmodelslim.utils.distributed.task_scheduler.types import TaskExecutionRecord, TaskSyncContext
record = TaskExecutionRecord(task_id="test", executor_rank=0)
sync_ctx = TaskSyncContext(model=q, rank=0, world_size=2)
with pytest.raises(SpecError, match="data-free"):
q.distributed_sync(record, sync_ctx)
def test_distributed_sync_data_free_quantizer_does_not_raise(self):
"""验证 data-free quantizer 调用 distributed_sync 不抛出异常"""
from msmodelslim.core.quantizer.base import AutoWeightQuantizer
class _DataFreeQuantizer(AutoWeightQuantizer):
def is_data_free(self):
return True
def init_weight(self, weight, bias=None):
self._weight = weight
def forward(self, x=None):
return torch.empty(0)
def get_q_storage(self):
return None
def get_q_param(self):
return None
q = _DataFreeQuantizer()
from msmodelslim.utils.distributed.task_scheduler.types import TaskExecutionRecord, TaskSyncContext
record = TaskExecutionRecord(task_id="test", executor_rank=0)
sync_ctx = TaskSyncContext(model=q, rank=0, world_size=2)
try:
q.distributed_sync(record, sync_ctx)
except SpecError:
pytest.fail("distributed_sync should not raise SpecError for data-free quantizers")
class TestSszCalculateQparam:
"""测试 ssz_calculate_qparam 函数"""
def setup_class(self):
"""设置测试环境"""
self.symmetric_config = QConfig(dtype="int8", scope="per_channel", method="ssz", symmetric=True)
self.asymmetric_config = QConfig(dtype="int8", scope="per_channel", method="ssz", symmetric=False)
def test_symmetric_quantization(self):
"""测试对称量化"""
from msmodelslim.core.quantizer.impl.ssz import ssz_calculate_qparam
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
initial_q_param = calculate_qparam(
min_val=torch.min(weight.T.value, dim=0)[0],
max_val=torch.max(weight.T.value, dim=0)[0],
q_dtype=QDType(self.symmetric_config.dtype),
q_scope=QScope(self.symmetric_config.scope),
symmetric=self.symmetric_config.symmetric,
)
result_q_param = ssz_calculate_qparam(weight.T, initial_q_param)
assert result_q_param is not None, "q_param is not correct, expected: %s, actual: %s" % (None, result_q_param)
assert result_q_param.scheme == self.symmetric_config.to_scheme(), (
"q_param.scheme is not correct, expected: %s, actual: %s"
% (self.symmetric_config.to_scheme(), result_q_param.scheme)
)
assert "scale" in result_q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"scale",
result_q_param.ext.keys(),
)
assert "offset" in result_q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"offset",
result_q_param.ext.keys(),
)
assert isinstance(result_q_param.ext["scale"], torch.Tensor), (
"q_param.ext['scale'] is not correct, expected: %s, actual: %s"
% (torch.Tensor, type(result_q_param.ext["scale"]))
)
assert isinstance(result_q_param.ext["offset"], torch.Tensor), (
"q_param.ext['offset'] is not correct, expected: %s, actual: %s"
% (torch.Tensor, type(result_q_param.ext["offset"]))
)
assert result_q_param.ext["scale"].shape == (weight.value.shape[0],), (
"q_param.ext['scale'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), result_q_param.ext["scale"].shape)
)
assert result_q_param.ext["offset"].shape == (weight.value.shape[0],), (
"q_param.ext['offset'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), result_q_param.ext["offset"].shape)
)
assert result_q_param.ext["offset"].max() == 0 and result_q_param.ext["offset"].min() == 0, (
"q_param.ext['offset'] is not correct, expected: %s, actual: %s"
% (True, result_q_param.ext["offset"].max() == 0 and result_q_param.ext["offset"].min() == 0)
)
def test_asymmetric_quantization(self):
"""测试非对称量化"""
from msmodelslim.core.quantizer.impl.ssz import ssz_calculate_qparam
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
initial_q_param = calculate_qparam(
min_val=torch.min(weight.T.value, dim=0)[0],
max_val=torch.max(weight.T.value, dim=0)[0],
q_dtype=QDType(self.asymmetric_config.dtype),
q_scope=QScope(self.asymmetric_config.scope),
symmetric=self.asymmetric_config.symmetric,
)
result_q_param = ssz_calculate_qparam(weight.T, initial_q_param)
assert result_q_param is not None, "q_param is not correct, expected: %s, actual: %s" % (None, result_q_param)
assert result_q_param.scheme == self.asymmetric_config.to_scheme(), (
"q_param.scheme is not correct, expected: %s, actual: %s"
% (self.asymmetric_config.to_scheme(), result_q_param.scheme)
)
assert "scale" in result_q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"scale",
result_q_param.ext.keys(),
)
assert "offset" in result_q_param.ext, "q_param.ext is not correct, expected: %s, actual: %s" % (
"offset",
result_q_param.ext.keys(),
)
assert isinstance(result_q_param.ext["scale"], torch.Tensor), (
"q_param.ext['scale'] is not correct, expected: %s, actual: %s"
% (torch.Tensor, type(result_q_param.ext["scale"]))
)
assert isinstance(result_q_param.ext["offset"], torch.Tensor), (
"q_param.ext['offset'] is not correct, expected: %s, actual: %s"
% (torch.Tensor, type(result_q_param.ext["offset"]))
)
assert result_q_param.ext["scale"].shape == (weight.value.shape[0],), (
"q_param.ext['scale'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), result_q_param.ext["scale"].shape)
)
assert result_q_param.ext["offset"].shape == (weight.value.shape[0],), (
"q_param.ext['offset'].shape is not correct, expected: %s, actual: %s"
% ((weight.value.shape[0],), result_q_param.ext["offset"].shape)
)
assert result_q_param.ext["offset"].max() != 0 or result_q_param.ext["offset"].min() != 0, (
"q_param.ext['offset'] is not correct, expected: %s, actual: %s"
% (True, result_q_param.ext["offset"].max() != 0 or result_q_param.ext["offset"].min() != 0)
)
def test_scale_offset_validity(self):
"""测试scale和offset的有效性"""
from msmodelslim.core.quantizer.impl.ssz import ssz_calculate_qparam
weight = QStorage(QDType.FLOAT, torch.randn(10, 20))
initial_q_param = calculate_qparam(
min_val=torch.min(weight.T.value, dim=0)[0],
max_val=torch.max(weight.T.value, dim=0)[0],
q_dtype=QDType(self.symmetric_config.dtype),
q_scope=QScope(self.symmetric_config.scope),
symmetric=self.symmetric_config.symmetric,
)
result_q_param = ssz_calculate_qparam(weight.T, initial_q_param)
assert torch.all(torch.isfinite(result_q_param.ext["scale"])), (
"scale is not correct, expected: %s, actual: %s"
% (True, torch.all(torch.isfinite(result_q_param.ext["scale"])))
)
assert torch.all(result_q_param.ext["scale"] != 0), "scale is not correct, expected: %s, actual: %s" % (
True,
torch.all(result_q_param.ext["scale"] != 0),
)
assert torch.all(torch.isfinite(result_q_param.ext["offset"])), (
"offset is not correct, expected: %s, actual: %s"
% (True, torch.all(torch.isfinite(result_q_param.ext["offset"])))
)