"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
-------------------------------------------------------------------------
"""
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch import nn
from msmodelslim.processor.flat_quant.flat_quant import (
FlatQuantProcessor,
FlatQuantProcessorConfig,
npu_available,
)
from msmodelslim.processor.flat_quant.flat_quant_interface import FlatQuantInterface
from msmodelslim.core.base.protocol import BatchProcessRequest
from msmodelslim.core.quantizer.linear import LinearQConfig
from msmodelslim.core.quantizer.base import QConfig
from msmodelslim.ir.qal import QDType, QScope
def _make_qconfig():
    """Build a minimal, usable LinearQConfig for the tests.

    The weight config selects the FLOAT dtype, per-tensor scope, symmetric
    mode and the "none" method; per the original author's note, this float
    mode does not perform real quantization.
    """
    weight_cfg = QConfig(
        dtype=QDType.FLOAT,
        scope=QScope.PER_TENSOR,
        symmetric=True,
        method="none",
    )
    return LinearQConfig(weight=weight_cfg)
class _FlatQuantSubgraphAdapter(FlatQuantInterface):
    """Minimal FlatQuantInterface implementation used by these tests.

    It reports a single subgraph entry whose ``pair_class`` is ``None``. Per
    the original author's note, the intent is to register no PAI and only
    check that the processing flow runs through.
    """

    def get_flatquant_subgraph(self):
        """Return a one-element list describing the ``layers`` subgraph."""
        layers_entry = dict(
            source="layers",
            targets=["layers"],
            pair_class=None,
            extra_config={},
        )
        return [layers_entry]
class _ModelWithTieWeights(nn.Module):
    """Smallest model exposing ``tie_weights``.

    Per the original author's note, FlatQuantProcessor.post_run relies on
    the model providing this method.
    """

    def __init__(self):
        super().__init__()
        linears = []
        for _ in range(2):
            linears.append(nn.Linear(8, 8))
        self.layers = nn.ModuleList(linears)

    def tie_weights(self):
        """Stand-in for PreTrainedModel.tie_weights; intentionally a no-op."""
        return None
def _make_request(module, name, datas, outputs):
    """Build a BatchProcessRequest stand-in.

    A ``MagicMock`` constrained by ``spec=BatchProcessRequest`` is used so the
    test avoids the real class's heavier dependencies (per the original
    author's note).

    Args:
        module: Value stored on the request's ``module`` attribute.
        name: Value stored on the request's ``name`` attribute.
        datas: Value stored on the request's ``datas`` attribute.
        outputs: Value stored on the request's ``outputs`` attribute.

    Returns:
        The configured mock request.
    """
    request = MagicMock(spec=BatchProcessRequest)
    fields = (
        ("name", name),
        ("module", module),
        ("datas", datas),
        ("outputs", outputs),
    )
    for attr, value in fields:
        setattr(request, attr, value)
    return request
def _make_processor(model, config):
    """Build a FlatQuantProcessor with ``LayerTrainer`` patched out.

    A stub adapter is supplied so that, per the original author's note, no
    real model registration is required.

    Args:
        model: Model handed to the processor.
        config: Processor configuration handed to the processor.

    Returns:
        The constructed FlatQuantProcessor.
    """
    stub_adapter = _FlatQuantSubgraphAdapter()
    trainer_target = "msmodelslim.processor.flat_quant.flat_quant.LayerTrainer"
    # The patch is only active while the processor is being constructed.
    with patch(trainer_target, autospec=True):
        return FlatQuantProcessor(model=model, config=config, adapter=stub_adapter)
@pytest.mark.smoke
@pytest.mark.xfail(reason="Full pipeline requires real Norm+Linear model + real PAI class, not a mock adapter")
@pytest.mark.parametrize(
    "test_device, test_dtype",
    [
        pytest.param("cpu", torch.float32),
        pytest.param("npu", torch.float16, marks=pytest.mark.skipif(not npu_available, reason="NPU not available")),
        pytest.param("npu", torch.bfloat16, marks=pytest.mark.skipif(not npu_available, reason="NPU not available")),
    ],
)
def test_flat_quant_processor_runs_full_pipeline_when_config_is_minimal(test_device, test_dtype):
    """Smoke: run pre_run -> preprocess -> process -> postprocess -> post_run.

    The test is marked xfail because the mock adapter does not provide a real
    Norm+Linear model or PAI class.

    Args:
        test_device: Device string the model and dummy tensors are placed on.
        test_dtype: Torch dtype used for the model and dummy tensors.
    """
    # Place the model on the parametrized device; previously ``test_device``
    # was ignored, so the "npu" cases silently ran on CPU.
    model = _ModelWithTieWeights().to(device=test_device, dtype=test_dtype)
    config = FlatQuantProcessorConfig(
        type="flatquant",
        include=["*"],
        exclude=[],
    )

    mock_trainer = MagicMock()
    mock_trainer.train_layer = MagicMock(return_value=[])

    # Keep LayerTrainer patched for the whole pipeline so the mock is used
    # regardless of whether the processor creates its trainer in __init__ or
    # lazily in a later stage.
    with patch(
        "msmodelslim.processor.flat_quant.flat_quant.LayerTrainer",
        return_value=mock_trainer,
    ):
        proc = FlatQuantProcessor(model=model, config=config, adapter=_FlatQuantSubgraphAdapter())
        proc.pre_run()

        target_module = model.layers[0]
        dummy_input = torch.zeros(1, 8, dtype=test_dtype, device=test_device)
        dummy_output = torch.zeros(1, 8, dtype=test_dtype, device=test_device)
        req = _make_request(
            module=target_module,
            name="layers.0",
            datas=[[(dummy_input,)]],
            outputs=[(dummy_output,)],
        )

        proc.preprocess(req)
        proc.process(req)
        proc.postprocess(req)
        proc.post_run()

    # NOTE(review): placeholder check only; it confirms the pipeline finished
    # without raising, not that the model structure was converted.
    assert model is not None
@pytest.mark.smoke
def test_flat_quant_processor_config_validates_unsupported_yaml_fields():
    """Smoke: explicit construction of FlatQuantProcessorConfig keeps its fields.

    Per the original author's note, explicit construction is expected to
    accept only ``init=True`` fields (the remaining ones come from YAML);
    this test only checks that ``type`` and ``include`` round-trip.
    """
    init_kwargs = {"type": "flatquant", "include": ["*"], "exclude": []}
    cfg = FlatQuantProcessorConfig(**init_kwargs)

    expected = (("type", "flatquant"), ("include", ["*"]))
    for field_name, field_value in expected:
        assert getattr(cfg, field_name) == field_value
@pytest.mark.smoke
def test_flat_quant_processor_need_kv_cache_returns_false():
    """Smoke: FlatQuantProcessor.need_kv_cache() should return False."""
    processor = FlatQuantProcessor(
        model=_ModelWithTieWeights(),
        config=FlatQuantProcessorConfig(type="flatquant"),
        adapter=_FlatQuantSubgraphAdapter(),
    )
    kv_cache_needed = processor.need_kv_cache()
    assert kv_cache_needed is False
@pytest.mark.smoke
def test_flat_quant_processor_post_run_calls_tie_weights_when_invoked():
    """Smoke: post_run should call model.tie_weights() exactly once."""
    model = _ModelWithTieWeights()
    # Replace the method with a mock so the call can be observed.
    tie_spy = MagicMock()
    model.tie_weights = tie_spy

    processor = FlatQuantProcessor(
        model=model,
        config=FlatQuantProcessorConfig(type="flatquant"),
        adapter=_FlatQuantSubgraphAdapter(),
    )
    processor.post_run()

    tie_spy.assert_called_once()
@pytest.mark.smoke
def test_flat_quant_processor_pre_run_sets_model_to_eval_mode():
    """Smoke: pre_run should switch the model to eval mode."""
    model = _ModelWithTieWeights()
    # Start in training mode so the switch to eval is observable.
    model.train(True)

    processor = FlatQuantProcessor(
        model=model,
        config=FlatQuantProcessorConfig(type="flatquant"),
        adapter=_FlatQuantSubgraphAdapter(),
    )
    processor.pre_run()

    assert not model.training