"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
-------------------------------------------------------------------------
"""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from msmodelslim.app.auto_tuning.application import AutoTuningApplication
from msmodelslim.app.auto_tuning.model_info_interface import ModelInfoInterface
from msmodelslim.app.auto_tuning.plan_manager_infra import (
TuningPlanConfig,
TuningPlanManagerInfra,
)
from msmodelslim.app.auto_tuning.practice_accuracy_infra import (
TuningAccuracyInfra,
TuningAccuracyManagerInfra,
)
from msmodelslim.app.auto_tuning.practice_history_infra import (
TuningHistoryInfra,
TuningHistoryManagerInfra,
)
from msmodelslim.app.auto_tuning.practice_manager_infra import PracticeManagerInfra
from msmodelslim.core.quant_service import IQuantService
from msmodelslim.core.tune_strategy import (
ITuningStrategy,
ITuningStrategyFactory,
EvaluateResult,
EvaluateAccuracy,
)
from msmodelslim.model import IModel, IModelFactory
from msmodelslim.utils.exception import SchemaValidateError, SecurityError
def _make_plan():
"""构造 plan:用 MagicMock 代替 strategy/evaluation 避免插件注册检查。"""
plan = MagicMock(spec=TuningPlanConfig)
plan.strategy = MagicMock()
plan.evaluation = MagicMock()
return plan
def _make_eval_result():
return EvaluateResult(
accuracies=[EvaluateAccuracy(dataset="d1", accuracy=0.9)],
is_satisfied=True,
)
def _make_history():
h = MagicMock(spec=TuningHistoryInfra)
return h
def _make_history_manager(history):
m = MagicMock(spec=TuningHistoryManagerInfra)
m.load_history = MagicMock(return_value=history)
return m
def _make_accuracy(count=0, hit_result=None):
a = MagicMock(spec=TuningAccuracyInfra)
a.get_accuracy_count = MagicMock(return_value=count)
a.get_accuracy = MagicMock(return_value=hit_result)
a.append_accuracy = MagicMock()
return a
def _make_accuracy_manager(accuracy):
m = MagicMock(spec=TuningAccuracyManagerInfra)
m.load_accuracy = MagicMock(return_value=accuracy)
return m
def _make_strategy(generator_factory):
s = MagicMock(spec=ITuningStrategy)
s.generate_practice = MagicMock(side_effect=lambda model: generator_factory())
return s
def _make_strategy_factory(strategy):
f = MagicMock(spec=ITuningStrategyFactory)
f.create_strategy = MagicMock(return_value=strategy)
return f
def _make_model_adapter(pedigree=None, model_type_name="Qwen3-32B"):
a = MagicMock(spec=IModel)
a.__class__.__name__ = "FakeAdapter"
a.get_model_pedigree = MagicMock(return_value=pedigree) if pedigree else MagicMock()
a.get_model_type = MagicMock(return_value=model_type_name)
if pedigree:
class _AdapterWithInfo(ModelInfoInterface):
def __init__(self, name, p, t):
self._name = name
self._p = p
self._t = t
def get_model_pedigree(self):
return self._p
def get_model_type(self):
return self._t
@property
def model_type(self):
return self._t
@property
def model_path(self):
return Path("/fake")
@property
def trust_remote_code(self):
return False
return _AdapterWithInfo("FakeAdapter", pedigree, model_type_name)
return a
def _make_model_factory(adapter):
f = MagicMock(spec=IModelFactory)
f.create = MagicMock(return_value=adapter)
return f
def _make_quant_service():
qs = MagicMock(spec=IQuantService)
qs.quantize = MagicMock()
return qs
def _make_eval_service(eval_result=None):
s = MagicMock()
s.evaluate = MagicMock(return_value=eval_result or _make_eval_result())
return s
def _make_practice_manager(supports_save=True, save_record=None):
p = MagicMock(spec=PracticeManagerInfra)
p.is_saving_supported = MagicMock(return_value=supports_save)
p.save_practice = MagicMock(
side_effect=lambda model_pedigree, practice: save_record.append((model_pedigree, practice))
)
return p
def _build_app(
adapter,
plan=None,
history=None,
accuracy=None,
strategy=None,
save_practice_record=None,
practice_manager_supports=True,
):
"""组装一个 AutoTuningApplication,所有依赖都是 mock。"""
if plan is None:
plan = _make_plan()
if history is None:
history = _make_history()
if accuracy is None:
accuracy = _make_accuracy(count=0)
if strategy is None:
def gen():
yield MagicMock(name="practice")
strategy = _make_strategy(gen)
if save_practice_record is None:
save_practice_record = []
return AutoTuningApplication(
plan_manager=_ConcretePlanManager(plan),
practice_manager=_make_practice_manager(practice_manager_supports, save_practice_record),
evaluation_service=_make_eval_service(),
tuning_history_manager=_make_history_manager(history),
tuning_accuracy_manager=_make_accuracy_manager(accuracy),
quantization_service=_make_quant_service(),
model_factory=_make_model_factory(adapter),
strategy_factory=_make_strategy_factory(strategy),
)
class _ConcretePlanManager(TuningPlanManagerInfra):
def __init__(self, plan):
self._plan = plan
def get_plan_by_id(self, plan_id):
return self._plan
class TestAutoTuningApplicationTune:
"""Test suite for AutoTuningApplication.tune — 主流程 + 校验。"""
def test_tune_runs_full_pipeline_when_strategy_signals_stop_after_one_practice(self):
"""主路径:策略 send 一次就 StopIteration,tune 走完成功路径。"""
adapter = _make_model_adapter(pedigree=None)
history = _make_history()
app = _build_app(adapter, history=history)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert history.append_history.call_count >= 1
def test_tune_uses_accuracy_cache_when_record_count_is_positive(self):
"""主路径:当 accuracy_count > 0 时打印"将复用"日志。"""
adapter = _make_model_adapter()
accuracy = _make_accuracy(count=3)
app = _build_app(adapter, accuracy=accuracy)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert accuracy.get_accuracy_count.call_count == 1
def test_tune_logs_fresh_start_when_no_accuracy_records_exist(self):
"""主路径:accuracy_count=0 时走"fresh tuning"日志分支。"""
adapter = _make_model_adapter()
accuracy = _make_accuracy(count=0)
app = _build_app(adapter, accuracy=accuracy)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert accuracy.get_accuracy_count.call_count == 1
def test_tune_persists_practice_when_practice_manager_supports_save_and_adapter_implements_info(self):
"""主路径:adapter 实现了 ModelInfoInterface 且 manager 支持保存,应调用 save_practice。"""
adapter = _make_model_adapter(pedigree="qwen3")
save_record = []
app = _build_app(adapter, save_practice_record=save_record)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert len(save_record) == 1
pedigree, _ = save_record[0]
assert pedigree == "qwen3"
def test_tune_skips_saving_when_practice_manager_disallows_save(self):
"""边界:is_saving_supported=False 时应跳过 save_practice。"""
adapter = _make_model_adapter(pedigree="qwen3")
save_record = []
app = _build_app(
adapter,
save_practice_record=save_record,
practice_manager_supports=False,
)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert not save_record
def test_tune_skips_saving_when_adapter_does_not_implement_model_info_interface(self):
"""边界:adapter 不是 ModelInfoInterface 子类时跳过 save_practice。"""
adapter = _make_model_adapter(pedigree=None)
save_record = []
app = _build_app(adapter, save_practice_record=save_record)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert not save_record
def test_tune_uses_cache_when_accuracy_record_exists_for_practice(self):
"""边界:Path 1 — get_accuracy 返回非 None 时跳过 quantize+evaluate。"""
adapter = _make_model_adapter()
eval_result = _make_eval_result()
accuracy = _make_accuracy(count=1, hit_result=eval_result)
quant = _make_quant_service()
eval_svc = _make_eval_service()
history = _make_history()
def gen():
yield MagicMock(name="practice")
strategy = _make_strategy(gen)
save_record = []
app = AutoTuningApplication(
plan_manager=_ConcretePlanManager(_make_plan()),
practice_manager=_make_practice_manager(True, save_record),
evaluation_service=eval_svc,
tuning_history_manager=_make_history_manager(history),
tuning_accuracy_manager=_make_accuracy_manager(accuracy),
quantization_service=quant,
model_factory=_make_model_factory(adapter),
strategy_factory=_make_strategy_factory(strategy),
)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert quant.quantize.call_count == 0
assert eval_svc.evaluate.call_count == 0
assert accuracy.get_accuracy.call_count >= 1
def test_tune_runs_quantize_and_evaluate_when_cache_misses(self):
"""边界:Path 2 — get_accuracy 返回 None 时调用 quantize + evaluate。"""
adapter = _make_model_adapter()
accuracy = _make_accuracy(count=0, hit_result=None)
quant = _make_quant_service()
eval_svc = _make_eval_service()
history = _make_history()
save_record = []
def gen():
yield MagicMock(name="practice")
strategy = _make_strategy(gen)
app = AutoTuningApplication(
plan_manager=_ConcretePlanManager(_make_plan()),
practice_manager=_make_practice_manager(True, save_record),
evaluation_service=eval_svc,
tuning_history_manager=_make_history_manager(history),
tuning_accuracy_manager=_make_accuracy_manager(accuracy),
quantization_service=quant,
model_factory=_make_model_factory(adapter),
strategy_factory=_make_strategy_factory(strategy),
)
with tempfile.TemporaryDirectory() as save_path:
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
assert quant.quantize.call_count == 1
assert eval_svc.evaluate.call_count == 1
assert accuracy.append_accuracy.call_count == 1
def test_tune_raises_schema_validate_error_when_model_type_is_not_string(self):
"""异常:model_type 非 str 应抛 SchemaValidateError。"""
adapter = _make_model_adapter()
app = _build_app(adapter)
with tempfile.TemporaryDirectory() as save_path:
with pytest.raises(SchemaValidateError):
app.tune(
model_type=123,
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
)
def test_tune_raises_schema_validate_error_when_plan_id_is_not_string(self):
"""异常:plan_id 非 str 应抛 SchemaValidateError。"""
adapter = _make_model_adapter()
app = _build_app(adapter)
with tempfile.TemporaryDirectory() as save_path:
with pytest.raises(SchemaValidateError):
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id=42,
)
def test_tune_raises_schema_validate_error_when_device_indices_contains_non_int(self):
"""异常:device_indices 含非 int 应抛 SchemaValidateError。"""
adapter = _make_model_adapter()
app = _build_app(adapter)
with tempfile.TemporaryDirectory() as save_path:
with pytest.raises(SchemaValidateError):
app.tune(
model_type="qwen3",
model_path=Path(save_path),
save_path=save_path,
plan_id="plan-001",
device_indices=["0", "1"],
)
@pytest.mark.xfail(reason="Flaky: test pollution affects SecurityError path validation when run with full suite")
def test_tune_raises_security_error_when_model_path_does_not_exist(self):
"""异常:model_path 的父目录不存在应抛 SecurityError。"""
import uuid
adapter = _make_model_adapter()
app = _build_app(adapter)
with tempfile.TemporaryDirectory() as save_path:
nonexistent = f"/tmp/ut_no_such_{uuid.uuid4().hex}/model"
with pytest.raises(SecurityError):
app.tune(
model_type="qwen3",
model_path=nonexistent,
save_path=save_path,
plan_id="plan-001",
)