import os
import sys
import unittest
from unittest.mock import MagicMock
# Give MINDIE_LLM_BENCHMARK_FILEPATH a default before any mindie_llm import;
# setdefault keeps a value already exported by the caller.
# NOTE(review): presumably the variable is read at import time by mindie_llm — confirm.
os.environ.setdefault("MINDIE_LLM_BENCHMARK_FILEPATH", "/tmp/benchmark.jsonl")
import torch

# Install a MagicMock stand-in for ``torch.npu`` when no NPU backend has
# registered it (e.g. torch_npu is not installed), so that modules touching
# ``torch.npu`` during import can load without Ascend hardware.
if not hasattr(torch, "npu"):
    torch.npu = MagicMock()
    # The assignments below configure the stub only. They must stay inside this
    # branch: e.g. ``current_stream.return_value`` exists on a MagicMock but not on
    # a real function, and FloatTensor/IntTensor would clobber real tensor types.
    torch.npu.config.allow_internal_format = True
    # Make ``torch.npu.current_stream().synchronize()`` a no-op returning None.
    torch.npu.current_stream.return_value.synchronize.return_value = None
    # Tensor "types" become the MagicMock class itself, so constructing one yields a mock.
    torch.npu.FloatTensor = MagicMock
    torch.npu.IntTensor = MagicMock
from mindie_llm.runtime.utils.npu.device_utils import DeviceType

# Fake node-info object: reports an ASCEND_910_93 device with HBM capacity and
# usage of 0, so no real device query is needed.
mock_node_info = MagicMock()
mock_node_info.get_device_type.return_value = DeviceType.ASCEND_910_93
mock_node_info.get_hbm_capacity.return_value = 0
mock_node_info.get_hbm_usage.return_value = 0
import mindie_llm.runtime.utils.npu.device_utils as device_utils_mod

# Patch the device-query entry points on the module object. The patch is never
# undone, so it stays in effect for the rest of the test process.
# NOTE(review): code that already did ``from ...device_utils import get_npu_node_info``
# before this point keeps the real function; only attribute lookups through the
# module (and later from-imports) see these mocks.
device_utils_mod.get_npu_node_info = MagicMock(return_value=mock_node_info)
device_utils_mod.get_npu_hbm_info = MagicMock()
# Register a mock under the ``mie_ops`` module name before model_runner_exp is
# imported, so any import of that module resolves to the stub.
# NOTE(review): presumably mie_ops is a hardware op library unavailable in CI — confirm.
# The entry is left in sys.modules for the rest of the process.
mock_mie_ops = MagicMock()
sys.modules["mindie_llm.runtime.ops.mie_ops"] = mock_mie_ops

_MODEL_RUNNER_EXP_MODULE = "mindie_llm.runtime.model_runner.model_runner_exp"
# Drop any cached copy so the module is re-executed against the stubs above.
sys.modules.pop(_MODEL_RUNNER_EXP_MODULE, None)
# Use import_module rather than ``from mindie_llm.runtime.model_runner import
# model_runner_exp``. The ``from`` form first returns the parent package's existing
# attribute, which after the purge above can still be the *stale* module object
# from an earlier import. import_module returns the freshly executed module from
# sys.modules, and taking the class from that same object keeps both names
# consistent.
import importlib

model_runner_exp = importlib.import_module(_MODEL_RUNNER_EXP_MODULE)
ModelRunnerExp = model_runner_exp.ModelRunnerExp

# If the outermost class decorator returned a wrapper exposing ``__wrapped__``,
# inspect the original class (its own ``__dict__`` holds the decorated methods);
# otherwise the name already refers to the class itself.
_ModelRunnerExpClass = getattr(ModelRunnerExp, "__wrapped__", ModelRunnerExp)
class TestModelRunnerExpSourceStructure(unittest.TestCase):
    """Structural tests that verify the source has the expected decorators.

    The checks parse the source text of ``model_runner_exp`` with :mod:`ast`, so
    they do not depend on what the decorators return at runtime.
    """

    @staticmethod
    def _parse_module_source():
        """Return the AST of the ``model_runner_exp`` module's source code."""
        import ast
        import inspect

        return ast.parse(inspect.getsource(model_runner_exp))

    @staticmethod
    def _find_class_def(tree, class_name="ModelRunnerExp"):
        """Return the first ``ast.ClassDef`` named *class_name* in *tree*, or ``None``."""
        import ast

        for node in ast.walk(tree):
            if isinstance(node, ast.ClassDef) and node.name == class_name:
                return node
        return None

    @staticmethod
    def _decorator_name(decorator):
        """Return the bare name of a decorator expression, or ``None``.

        Handles ``@name``, ``@pkg.name`` and the called forms ``@name(...)`` /
        ``@pkg.name(...)``. For dotted names only the last component is returned.
        Any other expression yields ``None``.
        """
        import ast

        if isinstance(decorator, ast.Call):
            decorator = decorator.func
        if isinstance(decorator, ast.Name):
            return decorator.id
        if isinstance(decorator, ast.Attribute):
            return decorator.attr
        return None

    def _get_model_runner_class_def(self):
        """Parse the module and return the ModelRunnerExp ClassDef, failing if absent."""
        class_def = self._find_class_def(self._parse_module_source())
        self.assertIsNotNone(class_def, "class ModelRunnerExp not found in model_runner_exp.py")
        return class_def

    def test_decorator_import_exists(self):
        """The file should import the exception_handler (via ``from ... import``)."""
        import ast

        tree = self._parse_module_source()
        found_import = any(
            alias.name == "exception_handler"
            for node in ast.walk(tree)
            if isinstance(node, ast.ImportFrom)
            for alias in node.names
        )
        self.assertTrue(found_import, "@exception_handler import not found in model_runner_exp.py")

    def test_exception_handler_decorator_before_class(self):
        """The @exception_handler decorator should decorate class ModelRunnerExp.

        Bare (``@exception_handler``), dotted (``@mod.exception_handler``) and
        called (``@exception_handler(...)``) forms are all accepted.
        """
        class_def = self._get_model_runner_class_def()
        found_decorator = any(
            self._decorator_name(decorator) == "exception_handler"
            for decorator in class_def.decorator_list
        )
        self.assertTrue(found_decorator, "@exception_handler decorator not found on ModelRunnerExp")

    def test_auto_speculative_method_router_present(self):
        """@auto_speculative_method_router(...) should still decorate ModelRunnerExp.

        It is expected to be the outer decorator. This test checks only that it is
        present in called form, not its position relative to @exception_handler.
        """
        import ast

        class_def = self._get_model_runner_class_def()
        found_router = any(
            isinstance(decorator, ast.Call)
            and self._decorator_name(decorator) == "auto_speculative_method_router"
            for decorator in class_def.decorator_list
        )
        self.assertTrue(found_router, "@auto_speculative_method_router decorator not found on ModelRunnerExp")
class TestModelRunnerExpOomContract(unittest.TestCase):
    """Verify the OOM contract: forward/compile/load_weights are wrapped.

    A method counts as wrapped when the object stored in the class's own
    ``__dict__`` carries a ``__wrapped__`` attribute (for example, as set by
    ``functools.wraps``). Looking in ``__dict__`` means an inherited method does
    not satisfy the check.
    NOTE(review): the wrapper is presumably the OOM handler (``_torch_oom_handler``)
    installed by ``@exception_handler`` — confirm against model_runner_exp.
    """

    def _assert_method_wrapped(self, method_name):
        """Assert that *method_name* is defined directly on the class and is wrapped.

        Args:
            method_name: Attribute name looked up in ``_ModelRunnerExpClass.__dict__``.
        """
        method = _ModelRunnerExpClass.__dict__.get(method_name)
        self.assertIsNotNone(method, f"{method_name} is not defined directly on ModelRunnerExp")
        self.assertTrue(
            hasattr(method, "__wrapped__"),
            f"{method_name} should be wrapped by exception_handler",
        )

    def test_forward_is_wrapped(self):
        """forward should be wrapped by the exception_handler OOM wrapper (has __wrapped__)."""
        self._assert_method_wrapped("forward")

    def test_compile_is_wrapped(self):
        """compile should be wrapped by the exception_handler OOM wrapper (has __wrapped__)."""
        self._assert_method_wrapped("compile")

    def test_load_weights_is_wrapped(self):
        """load_weights should be wrapped by the exception_handler OOM wrapper (has __wrapped__)."""
        self._assert_method_wrapped("load_weights")
if __name__ == "__main__":
    # Script entry point: run every TestCase defined in this module when the
    # file is executed directly (``module="__main__"`` is unittest's default).
    unittest.main(module="__main__")