"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import pytest
import os
from unittest.mock import Mock, patch
from msmodelslim.utils.cache.pth import (
load_cached_data, InputCapture, DumperManager, to_device
)
from msmodelslim.utils.exception import SchemaValidateError
class TestLoadCachedData:
    """Tests for the ``load_cached_data`` function."""

    @pytest.fixture
    def mock_pth_file_path(self):
        """Return a relative path that stands in for the cached PTH file."""
        return os.path.join("test", "cache", "calib_data.pth")

    @pytest.fixture
    def mock_generate_func(self):
        """Return a mock standing in for the data-generation callable."""
        return Mock()

    @pytest.fixture
    def mock_model(self):
        """Return a mock standing in for the model."""
        return Mock()

    @pytest.fixture
    def mock_dump_config(self):
        """Return a mock dump configuration whose ``capture_mode`` is ``"args"``."""
        dump_config = Mock()
        dump_config.capture_mode = "args"
        return dump_config

    def test_load_cached_data_file_exists(
        self, mock_pth_file_path, mock_generate_func, mock_model, mock_dump_config
    ):
        """Return the loaded payload when the cache file is reported to exist."""
        expected = {"test": "data"}

        # All collaborators are patched in the same order as before; the
        # logger is patched only to keep the call silent, so it is not bound.
        with patch('os.path.exists', return_value=True), \
                patch('msmodelslim.utils.cache.pth.safe_torch_load',
                      return_value=expected) as loader, \
                patch('msmodelslim.utils.cache.pth.get_valid_read_path',
                      return_value=mock_pth_file_path), \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            loaded = load_cached_data(
                mock_pth_file_path,
                mock_generate_func,
                mock_model,
                mock_dump_config,
            )

        assert loaded == expected
        loader.assert_called_once_with(mock_pth_file_path)

    def test_load_cached_data_file_not_exists(
        self, mock_pth_file_path, mock_generate_func, mock_model, mock_dump_config
    ):
        """Generate, dump and reload the data when the cache file is missing."""
        expected = {"generated": "data"}
        dumper_instance = Mock()

        with patch('os.path.exists', return_value=False), \
                patch('msmodelslim.utils.cache.pth.DumperManager',
                      return_value=dumper_instance) as dumper_cls, \
                patch('msmodelslim.utils.cache.pth.safe_torch_load',
                      return_value=expected), \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            loaded = load_cached_data(
                mock_pth_file_path,
                mock_generate_func,
                mock_model,
                mock_dump_config,
            )

        assert loaded == expected
        mock_generate_func.assert_called_once()
        # The dumper must be built from the model with the configured mode
        # and asked to save to the cache path.
        dumper_cls.assert_called_once_with(mock_model, capture_mode="args")
        dumper_instance.save.assert_called_once_with(mock_pth_file_path)
class TestInputCapture:
    """Tests for the ``InputCapture`` class."""

    @pytest.fixture(autouse=True)
    def _isolate_input_capture(self):
        """Reset the shared ``InputCapture`` records around every test.

        ``InputCapture`` keeps its records at class level, so anything a test
        adds would otherwise remain visible to whichever test runs next.
        """
        InputCapture.reset()
        yield
        InputCapture.reset()

    def test_input_capture_reset(self):
        """``reset`` discards every previously added record."""
        InputCapture.add_record({"test": "data"})
        assert len(InputCapture.get_all()) > 0

        InputCapture.reset()
        assert len(InputCapture.get_all()) == 0

    def test_input_capture_add_and_get_record(self):
        """A record passed to ``add_record`` is returned by ``get_all``."""
        test_record = {"test_key": "test_value"}
        InputCapture.add_record(test_record)

        all_records = InputCapture.get_all()
        assert len(all_records) == 1
        assert all_records[0] == test_record

    def test_input_capture_capture_forward_inputs_args_mode(self):
        """A wrapped plain function returns its result and records one call."""

        def add_two(arg1, arg2, kwarg1="default"):
            return arg1 + arg2

        wrapped_func = InputCapture.capture_forward_inputs(add_two, capture_mode="args")
        result = wrapped_func(10, 20, kwarg1="custom")
        assert result == 30

        captured = InputCapture.get_all()
        assert len(captured) == 1
        assert len(captured[0]) >= 2

    def test_input_capture_capture_forward_inputs_method(self):
        """A wrapped unbound method returns its result and records one call."""

        class Adder:
            def add(self, arg1, arg2):
                return arg1 + arg2

        wrapped_method = InputCapture.capture_forward_inputs(Adder.add, capture_mode="args")
        result = wrapped_method(Adder(), 15, 25)
        assert result == 40

        captured = InputCapture.get_all()
        assert len(captured) == 1
        assert len(captured[0]) >= 2

    def test_input_capture_capture_forward_inputs_invalid_mode(self):
        """Wrapping with an unrecognised ``capture_mode`` still yields a callable."""

        def do_nothing():
            pass

        wrapped_func = InputCapture.capture_forward_inputs(do_nothing, capture_mode="invalid_mode")
        assert wrapped_func is not None
        assert callable(wrapped_func)
class TestDumperManager:
    """Tests for the ``DumperManager`` class."""

    @pytest.fixture(autouse=True)
    def _isolate_input_capture(self):
        """Reset the shared ``InputCapture`` records around every test.

        Several tests here add records to ``InputCapture``; without this
        cleanup those records would leak into later tests and make the
        results depend on execution order.
        """
        InputCapture.reset()
        yield
        InputCapture.reset()

    @pytest.fixture
    def mock_module(self):
        """Return a mock standing in for the module being hooked."""
        return Mock()

    @pytest.fixture
    def mock_dump_config(self):
        """Return a mock dump configuration whose ``capture_mode`` is ``"args"``."""
        config = Mock()
        config.capture_mode = "args"
        return config

    def test_dumper_manager_initialization(self, mock_module, mock_dump_config):
        """Construction stores the module, the capture mode and the old forward."""
        dumper = DumperManager(mock_module, capture_mode=mock_dump_config.capture_mode)

        assert dumper.module is mock_module
        assert dumper.capture_mode == "args"
        assert dumper.old_forward is not None

    def test_dumper_manager_initialization_invalid_capture_mode(self, mock_module):
        """An unknown ``capture_mode`` is rejected with ``SchemaValidateError``."""
        with pytest.raises(SchemaValidateError, match="Invalid capture_mode: 'invalid_mode'"):
            DumperManager(mock_module, capture_mode="invalid_mode")

    def test_dumper_manager_save(self, mock_module, mock_dump_config):
        """``save`` calls ``torch.save`` once and clears ``old_forward``."""
        dumper = DumperManager(mock_module, capture_mode=mock_dump_config.capture_mode)
        InputCapture.add_record({"test": "data"})

        # torch.save is patched so nothing is written to disk; the logger is
        # patched only to keep the call silent.
        with patch('msmodelslim.utils.cache.pth.torch.save') as mock_torch_save, \
                patch('msmodelslim.utils.cache.pth.get_logger'):
            dumper.save("/test/output.pth")

        mock_torch_save.assert_called_once()
        assert dumper.old_forward is None

    def test_dumper_manager_reset(self, mock_module, mock_dump_config):
        """``reset`` empties the records held by ``InputCapture``."""
        dumper = DumperManager(mock_module, capture_mode=mock_dump_config.capture_mode)
        InputCapture.add_record({"test": "data"})
        assert len(InputCapture.get_all()) > 0

        dumper.reset()
        assert len(InputCapture.get_all()) == 0

    def test_dumper_manager_add_hook(self, mock_module, mock_dump_config):
        """Construction replaces ``module.forward`` with a wrapping callable."""
        dumper = DumperManager(mock_module, capture_mode=mock_dump_config.capture_mode)

        assert mock_module.forward != dumper.old_forward
        # A plain Mock attribute does not expose ``__wrapped__``, so its
        # presence shows that forward was replaced by a wrapper.
        assert hasattr(mock_module.forward, '__wrapped__')
class TestToDevice:
    """Tests for the ``to_device`` function."""

    @pytest.fixture
    def mock_torch(self):
        """Patch ``torch`` inside the module under test for the test's duration.

        ``Tensor`` is set to ``Mock`` so that ``Mock`` instances can stand in
        for tensors.
        """
        with patch('msmodelslim.utils.cache.pth.torch') as patched_torch:
            patched_torch.Tensor = Mock
            yield patched_torch

    @staticmethod
    def _count_nested_calls(container):
        """Run ``to_device`` on *container* and count calls to the patched module-level name."""
        with patch('msmodelslim.utils.cache.pth.to_device') as nested_call:
            nested_call.return_value = "device_data"
            to_device(container, "cpu")
        return nested_call.call_count

    def test_to_device_dict(self, mock_torch):
        """A two-entry dict triggers at least two nested ``to_device`` calls."""
        test_dict = {"key1": Mock(), "key2": Mock()}
        assert self._count_nested_calls(test_dict) >= 2

    def test_to_device_list(self, mock_torch):
        """A two-item list triggers at least two nested ``to_device`` calls."""
        test_list = [Mock(), Mock()]
        assert self._count_nested_calls(test_list) >= 2

    def test_to_device_tuple(self, mock_torch):
        """A two-item tuple triggers at least two nested ``to_device`` calls."""
        test_tuple = (Mock(), Mock())
        assert self._count_nested_calls(test_tuple) >= 2

    def test_to_device_tensor(self, mock_torch):
        """A tensor stand-in is moved through its ``to`` method."""
        fake_tensor = Mock()
        fake_tensor.to.return_value = "moved_tensor"

        moved = to_device(fake_tensor, "cpu")

        assert moved == "moved_tensor"
        fake_tensor.to.assert_called_once_with("cpu")

    def test_to_device_other_types(self, mock_torch):
        """A value of any other type (here a string) is returned unchanged."""
        test_data = "string_data"
        assert to_device(test_data, "cpu") == test_data

    def test_to_device_recursion_depth_limit(self, mock_torch):
        """Nesting 25 levels deep raises ``RecursionError`` with the expected message."""
        # Build the chain inside-out: an empty dict wrapped 25 times.
        deep_data = {}
        for _ in range(25):
            deep_data = {"nested": deep_data}

        with pytest.raises(RecursionError, match="Maximum recursion depth 20 exceeded"):
            to_device(deep_data, "cpu")