import os
import unittest
from unittest.mock import patch
os.environ.setdefault("MINDIE_LLM_BENCHMARK_FILEPATH", "/tmp/benchmark.jsonl")
import torch
from mindie_llm.utils.decorators.exception_handler import (
_is_target_method,
_torch_oom_handler,
exception_handler,
)
class TestTargetMethodFilter(unittest.TestCase):
    """Check which method names ``_is_target_method`` selects for wrapping."""

    def _assert_selected(self, method_name):
        """Assert that the filter accepts ``method_name``."""
        self.assertTrue(_is_target_method(method_name))

    def _assert_ignored(self, method_name):
        """Assert that the filter rejects ``method_name``."""
        self.assertFalse(_is_target_method(method_name))

    def test_includes_forward(self):
        self._assert_selected("forward")

    def test_includes_compile(self):
        self._assert_selected("compile")

    def test_includes_load_weights(self):
        self._assert_selected("load_weights")

    def test_includes_forward_with_suffix(self):
        # A name that merely starts with "forward" is still selected.
        self._assert_selected("forward_layerwise_disaggregated")

    def test_excludes_generate_position_ids(self):
        self._assert_ignored("generate_position_ids")

    def test_excludes_dunder_methods(self):
        for dunder_name in ("__init__", "__call__"):
            self._assert_ignored(dunder_name)

    def test_excludes_arbitrary_method(self):
        self._assert_ignored("some_other_method")
class TestTorchOomHandler(unittest.TestCase):
    """Test _torch_oom_handler wrapper behavior."""

    # Evaluated once at class creation, before any test temporarily removes
    # the attribute (see test_oom_skipped_when_torch_lacks_oom_error).
    _HAS_TORCH_OOM = hasattr(torch, "OutOfMemoryError")

    def _require_torch_oom(self):
        """Skip the current test when this torch build has no OutOfMemoryError.

        Tests that raise ``torch.OutOfMemoryError`` directly would otherwise
        fail with an AttributeError instead of being reported as skipped.
        """
        if not self._HAS_TORCH_OOM:
            self.skipTest("torch.OutOfMemoryError not available")

    def test_oom_caught_and_re_raised_as_runtime_error(self):
        """torch.OutOfMemoryError should be caught, logged, and re-raised as RuntimeError."""
        self._require_torch_oom()

        def failing_func():
            raise torch.OutOfMemoryError("test OOM")

        wrapped = _torch_oom_handler(failing_func)
        with self.assertRaises(RuntimeError) as cm:
            wrapped()
        # The error code in the message distinguishes the converted error
        # from the raw OOM (which is itself a RuntimeError subclass).
        self.assertIn("MIE05E000006", str(cm.exception))

    def test_oom_logger_called_with_error_code(self):
        """Logger.error should be called with ACL_GRAPH_OUT_OF_MEMORY error code."""
        self._require_torch_oom()
        from mindie_llm.utils.log.error_code import ErrorCode

        def failing_func():
            raise torch.OutOfMemoryError("test OOM")

        wrapped = _torch_oom_handler(failing_func)
        with patch("mindie_llm.utils.decorators.exception_handler.logger") as mock_logger:
            with self.assertRaises(RuntimeError):
                wrapped()
        mock_logger.error.assert_called_once()
        args, _ = mock_logger.error.call_args
        self.assertIs(args[1], ErrorCode.ACL_GRAPH_OUT_OF_MEMORY)

    def test_oom_preserves_original_exception_chain(self):
        """The original torch.OutOfMemoryError should be chained via __cause__."""
        self._require_torch_oom()

        def failing_func():
            raise torch.OutOfMemoryError("test OOM")

        wrapped = _torch_oom_handler(failing_func)
        with self.assertRaises(RuntimeError) as cm:
            wrapped()
        self.assertIsInstance(cm.exception.__cause__, torch.OutOfMemoryError)

    def test_non_oom_exception_passes_through(self):
        """Non-OOM exceptions should be re-raised unchanged."""
        def failing_func():
            raise ValueError("some other error")

        wrapped = _torch_oom_handler(failing_func)
        with self.assertRaises(ValueError):
            wrapped()

    def test_successful_call_returns_result(self):
        """A successful call should return the original result."""
        def successful_func():
            return 42

        wrapped = _torch_oom_handler(successful_func)
        result = wrapped()
        self.assertEqual(result, 42)

    def test_oom_skipped_when_torch_lacks_oom_error(self):
        """If torch.OutOfMemoryError is unavailable at call time, non-OOM errors should still pass through."""
        oom_cls = getattr(torch, "OutOfMemoryError", None)
        if oom_cls is None:
            self.skipTest("torch.OutOfMemoryError not available, nothing to mask")

        def failing_func():
            raise ValueError("non-OOM error")

        wrapped = _torch_oom_handler(failing_func)
        # Remove the attribute only after wrapping, to simulate a torch build
        # that lacks it when the wrapped call runs. The finally block restores
        # it so later tests are unaffected.
        del torch.OutOfMemoryError
        try:
            with self.assertRaises(ValueError):
                wrapped()
        finally:
            torch.OutOfMemoryError = oom_cls
class TestExceptionHandlerDecorator(unittest.TestCase):
    """Test @exception_handler class decorator.

    ``torch.OutOfMemoryError`` is itself a ``RuntimeError`` subclass, so
    ``assertRaises(RuntimeError)`` alone would also pass when a method is NOT
    wrapped and the raw OOM escapes. ``_assert_oom_converted`` additionally
    rejects the raw OOM type, so the wrapping is actually observed.
    """

    # Evaluated once at class creation.
    _HAS_TORCH_OOM = hasattr(torch, "OutOfMemoryError")

    def _require_torch_oom(self):
        """Skip the current test when this torch build has no OutOfMemoryError."""
        if not self._HAS_TORCH_OOM:
            self.skipTest("torch.OutOfMemoryError not available")

    def _assert_oom_converted(self, call):
        """Assert ``call()`` raises a RuntimeError that is not the raw torch OOM."""
        with self.assertRaises(RuntimeError) as cm:
            call()
        self.assertNotIsInstance(cm.exception, torch.OutOfMemoryError)

    def test_wraps_forward_method(self):
        """forward() should be wrapped by the handler."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            def forward(self):
                raise torch.OutOfMemoryError("test OOM")

        runner = FakeRunner()
        self._assert_oom_converted(runner.forward)

    def test_wraps_compile_method(self):
        """compile() should be wrapped by the handler."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            def compile(self):
                raise torch.OutOfMemoryError("test OOM")

        runner = FakeRunner()
        self._assert_oom_converted(runner.compile)

    def test_wraps_load_weights_method(self):
        """load_weights() should be wrapped by the handler."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            def load_weights(self):
                raise torch.OutOfMemoryError("test OOM")

        runner = FakeRunner()
        self._assert_oom_converted(runner.load_weights)

    def test_skips_generate_position_ids(self):
        """generate_position_ids() should NOT be wrapped."""
        @exception_handler
        class FakeRunner:
            def generate_position_ids(self, fail=False):
                if fail:
                    raise torch.OutOfMemoryError("test OOM")
                return "not wrapped"

        runner = FakeRunner()
        self.assertEqual(runner.generate_position_ids(), "not wrapped")
        # A return value alone cannot reveal wrapping. An unwrapped method lets
        # the raw OOM escape unchanged; a wrapped one would raise a converted
        # error that this assertRaises does not catch.
        self._require_torch_oom()
        with self.assertRaises(torch.OutOfMemoryError):
            runner.generate_position_ids(fail=True)

    def test_non_oom_passes_through_on_wrapped_method(self):
        """Wrapped methods should still let non-OOM exceptions through."""
        @exception_handler
        class FakeRunner:
            def forward(self):
                raise ValueError("not OOM")

        runner = FakeRunner()
        with self.assertRaises(ValueError):
            runner.forward()

    def test_works_with_classmethod(self):
        """@exception_handler should handle classmethod correctly."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            @classmethod
            def forward(cls):
                raise torch.OutOfMemoryError("test OOM")

        self._assert_oom_converted(FakeRunner.forward)

    def test_works_with_staticmethod(self):
        """@exception_handler should handle staticmethod correctly."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            @staticmethod
            def forward():
                raise torch.OutOfMemoryError("test OOM")

        self._assert_oom_converted(FakeRunner.forward)

    def test_non_oom_methods_unchanged(self):
        """Methods not in the target list should behave normally."""
        @exception_handler
        class FakeRunner:
            def some_helper(self, fail=False):
                if fail:
                    raise torch.OutOfMemoryError("test OOM")
                return "unchanged"

        runner = FakeRunner()
        self.assertEqual(runner.some_helper(), "unchanged")
        # An unwrapped helper must let the raw OOM escape unchanged.
        self._require_torch_oom()
        with self.assertRaises(torch.OutOfMemoryError):
            runner.some_helper(fail=True)

    def test_wraps_functions_starting_with_compile(self):
        """Any method starting with 'compile' should be wrapped."""
        self._require_torch_oom()

        @exception_handler
        class FakeRunner:
            def compile_special(self):
                raise torch.OutOfMemoryError("test OOM")

        runner = FakeRunner()
        self._assert_oom_converted(runner.compile_special)
if __name__ == "__main__":
    # Allow running this module directly as a script; unittest discovers
    # every TestCase subclass defined above.
    unittest.main()