from unittest.mock import patch
import pytest
from msmodelslim.utils.dependency_check import DependencyChecker, require_packages, get_require_packages
from msmodelslim.utils.exception import VersionError
def setup_function():
    """Empty ``DependencyChecker._requirements`` before each test runs.

    The registry is class-level state, so entries registered by one test
    would otherwise still be present when the next test starts.
    """
    registry = DependencyChecker._requirements
    registry.clear()
def test_add_requirements():
    """Registering the same plugin twice keeps the entries from both calls."""
    # Two separate registrations for one plugin name; both must survive.
    for spec in ({"torch": ">=2.0"}, {"numpy": ">=1.25"}):
        DependencyChecker.set_plugin("pluginA", spec)

    registry = DependencyChecker._requirements
    assert "pluginA" in registry

    plugin_reqs = registry["pluginA"]
    assert plugin_reqs["torch"] == ">=2.0"
    assert plugin_reqs["numpy"] == ">=1.25"
@patch("msmodelslim.utils.dependency_check.DependencyChecker._check_single")
def test_check_plugin(mock_check):
    """``check_plugin`` calls ``_check_single`` once per registered package."""
    DependencyChecker.set_plugin(
        "pluginA",
        {"torch": ">=2.0", "numpy": "<2.0"},
    )

    DependencyChecker.check_plugin("pluginA")

    # One call per requirement, each with (package, version constraint).
    assert mock_check.call_count == 2
    for package, constraint in (("torch", ">=2.0"), ("numpy", "<2.0")):
        mock_check.assert_any_call(package, constraint)
@patch("msmodelslim.utils.dependency_check.metadata_version")
def test_check_single_version_invalid(mock_metadata):
    """``_check_single`` raises ``VersionError`` for an unmet constraint."""
    # Patched version lookup reports 1.0.0, which does not satisfy ">=2.0".
    mock_metadata.return_value = "1.0.0"
    package, constraint = "torch", ">=2.0"

    with pytest.raises(VersionError):
        DependencyChecker._check_single(package, constraint)
def test_require_packages_decorator():
    """``require_packages`` stores its mapping on ``_require_packages``."""

    @require_packages({"torch": ">=2.0"})
    class MyModel:
        pass

    # The attribute must exist before its contents are compared.
    assert hasattr(MyModel, "_require_packages")
    declared = MyModel._require_packages
    assert declared == {"torch": ">=2.0"}
def test_get_require_packages():
    """``get_require_packages`` returns the mapping set by the decorator."""

    @require_packages({"numpy": ">=1.0"})
    class MyModel:
        pass

    assert get_require_packages(MyModel) == {"numpy": ">=1.0"}