"""
msmodelslim.utils.function_hijacker 模块的单元测试
"""
import unittest
from unittest.mock import patch, MagicMock
from msmodelslim.utils.function_hijacker import (
_get_target_id, _create_kwargs_wrapper, FunctionHijacker,
hijack_function, restore_function, restore_all_hijacked,
get_original_function, _hijacker
)
def original_func_for_test(a, b=10, c=None):
"""用于测试的原始函数"""
return a + b + (c or 0)
def replacement_func_for_test(a, b=10, c=None):
"""用于测试的替换函数"""
return a * b * (c or 1)
def original_func_with_args_kwargs(*args, **kwargs):
"""用于测试带*args和**kwargs的原始函数"""
return sum(args) + sum(kwargs.values())
def replacement_func_with_args_kwargs(**kwargs):
"""用于测试带*args和**kwargs的替换函数"""
raw_args = kwargs.get('_raw_args', ())
if raw_args:
args_sum = sum([x for x in raw_args if isinstance(x, (int, float))])
else:
args_sum = sum([x for x in kwargs.get('args', ()) if isinstance(x, (int, float))])
kwargs_sum = sum([v for v in kwargs.get('kwargs', {}).values() if isinstance(v, (int, float))])
other_sum = sum([v for k, v in kwargs.items() if
k not in ['args', 'kwargs', '_raw_args'] and isinstance(v, (int, float))])
return args_sum + kwargs_sum + other_sum
def replacement_func_for_inspection_error(**kwargs):
"""用于测试检查错误的替换函数"""
raw_args = kwargs.get('_raw_args', ())
numeric_args = [x for x in raw_args if isinstance(x, (int, float))]
numeric_kwargs = [v for k, v in kwargs.items() if k != '_raw_args' and isinstance(v, (int, float))]
return sum(numeric_args) + sum(numeric_kwargs)
def replacement_func_for_bind_error(**kwargs):
"""用于测试绑定错误的替换函数"""
raw_args = kwargs.get('_raw_args', ())
if len(raw_args) >= 2:
return raw_args[0] + raw_args[1]
else:
return kwargs.get('a', 0) + kwargs.get('b', 0)
def replacement_func_for_complex_func(**kwargs):
"""用于测试复杂函数的替换函数"""
raw_args = kwargs.get('_raw_args', ())
kw_arg = kwargs.get('kw_arg', 1)
extra = kwargs.get('extra', 0)
numeric_raw_args = [x for x in raw_args if isinstance(x, (int, float))]
if raw_args:
first_arg = raw_args[0]
return first_arg * kw_arg * (sum(numeric_raw_args) + 1) * (extra + 1)
else:
return 0
def complex_func_for_test(a: int, b: str = "default", *args, **kwargs) -> str:
"""用于测试复杂函数的原始函数"""
return f"{a}-{b}"
def replacement_func_for_complex_test(**kwargs):
"""用于测试复杂函数的替换函数"""
a = kwargs.get('a', 0)
b = kwargs.get('b', 'default')
return f"replaced-{a}-{b}"
class TestTargetId(unittest.TestCase):
"""测试_get_target_id函数"""
def test_get_target_id_basic(self):
"""测试基本的target_id生成"""
class TestContainer:
def test_func(self):
pass
container = TestContainer()
target = (container, 'test_func')
target_id = _get_target_id(target)
self.assertIsInstance(target_id, str)
self.assertIn(f"{id(container)}:", target_id)
self.assertIn(":test_func", target_id)
def test_get_target_id_different_objects(self):
"""测试不同对象的target_id唯一性"""
class TestContainer:
def func1(self):
pass
def func2(self):
pass
container1 = TestContainer()
container2 = TestContainer()
target1 = (container1, 'func1')
target2 = (container2, 'func1')
target3 = (container1, 'func2')
id1 = _get_target_id(target1)
id2 = _get_target_id(target2)
id3 = _get_target_id(target3)
self.assertNotEqual(id1, id2)
self.assertNotEqual(id1, id3)
class TestCreateKwargsWrapper(unittest.TestCase):
"""测试_create_kwargs_wrapper函数"""
def test_wrapper_with_inspectable_function(self):
"""测试可检查函数的包装器"""
wrapper = _create_kwargs_wrapper(original_func_for_test, replacement_func_for_test)
result = wrapper(2, 3, 4)
self.assertEqual(result, 24)
result = wrapper(a=2, b=3, c=4)
self.assertEqual(result, 24)
result = wrapper(a=2, b=3)
self.assertEqual(result, 6)
def test_wrapper_with_args(self):
"""测试包含位置参数的包装器"""
wrapper = _create_kwargs_wrapper(
original_func_with_args_kwargs, replacement_func_with_args_kwargs
)
result = wrapper(1, 2, 3, x=4, y=5)
self.assertEqual(result, 15)
def test_wrapper_with_inspection_error(self):
"""测试检查时出错的情况"""
original_func = MagicMock()
original_func.__name__ = 'test_func'
with patch('inspect.signature', side_effect=TypeError("Cannot inspect")):
wrapper = _create_kwargs_wrapper(
original_func, replacement_func_for_inspection_error
)
result = wrapper(x=1, y=2)
self.assertEqual(result, 3)
def test_wrapper_with_inspection_error_with_args(self):
"""测试检查时出错且包含位置参数的情况"""
original_func = MagicMock()
original_func.__name__ = 'test_func'
with patch('inspect.signature', side_effect=TypeError("Cannot inspect")):
wrapper = _create_kwargs_wrapper(
original_func, replacement_func_for_inspection_error
)
result = wrapper(1, 2, 3, x=4, y=5)
self.assertEqual(result, 15)
def test_wrapper_with_bind_error(self):
"""测试绑定参数时出错的情况"""
def original_func(a, b):
return a + b
wrapper = _create_kwargs_wrapper(original_func, replacement_func_for_bind_error)
result = wrapper(1, 2, 3)
self.assertEqual(result, 3)
def test_wrapper_with_inspect_signature_error(self):
"""测试inspect.signature出错的情况"""
def original_func(a, b):
return a + b
def replacement_func(**kwargs):
return kwargs.get('a', 0) + kwargs.get('b', 0)
with patch('inspect.signature', side_effect=ValueError("Cannot inspect")):
wrapper = _create_kwargs_wrapper(original_func, replacement_func)
result = wrapper(1, 2)
self.assertIsInstance(result, int)
def test_wrapper_with_no_args(self):
"""测试无参数函数的包装器"""
def original_func_no_args():
return 42
def replacement_func_no_args():
return 84
wrapper = _create_kwargs_wrapper(original_func_no_args, replacement_func_no_args)
result = wrapper()
self.assertEqual(result, 84)
class TestFunctionHijacker(unittest.TestCase):
"""测试FunctionHijacker类"""
def setUp(self):
"""每个测试方法前的设置"""
self.hijacker = FunctionHijacker()
class TestContainer:
@staticmethod
def original_func(x, y=10):
return x + y
@staticmethod
def another_func(z):
return z * 2
self.container = TestContainer()
def test_initialization(self):
"""测试初始化"""
hijacker = FunctionHijacker()
self.assertIsInstance(hijacker.original_functions, dict)
self.assertIsInstance(hijacker.hijacked_targets, dict)
self.assertEqual(len(hijacker.original_functions), 0)
self.assertEqual(len(hijacker.hijacked_targets), 0)
def test_hijack_function_basic(self):
"""测试基本的函数劫持"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
self.hijacker.hijack_function(target, replacement_func)
result = self.container.original_func(2, 3)
self.assertEqual(result, 6)
original_func = self.hijacker.get_original_function(target)
self.assertIsNotNone(original_func)
self.assertEqual(original_func(2, 3), 5)
def test_hijack_function_update(self):
"""测试更新已劫持的函数"""
def replacement_func1(x, y=10):
return x * y
def replacement_func2(x, y=10):
return x ** y
target = (self.container, 'original_func')
self.hijacker.hijack_function(target, replacement_func1)
result1 = self.container.original_func(2, 3)
self.assertEqual(result1, 6)
self.hijacker.hijack_function(target, replacement_func2)
result2 = self.container.original_func(2, 3)
self.assertEqual(result2, 8)
def test_restore_function(self):
"""测试恢复函数"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
self.hijacker.hijack_function(target, replacement_func)
hijacked_result = self.container.original_func(2, 3)
self.assertEqual(hijacked_result, 6)
self.hijacker.restore_function(target)
restored_result = self.container.original_func(2, 3)
self.assertEqual(restored_result, 5)
def test_restore_function_not_found(self):
"""测试恢复不存在的函数"""
target = (self.container, 'nonexistent_func')
self.hijacker.restore_function(target)
def test_restore_all(self):
"""测试恢复所有函数"""
def replacement_func1(x, y=10):
return x * y
def replacement_func2(z):
return z ** 2
target1 = (self.container, 'original_func')
target2 = (self.container, 'another_func')
self.hijacker.hijack_function(target1, replacement_func1)
self.hijacker.hijack_function(target2, replacement_func2)
self.assertEqual(self.container.original_func(2, 3), 6)
self.assertEqual(self.container.another_func(4), 16)
self.hijacker.restore_all()
self.assertEqual(self.container.original_func(2, 3), 5)
self.assertEqual(self.container.another_func(4), 8)
def test_get_original_function(self):
"""测试获取原函数"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
original = self.hijacker.get_original_function(target)
self.assertIsNone(original)
self.hijacker.hijack_function(target, replacement_func)
original = self.hijacker.get_original_function(target)
self.assertIsNotNone(original)
self.assertEqual(original(2, 3), 5)
def test_hijack_nonexistent_attribute(self):
"""测试劫持不存在的属性"""
def replacement_func():
return "replaced"
target = (self.container, 'nonexistent_attr')
with self.assertRaises(AttributeError):
self.hijacker.hijack_function(target, replacement_func)
def test_complex_function_with_kwargs(self):
"""测试复杂的函数劫持(带kwargs)"""
def original_func(a, b=1, *args, **kwargs):
numeric_args = [arg for arg in args if isinstance(arg, (int, float))]
numeric_kwargs = [v for v in kwargs.values() if isinstance(v, (int, float))]
return a + b + sum(numeric_args) + sum(numeric_kwargs)
def replacement_func(**kwargs):
a = kwargs.get('a', 0)
b = kwargs.get('b', 1)
args_tuple = kwargs.get('args', ())
kwargs_dict = kwargs.get('kwargs', {})
numeric_args = [x for x in args_tuple if isinstance(x, (int, float))]
numeric_kwargs = [v for v in kwargs_dict.values() if isinstance(v, (int, float))]
return a * b * (sum(numeric_args) + 1) * (sum(numeric_kwargs) + 1)
self.container.original_func = original_func
target = (self.container, 'original_func')
self.hijacker.hijack_function(target, replacement_func)
result = self.container.original_func(2, 3, 4, 5, x=6, y=7)
self.assertEqual(result, 840)
def test_hijack_builtin_function(self):
"""测试劫持内置函数(通过模拟)"""
class MockBuiltin:
def __init__(self):
self.builtin_func = len
mock_obj = MockBuiltin()
def replacement_func(*args, **kwargs):
return 42
target = (mock_obj, 'builtin_func')
self.hijacker.hijack_function(target, replacement_func)
result = mock_obj.builtin_func([1, 2, 3])
self.assertEqual(result, 42)
def test_hijacker_with_no_hijacks(self):
"""测试没有劫持时的恢复操作"""
self.hijacker.restore_all()
original = self.hijacker.get_original_function(('dummy', 'dummy'))
self.assertIsNone(original)
class TestGlobalFunctions(unittest.TestCase):
"""测试全局函数"""
def setUp(self):
"""每个测试方法前的设置"""
_hijacker.original_functions.clear()
_hijacker.hijacked_targets.clear()
class TestContainer:
@staticmethod
def original_func(x, y=10):
return x + y
self.container = TestContainer()
def test_global_hijack_function(self):
"""测试全局hijack_function函数"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
hijack_function(target, replacement_func)
result = self.container.original_func(2, 3)
self.assertEqual(result, 6)
def test_global_restore_function(self):
"""测试全局restore_function函数"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
hijack_function(target, replacement_func)
self.assertEqual(self.container.original_func(2, 3), 6)
restore_function(target)
self.assertEqual(self.container.original_func(2, 3), 5)
def test_global_restore_all_hijacked(self):
"""测试全局restore_all_hijacked函数"""
def replacement_func1(x, y=10):
return x * y
def replacement_func2(x, y=10):
return x ** y
target1 = (self.container, 'original_func')
target2 = (self.container, 'original_func')
hijack_function(target1, replacement_func1)
hijack_function(target1, replacement_func2)
self.assertEqual(self.container.original_func(2, 3), 8)
restore_all_hijacked()
self.assertEqual(self.container.original_func(2, 3), 5)
def test_global_get_original_function(self):
"""测试全局get_original_function函数"""
def replacement_func(x, y=10):
return x * y
target = (self.container, 'original_func')
original = get_original_function(target)
self.assertIsNone(original)
hijack_function(target, replacement_func)
original = get_original_function(target)
self.assertIsNotNone(original)
self.assertEqual(original(2, 3), 5)
def test_global_functions_with_multiple_targets(self):
"""测试多个目标的全局函数"""
def replacement_func1(x, y=10):
return x * y
def replacement_func2(x, y=10):
return x + y + 100
class MultiContainer:
@staticmethod
def func1(x, y=10):
return x + y
@staticmethod
def func2(x, y=10):
return x - y
container = MultiContainer()
target1 = (container, 'func1')
target2 = (container, 'func2')
hijack_function(target1, replacement_func1)
hijack_function(target2, replacement_func2)
self.assertEqual(container.func1(2, 3), 6)
self.assertEqual(container.func2(10, 3), 113)
restore_all_hijacked()
self.assertEqual(container.func1(2, 3), 5)
self.assertEqual(container.func2(10, 3), 7)
def test_global_functions_edge_cases(self):
"""测试全局函数的边界情况"""
restore_all_hijacked()
class TestEdgeCases(unittest.TestCase):
"""测试边界情况"""
def test_empty_container(self):
"""测试空容器"""
class EmptyContainer:
pass
container = EmptyContainer()
target = (container, 'nonexistent')
def replacement():
return "replaced"
with self.assertRaises(AttributeError):
hijack_function(target, replacement)
def test_none_values_in_wrapper(self):
"""测试包装器中的None值处理"""
def original_func(x=None):
return x or 0
def replacement_func(x=None):
return (x or 0) + 10
wrapper = _create_kwargs_wrapper(original_func, replacement_func)
result = wrapper(x=None)
self.assertEqual(result, 10)
result = wrapper()
self.assertEqual(result, 10)
def test_function_with_no_parameters(self):
"""测试无参数函数"""
def original_func():
return 42
def replacement_func():
return 84
wrapper = _create_kwargs_wrapper(original_func, replacement_func)
result = wrapper()
self.assertEqual(result, 84)
def test_function_with_various_arg_types(self):
"""测试包含各种参数类型的函数"""
def replacement_func(**kwargs):
args_val = kwargs.get('args', ())
kwargs_dict = kwargs.get('kwargs', {})
if args_val:
first_arg = args_val[0]
kw_arg_val = kwargs_dict.get('kw_arg', 1)
extra_val = kwargs_dict.get('extra', 0)
return first_arg * kw_arg_val * (sum([x for x in args_val if isinstance(x, (int, float))]) + 1) * (
extra_val + 1)
else:
return 0
wrapper = _create_kwargs_wrapper(original_func_with_args_kwargs, replacement_func)
result = wrapper(2, kw_arg=3, extra=5)
self.assertEqual(result, 108)
def test_inspect_signature_with_complex_function(self):
"""测试复杂函数的签名检查"""
wrapper = _create_kwargs_wrapper(
complex_func_for_test, replacement_func_for_complex_test
)
result = wrapper(42, "test")
self.assertEqual(result, "replaced-42-test")
result = wrapper(42)
self.assertEqual(result, "replaced-42-default")
if __name__ == '__main__':
unittest.main()