import os
from functools import lru_cache
from pathlib import Path
from typing import List
from unittest.mock import patch, MagicMock
import torch
from resources.fake_llama.fake_llama import get_fake_llama_model_and_tokenizer
from torch import nn
from transformers import PretrainedConfig, PreTrainedTokenizerBase
from msmodelslim.core.const import DeviceType
from msmodelslim.model.qwen3.model_adapter import Qwen3ModelAdapter
from msmodelslim.processor.kv_smooth import KVSmoothFusedUnit, KVSmoothFusedType
@lru_cache(maxsize=1)
def is_npu_available():
    """Report whether an Ascend NPU can be used through torch.

    ``torch_npu`` is imported first (presumably for its side effect of
    registering the ``torch.npu`` namespace); if that import — or the
    availability query — raises ``ImportError``, the answer is ``False``.
    The result is computed once and cached.
    """
    try:
        import torch_npu  # noqa: F401  (imported only for its side effect)
        npu_ready = torch.npu.is_available()
    except ImportError:
        npu_ready = False
    return npu_ready
@lru_cache(maxsize=1)
def is_cuda_available():
    """Report whether torch can see a CUDA device (cached after the first call).

    An ``ImportError`` raised while querying is treated as "not available".
    """
    try:
        cuda_ready = torch.cuda.is_available()
    except ImportError:
        cuda_ready = False
    return cuda_ready
class FakeLlamaModelAdapter(Qwen3ModelAdapter):
    """Qwen3ModelAdapter variant that serves the bundled fake llama model.

    The model, its config and the tokenizer are created eagerly via
    ``get_fake_llama_model_and_tokenizer()``; the ``_load_*`` overrides hand
    those objects back instead of loading anything from ``model_path``.
    """

    def __init__(self, model_type: str, model_path: Path, trust_remote_code: bool = False):
        # The fakes are stored *before* the base initializer runs — presumably
        # because it invokes the _load_* hooks below (confirm in Qwen3ModelAdapter).
        fake_model, fake_tokenizer = get_fake_llama_model_and_tokenizer()
        self.loaded_config = fake_model.config
        self.loaded_model = fake_model
        self.loaded_tokenizer = fake_tokenizer
        Qwen3ModelAdapter.__init__(self, model_type, model_path, trust_remote_code)

    def _load_config(self, trust_remote_code=False) -> PretrainedConfig:
        """Return the fake model's config; ``trust_remote_code`` is ignored."""
        return self.loaded_config

    def _load_model(self, device: DeviceType) -> nn.Module:
        """Return the pre-built fake model; ``device`` is ignored."""
        return self.loaded_model

    def _load_tokenizer(self, trust_remote_code=False) -> PreTrainedTokenizerBase:
        """Return the pre-built fake tokenizer; ``trust_remote_code`` is ignored."""
        return self.loaded_tokenizer

    def get_kvcache_smooth_fused_subgraph(self) -> List[KVSmoothFusedUnit]:
        """Build one KV-cache smoothing fusion unit per decoder layer.

        Every unit targets ``model.layers.<idx>.self_attn``, fuses from the
        ``q_proj`` / ``k_proj`` states and uses
        ``KVSmoothFusedType.StateViaRopeToLinear``. The layer count is read from
        ``self.config.num_hidden_layers`` (``self.config`` is presumably set by
        the base class).
        """
        layer_count = self.config.num_hidden_layers
        return [
            KVSmoothFusedUnit(
                attention_name=f"model.layers.{idx}.self_attn",
                layer_idx=idx,
                fused_from_query_states_name="q_proj",
                fused_from_key_states_name="k_proj",
                fused_type=KVSmoothFusedType.StateViaRopeToLinear,
            )
            for idx in range(layer_count)
        ]
def invoke_test(config_name: str, model_save_path: str, device: str = 'cpu', offload_device: str = 'cpu'):
    """Run the real ``msmodelslim quant`` CLI on the fake llama model and return its model adapter.

    The real CLI parser is driven through a temporarily replaced ``sys.argv``.
    Plugin discovery is patched so that model type ``fake_llama`` resolves to
    ``FakeLlamaModelAdapter``. ``copy_files`` and the plugin dependency check are
    mocked out. ``QuantServiceProxy.quantize`` is wrapped so the adapter it
    receives can be captured.

    Args:
        config_name: File name of a quantization config in the ``configs``
            directory next to this file.
        model_save_path: Value passed as ``--save_path``.
        device: Value passed as ``--device``.
        offload_device: Default ``offload_device`` injected into
            ``LayerWiseRunner.__init__`` when the caller does not pass one.

    Returns:
        The model adapter passed to ``QuantServiceProxy.quantize``, or ``None``
        if quantize was never reached.
    """
    import sys
    from msmodelslim.cli.__main__ import main as cli_main

    original_argv = sys.argv.copy()
    captured_model_adapter = None

    # Fake entry point: plugin discovery resolves "fake_llama" to the fake adapter class.
    fake_ep = MagicMock()
    fake_ep.name = "fake_llama"
    fake_ep.load.return_value = FakeLlamaModelAdapter

    try:
        config_path = os.path.join(os.path.dirname(__file__), "configs", config_name)
        sys.argv = [
            'msmodelslim',
            'quant',
            '--model_type', 'fake_llama',
            '--model_path', './',
            '--save_path', model_save_path,
            '--device', device,
            '--config_path', config_path,
            '--trust_remote_code', 'False'
        ]
        with (patch(
            "msmodelslim.model.plugin_factory.entry_points"
        ) as mock_entry_points, patch(
            "msmodelslim.core.quant_service.modelslim_v1.save.ascendv1.copy_files"
        ), patch(
            "msmodelslim.model.plugin_factory.DependencyChecker.check_plugin"
        ) as mock_check_plugin):
            mock_entry_points.return_value.select.return_value = [fake_ep]
            mock_check_plugin.return_value = None

            from msmodelslim.core.quant_service.proxy import QuantServiceProxy
            from msmodelslim.core.runner.layer_wise_runner import LayerWiseRunner

            original_quantize = QuantServiceProxy.quantize
            original_init = LayerWiseRunner.__init__

            def capture_model_adapter(
                    self,
                    quant_config,
                    model_adapter,
                    save_path=None,
                    device=None,
                    device_indices=None):
                # Record the adapter, then delegate to the real quantize unchanged.
                nonlocal captured_model_adapter
                captured_model_adapter = model_adapter
                return original_quantize(self, quant_config, model_adapter, save_path, device, device_indices)

            def mock_init(self, adapter, offload_device=offload_device):
                # Default offload_device to this helper's argument.
                original_init(self, adapter, offload_device)

            # patch.object guarantees both class attributes are restored on exit,
            # even if cli_main (or entering the second patch) raises. The previous
            # manual assignment left quantize patched if anything between the
            # assignment and the try/finally failed.
            quantize_patch = patch.object(QuantServiceProxy, "quantize", capture_model_adapter)
            init_patch = patch.object(LayerWiseRunner, "__init__", mock_init)
            with quantize_patch, init_patch:
                cli_main()
    finally:
        sys.argv = original_argv
    return captured_model_adapter
def invoke_analysis_test(metrics: str = "kurtosis", patterns: list = None, topk: int = 15):
    """Drive the real CLI parser with ``analyze`` arguments and return what the analysis entry point produced.

    ``LayerAnalysisApplication``, ``FileDatasetLoader`` and
    ``LayerSelectorAnalysisService`` in ``msmodelslim.cli.analysis.__main__`` are
    replaced with mocks. The application mock's ``analyze`` returns
    ``"mock_analysis_result"``.

    Args:
        metrics: Analysis metric, passed as ``--metrics``.
        patterns: Optional list of layer patterns, passed after ``--pattern``.
        topk: Number of top sensitive layers to output, passed as ``--topk``.

    Returns:
        The return value of ``msmodelslim.cli.analysis.__main__.main``. If that
        function raised an ``Exception``, the exception object is returned
        instead of being propagated. ``None`` if it was never called.
    """
    import sys
    from msmodelslim.cli.__main__ import main as cli_main

    saved_argv = sys.argv.copy()
    result = None
    try:
        argv = [
            'msmodelslim',
            'analyze',
            '--model_type', 'fake_llama',
            '--model_path', './',
            '--device', 'cpu',
            '--metrics', metrics,
            '--calib_dataset', 'boolq.jsonl',
            '--topk', str(topk),
            '--trust_remote_code', 'False'
        ]
        if patterns:
            argv += ['--pattern', *patterns]
        sys.argv = argv

        with patch('msmodelslim.cli.analysis.__main__.LayerAnalysisApplication') as app_cls, \
                patch('msmodelslim.cli.analysis.__main__.FileDatasetLoader'), \
                patch('msmodelslim.cli.analysis.__main__.LayerSelectorAnalysisService'):
            app_instance = MagicMock()
            app_instance.analyze.return_value = "mock_analysis_result"
            app_cls.return_value = app_instance

            import msmodelslim.cli.analysis.__main__ as analysis_module
            real_main = analysis_module.main

            def capture_result(args):
                # Deliberately surface failures to the caller as the result value.
                nonlocal result
                try:
                    result = real_main(args)
                except Exception as exc:
                    result = exc
                    return None
                else:
                    return result

            # Swap the module-level entry point for the duration of the CLI call.
            with patch.object(analysis_module, "main", capture_result):
                cli_main()
    finally:
        sys.argv = saved_argv
    return result