import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
import torch
from msmodelslim.core.const import DeviceType
from msmodelslim.model.common.transformers import TransformersModel
from msmodelslim.utils.exception import SchemaValidateError
class DummyConfig:
    """Minimal stand-in for a model configuration object used by these tests."""

    def __init__(self) -> None:
        # Only the two attributes the tests rely on are provided.
        self.model_type: str = 'DummyModel'
        self.num_hidden_layers: int = 3
class TestTransformersModelLoadConfig(unittest.TestCase):
    """Tests for ``TransformersModel._load_config``."""

    def setUp(self):
        """Use the current directory as a stand-in model path."""
        self.model_path = Path('.')

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_config_from_pretrained')
    def test_load_config_when_called_then_delegate_to_safe_generator(self, mock_get_config):
        """``_load_config`` should return the object produced by SafeGenerator."""
        expected_config = DummyConfig()
        mock_get_config.return_value = expected_config

        # __init__ is stubbed and the instance is built with __new__, so the
        # adapter carries only the attributes assigned here.
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.model_path = self.model_path

            loaded = adapter._load_config(trust_remote_code=False)

            self.assertEqual(loaded, expected_config)
            mock_get_config.assert_called_once_with(
                model_path=str(self.model_path),
                trust_remote_code=False,
            )

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_config_from_pretrained')
    def test_load_config_with_trust_remote_code_when_called_then_pass_trust_flag(self, mock_get_config):
        """``trust_remote_code=True`` should be forwarded to SafeGenerator unchanged."""
        mock_get_config.return_value = DummyConfig()

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.model_path = self.model_path

            # Only the forwarded arguments matter here; the return value is
            # covered by the delegation test above.
            adapter._load_config(trust_remote_code=True)

            mock_get_config.assert_called_once_with(
                model_path=str(self.model_path),
                trust_remote_code=True,
            )
class TestTransformersModelLoadTokenizer(unittest.TestCase):
    """Tests for ``TransformersModel._load_tokenizer``."""

    def setUp(self):
        """Use the current directory as a stand-in model path."""
        self.model_path = Path('.')

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_tokenizer_from_pretrained')
    def test_load_tokenizer_when_called_then_delegate_to_safe_generator(self, mock_get_tokenizer):
        """``_load_tokenizer`` should return the object produced by SafeGenerator."""
        expected_tokenizer = MagicMock()
        mock_get_tokenizer.return_value = expected_tokenizer

        # __init__ is stubbed and the instance is built with __new__, so the
        # adapter carries only the attributes assigned here.
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.model_path = self.model_path

            loaded = adapter._load_tokenizer(trust_remote_code=False)

            self.assertEqual(loaded, expected_tokenizer)
            mock_get_tokenizer.assert_called_once_with(
                model_path=str(self.model_path),
                use_fast=False,
                legacy=False,
                trust_remote_code=False,
            )

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_tokenizer_from_pretrained')
    def test_load_tokenizer_with_trust_remote_code_when_called_then_pass_trust_flag(self, mock_get_tokenizer):
        """``trust_remote_code=True`` should be forwarded to SafeGenerator unchanged."""
        mock_get_tokenizer.return_value = MagicMock()

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.model_path = self.model_path

            # Only the forwarded arguments matter here; the return value is
            # covered by the delegation test above.
            adapter._load_tokenizer(trust_remote_code=True)

            mock_get_tokenizer.assert_called_once_with(
                model_path=str(self.model_path),
                use_fast=False,
                legacy=False,
                trust_remote_code=True,
            )
class TestTransformersModelLoadModel(unittest.TestCase):
    """Tests for ``TransformersModel._load_model``."""

    def setUp(self):
        """Use the current directory as a stand-in model path."""
        self.model_path = Path('.')

    def _load_model_on(self, device):
        """Call ``_load_model`` for *device* on an adapter built without ``__init__``.

        The adapter only receives the attributes these tests assign:
        ``model_path``, ``config`` and ``trust_remote_code``.
        """
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.model_path = self.model_path
            adapter.config = DummyConfig()
            adapter.trust_remote_code = False
            return adapter._load_model(device=device)

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_model_from_pretrained')
    def test_load_model_with_npu_device_when_called_then_use_auto_device_map(self, mock_get_model):
        """On NPU the model should be requested with ``device_map='auto'``."""
        expected_model = MagicMock()
        mock_get_model.return_value = expected_model

        loaded = self._load_model_on(DeviceType.NPU)

        self.assertEqual(loaded, expected_model)
        # Fail with a clear assertion instead of a TypeError on ``None``
        # if the loader was never invoked.
        mock_get_model.assert_called()
        call_kwargs = mock_get_model.call_args[1]
        self.assertEqual(call_kwargs['device_map'], 'auto')

    @patch('msmodelslim.model.common.transformers.SafeGenerator.get_model_from_pretrained')
    def test_load_model_with_cpu_device_when_called_then_use_cpu_device_map(self, mock_get_model):
        """On CPU the model should be requested with ``device_map='cpu'``."""
        mock_get_model.return_value = MagicMock()

        self._load_model_on(DeviceType.CPU)

        # Fail with a clear assertion instead of a TypeError on ``None``
        # if the loader was never invoked.
        mock_get_model.assert_called()
        call_kwargs = mock_get_model.call_args[1]
        self.assertEqual(call_kwargs['device_map'], 'cpu')
class TestTransformersModelGetModelType(unittest.TestCase):
    """Tests for ``TransformersModel._get_model_type``."""

    def test_get_model_type_with_none_when_called_then_return_config_model_type(self):
        """Passing ``None`` should fall back to ``config.model_type``."""
        config = DummyConfig()
        config.model_type = 'ConfigModelType'

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.config = config

            resolved = adapter._get_model_type(None)

            self.assertEqual(resolved, 'ConfigModelType')

    def test_get_model_type_with_value_when_called_then_return_input_value(self):
        """An explicit model type should be returned as given."""
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.config = DummyConfig()

            resolved = adapter._get_model_type('CustomModelType')

            self.assertEqual(resolved, 'CustomModelType')
class TestTransformersModelGetModelPedigree(unittest.TestCase):
    """Tests for ``TransformersModel._get_model_pedigree``."""

    def test_get_model_pedigree_with_none_when_called_then_return_config_model_type(self):
        """Passing ``None`` should fall back to ``config.model_type``."""
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.config = DummyConfig()
            adapter.config.model_type = 'Qwen2'

            pedigree = adapter._get_model_pedigree(None)

            self.assertEqual(pedigree, 'Qwen2')

    def test_get_model_pedigree_with_valid_name_when_called_then_extract_prefix(self):
        """A valid model name should yield its lower-cased leading prefix."""
        expected_by_name = (
            ('Llama3-8B', 'llama'),
            ('DeepSeek-V3', 'deepseek'),
        )

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.config = DummyConfig()

            # subTest reports each name separately, so a failure on the first
            # name no longer hides the outcome of the second.
            for model_name, expected in expected_by_name:
                with self.subTest(model_name=model_name):
                    self.assertEqual(adapter._get_model_pedigree(model_name), expected)

    def test_get_model_pedigree_with_invalid_name_when_called_then_raise_schema_validate_error(self):
        """An invalid model name should raise ``SchemaValidateError``."""
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.config = DummyConfig()

            with self.assertRaises(SchemaValidateError) as context:
                adapter._get_model_pedigree('123-invalid')

        # Checked after the assertRaises block; inside it, this line would
        # never run once the exception is raised.
        self.assertIn("Invalid model_name", str(context.exception))
class TestTransformersModelGetTokenizedData(unittest.TestCase):
    """Tests for ``TransformersModel._get_tokenized_data``."""

    def test_get_tokenized_data_with_non_list_when_called_then_raise_schema_validate_error(self):
        """A non-list ``calib_list`` should raise ``SchemaValidateError``."""
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)

            with self.assertRaises(SchemaValidateError) as context:
                adapter._get_tokenized_data('not_a_list', DeviceType.CPU)

        # Checked after the assertRaises block; inside it, this line would
        # never run once the exception is raised.
        self.assertIn("calib_list must be a list", str(context.exception))
class TestTransformersModelGetBatchTokenizedData(unittest.TestCase):
    """Tests for ``TransformersModel._get_batch_tokenized_data``."""

    def test_get_batch_tokenized_data_with_valid_list_when_called_then_return_batched_data(self):
        """Four texts with ``batch_size=2`` should produce two padded batches."""
        first_batch = [torch.tensor([[1, 2, 3]])]
        second_batch = [torch.tensor([[4, 5, 6]])]

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            # Replace the padding step so the test only observes how often
            # it is invoked and what the batching method returns.
            adapter._get_padding_data = MagicMock(side_effect=[first_batch, second_batch])

            calib_list = ['text1', 'text2', 'text3', 'text4']
            batches = adapter._get_batch_tokenized_data(calib_list, batch_size=2, device=DeviceType.CPU)

            self.assertIsInstance(batches, list)
            self.assertEqual(len(batches), 2)
            self.assertEqual(adapter._get_padding_data.call_count, 2)

    def test_get_batch_tokenized_data_with_non_list_when_called_then_raise_schema_validate_error(self):
        """A non-list ``calib_list`` should raise ``SchemaValidateError``."""
        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)

            with self.assertRaises(SchemaValidateError) as context:
                adapter._get_batch_tokenized_data('not_a_list', batch_size=2, device=DeviceType.CPU)

        # Checked after the assertRaises block; inside it, this line would
        # never run once the exception is raised.
        self.assertIn("calib_list must be a list", str(context.exception))
class TestTransformersModelGetPaddingData(unittest.TestCase):
    """Tests for ``TransformersModel._get_padding_data``."""

    @staticmethod
    def _fake_encoding(token_ids):
        """Return a mock tokenizer output whose ``data`` holds one row of ``input_ids``."""
        encoding = MagicMock()
        encoding.data = {'input_ids': torch.tensor([token_ids])}
        return encoding

    def _pad(self, token_rows, calib_list):
        """Run ``_get_padding_data`` with a tokenizer that yields *token_rows* in order.

        Args:
            token_rows: One list of token ids per expected tokenizer call.
            calib_list: Texts handed to ``_get_padding_data``.

        Returns:
            Whatever ``_get_padding_data`` returns for ``DeviceType.CPU``.
        """
        import torch.nn.functional as F_torch

        tokenizer = MagicMock()
        tokenizer.side_effect = [self._fake_encoding(row) for row in token_rows]

        with patch.object(TransformersModel, '__init__', return_value=None):
            adapter = TransformersModel.__new__(TransformersModel)
            adapter.tokenizer = tokenizer
            # Bind the module-level ``F`` name to torch.nn.functional for the call.
            with patch('msmodelslim.model.common.transformers.F', F_torch):
                return adapter._get_padding_data(calib_list, DeviceType.CPU)

    def test_get_padding_data_with_same_length_when_called_then_no_padding(self):
        """Equal-length inputs should come back as a single-element list."""
        padded = self._pad([[1, 2, 3], [4, 5, 6]], ['text1', 'text2'])

        self.assertIsInstance(padded, list)
        self.assertEqual(len(padded), 1)

    def test_get_padding_data_with_different_lengths_when_called_then_apply_padding(self):
        """Inputs of length 2 and 4 should be padded to a width of 4."""
        padded = self._pad([[1, 2], [4, 5, 6, 7]], ['short', 'longer text'])

        self.assertIsInstance(padded, list)
        self.assertEqual(len(padded), 1)
        self.assertEqual(padded[0].shape[1], 4)