"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
-------------------------------------------------------------------------
"""
from unittest.mock import patch
from msmodelslim.cli.__main__ import _normalize_analyze_argv, main
class TestNormalizeAnalyzeArgv:
    """Tests for ``_normalize_analyze_argv``: argv normalization for the ``analyze`` subcommand.

    Expected results are written as literal lists instead of being compared
    against the input ``argv`` object. That way an "unchanged" assertion stays
    meaningful even if the function mutated its argument in place and
    returned that same object.
    """

    def test_normalize_returns_unchanged_when_no_analyze_subcommand(self):
        """Main path: a non-``analyze`` command is returned unchanged."""
        argv = ["quant", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert result == ["quant", "--model_type", "x"]

    def test_normalize_returns_unchanged_when_empty_argv(self):
        """Boundary: an empty argv is returned unchanged."""
        assert _normalize_analyze_argv([]) == []

    def test_normalize_injects_linear_scope_when_scope_omitted(self):
        """Main path: ``analyze`` without a scope gets ``linear`` injected right after it."""
        argv = ["analyze", "--model_type", "qwen3"]
        result = _normalize_analyze_argv(argv)
        # The scope must directly follow ``analyze``; the remaining args keep their order.
        assert result == ["analyze", "linear", "--model_type", "qwen3"]

    def test_normalize_returns_unchanged_when_explicit_scope_layer(self):
        """Main path: an explicit ``layer`` scope is kept as-is."""
        argv = ["analyze", "layer", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert result == ["analyze", "layer", "--model_type", "x"]

    def test_normalize_returns_unchanged_when_explicit_scope_attn(self):
        """Main path: an explicit ``attn`` scope is kept as-is."""
        argv = ["analyze", "attn", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert result == ["analyze", "attn", "--model_type", "x"]

    def test_normalize_returns_unchanged_when_explicit_scope_linear(self):
        """Main path: an explicit ``linear`` scope is kept as-is."""
        argv = ["analyze", "linear", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert result == ["analyze", "linear", "--model_type", "x"]

    def test_normalize_returns_unchanged_when_help_requested(self):
        """Boundary: ``analyze -h`` must not get a scope injected, so help output stays clean."""
        argv = ["analyze", "-h"]
        result = _normalize_analyze_argv(argv)
        # Literal comparison also proves "linear" was not injected.
        assert result == ["analyze", "-h"]

    def test_normalize_returns_unchanged_when_help_requested_long_form(self):
        """Boundary: ``analyze --help`` must not get a scope injected either."""
        argv = ["analyze", "--help"]
        result = _normalize_analyze_argv(argv)
        assert result == ["analyze", "--help"]

    def test_normalize_converts_legacy_attention_mse_to_attn_scope(self):
        """Main path: ``--metrics attention_mse`` is rewritten to ``attn --metrics mse``."""
        argv = ["analyze", "--model_type", "qwen3", "--metrics", "attention_mse"]
        result = _normalize_analyze_argv(argv)
        # The converted scope occupies the scope slot right after ``analyze``.
        assert result[:2] == ["analyze", "attn"]
        # Converting to attn must not additionally inject the default linear scope.
        assert "linear" not in result
        assert "attention_mse" not in result
        assert result[result.index("--metrics") + 1] == "mse"
        # Unrelated options survive the rewrite.
        assert result[result.index("--model_type") + 1] == "qwen3"

    def test_normalize_drops_pattern_arg_when_converting_legacy_attention_mse(self):
        """Boundary: converting legacy ``attention_mse`` to ``attn`` drops ``--pattern``.

        The ``attn`` scope does not accept ``--pattern``, so both the flag and
        its value must be removed.
        """
        argv = ["analyze", "--metrics", "attention_mse", "--pattern", "layer.*", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        # Guard that the conversion actually happened; otherwise the drop checks prove nothing.
        assert "attn" in result
        assert "--pattern" not in result
        assert "layer.*" not in result

    def test_normalize_does_not_modify_args_before_analyze(self):
        """Boundary: arguments before ``analyze`` are kept in place and in order."""
        argv = ["--config", "/etc/conf", "analyze", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert result == ["--config", "/etc/conf", "analyze", "linear", "--model_type", "x"]

    def test_normalize_returns_unchanged_when_metrics_value_malformed(self):
        """Boundary: no legacy conversion when ``--metrics`` has no value.

        Covers the case where ``--metrics`` is followed by another flag (the
        original note also mentions the ValueError path). "Unchanged" here
        means no attn conversion; the default ``linear`` scope is still
        injected.
        """
        argv = ["analyze", "--metrics", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert "attn" not in result
        assert "linear" in result
        assert "--metrics" in result

    def test_normalize_returns_unchanged_when_metrics_not_attention_mse(self):
        """Boundary: a non-legacy ``--metrics`` value is not converted.

        "Unchanged" here means no attn conversion; the default ``linear``
        scope is still injected and the metrics value is preserved.
        """
        argv = ["analyze", "--metrics", "kurtosis", "--model_type", "x"]
        result = _normalize_analyze_argv(argv)
        assert "attn" not in result
        assert "linear" in result
        assert result[result.index("--metrics") + 1] == "kurtosis"
class TestMainDispatcher:
    """Tests for ``main()``: the top-level CLI dispatcher that routes by subcommand.

    Each test replaces the ``sys`` name inside ``msmodelslim.cli.__main__``
    with a mock so it can set ``sys.argv``. If ``main`` calls ``sys.exit``
    through that module-level name, the call hits the mock rather than ending
    the test run.
    """

    @patch("msmodelslim.cli.__main__.sys")
    @patch("msmodelslim.cli.naive_quantization.__main__.main")
    def test_main_dispatches_to_naive_quantization_when_command_is_quant(self, mock_nq_main, mock_sys):
        """Main path: ``quant`` routes to ``cli.naive_quantization.__main__.main``."""
        mock_sys.argv = ["msmodelslim", "quant", "--model_type", "qwen3", "--model_path", "/x", "--save_path", "/y"]
        mock_nq_main.return_value = None
        main()
        mock_nq_main.assert_called_once()

    @patch("msmodelslim.cli.__main__.sys")
    @patch("msmodelslim.cli.analysis.__main__.main")
    def test_main_dispatches_to_analysis_when_command_is_analyze(self, mock_a_main, mock_sys):
        """Main path: ``analyze`` routes to ``cli.analysis.__main__.main``."""
        mock_sys.argv = [
            "msmodelslim",
            "analyze",
            "linear",
            "--model_type",
            "qwen3",
            "--model_path",
            "/x",
        ]
        mock_a_main.return_value = None
        main()
        mock_a_main.assert_called_once()

    @patch("msmodelslim.cli.__main__.sys")
    @patch("msmodelslim.cli.auto_tuning.__main__.main")
    def test_main_dispatches_to_auto_tuning_when_command_is_tune(self, mock_t_main, mock_sys):
        """Main path: ``tune`` routes to ``cli.auto_tuning.__main__.main``."""
        mock_sys.argv = [
            "msmodelslim",
            "tune",
            "--model_type",
            "qwen3",
            "--model_path",
            "/x",
            "--save_path",
            "/y",
            "--config",
            "/p",
        ]
        mock_t_main.return_value = None
        main()
        mock_t_main.assert_called_once()

    @patch("msmodelslim.cli.__main__.sys")
    @patch("argparse.ArgumentParser.print_help")
    def test_main_prints_help_when_no_command_given(self, mock_print_help, mock_sys):
        """Boundary: with no subcommand, ``main`` prints help instead of raising."""
        mock_sys.argv = ["msmodelslim"]
        main()
        # Without this assertion the test only proved "no exception". This pins
        # the documented behavior. It assumes main's parser is a plain
        # argparse.ArgumentParser that does not override print_help; confirm
        # if the parser class changes.
        mock_print_help.assert_called_once()