import unittest
from unittest.mock import MagicMock
import torch
import torch.nn as nn
class TestAnalysisMethods(unittest.TestCase):
    """Unit tests for msmodelslim.core.analysis_service.analysis_methods."""

    def test_analysis_target_matcher_get_linear_conv_layers(self):
        """get_linear_conv_layers returns only the names of Linear / Conv2d modules."""
        from msmodelslim.core.analysis_service.analysis_methods import AnalysisTargetMatcher

        mock_model = MagicMock()
        # spec= makes the mocks pass isinstance() checks against the real layer classes.
        mock_linear = MagicMock(spec=nn.Linear)
        mock_conv = MagicMock(spec=nn.Conv2d)
        mock_other = MagicMock()
        mock_model.named_modules.return_value = [
            ('linear1', mock_linear),
            ('conv1', mock_conv),
            ('other1', mock_other),
        ]

        result = AnalysisTargetMatcher.get_linear_conv_layers(mock_model)
        self.assertEqual(result, ['linear1', 'conv1'])

    def test_analysis_target_matcher_filter_layers_by_patterns(self):
        """filter_layers_by_patterns: '*' and an empty pattern list keep all, a glob narrows."""
        from msmodelslim.core.analysis_service.analysis_methods import AnalysisTargetMatcher

        layer_names = ['layer1.linear', 'layer2.conv', 'layer3.other']

        result = AnalysisTargetMatcher.filter_layers_by_patterns(layer_names, ['*'])
        self.assertEqual(result, layer_names)

        result = AnalysisTargetMatcher.filter_layers_by_patterns(layer_names, ['layer1.*'])
        self.assertEqual(result, ['layer1.linear'])

        # An empty pattern list is expected to keep every layer.
        result = AnalysisTargetMatcher.filter_layers_by_patterns(layer_names, [])
        self.assertEqual(result, layer_names)

    def test_quantile_analysis_method(self):
        """QuantileAnalysisMethod exposes its name, accepts layer data and provides a hook."""
        from msmodelslim.core.analysis_service.analysis_methods import QuantileAnalysisMethod

        method = QuantileAnalysisMethod(sample_step=10)
        self.assertEqual(method.name, 'quantile')

        layer_data = {
            'tensor': [torch.tensor([[1.0, 2.0, 3.0, 4.0]])],
            'device': torch.device('cpu'),
        }
        # Smoke check only: compute_score must not raise; its value is not asserted here.
        method.compute_score(layer_data)

        hook = method.get_hook()
        self.assertTrue(callable(hook))

    def test_quantile_analysis_method_hook_basic_functionality(self):
        """The quantile hook records one tensor and the input's device for a layer."""
        from msmodelslim.core.analysis_service.analysis_methods import QuantileAnalysisMethod

        method = QuantileAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('tensor', layer_data)
        self.assertIn('device', layer_data)
        self.assertEqual(len(layer_data.get('tensor', [])), 1)
        self.assertEqual(layer_data.get('device'), input_tensor.device)

    def test_quantile_analysis_method_hook_tuple_input(self):
        """With a tuple input the quantile hook records a single tensor entry."""
        from msmodelslim.core.analysis_service.analysis_methods import QuantileAnalysisMethod

        method = QuantileAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = (torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), torch.tensor([[6.0, 7.0]]))
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('tensor', layer_data)
        # Only one entry is expected even though the tuple holds two tensors.
        self.assertEqual(len(layer_data.get('tensor', [])), 1)

    def test_quantile_analysis_method_hook_data_accumulation(self):
        """Repeated quantile-hook calls append to the same layer's tensor list."""
        from msmodelslim.core.analysis_service.analysis_methods import QuantileAnalysisMethod

        method = QuantileAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        mock_module = MagicMock()
        output_tensor = None

        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertEqual(len(layer_data.get('tensor', [])), 1)

        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertEqual(len(layer_data.get('tensor', [])), 2)
        self.assertEqual(layer_data.get('device'), input_tensor1.device)

    def test_quantile_analysis_method_hook_multiple_layers(self):
        """The quantile hook keeps separate stats entries per layer name."""
        from msmodelslim.core.analysis_service.analysis_methods import QuantileAnalysisMethod

        method = QuantileAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        mock_module = MagicMock()
        output_tensor = None

        layer_name1 = 'layer1'
        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name1, stats_dict)

        layer_name2 = 'layer2'
        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0, 7.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name2, stats_dict)

        self.assertIn(layer_name1, stats_dict)
        self.assertIn(layer_name2, stats_dict)
        layer1_data = stats_dict.get(layer_name1, {})
        layer2_data = stats_dict.get(layer_name2, {})
        self.assertEqual(len(layer1_data.get('tensor', [])), 1)
        self.assertEqual(len(layer2_data.get('tensor', [])), 1)
        self.assertEqual(layer1_data.get('device'), input_tensor1.device)
        self.assertEqual(layer2_data.get('device'), input_tensor2.device)

    def test_std_analysis_method(self):
        """StdAnalysisMethod exposes its name, returns a float score and provides a hook."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        self.assertEqual(method.name, 'std')

        layer_data = {
            't_max': torch.tensor(5.0),
            't_min': torch.tensor(1.0),
            'std': torch.tensor(2.0),
        }
        score = method.compute_score(layer_data)
        self.assertIsInstance(score, float)

        hook = method.get_hook()
        self.assertTrue(callable(hook))

    def test_std_analysis_method_hook_basic_functionality(self):
        """The std hook records shift, t_max, t_min and std tensors for a layer."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('shift', layer_data)
        self.assertIn('t_max', layer_data)
        self.assertIn('t_min', layer_data)
        self.assertIn('std', layer_data)
        # shift is expected to have one value per channel (last dim of the input: 5).
        self.assertEqual(layer_data.get('shift', torch.tensor([])).shape, torch.Size([5]))
        self.assertIsInstance(layer_data.get('t_max'), torch.Tensor)
        self.assertIsInstance(layer_data.get('t_min'), torch.Tensor)
        self.assertIsInstance(layer_data.get('std'), torch.Tensor)

    def test_std_analysis_method_hook_tuple_input(self):
        """With a tuple input the std hook computes shift from the first tensor (5 channels)."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = (torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), torch.tensor([[6.0, 7.0]]))
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('shift', layer_data)
        self.assertEqual(layer_data.get('shift', torch.tensor([])).shape, torch.Size([5]))

    def test_std_analysis_method_hook_data_accumulation(self):
        """Across calls the std hook's t_max never decreases and t_min never increases."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        mock_module = MagicMock()
        output_tensor = None

        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        # clone() so that in-place updates by the next hook call cannot alter the snapshot.
        first_max = layer_data.get('t_max', torch.tensor(0.0)).clone()
        first_min = layer_data.get('t_min', torch.tensor(0.0)).clone()

        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        second_max = layer_data.get('t_max', torch.tensor(0.0))
        second_min = layer_data.get('t_min', torch.tensor(0.0))

        # torch.all reduces the element-wise comparison so this works whether the
        # statistics are scalars or per-channel tensors (bool() of a multi-element
        # tensor would raise).
        self.assertTrue(torch.all(second_max >= first_max))
        self.assertTrue(torch.all(second_min <= first_min))

    def test_std_analysis_method_hook_multiple_layers(self):
        """The std hook keeps a full set of statistics per layer name."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        hook = method.get_hook()
        stats_dict = {}
        mock_module = MagicMock()
        output_tensor = None

        layer_name1 = 'layer1'
        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name1, stats_dict)

        layer_name2 = 'layer2'
        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0, 7.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name2, stats_dict)

        self.assertIn(layer_name1, stats_dict)
        self.assertIn(layer_name2, stats_dict)
        for layer_name in (layer_name1, layer_name2):
            layer_data = stats_dict.get(layer_name, {})
            for key in ('shift', 't_max', 't_min', 'std'):
                self.assertIn(key, layer_data)

    def test_std_analysis_method_hook_shift_calculation(self):
        """The std hook's shift equals the per-channel midpoint (max + min) / 2."""
        from msmodelslim.core.analysis_service.analysis_methods import StdAnalysisMethod

        method = StdAnalysisMethod()
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        # Reference: per-channel max/min over all rows of the flattened input.
        tensor_max = torch.max(input_tensor.reshape(-1, 5), dim=0)[0]
        tensor_min = torch.min(input_tensor.reshape(-1, 5), dim=0)[0]
        expected_shift = (tensor_max + tensor_min) / 2

        layer_data = stats_dict.get(layer_name, {})
        self.assertTrue(torch.allclose(layer_data.get('shift', torch.tensor(0.0)), expected_shift))

    def test_kurtosis_analysis_method(self):
        """KurtosisAnalysisMethod exposes its name, accepts layer data and provides a hook."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        self.assertEqual(method.name, 'kurtosis')

        layer_data = {
            'tensor': [torch.tensor([[1.0, 2.0, 3.0, 4.0]])],
            'device': torch.device('cpu'),
        }
        # Smoke check only: compute_score must not raise; its value is not asserted here.
        method.compute_score(layer_data)

        hook = method.get_hook()
        self.assertTrue(callable(hook))

    def test_kurtosis_analysis_method_hook_basic_functionality(self):
        """The kurtosis hook records one tensor and the input's device for a layer."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('tensor', layer_data)
        self.assertIn('device', layer_data)
        self.assertEqual(len(layer_data.get('tensor', [])), 1)
        self.assertEqual(layer_data.get('device'), input_tensor.device)

    def test_kurtosis_analysis_method_hook_tuple_input(self):
        """With a tuple input the kurtosis hook records a single tensor entry."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        input_tensor = (torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), torch.tensor([[6.0, 7.0]]))
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        self.assertIn(layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertIn('tensor', layer_data)
        self.assertEqual(len(layer_data.get('tensor', [])), 1)

    def test_kurtosis_analysis_method_hook_data_accumulation(self):
        """Repeated kurtosis-hook calls append to the same layer's tensor list."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        mock_module = MagicMock()
        output_tensor = None

        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertEqual(len(layer_data.get('tensor', [])), 1)

        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name, stats_dict)
        layer_data = stats_dict.get(layer_name, {})
        self.assertEqual(len(layer_data.get('tensor', [])), 2)
        self.assertEqual(layer_data.get('device'), input_tensor1.device)

    def test_kurtosis_analysis_method_hook_multiple_layers(self):
        """The kurtosis hook keeps separate stats entries per layer name."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        mock_module = MagicMock()
        output_tensor = None

        layer_name1 = 'layer1'
        input_tensor1 = torch.tensor([[1.0, 2.0, 3.0]])
        hook(mock_module, input_tensor1, output_tensor, layer_name1, stats_dict)

        layer_name2 = 'layer2'
        input_tensor2 = torch.tensor([[4.0, 5.0, 6.0, 7.0]])
        hook(mock_module, input_tensor2, output_tensor, layer_name2, stats_dict)

        self.assertIn(layer_name1, stats_dict)
        self.assertIn(layer_name2, stats_dict)
        layer1_data = stats_dict.get(layer_name1, {})
        layer2_data = stats_dict.get(layer_name2, {})
        self.assertEqual(len(layer1_data.get('tensor', [])), 1)
        self.assertEqual(len(layer2_data.get('tensor', [])), 1)
        self.assertEqual(layer1_data.get('device'), input_tensor1.device)
        self.assertEqual(layer2_data.get('device'), input_tensor2.device)

    def test_kurtosis_analysis_method_hook_sorting_behavior(self):
        """The kurtosis hook stores the recorded values in ascending order."""
        from msmodelslim.core.analysis_service.analysis_methods import KurtosisAnalysisMethod

        method = KurtosisAnalysisMethod(sample_step=10)
        hook = method.get_hook()
        stats_dict = {}
        layer_name = 'test_layer'
        raw_values = [3.0, 1.0, 5.0, 2.0, 4.0]
        input_tensor = torch.tensor([raw_values])
        output_tensor = None
        mock_module = MagicMock()

        hook(mock_module, input_tensor, output_tensor, layer_name, stats_dict)

        layer_data = stats_dict.get(layer_name, {})
        stored_tensor = layer_data.get('tensor', [torch.tensor([])])[0]
        stored_values = stored_tensor.squeeze().tolist()
        self.assertEqual(stored_values, sorted(raw_values))

    def test_analysis_method_factory(self):
        """AnalysisMethodFactory creates built-ins, rejects unknown names and accepts registrations."""
        from msmodelslim.core.analysis_service.analysis_methods import AnalysisMethodFactory
        from msmodelslim.core.analysis_service.analysis_methods import LayerAnalysisMethod

        method = AnalysisMethodFactory.create_method('std')
        self.assertEqual(method.name, 'std')

        with self.assertRaises(ValueError):
            AnalysisMethodFactory.create_method('invalid_method')

        class TestMethod(LayerAnalysisMethod):
            """Minimal concrete LayerAnalysisMethod used only to exercise registration."""

            @property
            def name(self):
                return 'test'

            def get_hook(self):
                return lambda: None

            def compute_score(self, layer_data):
                return 0.0

        # NOTE(review): this registration is presumably stored in factory-level state
        # and is never removed, so it may leak into other tests that list supported
        # methods. Unregister here once the factory exposes a way to do so.
        AnalysisMethodFactory.register_method('test', TestMethod)
        method = AnalysisMethodFactory.create_method('test')
        self.assertIsInstance(method, TestMethod)

        supported = AnalysisMethodFactory.get_supported_methods()
        self.assertIn('std', supported)
        self.assertIn('test', supported)

    def test_kurtosis_function(self):
        """kurtosis() returns a tensor for 1-D input and for 2-D input with an explicit dim."""
        from msmodelslim.core.analysis_service.analysis_methods import kurtosis

        x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
        result = kurtosis(x)
        self.assertIsInstance(result, torch.Tensor)

        x_2d = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        result = kurtosis(x_2d, dim=0)
        self.assertIsInstance(result, torch.Tensor)
# Allow running this module directly: `python test_analysis_methods.py`.
if __name__ == '__main__':
    unittest.main()