import unittest
from typing import Tuple, Dict, Optional
from unittest.mock import Mock
import torch
from resources.fake_llama.fake_llama import get_fake_llama_model_and_tokenizer
from transformers import PreTrainedModel, PreTrainedTokenizer
from msmodelslim.core.base.protocol import ProcessRequest
from msmodelslim.core.const import DeviceType
from msmodelslim.core.runner.generated_runner import GeneratedRunner
from msmodelslim.core.runner.pipeline_parallel_runner import PPRunner
from msmodelslim.processor import AutoProcessorConfig
# Keys used to read the tokenizer output and to name the keyword arguments
# of the forward ProcessRequest built in run_processor_with_cfg.
KEY_INPUT_IDS = "input_ids"
KEY_ATTENTION_MASK = "attention_mask"
# Prompt text that is tokenized to build the single calibration sample.
STR_TEST_PROMPT = "Hello world"
# Value passed as the tokenizer's ``return_tensors`` argument.
RETURN_TENSOR_TYPE = "pt"
class TestProcessorBase(unittest.TestCase):
    """Base class for processor tests, providing shared helpers and assertions.

    ``setUp`` builds a model/tokenizer pair through ``init_model`` and stores
    an independent snapshot of the model's state dict in
    ``self.original_state_dict`` so that tests can compare parameters before
    and after running a processor.
    """

    @staticmethod
    def init_model() -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Create the model and tokenizer used by the tests.

        Returns:
            The ``(model, tokenizer)`` pair returned by
            ``get_fake_llama_model_and_tokenizer``.
        """
        return get_fake_llama_model_and_tokenizer()

    @staticmethod
    def _tensors_match(expected: torch.Tensor, actual: torch.Tensor) -> bool:
        """Return True when two tensors have equal shape/dtype and close values.

        ``torch.allclose`` raises on dtype mismatches and on shapes that cannot
        be broadcast, so those cases are reported here as "not matching"
        instead of surfacing as an unrelated ``RuntimeError``.
        """
        if expected.shape != actual.shape or expected.dtype != actual.dtype:
            return False
        return bool(torch.allclose(expected, actual))

    def setUp(self):
        """Prepare the model, the tokenizer and a snapshot of the parameters."""
        self.model, self.tokenizer = self.init_model()
        # state_dict() returns tensors that share storage with the live
        # parameters; a shallow copy would therefore follow every in-place
        # update. Clone each tensor so the snapshot really is "original".
        snapshot = self.model.state_dict().copy()
        for key, value in list(snapshot.items()):
            if isinstance(value, torch.Tensor):
                snapshot[key] = value.detach().clone()
        self.original_state_dict = snapshot

    def tearDown(self):
        """Drop the model and tokenizer and release cached CUDA memory."""
        del self.model
        del self.tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def run_processor_with_cfg(self, config: AutoProcessorConfig) -> GeneratedRunner:
        """Run a processor described by ``config`` on the test model.

        A ``Mock`` adapter is wired to return the test model and a one-sample
        calibration dataset built from ``STR_TEST_PROMPT``; a ``PPRunner`` is
        then created on top of it, given ``config`` and run on CPU.

        Args:
            config: Processor configuration passed to ``runner.add_processor``.

        Returns:
            The runner, after ``run`` has completed.
        """
        test_prompt = self.tokenizer(
            STR_TEST_PROMPT,
            return_tensors=RETURN_TENSOR_TYPE,
            padding=True,
            truncation=True,
        )
        input_ids = test_prompt[KEY_INPUT_IDS]
        attention_mask = test_prompt[KEY_ATTENTION_MASK]
        dataset_calib = [[input_ids, attention_mask]]

        def create_model_visit_generator():
            yield ProcessRequest(name="", module=self.model, args=(), kwargs={})

        def create_model_forward_generator():
            yield ProcessRequest(
                name="",
                module=self.model,
                args=(),
                kwargs={
                    KEY_INPUT_IDS: input_ids,
                    KEY_ATTENTION_MASK: attention_mask,
                },
            )

        mock_adapter = Mock()
        mock_adapter.model = self.model
        mock_adapter.init_model.return_value = self.model
        mock_adapter.handle_dataset.return_value = dataset_calib
        # Use side_effect rather than return_value: a generator can only be
        # consumed once, so every call must receive a fresh one instead of the
        # same (possibly already exhausted) object.
        mock_adapter.generate_model_visit.side_effect = (
            lambda *args, **kwargs: create_model_visit_generator()
        )
        mock_adapter.generate_model_forward.side_effect = (
            lambda *args, **kwargs: create_model_forward_generator()
        )
        mock_adapter.enable_kv_cache.return_value = None

        runner = PPRunner(mock_adapter)
        runner.add_processor(config)
        runner.run(model=self.model, calib_data=dataset_calib, device=DeviceType.CPU)
        return runner

    # The default text is kept as a literal so the signature is unchanged; it
    # has the same value as STR_TEST_PROMPT.
    def create_test_input(self, text: str = "Hello world", max_length: int = 10) -> Dict[str, torch.Tensor]:
        """Tokenize ``text`` into model inputs padded/truncated to ``max_length``.

        Args:
            text: Prompt to tokenize.
            max_length: Length passed to the tokenizer; used together with
                ``padding="max_length"`` and ``truncation=True``.

        Returns:
            The tokenizer output.
        """
        return self.tokenizer(
            text,
            return_tensors=RETURN_TENSOR_TYPE,
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )

    def run_model_forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run ``self.model`` on ``inputs`` without gradients and return its logits.

        Args:
            inputs: Keyword arguments forwarded to the model call.

        Returns:
            The ``logits`` attribute of the model output.
        """
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits

    def assert_model_output_shape(self, outputs: torch.Tensor, expected_shape: Tuple[int, ...]):
        """Assert that ``outputs.shape`` equals ``expected_shape``."""
        self.assertEqual(outputs.shape, expected_shape)

    def assert_model_output_dtype(self, outputs: torch.Tensor, expected_dtype: torch.dtype):
        """Assert that ``outputs.dtype`` equals ``expected_dtype``."""
        self.assertEqual(outputs.dtype, expected_dtype)

    def assert_model_parameters_changed(self, original_state_dict: Dict[str, torch.Tensor]):
        """Assert that every shared parameter differs from its original value.

        Each key of ``original_state_dict`` that is still present in the
        current state dict of ``self.model`` must have changed; keys that no
        longer exist are skipped. A different shape or dtype counts as changed.

        Args:
            original_state_dict: Parameter values captured before the change.
        """
        current_state_dict = self.model.state_dict()
        for key, original in original_state_dict.items():
            if key not in current_state_dict:
                continue
            self.assertFalse(
                self._tensors_match(original, current_state_dict[key]),
                f"Parameter {key} should have changed"
            )

    def assert_model_parameters_unchanged(self, original_state_dict: Dict[str, torch.Tensor]):
        """Assert that every shared parameter still matches its original value.

        Each key of ``original_state_dict`` that is still present in the
        current state dict of ``self.model`` must be close to the original;
        keys that no longer exist are skipped. A different shape or dtype
        counts as changed and fails the assertion.

        Args:
            original_state_dict: Parameter values captured before the run.
        """
        current_state_dict = self.model.state_dict()
        for key, original in original_state_dict.items():
            if key not in current_state_dict:
                continue
            self.assertTrue(
                self._tensors_match(original, current_state_dict[key]),
                f"Parameter {key} should not have changed"
            )

    def assert_linear_layers_quantized(self, model: PreTrainedModel, expected_quantized_layers: list):
        """Assert that the named layers exist and are no longer ``torch.nn.Linear``.

        Args:
            model: Model to inspect.
            expected_quantized_layers: Dotted module names to check.
        """
        for layer_name in expected_quantized_layers:
            layer = self.get_module_by_name(model, layer_name)
            self.assertIsNotNone(layer, f"Layer {layer_name} not found")
            self.assertNotIsInstance(layer, torch.nn.Linear, f"Layer {layer_name} should be quantized")

    def assert_linear_layers_not_quantized(self, model: PreTrainedModel, expected_unquantized_layers: list):
        """Assert that the named layers exist and are still ``torch.nn.Linear``.

        Args:
            model: Model to inspect.
            expected_unquantized_layers: Dotted module names to check.
        """
        for layer_name in expected_unquantized_layers:
            layer = self.get_module_by_name(model, layer_name)
            self.assertIsNotNone(layer, f"Layer {layer_name} not found")
            self.assertIsInstance(layer, torch.nn.Linear, f"Layer {layer_name} should not be quantized")

    def get_module_by_name(self, model: PreTrainedModel, module_name: str) -> Optional[torch.nn.Module]:
        """Resolve a dotted attribute path starting from ``model``.

        Args:
            model: Root module of the lookup.
            module_name: Dot-separated attribute names, e.g. ``"a.b.c"``.

        Returns:
            The object found at the end of the path, or ``None`` as soon as
            one component of the path cannot be resolved.
        """
        current_module = model
        for name in module_name.split('.'):
            current_module = getattr(current_module, name, None)
            if current_module is None:
                return None
        return current_module

    def count_linear_layers(self, model: PreTrainedModel) -> int:
        """Count the ``torch.nn.Linear`` modules in ``model``.

        Unlike ``get_linear_layer_names`` this does not skip ``lm_head``.
        """
        return sum(1 for module in model.modules() if isinstance(module, torch.nn.Linear))

    def get_linear_layer_names(self, model: PreTrainedModel) -> list:
        """List the names of the ``torch.nn.Linear`` modules in ``model``.

        Modules whose name contains ``'lm_head'`` are excluded.
        """
        return [
            name
            for name, module in model.named_modules()
            if 'lm_head' not in name and isinstance(module, torch.nn.Linear)
        ]

    def assert_model_runs_without_error(self, inputs: Dict[str, torch.Tensor]):
        """Assert that a forward pass succeeds and yields a tensor.

        Args:
            inputs: Keyword arguments forwarded to the model call.
        """
        try:
            outputs = self.run_model_forward(inputs)
        except Exception as e:
            self.fail(f"Model should run without error, but got: {e}")
        # Checked outside the try block so that a wrong output type is reported
        # as such instead of being re-labelled as a model runtime error.
        self.assertIsInstance(outputs, torch.Tensor)

    def assert_outputs_close(self, outputs1: torch.Tensor, outputs2: torch.Tensor, rtol: float = 1e-3,
                             atol: float = 1e-3):
        """Assert that two tensors are close according to ``torch.allclose``.

        Args:
            outputs1: First tensor.
            outputs2: Second tensor.
            rtol: Relative tolerance.
            atol: Absolute tolerance.
        """
        self.assertTrue(
            torch.allclose(outputs1, outputs2, rtol=rtol, atol=atol),
            f"Outputs should be close with rtol={rtol}, atol={atol}"
        )

    def assert_outputs_not_close(self, outputs1: torch.Tensor, outputs2: torch.Tensor, rtol: float = 1e-3,
                                 atol: float = 1e-3):
        """Assert that two tensors are not close according to ``torch.allclose``.

        Args:
            outputs1: First tensor.
            outputs2: Second tensor.
            rtol: Relative tolerance.
            atol: Absolute tolerance.
        """
        self.assertFalse(
            torch.allclose(outputs1, outputs2, rtol=rtol, atol=atol),
            f"Outputs should not be close with rtol={rtol}, atol={atol}"
        )