import os
import queue
import unittest
import sys
from unittest.mock import MagicMock, patch
import numpy as np
from ddt import ddt, data, unpack
from mindie_llm.utils.env import ENV
from mindie_llm.utils.status import MindieLlmStatusCode
from mindie_llm.text_generator.generator import Generator, PDInterface, PDModelConfig, STANDARD_TAG
from mindie_llm.text_generator.utils.generation_output import GenerationOutput
from mindie_llm.text_generator.utils.request import Request
from mindie_llm.text_generator.utils.input_metadata import InputMetadata, SAMPLING_DTYPE
from mindie_llm.text_generator.adapter import generator_torch
from mindie_llm.text_generator.utils.generation_metadata import GenerationParams
from mindie_llm.connector.common.model_execute_data_pb2 import LoraOperationStatus
from mindie_llm.modeling.model_wrapper.model_info import ModelInfo
from mindie_llm.utils.log.error_code import ErrorCode, ErrorCodeException
from tests.pythontest.npu import FakeModelRunner, FakeModelWrapper, FakeParallelInfo
@ddt
class TestGenerator(unittest.TestCase):
@classmethod
def setUpClass(cls):
    """Install a ``MagicMock`` as ``sys.modules['_libatb_torch']`` for the class.

    Whatever was registered under that name beforehand (if anything) is kept
    on the class so that ``tearDownClass`` can put it back instead of
    discarding it.
    """
    cls._original_libatb_torch = sys.modules.get('_libatb_torch')
    sys.modules['_libatb_torch'] = MagicMock()

@classmethod
def tearDownClass(cls):
    """Undo the ``sys.modules`` stub installed by ``setUpClass``."""
    original = getattr(cls, '_original_libatb_torch', None)
    if original is not None:
        # A module was registered before the tests ran: restore it.
        sys.modules['_libatb_torch'] = original
    else:
        # pop() rather than del: do not fail if the entry is already gone.
        sys.modules.pop('_libatb_torch', None)
def setUp(self):
    """Create the fixtures shared by the ``Generator`` tests.

    Sets ``self.model_config`` (the flat dict handed to ``Generator``),
    ``self.fake_model_runner`` (a ``FakeModelRunner`` whose dp/tp/sp/cp sizes
    are read from that dict) and two five-sequence ``GenerationOutput``
    objects that differ only in ``finish_reason``: all 0 for
    ``generation_output_prefill`` and all 6 for ``generation_output_decode``.
    """
    self.model_config = {
        'backend_bin_path': '/usr/local/Ascend/mindie/2.0.RC1/mindie-llm/bin/',
        'backend_log_file': '/usr/local/Ascend/mindie/2.0.RC1/mindie-service/logs/mindie-server.log',
        'backend_modelInstance_id': '0',
        'backend_type': 'atb',
        'block_size': '128',
        'cpu_mem': '5',
        'deploy_type': 'INTER_PROCESS',
        'dp': '1',
        'executor_type': 'LLM_EXECUTOR_PYTHON',
        'globalRankIds': '',
        'globalWorldSize': '0',
        'interNodeKmcKsfMaster': 'tools/pmt/master/ksfa',
        'interNodeKmcKsfStandby': 'tools/pmt/standby/ksfb',
        'interNodeTLSEnabled': '1',
        'interNodeTlsCaFiles': 'ca.pem,',
        'interNodeTlsCaPath': 'security/grpc/ca/',
        'interNodeTlsCert': 'security/grpc/certs/server.pem',
        'interNodeTlsCrlFiles': 'server_crl.pem,',
        'interNodeTlsCrlPath': 'security/grpc/certs/',
        'interNodeTlsPk': 'security/grpc/keys/server.key.pem',
        'interNodeTlsPkPwd': 'security/grpc/pass/mindie_server_key_pwd.txt',
        'isMaster': '0',
        'localIP': '',
        'local_rank': '0',
        'masterIP': '',
        'max_input_len': '2048',
        'max_iter_times': '512',
        'max_prefill_tokens': '8192',
        'max_seq_len': '2560',
        'model_id': '/home/data/llama3',
        'model_instance_number': '1',
        'model_instance_type': 'Standard',
        'model_name': 'deepseekv2',
        'moe_tp': '1',
        'multiNodesInferEnabled': '0',
        'multiNodesInferPort': '1120',
        'npu_device_id': '0',
        'npu_device_ids': '0,1,2,3,4,5,6,7',
        'npu_mem': '-1',
        'rank': '0',
        'slaveIPs': '',
        'speculation_gamma': '0',
        'tp': '4',
        'trust_remote_code': '0',
        'world_size': '4',
        'num_speculative_tokens': '0',
        'max_batch_size': '5',
        'max_prefill_batch_size': '5',
        'distributed_enable': 'false',
        'vocab_size': 100000,
        'enable_warmup_with_sampling': 'false',
        'cp': '1',
        'sp': '1',
        'moe_ep': '1',
    }

    # The parallel sizes are stored as strings in the config; convert them once.
    parallel_sizes = {axis: int(self.model_config[axis]) for axis in ('dp', 'tp', 'sp', 'cp')}
    self.fake_model_runner = FakeModelRunner(parallel_info=FakeParallelInfo(**parallel_sizes))

    batch = 5

    def make_output(finish_code):
        # Fresh arrays on every call so the two fixtures never share storage.
        return GenerationOutput(
            sequence_ids=np.array([0, 1, 2, 3, 4]),
            parent_sequence_ids=np.array([0, 1, 2, 3, 4]),
            group_indices=[(start, start + 1) for start in range(batch)],
            token_ids=np.array([[273]] * batch),
            logprobs=np.array([[0.]] * batch, dtype=np.float32),
            top_token_ids=np.array([], dtype=np.int32),
            top_logprobs=np.array([], dtype=np.float32),
            num_new_tokens=np.array([1] * batch),
            num_top_tokens=np.array([0] * batch, dtype=np.int32),
            cumulative_logprobs=np.array([0.] * batch, dtype=np.float32),
            finish_reason=np.array([finish_code] * batch, dtype=np.int32),
            truncation_indices=np.array([0] * batch),
            current_token_indices=np.array([0] * batch, dtype=np.int32),
            eos_info=None,
            trace_ids=[None] * batch,
            simulator_ids=None,
        )

    self.generation_output_prefill = make_output(0)
    self.generation_output_decode = make_output(6)
@patch.object(generator_torch, 'GeneratorTorch')
@patch("mindie_llm.modeling.model_wrapper.atb.atb_model_wrapper.ModelRunner")
@patch("mindie_llm.text_generator.plugins.plugin_manager.PluginManager.generate_token")
def test_init(
    self,
    mock_plugin_manager_generate_token,
    mock_model_runner,
    mock_generator_torch
):
    """Construct a ``Generator`` on mocked backends and check benchmark-file removal.

    ``ENV.benchmark_filepath`` is pointed at ``./tmp.txt`` and the file is
    created before construction; afterwards the test asserts that the file no
    longer exists.  The previous ``ENV.benchmark_filepath`` value is restored
    and any leftover file is removed when the test ends, pass or fail, so
    later tests do not inherit this state.
    """
    mock_model_runner.return_value = self.fake_model_runner
    fake_runner = mock_model_runner.return_value
    mock_model_info = ModelInfo(fake_runner.device,
                                fake_runner.kv_cache_dtype,
                                2,
                                fake_runner.num_layers,
                                fake_runner.num_kv_heads,
                                fake_runner.head_size,
                                k_head_size=fake_runner.k_head_size,
                                v_head_size=fake_runner.v_head_size,
                                enable_nz=False,
                                kvcache_quant_layers=fake_runner.kvcache_quant_layers)
    mock_plugin_manager_generate_token.return_value = self.generation_output_decode
    fake_model_wrapper = FakeModelWrapper(mock_model_info, fake_runner)
    mock_generator_torch_ins = mock_generator_torch.return_value
    mock_generator_torch_ins.model_wrapper = fake_model_wrapper
    mock_generator_torch_ins.obfuscation_func = None
    mock_generator_torch_ins.model_info = mock_model_info
    self.model_config["npu_mem"] = 10

    benchmark_path = "./tmp.txt"

    def remove_leftover_file():
        # Only relevant when the Generator did not delete the file itself.
        if os.path.exists(benchmark_path):
            os.remove(benchmark_path)

    # Cleanups run even when an assertion below fails.
    self.addCleanup(setattr, ENV, "benchmark_filepath", getattr(ENV, "benchmark_filepath", None))
    self.addCleanup(remove_leftover_file)

    ENV.benchmark_filepath = benchmark_path
    if not os.path.exists(benchmark_path):
        with open(benchmark_path, 'w', encoding='utf-8') as file:
            file.write('Hello, world!')
    os.chmod(benchmark_path, 0o600)

    generator = Generator(self.model_config)
    self.assertIsNotNone(generator)
    self.assertFalse(os.path.exists(ENV.benchmark_filepath))
@data('prefill', 'decoder')
@patch.object(generator_torch, 'GeneratorTorch')
@patch("mindie_llm.modeling.model_wrapper.atb.atb_model_wrapper.ModelRunner")
@patch("mindie_llm.text_generator.plugins.plugin_manager.PluginManager.generate_token")
def test_init_with_pd_role(
    self,
    role,
    mock_plugin_manager_generate_token,
    mock_model_runner,
    mock_generator_torch
):
    """Build a ``Generator`` for each role value supplied by ``@data``.

    The role and the local/remote addressing fields are merged into
    ``self.model_config`` before construction.  A ``RuntimeError`` whose
    message contains "Warmup failed" is tolerated; any other error propagates.
    """
    mock_model_runner.return_value = self.fake_model_runner
    runner = mock_model_runner.return_value
    mock_plugin_manager_generate_token.return_value = self.generation_output_decode

    model_info = ModelInfo(
        runner.device,
        runner.kv_cache_dtype,
        2,
        runner.num_layers,
        runner.num_kv_heads,
        runner.head_size,
        k_head_size=runner.k_head_size,
        v_head_size=runner.v_head_size,
        enable_nz=False,
        kvcache_quant_layers=runner.kvcache_quant_layers,
    )
    mock_generator_torch.return_value.configure_mock(
        enable_dap=False,
        model_wrapper=FakeModelWrapper(model_info, runner),
        model_info=model_info,
        obfuscation_func=None,
    )

    self.model_config.update(
        role=role,
        local_instance_id=0,
        local_device_ip='127.0.0.1',
        npu_device_id=0,
        local_physical_device_id=0,
        local_host_ip='127.0.0.1',
        remote_device_ips='127.0.0.2',
    )

    try:
        self.assertIsNotNone(Generator(self.model_config))
    except RuntimeError as err:
        # Only the warmup failure is an accepted outcome here.
        if "Warmup failed" not in str(err):
            raise
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_generate_token_plugin_none_raise_error(self, _):
    """``generate_token`` on an uninitialised ``Generator`` raises ``AttributeError``.

    ``__init__`` is patched out, so ``separate_deployment_worker`` is the only
    attribute the instance carries when ``generate_token`` is called.
    """
    bare_generator = Generator(self.model_config)
    bare_generator.separate_deployment_worker = None
    metadata = MagicMock(spec=InputMetadata)

    with self.assertRaises(AttributeError):
        bare_generator.generate_token(metadata)
@data('standard', 'flex')
@patch.object(generator_torch, 'GeneratorTorch')
@patch("mindie_llm.modeling.model_wrapper.atb.atb_model_wrapper.ModelRunner")
@patch("mindie_llm.text_generator.plugins.plugin_manager.PluginManager.generate_token")
def test_init_with_standard_flex(
    self,
    role,
    mock_plugin_manager_generate_token,
    mock_model_runner,
    mock_generator_torch
):
    """Build a ``Generator`` for the 'standard' and 'flex' role values.

    ``PluginManager.generate_token`` is stubbed to yield the prefill fixture
    and then the decode fixture.  A ``RuntimeError`` containing
    "Warmup failed" is tolerated; any other error propagates.
    """
    mock_model_runner.return_value = self.fake_model_runner
    runner = mock_model_runner.return_value

    optional_info = dict(
        k_head_size=runner.k_head_size,
        v_head_size=runner.v_head_size,
        enable_nz=False,
        kvcache_quant_layers=runner.kvcache_quant_layers,
    )
    model_info = ModelInfo(
        runner.device, runner.kv_cache_dtype, 2,
        runner.num_layers, runner.num_kv_heads, runner.head_size,
        **optional_info
    )

    torch_backend = mock_generator_torch.return_value
    backend_attrs = (
        ('enable_dap', False),
        ('model_wrapper', FakeModelWrapper(model_info, runner)),
        ('model_info', model_info),
        ('obfuscation_func', None),
        ('backend_type', 'atb'),
    )
    for attr_name, attr_value in backend_attrs:
        setattr(torch_backend, attr_name, attr_value)

    mock_plugin_manager_generate_token.side_effect = [
        self.generation_output_prefill,
        self.generation_output_decode,
    ]

    self.model_config.update({
        'role': role,
        'local_instance_id': 0,
        'local_device_ip': '127.0.0.1',
        'npu_device_id': 0,
        'local_physical_device_id': 0,
        'local_host_ip': '127.0.0.1',
        'remote_device_ips': '127.0.0.2',
    })

    try:
        self.assertIsNotNone(Generator(self.model_config))
    except RuntimeError as err:
        if "Warmup failed" not in str(err):
            raise
@patch.object(generator_torch, 'GeneratorTorch')
@patch("mindie_llm.modeling.model_wrapper.atb.atb_model_wrapper.ModelRunner")
@patch("mindie_llm.text_generator.plugins.plugin_manager.PluginManager.generate_token")
def test_init_with_prefixcache(
    self,
    mock_plugin_manager_generate_token,
    mock_model_runner,
    mock_generator_torch
):
    """Build a ``Generator`` with the ``prefix_cache`` plugin and ``cp`` set to 2.

    A dedicated ``FakeModelRunner`` is created after the config override so
    its parallel info reflects the new ``cp`` value.  A ``RuntimeError``
    containing "Warmup failed" is tolerated; any other error propagates.
    """
    # Override the config first: the parallel info below is derived from it.
    self.model_config['plugin_params'] = '{\"plugin_type\": \"prefix_cache\"}'
    self.model_config['cp'] = '2'

    sizes = {name: int(self.model_config[name]) for name in ('dp', 'tp', 'sp', 'cp')}
    runner = FakeModelRunner(parallel_info=FakeParallelInfo(**sizes))
    mock_model_runner.return_value = runner

    model_info = ModelInfo(
        runner.device,
        runner.kv_cache_dtype,
        2,
        runner.num_layers,
        runner.num_kv_heads,
        runner.head_size,
        k_head_size=runner.k_head_size,
        v_head_size=runner.v_head_size,
        enable_nz=False,
        kvcache_quant_layers=runner.kvcache_quant_layers,
    )
    mock_generator_torch.return_value.configure_mock(
        enable_dap=False,
        model_wrapper=FakeModelWrapper(model_info, runner),
        model_info=model_info,
        obfuscation_func=None,
        backend_type='atb',
    )

    # Two prefill/decode rounds are queued for the stubbed token generation.
    one_round = [self.generation_output_prefill, self.generation_output_decode]
    mock_plugin_manager_generate_token.side_effect = one_round + one_round

    try:
        self.assertIsNotNone(Generator(self.model_config))
    except RuntimeError as err:
        if "Warmup failed" not in str(err):
            raise
@patch.object(generator_torch, 'GeneratorTorch')
@patch("mindie_llm.modeling.model_wrapper.atb.atb_model_wrapper.ModelRunner")
@patch("mindie_llm.text_generator.plugins.plugin_manager.PluginManager.generate_token")
def test_init_with_structured_output(
    self,
    mock_plugin_manager_generate_token,
    mock_model_runner,
    mock_generator_torch
):
    """Build a ``Generator`` with ``enable_structured_output`` switched on.

    The mocked backend exposes a tokenizer whose ``len()`` is 1000.  On
    success the plugin manager must report structured output as enabled and
    hold a structured-output manager.  A ``RuntimeError`` containing
    "Warmup failed" is tolerated; any other error propagates.
    """
    self.model_config['enable_structured_output'] = True

    sizes = {name: int(self.model_config[name]) for name in ('dp', 'tp', 'sp', 'cp')}
    runner = FakeModelRunner(parallel_info=FakeParallelInfo(**sizes))
    mock_model_runner.return_value = runner

    model_info = ModelInfo(
        runner.device, runner.kv_cache_dtype, 2,
        runner.num_layers, runner.num_kv_heads, runner.head_size,
        k_head_size=runner.k_head_size,
        v_head_size=runner.v_head_size,
        enable_nz=False,
        kvcache_quant_layers=runner.kvcache_quant_layers,
    )

    tokenizer = MagicMock()
    tokenizer.__len__ = MagicMock(return_value=1000)

    torch_backend = mock_generator_torch.return_value
    torch_backend.enable_dap = False
    torch_backend.model_wrapper = FakeModelWrapper(model_info, runner)
    torch_backend.model_info = model_info
    torch_backend.obfuscation_func = None
    torch_backend.backend_type = 'atb'
    torch_backend.tokenizer = tokenizer

    mock_plugin_manager_generate_token.side_effect = [
        self.generation_output_prefill, self.generation_output_decode,
    ] * 2

    try:
        plugin_manager = Generator(self.model_config).plugin_manager
        self.assertTrue(plugin_manager._structured_output_enabled)
        self.assertIsNotNone(plugin_manager._structured_output_manager)
    except RuntimeError as err:
        if "Warmup failed" not in str(err):
            raise
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_generate(self, _):
    """Drive ``prefill``, ``decode`` and ``generate_mix`` with ``generate_token`` stubbed.

    Each entry point must return a ``GenerationOutput`` and, together, they
    must call the stubbed ``generate_token`` exactly three times.
    """
    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None

    prompt_tokens = [5159, 636, 374, 31346, 323, 358]
    sampling = np.array([(1.0, 0., 0., 0.7, 3., 0.92, False, 0)], dtype=SAMPLING_DTYPE)
    short_len = 2

    req1 = Request.request_from_token(
        prompt_tokens,
        sampling_params=sampling,
        generation_params=GenerationParams(max_new_tokens=short_len),
    )
    req1.sequences[0].block_tables = np.array([0])
    req2 = Request.request_from_token(
        prompt_tokens,
        sampling_params=sampling,
        generation_params=GenerationParams(max_new_tokens=short_len+128),
    )
    req2.sequences[0].block_tables = np.array([1, 2])
    batch = [req1, req2]

    # Every stub call hands back a freshly built two-sequence output.
    generator.generate_token = MagicMock(
        side_effect=lambda *args, **kwargs: GenerationOutput(
            sequence_ids=np.array([0, 1]),
            parent_sequence_ids=np.array([0, 1]),
            group_indices=[(0, 1), (1, 2)],
        )
    )

    self.assertIsInstance(generator.prefill(batch), GenerationOutput)
    self.assertIsInstance(generator.decode(batch), GenerationOutput)

    req1.block_tables = np.array([0, -1])
    req2.block_tables = np.array([1, 2])
    mixed_output = generator.generate_mix(batch, is_prefill_batch=np.array([False, True]))
    self.assertIsInstance(mixed_output, GenerationOutput)

    self.assertEqual(generator.generate_token.call_count, 3)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_load_lora_not_active(self, _):
    """``load_lora`` returns ``UNSUPPORT_CMD`` when the model wrapper's ``adapter_manager`` is None."""
    lora_off = Generator(self.model_config)
    lora_off.model_wrapper = MagicMock(adapter_manager=None)
    lora_off.separate_deployment_worker = None

    status = lora_off.load_lora("fake_id", "fake_path")

    self.assertEqual(status, LoraOperationStatus.UNSUPPORT_CMD)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_load_lora_success(self, _):
    """``load_lora`` returns ``LORA_CMD_SUCCESS`` when an adapter manager is present and loading does not raise."""
    adapters = MagicMock()
    adapters.load_adapter = MagicMock()
    wrapper = MagicMock()
    wrapper.adapter_manager = adapters

    lora_on = Generator(self.model_config)
    lora_on.model_wrapper = wrapper
    lora_on.separate_deployment_worker = None

    status = lora_on.load_lora("fake_id", "fake_path")

    self.assertEqual(status, LoraOperationStatus.LORA_CMD_SUCCESS)
@data(("LORA MEMORY ERROR", LoraOperationStatus.SLOTS_FULL),
      ("DUPLICATED LORA ID", LoraOperationStatus.DUPLICATED_LORA_ID),
      ("INVALID LORA ID", LoraOperationStatus.INVALID_LORA_ID),
      ("INVALID LORA RANK", LoraOperationStatus.INVALID_LORA_RANK))
@unpack
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_load_lora_fail(self, exception, expected_ret, _):
    """``load_lora`` maps a failing ``load_adapter`` to the expected status code.

    ``exception`` is the message of the ``Exception`` raised by the mocked
    ``load_adapter``; ``expected_ret`` is the status ``load_lora`` must return.
    """
    failing_adapters = MagicMock()
    failing_adapters.load_adapter = MagicMock(side_effect=Exception(exception))
    wrapper = MagicMock()
    wrapper.adapter_manager = failing_adapters

    generator = Generator(self.model_config)
    generator.model_wrapper = wrapper
    generator.separate_deployment_worker = None

    status = generator.load_lora("fake_id", "fake_path")

    self.assertEqual(status, expected_ret)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_unload_lora_not_active(self, _):
    """``unload_lora`` returns ``UNSUPPORT_CMD`` when the model wrapper's ``adapter_manager`` is None."""
    lora_off = Generator(self.model_config)
    lora_off.model_wrapper = MagicMock(adapter_manager=None)
    lora_off.separate_deployment_worker = None

    status = lora_off.unload_lora("fake_id")

    self.assertEqual(status, LoraOperationStatus.UNSUPPORT_CMD)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_unload_lora_success(self, _):
    """``unload_lora`` returns ``LORA_CMD_SUCCESS`` when an adapter manager is present and unloading does not raise."""
    adapters = MagicMock()
    adapters.unload_adapter = MagicMock()
    wrapper = MagicMock()
    wrapper.adapter_manager = adapters

    lora_on = Generator(self.model_config)
    lora_on.model_wrapper = wrapper
    lora_on.separate_deployment_worker = None

    status = lora_on.unload_lora("fake_id")

    self.assertEqual(status, LoraOperationStatus.LORA_CMD_SUCCESS)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
@patch("mindie_llm.text_generator.generator.acl")
def test_execute_recover_command_reinit_npu_success(self, mock_acl, _):
    """``CMD_REINIT_NPU`` returns the backend's success result.

    Also checks that the inference context is reset once and that the command
    is forwarded to the backend exactly once.
    """
    backend = MagicMock()
    backend.execute_recover_command = MagicMock(return_value={
        "command_result": 0,
        "error_msg": "",
        "npu_device_id": 0
    })
    context = MagicMock()
    context.reset_all_context = MagicMock()

    generator = Generator(self.model_config)
    fixture_attrs = {
        'separate_deployment_worker': None,
        'backend_type': 'atb',
        'npu_device_id': 0,
        'plugin': MagicMock(),
        'plugin_manager': MagicMock(),
        'infer_context': context,
        'generator_backend': backend,
        'model_wrapper': MagicMock(),
    }
    for attr_name, attr_value in fixture_attrs.items():
        setattr(generator, attr_name, attr_value)

    result = generator.execute_recover_command("CMD_REINIT_NPU")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    generator.infer_context.reset_all_context.assert_called_once()
    generator.generator_backend.execute_recover_command.assert_called_once_with("CMD_REINIT_NPU")
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_reinit_npu_backend_failure(self, _):
    """``CMD_REINIT_NPU`` surfaces a failure result reported by the backend.

    The inference context must still be reset once and the command forwarded
    to the backend exactly once.
    """
    backend = MagicMock()
    backend.execute_recover_command = MagicMock(return_value={
        "command_result": 1,
        "error_msg": "Backend error",
        "npu_device_id": 0
    })
    context = MagicMock()
    context.reset_all_context = MagicMock()

    generator = Generator(self.model_config)
    fixture_attrs = {
        'separate_deployment_worker': None,
        'backend_type': 'atb',
        'npu_device_id': 0,
        'plugin': MagicMock(),
        'plugin_manager': MagicMock(),
        'infer_context': context,
        'generator_backend': backend,
        'model_wrapper': MagicMock(),
    }
    for attr_name, attr_value in fixture_attrs.items():
        setattr(generator, attr_name, attr_value)

    result = generator.execute_recover_command("CMD_REINIT_NPU")

    self.assertEqual(result["command_result"], 1)
    self.assertEqual(result["error_msg"], "Backend error")
    generator.infer_context.reset_all_context.assert_called_once()
    generator.generator_backend.execute_recover_command.assert_called_once_with("CMD_REINIT_NPU")
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
@patch("mindie_llm.text_generator.generator.acl")
def test_execute_recover_command_reinit_npu_exception(self, mock_acl, _):
    """``CMD_REINIT_NPU`` reports a failure result when the backend call raises.

    The backend's ``execute_recover_command`` raises ``Exception``; the
    returned dict must carry ``command_result`` 1, an error message containing
    "Failed to execute recovery command" and the device id.
    """
    raising_backend = MagicMock()
    raising_backend.execute_recover_command = MagicMock(side_effect=Exception("Test exception"))
    context = MagicMock()
    context.reset_all_context = MagicMock()

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.plugin = MagicMock()
    generator.plugin_manager = MagicMock()
    generator.infer_context = context
    generator.generator_backend = raising_backend
    generator.model_wrapper = MagicMock()
    mock_acl.rt.set_device = MagicMock()

    result = generator.execute_recover_command("CMD_REINIT_NPU")

    self.assertEqual(result["command_result"], 1)
    self.assertIn("Failed to execute recovery command", result["error_msg"])
    self.assertEqual(result["npu_device_id"], 0)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
@patch("mindie_llm.text_generator.generator.time")
def test_execute_recover_command_start_engine(self, mock_time, _):
    """``CMD_START_ENGINE`` succeeds and clears the paused state.

    With no output queue configured, the command must reset
    ``last_sequence_ids`` to None, clear both pause flags and call
    ``time.sleep(1)`` exactly once.
    """
    manager = MagicMock()
    manager.last_sequence_ids = [1, 2, 3]
    manager.is_inference_pause = True
    manager.output_queue = None

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = True
    generator.plugin_manager = manager

    result = generator.execute_recover_command("CMD_START_ENGINE")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    self.assertIsNone(generator.plugin_manager.last_sequence_ids)
    self.assertFalse(generator.plugin_manager.is_inference_pause)
    self.assertFalse(generator.is_inference_pause)
    mock_time.sleep.assert_called_once_with(1)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
@patch("mindie_llm.text_generator.generator.time")
@patch("mindie_llm.text_generator.utils.model_output.ModelOutputWrapper")
def test_execute_recover_command_start_engine_with_output_queue(self, mock_model_output_wrapper_class, mock_time, _):
    """``CMD_START_ENGINE`` with an output queue that reports itself empty.

    Besides the success result and cleared pause state, the value returned by
    the mocked ``ModelOutputWrapper.make_empty`` must be put on the queue
    exactly once.
    """
    empty_output = MagicMock()
    mock_model_output_wrapper_class.make_empty = MagicMock(return_value=empty_output)

    drained_queue = MagicMock()
    drained_queue.empty = MagicMock(return_value=True)
    drained_queue.put = MagicMock()

    manager = MagicMock()
    manager.last_sequence_ids = [1, 2, 3]
    manager.is_inference_pause = True
    manager.output_queue = drained_queue

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = True
    generator.plugin_manager = manager

    result = generator.execute_recover_command("CMD_START_ENGINE")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    self.assertIsNone(generator.plugin_manager.last_sequence_ids)
    self.assertFalse(generator.plugin_manager.is_inference_pause)
    self.assertFalse(generator.is_inference_pause)
    mock_time.sleep.assert_called_once_with(1)
    drained_queue.put.assert_called_once_with(empty_output)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
@patch("mindie_llm.text_generator.generator.time")
def test_execute_recover_command_start_engine_output_queue_not_empty(self, mock_time, _):
    """``CMD_START_ENGINE`` with an output queue that reports itself non-empty.

    The command must still succeed and clear the pause state, but nothing may
    be put on the queue.
    """
    busy_queue = MagicMock()
    busy_queue.empty = MagicMock(return_value=False)
    busy_queue.put = MagicMock()

    manager = MagicMock()
    manager.last_sequence_ids = [1, 2, 3]
    manager.is_inference_pause = True
    manager.output_queue = busy_queue

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = True
    generator.plugin_manager = manager

    result = generator.execute_recover_command("CMD_START_ENGINE")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    self.assertIsNone(generator.plugin_manager.last_sequence_ids)
    self.assertFalse(generator.plugin_manager.is_inference_pause)
    self.assertFalse(generator.is_inference_pause)
    mock_time.sleep.assert_called_once_with(1)
    busy_queue.put.assert_not_called()
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_pause_engine(self, _):
    """``CMD_PAUSE_ENGINE`` returns the backend's success result and sets both pause flags.

    The command must be forwarded to the backend exactly once.
    """
    backend = MagicMock()
    backend.execute_recover_command = MagicMock(return_value={
        "command_result": 0,
        "error_msg": "",
        "npu_device_id": 0
    })
    manager = MagicMock()
    manager.is_inference_pause = False

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = False
    generator.plugin_manager = manager
    generator.generator_backend = backend

    result = generator.execute_recover_command("CMD_PAUSE_ENGINE")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    self.assertTrue(generator.is_inference_pause)
    self.assertTrue(generator.plugin_manager.is_inference_pause)
    generator.generator_backend.execute_recover_command.assert_called_once_with("CMD_PAUSE_ENGINE")
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_pause_engine_backend_failure(self, _):
    """``CMD_PAUSE_ENGINE`` still sets the pause flags when the backend reports failure.

    The backend's failure result must be returned unchanged in its three
    checked fields, and the command forwarded to the backend exactly once.
    """
    backend = MagicMock()
    backend.execute_recover_command = MagicMock(return_value={
        "command_result": 1,
        "error_msg": "Stop device failed",
        "npu_device_id": 0
    })
    manager = MagicMock()
    manager.is_inference_pause = False

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = False
    generator.plugin_manager = manager
    generator.generator_backend = backend

    result = generator.execute_recover_command("CMD_PAUSE_ENGINE")

    expected_fields = (
        ("command_result", 1),
        ("error_msg", "Stop device failed"),
        ("npu_device_id", 0),
    )
    for field, expected in expected_fields:
        self.assertEqual(result[field], expected)
    self.assertTrue(generator.is_inference_pause)
    self.assertTrue(generator.plugin_manager.is_inference_pause)
    generator.generator_backend.execute_recover_command.assert_called_once_with("CMD_PAUSE_ENGINE")
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_pause_engine_roce(self, _):
    """``CMD_PAUSE_ENGINE_ROCE`` returns a success result and sets both pause flags."""
    manager = MagicMock()
    manager.is_inference_pause = False

    generator = Generator(self.model_config)
    generator.separate_deployment_worker = None
    generator.backend_type = 'atb'
    generator.npu_device_id = 0
    generator.is_inference_pause = False
    generator.plugin_manager = manager

    result = generator.execute_recover_command("CMD_PAUSE_ENGINE_ROCE")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
    self.assertTrue(generator.is_inference_pause)
    self.assertTrue(generator.plugin_manager.is_inference_pause)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_clear_transer(self, _):
    """``CMD_CLEAR_TRANSER`` returns a success result carrying the device id."""
    generator = Generator(self.model_config)
    for attr_name, attr_value in (
        ('separate_deployment_worker', None),
        ('backend_type', 'atb'),
        ('npu_device_id', 0),
    ):
        setattr(generator, attr_name, attr_value)

    result = generator.execute_recover_command("CMD_CLEAR_TRANSER")

    for field, expected in (("command_result", 0), ("error_msg", ""), ("npu_device_id", 0)):
        self.assertEqual(result[field], expected)
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_execute_recover_command_unknown_command(self, _):
    """An unrecognised command yields a failure result mentioning "Unknown recovery command"."""
    generator = Generator(self.model_config)
    for attr_name, attr_value in (
        ('separate_deployment_worker', None),
        ('backend_type', 'atb'),
        ('npu_device_id', 0),
    ):
        setattr(generator, attr_name, attr_value)

    result = generator.execute_recover_command("CMD_UNKNOWN")

    self.assertEqual(result["command_result"], 1)
    self.assertIn("Unknown recovery command", result["error_msg"])
    self.assertEqual(result["npu_device_id"], 0)
@patch('mindie_llm.utils.prof.profiler.span_end')
@patch('mindie_llm.utils.prof.profiler.span_attr')
@patch('mindie_llm.utils.prof.profiler.span_start')
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_generate_token_sets_fault_device_when_exception_maps_to_error_code(
    self, _, mock_span_start, mock_span_attr, mock_span_end
):
    """When the exception message matches ``convert_exception_to_error_code``, ``is_fault_device`` is set and ``ErrorCodeException`` is raised.

    The plugin manager raises a ``RuntimeError`` whose text contains
    "MIE05E0000005"; the resulting exception must carry
    ``ErrorCode.TEXT_GENERATOR_OUT_OF_MEMORY``.
    """
    mock_span_start.return_value = None

    failing_manager = MagicMock()
    failing_manager.generate_token.side_effect = RuntimeError(
        "backend reported MIE05E0000005 in stack"
    )
    backend = MagicMock()
    backend.is_fault_device = False
    role_config = MagicMock()
    role_config.model_role = STANDARD_TAG

    generator = Generator(self.model_config)
    generator.pd_config = role_config
    generator.input_metadata_queue = queue.Queue()
    generator.rank = 0
    generator.async_inference = False
    generator.plugin_manager = failing_manager
    generator.generator_backend = backend

    metadata = MagicMock(spec=InputMetadata)
    metadata.batch_seq_len = np.array([0])
    metadata.is_prefill = False

    with self.assertRaises(ErrorCodeException) as raised:
        generator.generate_token(metadata, warmup=False)

    self.assertEqual(raised.exception.error_code, ErrorCode.TEXT_GENERATOR_OUT_OF_MEMORY)
    self.assertTrue(generator.generator_backend.is_fault_device)
@patch('mindie_llm.utils.prof.profiler.span_end')
@patch('mindie_llm.utils.prof.profiler.span_attr')
@patch('mindie_llm.utils.prof.profiler.span_start')
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_generate_token_notify_force_stop_when_inference_paused(
    self, _, mock_span_start, mock_span_attr, mock_span_end
):
    """While inference is paused, a FORCE STOP exception should call ``notify_force_stop_exception`` and return an empty ``GenerationOutput``."""
    mock_span_start.return_value = None

    stopping_manager = MagicMock()
    stopping_manager.generate_token.side_effect = RuntimeError("User FORCE STOP request")
    backend = MagicMock()
    backend.notify_force_stop_exception = MagicMock()
    role_config = MagicMock()
    role_config.model_role = STANDARD_TAG

    generator = Generator(self.model_config)
    generator.pd_config = role_config
    generator.input_metadata_queue = queue.Queue()
    generator.rank = 0
    generator.async_inference = False
    generator.is_inference_pause = True
    generator.plugin_manager = stopping_manager
    generator.generator_backend = backend

    metadata = MagicMock(spec=InputMetadata)
    metadata.batch_seq_len = np.array([0])
    metadata.is_prefill = False

    output = generator.generate_token(metadata, warmup=False)

    self.assertIsInstance(output, GenerationOutput)
    self.assertEqual(output.sequence_ids.size, 0)
    generator.generator_backend.notify_force_stop_exception.assert_called_once()
@patch('mindie_llm.utils.prof.profiler.span_end')
@patch('mindie_llm.utils.prof.profiler.span_attr')
@patch('mindie_llm.utils.prof.profiler.span_start')
@patch("mindie_llm.text_generator.generator.Generator.__init__", return_value=None)
def test_generate_token_force_stop_reraises_when_not_paused(
    self, _, mock_span_start, mock_span_attr, mock_span_end
):
    """When inference is not paused, a FORCE STOP error is still re-raised as an unknown exception.

    The very same ``RuntimeError`` instance must reach the caller and
    ``notify_force_stop_exception`` must not be called.
    """
    mock_span_start.return_value = None

    original_error = RuntimeError("FORCE STOP abort")
    stopping_manager = MagicMock()
    stopping_manager.generate_token.side_effect = original_error
    role_config = MagicMock()
    role_config.model_role = STANDARD_TAG

    generator = Generator(self.model_config)
    generator.pd_config = role_config
    generator.input_metadata_queue = queue.Queue()
    generator.rank = 0
    generator.async_inference = False
    generator.is_inference_pause = False
    generator.plugin_manager = stopping_manager
    generator.generator_backend = MagicMock()

    metadata = MagicMock(spec=InputMetadata)
    metadata.batch_seq_len = np.array([0])
    metadata.is_prefill = False

    with self.assertRaises(RuntimeError) as raised:
        generator.generate_token(metadata, warmup=False)

    self.assertIs(raised.exception, original_error)
    generator.generator_backend.notify_force_stop_exception.assert_not_called()
class TestPDInterface(unittest.TestCase):
    """Unit tests for PDInterface with its separate-deployment worker replaced by mocks."""

    def setUp(self):
        """Build a PDInterface from a 'standard' role config and install a mock ``npu`` global."""
        self.config_dict = {
            'role': 'standard',
            'local_instance_id': 0,
            'local_device_ip': '127.0.0.1',
            'npu_device_id': 0,
            'local_physical_device_id': 0,
            'local_host_ip': '127.0.0.1',
            'remote_device_ips': '127.0.0.2'
        }
        self.pd_config = PDModelConfig(self.config_dict)
        self.pd_interface = PDInterface(self.pd_config)
        # Remember any pre-existing ``npu`` in this module's globals so that
        # tearDown can put it back.
        self.original_npu = globals().get('npu', None)
        npu_mock = MagicMock()
        npu_mock.set_device = MagicMock()
        npu_mock.max_memory_allocated = MagicMock(return_value=1024)
        globals()['npu'] = npu_mock

    def tearDown(self):
        """Undo the ``npu`` global installed by setUp."""
        if self.original_npu is not None:
            globals()['npu'] = self.original_npu
        else:
            # No ``npu`` global existed before setUp: remove the mock so it
            # does not leak into tests that run afterwards.
            globals().pop('npu', None)

    def test_link(self):
        """link() forwards all of its keyword arguments to the worker's link()."""
        worker_mock = MagicMock()
        worker_mock.link.return_value = "dummy_link"
        self.pd_interface.separate_deployment_worker = worker_mock
        remote_cluster_ids = {1: [10, 11]}
        remote_physical_device_ids = {1: [20, 21]}
        remote_device_ips = {1: ["192.168.1.2", "192.168.1.3"]}
        host_ips = {1: ["192.168.1.100", "192.168.1.101"]}
        remote_super_device_ids = {1: [8650754, 8650755]}
        remote_super_pod_ids = {1: [0, 0]}
        self.pd_interface.link(
            remote_cluster_ids=remote_cluster_ids,
            remote_physical_device_ids=remote_physical_device_ids,
            remote_device_ips=remote_device_ips,
            host_ips=host_ips,
            remote_super_device_ids=remote_super_device_ids,
            remote_super_pod_ids=remote_super_pod_ids,
        )
        worker_mock.link.assert_called_once_with(
            remote_cluster_ids=remote_cluster_ids,
            remote_physical_device_ids=remote_physical_device_ids,
            remote_device_ips=remote_device_ips,
            host_ips=host_ips,
            remote_super_device_ids=remote_super_device_ids,
            remote_super_pod_ids=remote_super_pod_ids
        )

    def test_unlink(self):
        """unlink() passes its argument to the worker's unlink() and returns the worker's result."""
        worker_mock = MagicMock()
        worker_mock.unlink.return_value = "dummy_unlink"
        self.pd_interface.separate_deployment_worker = worker_mock
        result = self.pd_interface.unlink(1)
        self.assertEqual(result, "dummy_unlink")
        worker_mock.unlink.assert_called_once_with(1)

    def test_switch_role(self):
        """switch_role() stores the new role on pd_config.model_role."""
        new_role = 'new_role'
        self.pd_interface.switch_role(new_role)
        self.assertEqual(self.pd_interface.pd_config.model_role, new_role)

    def test_pull_kv_success(self):
        """pull_kv() success path: expects (SUCCESS, 0) and exactly one pull_blocks call."""
        self.pd_interface.device_inited = False
        worker_mock = MagicMock()
        worker_mock.pull_blocks.return_value = MindieLlmStatusCode.SUCCESS
        self.pd_interface.separate_deployment_worker = worker_mock
        dummy_input_metadata = MagicMock()
        dummy_input_metadata.is_prefill = False
        # Each pd_infos entry is (remote instance id, source block table, destination block table),
        # as shown by the pull_blocks keyword arguments asserted below.
        pd_infos = [(10, [1, 2], [3, 4])]
        ret, model_instance_id = self.pd_interface.pull_kv(dummy_input_metadata, pd_infos)
        self.assertEqual(ret, MindieLlmStatusCode.SUCCESS)
        self.assertEqual(model_instance_id, 0)
        worker_mock.pull_blocks.assert_called_once_with(remote_model_instance_id=10,
                                                        src_block_table=[1, 2],
                                                        dst_block_table=[3, 4])

    def test_pull_kv_failure(self):
        """pull_kv() returns the error status directly when pull_blocks reports a failure.

        Expects the failing remote instance id (99) to be returned alongside
        the status, and the input metadata queue to stay empty.
        """
        self.pd_interface.device_inited = False
        worker_mock = MagicMock()
        worker_mock.pull_blocks.return_value = "FAIL"
        self.pd_interface.separate_deployment_worker = worker_mock
        dummy_input_metadata = MagicMock()
        dummy_input_metadata.is_prefill = False
        pd_infos = [(99, [1], [2])]
        ret, model_instance_id = self.pd_interface.pull_kv(dummy_input_metadata, pd_infos)
        self.assertEqual(ret, "FAIL")
        self.assertEqual(model_instance_id, 99)
        self.assertTrue(self.pd_interface.input_metadata_queue.empty())
        worker_mock.pull_blocks.assert_called_once_with(remote_model_instance_id=99,
                                                        src_block_table=[1],
                                                        dst_block_table=[2])

    @patch('mindie_llm.text_generator.utils.separate_deployment_engine.LLMDataDist')
    @patch('mindie_llm.text_generator.utils.separate_deployment_engine.LLMDataDistConfig')
    def test_init_sepd_engine(self, mock_llm_data_dist_config, mock_llm_data_dist):
        """Smoke-test _init_sepd_engine() for a 'flex' role config.

        LLMDataDist and LLMDataDistConfig are patched out. The test makes no
        assertions, so it passes as long as the call does not raise.
        """
        # Replace the fixtures from setUp with a 'flex' role configuration
        # that also carries the super device / super pod ids.
        self.config_dict = {
            'role': 'flex',
            'local_instance_id': 0,
            'local_device_ip': '127.0.0.1',
            'npu_device_id': 0,
            'local_physical_device_id': 0,
            'local_host_ip': '127.0.0.1',
            'remote_device_ips': '127.0.0.2',
            'local_super_device_id': 0,
            'local_super_pod_id': 0
        }
        self.pd_config = PDModelConfig(self.config_dict)
        self.pd_interface = PDInterface(self.pd_config)
        self.pd_interface._init_sepd_engine()
# Run every test case in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()