"""
Patcher导入相关功能测试模块
本模块测试patcher框架的导入相关功能,包括:
- skip_import: 跳过指定模块的导入(注册stub)
- inject_import: 注入缺失的导入
- with_imports: 延迟导入装饰器
测试设计原则:
- 使用mock模块模拟真实环境
- 直接加载patcher子模块,避免触发torch依赖
- 每个测试用例独立,通过setUp/tearDown管理测试环境
"""
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from functools import wraps as _functools_wraps
from typing import Dict, List
_project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
_patcher_dir = os.path.join(_project_root, "mx_driving", "patcher")
def _load_module_from_file(module_name: str, file_path: str):
"""Load a module directly from file without triggering parent package __init__.py."""
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
_patcher_logger_module = _load_module_from_file(
"mx_driving.patcher.patcher_logger",
os.path.join(_patcher_dir, "patcher_logger.py")
)
_reporting = _load_module_from_file(
"mx_driving.patcher.reporting",
os.path.join(_patcher_dir, "reporting.py")
)
_version_module = _load_module_from_file(
"mx_driving.patcher.version",
os.path.join(_patcher_dir, "version.py")
)
_patch_module = _load_module_from_file(
"mx_driving.patcher.patch",
os.path.join(_patcher_dir, "patch.py")
)
_patcher_module = _load_module_from_file(
"mx_driving.patcher.patcher",
os.path.join(_patcher_dir, "patcher.py")
)
Patcher = _patcher_module.Patcher
AtomicPatch = _patch_module.AtomicPatch
LegacyPatch = _patch_module.LegacyPatch
Patch = _patch_module.Patch
with_imports = _patch_module.with_imports
_patcher_logger = _patcher_logger_module.patcher_logger
class TestSkipImport(unittest.TestCase):
"""
skip_import功能测试
skip_import用于跳过指定模块的导入,通过注册stub模块到sys.modules。
这对于处理CUDA专用依赖(如flash_attn)非常有用。
"""
def setUp(self):
"""记录测试前的sys.modules状态"""
self._original_modules = set(sys.modules.keys())
def tearDown(self):
"""清理测试中添加的模块"""
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_skip_import_registers_stub(self):
"""
测试skip_import注册stub模块
Ground Truth:
- 调用skip_import后,模块应该存在于sys.modules中
- 模块应该是一个stub,可以安全访问任意属性
"""
patcher = Patcher()
patcher.skip_import("fake_cuda_module")
self.assertIn("fake_cuda_module", sys.modules)
module = sys.modules["fake_cuda_module"]
_ = module.some_function
_ = module.SomeClass
def test_skip_import_immediate_execution(self):
"""
测试skip_import立即执行
Ground Truth:
- skip_import应该在调用时立即注册stub
- 不需要等待apply()调用
"""
patcher = Patcher()
self.assertNotIn("immediate_test_module", sys.modules)
patcher.skip_import("immediate_test_module")
self.assertIn("immediate_test_module", sys.modules)
def test_skip_import_multiple_modules(self):
"""
测试skip_import支持多个模块
Ground Truth:
- 可以一次跳过多个模块
- 所有模块都应该被注册为stub
"""
patcher = Patcher()
patcher.skip_import("multi_test_a", "multi_test_b", "multi_test_c")
self.assertIn("multi_test_a", sys.modules)
self.assertIn("multi_test_b", sys.modules)
self.assertIn("multi_test_c", sys.modules)
def test_skip_import_submodule_access(self):
"""
测试skip_import的stub支持子模块访问
Ground Truth:
- stub应该支持子模块的导入语法
- from fake_module.submodule import something 应该工作
"""
patcher = Patcher()
patcher.skip_import("fake_parent_module")
module = sys.modules["fake_parent_module"]
submodule = module.submodule
self.assertIsNotNone(submodule)
def test_skip_import_already_imported_no_effect(self):
"""
测试skip_import对已导入模块无效
Ground Truth:
- 如果模块已经存在于sys.modules中,skip_import不应该覆盖它
"""
real_module = types.ModuleType("pre_existing_module")
real_module.real_attr = "real_value"
sys.modules["pre_existing_module"] = real_module
patcher = Patcher()
patcher.skip_import("pre_existing_module")
self.assertIs(sys.modules["pre_existing_module"], real_module)
self.assertEqual(sys.modules["pre_existing_module"].real_attr, "real_value")
def test_skip_import_chaining(self):
"""
测试skip_import支持链式调用
Ground Truth:
- skip_import应该返回self,支持链式调用
"""
patcher = Patcher()
result = patcher.skip_import("chain_test_1").skip_import("chain_test_2")
self.assertIs(result, patcher)
self.assertIn("chain_test_1", sys.modules)
self.assertIn("chain_test_2", sys.modules)
class TestInjectImport(unittest.TestCase):
"""
inject_import功能测试
inject_import用于将一个模块中的类/函数注入到另一个模块中。
这对于修复第三方代码中缺失的导出非常有用。
"""
def setUp(self):
"""创建测试模块"""
self.source_module = types.ModuleType("inject_source_module")
self.source_module.MyClass = type("MyClass", (), {"value": 42})
self.source_module.my_function = lambda x: x * 2
sys.modules["inject_source_module"] = self.source_module
self.target_module = types.ModuleType("inject_target_module")
self.target_module.__all__ = []
sys.modules["inject_target_module"] = self.target_module
def tearDown(self):
"""清理测试模块"""
for name in ["inject_source_module", "inject_target_module"]:
if name in sys.modules:
del sys.modules[name]
def test_inject_import_basic(self):
"""
测试基本的inject_import功能
Ground Truth:
- 类应该从source模块注入到target模块
- target模块应该可以访问注入的类
"""
patcher = Patcher()
patcher.inject_import(
"inject_source_module",
"MyClass",
"inject_target_module"
)
self.assertTrue(hasattr(self.target_module, "MyClass"))
self.assertEqual(self.target_module.MyClass.value, 42)
def test_inject_import_immediate_execution(self):
"""
测试inject_import立即执行
Ground Truth:
- inject_import应该在调用时立即执行注入
- 不需要等待apply()调用
"""
patcher = Patcher()
self.assertFalse(hasattr(self.target_module, "MyClass"))
patcher.inject_import(
"inject_source_module",
"MyClass",
"inject_target_module"
)
self.assertTrue(hasattr(self.target_module, "MyClass"))
def test_inject_import_adds_to_all(self):
"""
测试inject_import将类添加到__all__
Ground Truth:
- 如果target模块有__all__,注入的类名应该被添加进去
"""
patcher = Patcher()
patcher.inject_import(
"inject_source_module",
"MyClass",
"inject_target_module"
)
self.assertIn("MyClass", self.target_module.__all__)
def test_inject_import_function(self):
"""
测试inject_import可以注入函数
Ground Truth:
- 不仅类,函数也可以被注入
"""
patcher = Patcher()
patcher.inject_import(
"inject_source_module",
"my_function",
"inject_target_module"
)
self.assertTrue(hasattr(self.target_module, "my_function"))
self.assertEqual(self.target_module.my_function(5), 10)
def test_inject_import_nonexistent_class(self):
"""
测试inject_import处理不存在的类
Ground Truth:
- 如果source模块中不存在指定的类,应该静默失败
- 不应该抛出异常
"""
patcher = Patcher()
patcher.inject_import(
"inject_source_module",
"NonExistentClass",
"inject_target_module"
)
self.assertFalse(hasattr(self.target_module, "NonExistentClass"))
def test_inject_import_nonexistent_module(self):
"""
测试inject_import处理不存在的模块
Ground Truth:
- 如果source或target模块不存在,应该静默失败
- 不应该抛出异常
"""
patcher = Patcher()
patcher.inject_import(
"nonexistent_source",
"MyClass",
"inject_target_module"
)
patcher.inject_import(
"inject_source_module",
"MyClass",
"nonexistent_target"
)
def test_inject_import_no_duplicate(self):
"""
测试inject_import不会重复注入
Ground Truth:
- 如果target模块已经有同名属性,不应该覆盖
"""
self.target_module.MyClass = "original_value"
patcher = Patcher()
patcher.inject_import(
"inject_source_module",
"MyClass",
"inject_target_module"
)
self.assertEqual(self.target_module.MyClass, "original_value")
def test_inject_import_chaining(self):
"""
测试inject_import支持链式调用
Ground Truth:
- inject_import应该返回self,支持链式调用
"""
patcher = Patcher()
result = patcher.inject_import(
"inject_source_module",
"MyClass",
"inject_target_module"
).inject_import(
"inject_source_module",
"my_function",
"inject_target_module"
)
self.assertIs(result, patcher)
self.assertTrue(hasattr(self.target_module, "MyClass"))
self.assertTrue(hasattr(self.target_module, "my_function"))
class TestWithImports(unittest.TestCase):
"""
with_imports装饰器测试
with_imports用于延迟导入模块到函数的全局命名空间。
这允许replacement函数使用与原始函数相同的导入语法。
"""
def setUp(self):
"""创建测试模块"""
self.test_module = types.ModuleType("with_imports_test_module")
self.test_module.CONSTANT_A = 10
self.test_module.CONSTANT_B = 20
self.test_module.helper_func = lambda x: x * 3
sys.modules["with_imports_test_module"] = self.test_module
def tearDown(self):
"""清理测试模块"""
if "with_imports_test_module" in sys.modules:
del sys.modules["with_imports_test_module"]
def test_with_imports_basic(self):
"""
测试with_imports基本功能
Ground Truth:
- 装饰后的函数应该能访问注入的名称
- 函数应该正常执行并返回正确结果
"""
@with_imports(("with_imports_test_module", "CONSTANT_A", "CONSTANT_B"))
def my_func():
return CONSTANT_A + CONSTANT_B
result = my_func()
self.assertEqual(result, 30)
def test_with_imports_lazy_resolution(self):
"""
测试with_imports延迟解析
Ground Truth:
- 导入应该在第一次调用时才发生
- 装饰时不应该导入模块
"""
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def lazy_func():
return CONSTANT_A
result = lazy_func()
self.assertEqual(result, 10)
def test_with_imports_caching(self):
"""
测试with_imports缓存机制
Ground Truth:
- 第一次调用后,后续调用不应该重新导入
- 函数应该使用缓存的结果
"""
call_count = [0]
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def cached_func():
call_count[0] += 1
return CONSTANT_A
result1 = cached_func()
result2 = cached_func()
result3 = cached_func()
self.assertEqual(result1, 10)
self.assertEqual(result2, 10)
self.assertEqual(result3, 10)
self.assertEqual(call_count[0], 3)
def test_with_imports_multiple_specs(self):
"""
测试with_imports支持多个导入规格
Ground Truth:
- 可以从多个模块导入
- 所有导入的名称都应该可用
"""
other_module = types.ModuleType("other_test_module")
other_module.OTHER_CONST = 100
sys.modules["other_test_module"] = other_module
try:
@with_imports(
("with_imports_test_module", "CONSTANT_A"),
("other_test_module", "OTHER_CONST"),
)
def multi_import_func():
return CONSTANT_A + OTHER_CONST
result = multi_import_func()
self.assertEqual(result, 110)
finally:
del sys.modules["other_test_module"]
def test_with_imports_with_staticmethod(self):
"""
测试with_imports与@staticmethod配合使用
Ground Truth:
- @staticmethod应该放在@with_imports之前
- 静态方法应该正常工作
"""
class MyClass:
@staticmethod
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def static_method():
return CONSTANT_A * 2
result = MyClass.static_method()
self.assertEqual(result, 20)
def test_with_imports_with_classmethod(self):
"""
测试with_imports与@classmethod配合使用
Ground Truth:
- @classmethod应该放在@with_imports之前
- 类方法应该正常工作
"""
class MyClass:
multiplier = 3
@classmethod
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def class_method(cls):
return CONSTANT_A * cls.multiplier
result = MyClass.class_method()
self.assertEqual(result, 30)
def test_with_imports_with_arguments(self):
"""
测试with_imports装饰的函数可以接收参数
Ground Truth:
- 装饰后的函数应该正确传递参数
"""
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def func_with_args(x, y, z=1):
return CONSTANT_A + x + y + z
result = func_with_args(1, 2, z=3)
self.assertEqual(result, 16)
def test_with_imports_missing_name(self):
"""
测试with_imports处理不存在的名称
Ground Truth:
- 如果模块中不存在指定的名称,应该静默跳过
- 函数仍然可以执行(如果不使用该名称)
"""
@with_imports(("with_imports_test_module", "NONEXISTENT", "CONSTANT_A"))
def func_with_missing():
return CONSTANT_A
result = func_with_missing()
self.assertEqual(result, 10)
def test_with_imports_missing_module(self):
"""
测试with_imports处理不存在的模块
Ground Truth:
- 如果模块不存在,应该静默跳过
- 函数仍然可以执行(如果不使用该模块的名称)
"""
@with_imports(
("nonexistent_module", "SOMETHING"),
("with_imports_test_module", "CONSTANT_A"),
)
def func_with_missing_module():
return CONSTANT_A
result = func_with_missing_module()
self.assertEqual(result, 10)
def test_with_imports_function_attributes_preserved(self):
"""
测试with_imports保留函数属性
Ground Truth:
- 装饰后的函数应该保留原函数的__name__和__doc__
"""
@with_imports(("with_imports_test_module", "CONSTANT_A"))
def documented_func():
"""This is a docstring."""
return CONSTANT_A
self.assertEqual(documented_func.__name__, "documented_func")
self.assertEqual(documented_func.__doc__, "This is a docstring.")
class TestImportIntegration(unittest.TestCase):
"""
导入功能集成测试
测试skip_import、inject_import和apply的执行顺序和交互。
"""
def setUp(self):
"""记录测试前的sys.modules状态"""
self._original_modules = set(sys.modules.keys())
def tearDown(self):
"""清理测试中添加的模块"""
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_skip_then_inject_order(self):
"""
测试skip_import和inject_import的执行顺序
Ground Truth:
- skip_import和inject_import都应该在调用时立即执行
- 不需要等待apply()
"""
source = types.ModuleType("order_test_source")
source.MyClass = type("MyClass", (), {})
sys.modules["order_test_source"] = source
target = types.ModuleType("order_test_target")
sys.modules["order_test_target"] = target
patcher = Patcher()
patcher.skip_import("order_test_skip")
patcher.inject_import("order_test_source", "MyClass", "order_test_target")
self.assertIn("order_test_skip", sys.modules)
self.assertTrue(hasattr(target, "MyClass"))
def test_typical_usage_pattern(self):
"""
测试典型使用模式
Ground Truth:
- 典型用法: skip_import -> inject_import -> add patches -> apply
- 所有操作应该按预期工作
"""
source = types.ModuleType("typical_source")
source.Helper = type("Helper", (), {"value": 100})
sys.modules["typical_source"] = source
target = types.ModuleType("typical_target")
target.func = lambda x: x
sys.modules["typical_target"] = target
patcher = Patcher()
patcher.skip_import("cuda_only_module")
patcher.inject_import("typical_source", "Helper", "typical_target")
self.assertIn("cuda_only_module", sys.modules)
self.assertTrue(hasattr(target, "Helper"))
self.assertEqual(target.Helper.value, 100)
class TestStubModuleSpecialMethods(unittest.TestCase):
"""
_StubModule特殊方法测试
测试_StubModule的各种Python特殊方法,确保它能正确与Python导入机制配合。
这些测试覆盖了导致 TypeError: '_StubModule' object is not iterable 的场景。
"""
def setUp(self):
"""记录测试前的sys.modules状态"""
self._original_modules = set(sys.modules.keys())
def tearDown(self):
"""清理测试中添加的模块"""
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_stub_module_is_iterable(self):
"""
测试stub模块可迭代
Ground Truth:
- stub模块应该支持迭代操作
- 迭代结果应该为空列表
- 这是Python导入机制查找子模块时的必要条件
"""
patcher = Patcher()
patcher.skip_import("iterable_test_module")
module = sys.modules["iterable_test_module"]
items = list(module)
self.assertEqual(items, [])
iterator = iter(module)
self.assertEqual(list(iterator), [])
def test_stub_module_has_path(self):
"""
测试stub模块有__path__属性
Ground Truth:
- stub模块应该有__path__属性
- __path__应该是一个空列表
- 这是Python将模块视为包的必要条件
"""
patcher = Patcher()
patcher.skip_import("path_test_module")
module = sys.modules["path_test_module"]
self.assertTrue(hasattr(module, "__path__"))
self.assertIsInstance(module.__path__, list)
self.assertEqual(module.__path__, [])
def test_stub_module_has_all(self):
"""
测试stub模块有__all__属性
Ground Truth:
- stub模块应该有__all__属性
- __all__应该是一个空列表
- 这是"from module import *"语法的必要条件
"""
patcher = Patcher()
patcher.skip_import("all_test_module")
module = sys.modules["all_test_module"]
self.assertTrue(hasattr(module, "__all__"))
self.assertIsInstance(module.__all__, list)
self.assertEqual(module.__all__, [])
def test_stub_module_contains(self):
"""
测试stub模块支持in操作符
Ground Truth:
- stub模块应该支持"in"操作符
- 任何检查都应该返回False
"""
patcher = Patcher()
patcher.skip_import("contains_test_module")
module = sys.modules["contains_test_module"]
self.assertFalse("anything" in module)
self.assertFalse("submodule" in module)
def test_stub_module_len(self):
"""
测试stub模块支持len()
Ground Truth:
- stub模块应该支持len()调用
- 长度应该为0
"""
patcher = Patcher()
patcher.skip_import("len_test_module")
module = sys.modules["len_test_module"]
self.assertEqual(len(module), 0)
def test_stub_module_bool(self):
"""
测试stub模块的布尔值
Ground Truth:
- stub模块的布尔值应该为False
- 这表明它不是一个真实的模块
"""
patcher = Patcher()
patcher.skip_import("bool_test_module")
module = sys.modules["bool_test_module"]
self.assertFalse(bool(module))
def test_stub_module_repr(self):
"""
测试stub模块的字符串表示
Ground Truth:
- stub模块应该有可读的字符串表示
- 应该包含模块名称
"""
patcher = Patcher()
patcher.skip_import("repr_test_module")
module = sys.modules["repr_test_module"]
repr_str = repr(module)
self.assertIn("repr_test_module", repr_str)
self.assertIn("Stub", repr_str)
def test_stub_submodule_registered_in_sys_modules(self):
"""
测试stub子模块自动注册到sys.modules
Ground Truth:
- 访问stub.submodule时,子模块应该自动注册到sys.modules
- 这是Python导入机制能找到子模块的必要条件
- 解决 ModuleNotFoundError: No module named 'xxx.submodule' 问题
"""
patcher = Patcher()
patcher.skip_import("auto_register_test")
module = sys.modules["auto_register_test"]
submodule = module.submodule
self.assertIn("auto_register_test.submodule", sys.modules)
self.assertIs(sys.modules["auto_register_test.submodule"], submodule)
def test_stub_deep_submodule_registered(self):
"""
测试深层stub子模块自动注册到sys.modules
Ground Truth:
- 访问stub.level1.level2.level3时,所有层级都应该注册到sys.modules
"""
patcher = Patcher()
patcher.skip_import("deep_register_test")
module = sys.modules["deep_register_test"]
deep = module.level1.level2.level3
self.assertIn("deep_register_test.level1", sys.modules)
self.assertIn("deep_register_test.level1.level2", sys.modules)
self.assertIn("deep_register_test.level1.level2.level3", sys.modules)
def test_stub_submodule_iteration(self):
"""
测试stub子模块的迭代
Ground Truth:
- 通过属性访问获取的子模块也应该可迭代
- 这是导致原始bug的场景
"""
patcher = Patcher()
patcher.skip_import("parent_iter_test")
parent = sys.modules["parent_iter_test"]
submodule = parent.submodule
items = list(submodule)
self.assertEqual(items, [])
def test_stub_module_from_import_star(self):
"""
测试stub模块支持from ... import *语法
Ground Truth:
- "from stub_module import *"不应该抛出异常
- 由于__all__为空,不会导入任何内容
"""
patcher = Patcher()
patcher.skip_import("star_import_test")
module = sys.modules["star_import_test"]
self.assertTrue(hasattr(module, "__all__"))
self.assertEqual(module.__all__, [])
def test_stub_deep_nested_access(self):
"""
测试深层嵌套的stub访问
Ground Truth:
- 深层嵌套的属性访问应该都返回stub
- 每个层级都应该可迭代
"""
patcher = Patcher()
patcher.skip_import("deep_nested_test")
module = sys.modules["deep_nested_test"]
deep = module.level1.level2.level3.level4
self.assertEqual(list(module), [])
self.assertEqual(list(module.level1), [])
self.assertEqual(list(deep), [])
def test_stub_module_callable_returns_none(self):
"""
测试stub模块作为函数调用返回None
Ground Truth:
- stub模块可以被调用
- 调用结果应该是None
"""
patcher = Patcher()
patcher.skip_import("callable_test")
module = sys.modules["callable_test"]
result = module()
self.assertIsNone(result)
result = module(1, 2, 3, key="value")
self.assertIsNone(result)
def test_stub_attribute_callable(self):
"""
测试stub属性作为函数调用
Ground Truth:
- stub的任何属性都可以被调用
- 调用结果应该是None
"""
patcher = Patcher()
patcher.skip_import("attr_callable_test")
module = sys.modules["attr_callable_test"]
result = module.some_function(1, 2, 3)
self.assertIsNone(result)
result = module.SomeClass()
self.assertIsNone(result)
class TestRealWorldImportScenarios(unittest.TestCase):
"""
真实世界导入场景测试
模拟真实的CUDA依赖导入场景,确保skip_import能正确处理。
"""
def setUp(self):
"""记录测试前的sys.modules状态"""
self._original_modules = set(sys.modules.keys())
def tearDown(self):
"""清理测试中添加的模块"""
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_flash_attn_like_import(self):
"""
测试类似flash_attn的导入场景
Ground Truth:
- 模拟flash_attn的导入模式
- from flash_attn.flash_attn_interface import xxx 应该工作
"""
patcher = Patcher()
patcher.skip_import("flash_attn")
module = sys.modules["flash_attn"]
interface = module.flash_attn_interface
func = interface.flash_attn_varlen_kvpacked_func
self.assertIsNotNone(func)
result = func(None, None, None)
self.assertIsNone(result)
def test_flash_attn_real_import_statement(self):
"""
测试使用真实import语句导入flash_attn子模块
Ground Truth:
- 使用importlib.import_module模拟真实的import语句
- 子模块应该能被正确导入
- 这是解决 ModuleNotFoundError: No module named 'flash_attn.flash_attn_interface' 的关键测试
"""
import importlib
patcher = Patcher()
patcher.skip_import("flash_attn")
submodule = importlib.import_module("flash_attn.flash_attn_interface")
self.assertIsNotNone(submodule)
self.assertIn("flash_attn.flash_attn_interface", sys.modules)
func = submodule.flash_attn_varlen_kvpacked_func
self.assertIsNotNone(func)
def test_deep_submodule_real_import(self):
"""
测试使用真实import语句导入深层子模块
Ground Truth:
- 深层子模块也应该能被正确导入
"""
import importlib
patcher = Patcher()
patcher.skip_import("deep_import_test")
submodule = importlib.import_module("deep_import_test.level1.level2.level3")
self.assertIn("deep_import_test", sys.modules)
self.assertIn("deep_import_test.level1", sys.modules)
self.assertIn("deep_import_test.level1.level2", sys.modules)
self.assertIn("deep_import_test.level1.level2.level3", sys.modules)
def test_torch_scatter_like_import(self):
"""
测试类似torch_scatter的导入场景
Ground Truth:
- 模拟torch_scatter的导入模式
- from torch_scatter import scatter_max 应该工作
"""
patcher = Patcher()
patcher.skip_import("torch_scatter")
module = sys.modules["torch_scatter"]
scatter_max = module.scatter_max
self.assertIsNotNone(scatter_max)
def test_spconv_like_import(self):
"""
测试类似spconv的导入场景
Ground Truth:
- 模拟spconv的导入模式
- from spconv.pytorch import SparseConvTensor 应该工作
"""
patcher = Patcher()
patcher.skip_import("spconv")
module = sys.modules["spconv"]
pytorch = module.pytorch
SparseConvTensor = pytorch.SparseConvTensor
self.assertIsNotNone(SparseConvTensor)
def test_multiple_cuda_deps_skip(self):
"""
测试跳过多个CUDA依赖
Ground Truth:
- 可以同时跳过多个CUDA依赖
- 所有依赖都应该被正确stub
"""
patcher = Patcher()
patcher.skip_import(
"flash_attn",
"torch_scatter",
"spconv",
"cumm",
)
self.assertIn("flash_attn", sys.modules)
self.assertIn("torch_scatter", sys.modules)
self.assertIn("spconv", sys.modules)
self.assertIn("cumm", sys.modules)
for name in ["flash_attn", "torch_scatter", "spconv", "cumm"]:
module = sys.modules[name]
self.assertEqual(list(module), [])
class TestReplaceImport(unittest.TestCase):
"""
replace_import功能测试
测试Patcher.replace_import方法,用于替换模块导入。
"""
def setUp(self):
"""记录测试前的sys.modules状态"""
self._original_modules = set(sys.modules.keys())
def tearDown(self):
"""清理测试中添加的模块"""
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_replace_import_with_attrs(self):
"""
测试使用自定义属性替换模块
Ground Truth:
- 替换的模块应该出现在sys.modules中
- 模块的属性应该可以访问
"""
patcher = Patcher()
patcher.replace_import(
"custom_ops.my_op",
MyFunction=lambda x: x * 2,
)
self.assertIn("custom_ops.my_op", sys.modules)
self.assertEqual(sys.modules["custom_ops.my_op"].MyFunction(5), 10)
def test_replace_import_with_class(self):
"""
测试替换包含类的模块
Ground Truth:
- 模块中的类应该可以正常使用
"""
patcher = Patcher()
class DeformableAggregationFunction:
@staticmethod
def apply(x, y):
return x + y
patcher.replace_import(
"class_ops.aggregation",
DeformableAggregationFunction=DeformableAggregationFunction,
)
self.assertIn("class_ops.aggregation", sys.modules)
result = sys.modules["class_ops.aggregation"].DeformableAggregationFunction.apply(3, 4)
self.assertEqual(result, 7)
def test_replace_import_with_replacement_module(self):
"""
测试用另一个模块替换
Ground Truth:
- 可以用一个已存在的模块替换目标模块
- 替换后的模块应该有原模块的所有属性
"""
patcher = Patcher()
source = types.ModuleType("replacement_source")
source.VALUE = 42
source.func = lambda x: x * 3
sys.modules["replacement_source"] = source
patcher.replace_import(
"target_to_replace",
"replacement_source",
)
self.assertIn("target_to_replace", sys.modules)
self.assertEqual(sys.modules["target_to_replace"].VALUE, 42)
self.assertEqual(sys.modules["target_to_replace"].func(5), 15)
def test_replace_import_with_replacement_and_override(self):
"""
测试替换模块并覆盖特定属性
Ground Truth:
- 可以用另一个模块替换,同时覆盖特定属性
"""
patcher = Patcher()
source = types.ModuleType("override_source")
source.VALUE = 42
source.func = lambda x: x * 3
sys.modules["override_source"] = source
patcher.replace_import(
"override_target",
"override_source",
VALUE=100,
)
self.assertIn("override_target", sys.modules)
self.assertEqual(sys.modules["override_target"].VALUE, 100)
self.assertEqual(sys.modules["override_target"].func(5), 15)
def test_replace_import_skip_if_exists(self):
"""
测试已存在的模块不会被覆盖
Ground Truth:
- 如果模块已存在,replace_import应该跳过
"""
patcher = Patcher()
existing_module = types.ModuleType("existing_ops.op")
existing_module.original = True
sys.modules["existing_ops.op"] = existing_module
patcher.replace_import("existing_ops.op", original=False)
self.assertTrue(sys.modules["existing_ops.op"].original)
def test_replace_import_skip_if_exists_logs_warning(self):
"""
replace_import should warn with guidance when target module is already loaded.
"""
import logging
patcher = Patcher()
existing_module = types.ModuleType("existing_ops.warn")
sys.modules["existing_ops.warn"] = existing_module
warnings_logged = []
class CapHandler(logging.Handler):
def emit(self, record):
if record.levelno >= logging.WARNING:
warnings_logged.append(record.getMessage())
handler = CapHandler()
_patcher_logger._logger.addHandler(handler)
try:
patcher.replace_import("existing_ops.warn", VALUE=1)
finally:
_patcher_logger._logger.removeHandler(handler)
self.assertTrue(
any("replace_import skipped for existing_ops.warn" in msg for msg in warnings_logged)
)
self.assertTrue(
any("Call replace_import() before importing the target module" in msg
for msg in warnings_logged)
)
def test_replace_import_returns_self(self):
"""
测试replace_import返回self以支持链式调用
Ground Truth:
- replace_import()应该返回patcher实例本身
"""
patcher = Patcher()
result = patcher.replace_import("chain_test.op", value=1)
self.assertIs(result, patcher)
def test_replace_import_chaining(self):
"""
测试replace_import的链式调用
Ground Truth:
- 可以链式调用多个replace_import
"""
patcher = Patcher()
patcher.replace_import("chain.op1", value=1).replace_import("chain.op2", value=2)
self.assertIn("chain.op1", sys.modules)
self.assertIn("chain.op2", sys.modules)
self.assertEqual(sys.modules["chain.op1"].value, 1)
self.assertEqual(sys.modules["chain.op2"].value, 2)
def test_replace_import_importable(self):
"""
测试替换的模块可以通过import语句导入
Ground Truth:
- 替换后的模块应该可以通过importlib.import_module导入
"""
import importlib
patcher = Patcher()
patcher.replace_import("importable_ops.custom", CONSTANT=123)
imported = importlib.import_module("importable_ops.custom")
self.assertEqual(imported.CONSTANT, 123)
def test_deformable_aggregation_pattern(self):
"""
测试类似DeformableAggregation的使用模式
Ground Truth:
- 模拟DiffusionDrive中DeformableAggregationOp的使用场景
"""
import importlib
patcher = Patcher()
class DeformableAggregationFunction:
@staticmethod
def apply(*args, **kwargs):
return sum(args) if args else None
patcher.replace_import(
"projects.mmdet3d_plugin.ops.deformable_aggregation",
DeformableAggregationFunction=DeformableAggregationFunction,
)
imported = importlib.import_module("projects.mmdet3d_plugin.ops.deformable_aggregation")
result = imported.DeformableAggregationFunction.apply(1, 2, 3)
self.assertEqual(result, 6)
class TestWithImportsEnhancements(unittest.TestCase):
"""Tests for with_imports improvements from deep research."""
def setUp(self):
self.test_module = types.ModuleType("wi_enhance_test")
self.test_module.VALUE = 42
sys.modules["wi_enhance_test"] = self.test_module
def tearDown(self):
sys.modules.pop("wi_enhance_test", None)
sys.modules.pop("wi_decorator_mod", None)
def test_import_failure_logs_warning(self):
"""ImportError should log at WARNING level, not DEBUG."""
import logging
warnings_logged = []
class CapHandler(logging.Handler):
def emit(self, record):
if record.levelno >= logging.WARNING:
warnings_logged.append(record.getMessage())
handler = CapHandler()
_patcher_logger._logger.addHandler(handler)
try:
@with_imports("nonexistent_module_xyz123")
def func():
pass
func()
finally:
_patcher_logger._logger.removeHandler(handler)
self.assertTrue(any("nonexistent_module_xyz123" in w for w in warnings_logged),
"ImportError should produce a WARNING log")
def test_annotations_preserved(self):
"""Function annotations should be preserved after decoration."""
@with_imports(("wi_enhance_test", "VALUE"))
def annotated_func(x: int) -> int:
return VALUE + x
result = annotated_func(1)
self.assertEqual(result, 43)
self.assertIn('x', annotated_func.__annotations__)
def test_stacking_warning(self):
"""Stacking multiple @with_imports should produce a warning."""
import logging
warnings_logged = []
class CapHandler(logging.Handler):
def emit(self, record):
if record.levelno >= logging.WARNING:
warnings_logged.append(record.getMessage())
handler = CapHandler()
_patcher_logger._logger.addHandler(handler)
try:
@with_imports("wi_enhance_test")
@with_imports(("wi_enhance_test", "VALUE"))
def stacked_func():
return VALUE
finally:
_patcher_logger._logger.removeHandler(handler)
self.assertTrue(any("stacking" in w.lower() for w in warnings_logged),
"Stacking @with_imports should produce a warning")
def test_apply_decorators_basic(self):
"""apply_decorators should lazily apply decorators from target modules."""
dec_mod = types.ModuleType("wi_decorator_mod")
def my_decorator(multiplier=1):
def wrapper(func):
@_functools_wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs) * multiplier
return inner
return wrapper
dec_mod.my_decorator = my_decorator
sys.modules["wi_decorator_mod"] = dec_mod
@with_imports(
("wi_enhance_test", "VALUE"),
apply_decorators=[
("wi_decorator_mod.my_decorator", {"multiplier": 3})
]
)
def func():
return VALUE
result = func()
self.assertEqual(result, 42 * 3,
"apply_decorators should wrap the resolved function")
def test_at_prefix_decorator_with_kwargs(self):
"""@ string expression decorator with kwargs."""
dec_mod = types.ModuleType("wi_decorator_mod")
def my_decorator(multiplier=1):
def wrapper(func):
@_functools_wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs) * multiplier
return inner
return wrapper
dec_mod.my_decorator = my_decorator
sys.modules["wi_decorator_mod"] = dec_mod
@with_imports(
("wi_enhance_test", "VALUE"),
("wi_decorator_mod", "my_decorator"),
"@my_decorator(multiplier=5)",
)
def func():
return VALUE
self.assertEqual(func(), 42 * 5)
def test_at_prefix_decorator_no_args(self):
"""@ string expression decorator without kwargs."""
dec_mod = types.ModuleType("wi_decorator_mod")
def double(func):
@_functools_wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs) * 2
return inner
dec_mod.double = double
sys.modules["wi_decorator_mod"] = dec_mod
@with_imports(
("wi_enhance_test", "VALUE"),
("wi_decorator_mod", "double"),
"@double",
)
def func():
return VALUE
self.assertEqual(func(), 42 * 2)
def test_apply_decorators_missing_graceful(self):
"""apply_decorators with missing decorator should not crash."""
@with_imports(
("wi_enhance_test", "VALUE"),
apply_decorators=[
("nonexistent_module.decorator", {"arg": 1})
]
)
def func():
return VALUE
result = func()
self.assertEqual(result, 42)
def test_with_imports_decorated_flag(self):
"""Decorated functions should have _with_imports_decorated flag."""
@with_imports("wi_enhance_test")
def func():
pass
self.assertTrue(getattr(func, '_with_imports_decorated', False))
class TestInjectImportStubModuleFix(unittest.TestCase):
"""
Tests for inject_import() fix when target module is a _StubModule.
The old implementation used hasattr() which always returns True for _StubModule
(since __getattr__ returns a new stub for any non-private attribute).
The fix uses `class_name in vars(tgt_mod)` instead.
"""
def setUp(self):
self._original_modules = set(sys.modules.keys())
def tearDown(self):
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_inject_into_stub_target_succeeds(self):
"""
inject_import should succeed when target is a _StubModule.
Previously, hasattr(stub, 'MyClass') returned True because
_StubModule.__getattr__ returns a new stub, so injection was skipped.
"""
patcher = Patcher()
source = types.ModuleType("ii_stub_test.source")
source.MyClass = type("MyClass", (), {"value": 42})
sys.modules["ii_stub_test.source"] = source
patcher.skip_import("ii_stub_test.stub_target")
self.assertIn("ii_stub_test.stub_target", sys.modules)
patcher.inject_import("ii_stub_test.source", "MyClass", "ii_stub_test.stub_target")
target = sys.modules["ii_stub_test.stub_target"]
self.assertTrue(hasattr(target, "MyClass"))
self.assertEqual(target.MyClass.value, 42)
def test_inject_into_normal_module_still_skips_existing(self):
"""
inject_import should still skip if target already has the attribute
(for real modules, not stubs).
"""
patcher = Patcher()
source = types.ModuleType("ii_normal_test.source")
source.MyClass = type("MyClass", (), {"value": 42})
sys.modules["ii_normal_test.source"] = source
target = types.ModuleType("ii_normal_test.target")
target.MyClass = type("MyClass", (), {"value": 999})
sys.modules["ii_normal_test.target"] = target
patcher.inject_import("ii_normal_test.source", "MyClass", "ii_normal_test.target")
self.assertEqual(sys.modules["ii_normal_test.target"].MyClass.value, 999)
class TestReplaceImportImprovements(unittest.TestCase):
"""
Tests for replace_import() improvements:
- Parent package auto-registration
- Metadata preservation (__name__, __package__)
- Summary logging
- New API: base_module, exports, replace_with.module()
"""
def setUp(self):
self._original_modules = set(sys.modules.keys())
def tearDown(self):
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_parent_packages_auto_created(self):
"""
replace_import should auto-create parent packages in sys.modules.
This fixes the issue where `import pkg.mod` and `from pkg import mod`
failed because parent packages were not registered.
"""
patcher = Patcher()
patcher.replace_import("ri_parent.pkg.mod", VALUE=42)
self.assertIn("ri_parent", sys.modules)
self.assertIn("ri_parent.pkg", sys.modules)
self.assertIn("ri_parent.pkg.mod", sys.modules)
def test_parent_package_attribute_chain(self):
"""
Parent packages should have child attributes set for `from pkg import mod`.
"""
patcher = Patcher()
patcher.replace_import("ri_chain.sub.leaf", VALUE=99)
parent = sys.modules["ri_chain"]
self.assertTrue(hasattr(parent, "sub"))
sub = sys.modules["ri_chain.sub"]
self.assertTrue(hasattr(sub, "leaf"))
def test_from_pkg_import_mod_works(self):
"""
After replace_import, `from pkg import mod` style should work.
"""
patcher = Patcher()
patcher.replace_import("ri_from_test.pkg.mod", CONSTANT=55)
mod = sys.modules["ri_from_test.pkg.mod"]
self.assertEqual(mod.CONSTANT, 55)
pkg = sys.modules["ri_from_test.pkg"]
self.assertIs(getattr(pkg, "mod", None), mod)
def test_importlib_import_still_works(self):
"""
importlib.import_module should continue to work.
"""
import importlib
patcher = Patcher()
patcher.replace_import("ri_importlib.ops.custom", VALUE=77)
imported = importlib.import_module("ri_importlib.ops.custom")
self.assertEqual(imported.VALUE, 77)
def test_real_parent_packages_remain_importable(self):
"""
replace_import should not replace real filesystem packages with empty stubs.
Regression target:
- replace_import("pkg.sub.ops.leaf", ...)
- later import pkg.sub.apis should still work
"""
import importlib
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
root = os.path.join(temp_dir, "ri_real_parent")
subpkg = os.path.join(root, "subpkg")
apis = os.path.join(subpkg, "apis")
os.makedirs(apis)
with open(os.path.join(root, "__init__.py"), "w", encoding="utf-8") as f:
f.write("")
with open(os.path.join(subpkg, "__init__.py"), "w", encoding="utf-8") as f:
f.write("")
with open(os.path.join(apis, "__init__.py"), "w", encoding="utf-8") as f:
f.write("VALUE = 123\n")
sys.path.insert(0, temp_dir)
try:
patcher.replace_import("ri_real_parent.subpkg.ops.custom", VALUE=42)
parent_pkg = importlib.import_module("ri_real_parent.subpkg")
apis_pkg = importlib.import_module("ri_real_parent.subpkg.apis")
self.assertEqual(apis_pkg.VALUE, 123)
self.assertTrue(hasattr(parent_pkg, "__path__"))
self.assertNotEqual(list(parent_pkg.__path__), [])
finally:
sys.path.remove(temp_dir)
def test_namespace_parent_packages_remain_importable(self):
"""
replace_import should preserve namespace-package parents as real packages.
Regression target:
- top-level package directory exists without __init__.py
- later sibling imports under that namespace package should still work
"""
import importlib
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
root = os.path.join(temp_dir, "ri_ns_parent")
subpkg = os.path.join(root, "subpkg")
apis = os.path.join(subpkg, "apis")
os.makedirs(apis)
with open(os.path.join(subpkg, "__init__.py"), "w", encoding="utf-8") as f:
f.write("")
with open(os.path.join(apis, "__init__.py"), "w", encoding="utf-8") as f:
f.write("VALUE = 456\n")
sys.path.insert(0, temp_dir)
try:
patcher.replace_import("ri_ns_parent.subpkg.ops.custom", VALUE=42)
ns_pkg = importlib.import_module("ri_ns_parent")
apis_pkg = importlib.import_module("ri_ns_parent.subpkg.apis")
self.assertEqual(apis_pkg.VALUE, 456)
self.assertTrue(hasattr(ns_pkg, "__path__"))
self.assertNotEqual(list(ns_pkg.__path__), [])
finally:
sys.path.remove(temp_dir)
def test_single_segment_no_parent_needed(self):
"""
Single-segment module paths should work without parent creation.
"""
patcher = Patcher()
patcher.replace_import("ri_single_seg", VALUE=1)
self.assertIn("ri_single_seg", sys.modules)
self.assertEqual(sys.modules["ri_single_seg"].VALUE, 1)
def test_metadata_preserved_with_replacement_module(self):
"""
When using a replacement module, __name__ and __package__ should
reflect the target path, not the replacement source.
"""
patcher = Patcher()
source = types.ModuleType("ri_meta_source")
source.VALUE = 42
sys.modules["ri_meta_source"] = source
patcher.replace_import("ri_meta.target.mod", "ri_meta_source")
target = sys.modules["ri_meta.target.mod"]
self.assertEqual(target.__name__, "ri_meta.target.mod")
self.assertEqual(target.__package__, "ri_meta.target")
self.assertEqual(target.VALUE, 42)
def test_base_module_keyword(self):
"""
base_module= keyword should work as preferred replacement for positional arg.
"""
patcher = Patcher()
source = types.ModuleType("ri_bm_source")
source.func = lambda: "hello"
sys.modules["ri_bm_source"] = source
patcher.replace_import("ri_bm.target", base_module="ri_bm_source")
target = sys.modules["ri_bm.target"]
self.assertEqual(target.func(), "hello")
def test_replacement_and_base_module_conflict(self):
"""
Passing both positional replacement and base_module= should raise ValueError.
"""
patcher = Patcher()
source = types.ModuleType("ri_conflict_source")
sys.modules["ri_conflict_source"] = source
with self.assertRaises(ValueError):
patcher.replace_import("ri_conflict.target", "ri_conflict_source",
base_module="ri_conflict_source")
def test_exports_keyword(self):
"""
exports= should work as the preferred way to declare module attributes.
"""
patcher = Patcher()
patcher.replace_import(
"ri_exports.ops",
exports={"MyFunc": lambda: 42, "MyClass": int},
)
target = sys.modules["ri_exports.ops"]
self.assertEqual(target.MyFunc(), 42)
self.assertIs(target.MyClass, int)
def test_exports_and_attrs_conflict(self):
"""
If same key appears in both exports= and **attrs, should raise ValueError.
"""
patcher = Patcher()
with self.assertRaises(ValueError):
patcher.replace_import(
"ri_exp_conflict.ops",
exports={"VALUE": 1},
VALUE=2,
)
def test_base_module_with_exports(self):
"""
base_module + exports should use base as template and override with exports.
"""
patcher = Patcher()
source = types.ModuleType("ri_base_exp_source")
source.VALUE = 42
source.func = lambda: "original"
sys.modules["ri_base_exp_source"] = source
patcher.replace_import(
"ri_base_exp.target",
base_module="ri_base_exp_source",
exports={"func": lambda: "overridden"},
)
target = sys.modules["ri_base_exp.target"]
self.assertEqual(target.VALUE, 42)
self.assertEqual(target.func(), "overridden")
def test_replace_with_module_string(self):
"""
replace_with.module("new.module") should work as replacement spec.
"""
replace_with = _patcher_module.replace_with
patcher = Patcher()
source = types.ModuleType("ri_rw_source")
source.VALUE = 99
sys.modules["ri_rw_source"] = source
patcher.replace_import("ri_rw.target", replace_with.module("ri_rw_source"))
target = sys.modules["ri_rw.target"]
self.assertEqual(target.VALUE, 99)
def test_replace_with_module_exports(self):
"""
replace_with.module(Foo=Bar) should create a module with given exports.
"""
replace_with = _patcher_module.replace_with
patcher = Patcher()
patcher.replace_import(
"ri_rw_exp.ops",
replace_with.module(MyFunc=lambda: 42),
)
target = sys.modules["ri_rw_exp.ops"]
self.assertEqual(target.MyFunc(), 42)
def test_replace_with_module_base_and_override(self):
"""
replace_with.module("base", Foo=Bar) should use base module + override.
"""
replace_with = _patcher_module.replace_with
patcher = Patcher()
source = types.ModuleType("ri_rw_bo_source")
source.VALUE = 42
source.func = lambda: "original"
sys.modules["ri_rw_bo_source"] = source
patcher.replace_import(
"ri_rw_bo.target",
replace_with.module("ri_rw_bo_source", func=lambda: "override"),
)
target = sys.modules["ri_rw_bo.target"]
self.assertEqual(target.VALUE, 42)
self.assertEqual(target.func(), "override")
def test_replace_import_in_summary(self):
"""
replace_import should record entries for summary output.
"""
patcher = Patcher()
_patcher_logger._replaced_imports.clear()
patcher.replace_import("ri_summary.test", VALUE=1)
self.assertTrue(len(_patcher_logger._replaced_imports) > 0)
self.assertTrue(any("ri_summary.test" in desc
for desc in _patcher_logger._replaced_imports))
def test_legacy_positional_replacement_still_works(self):
"""
Legacy patcher.replace_import("old", "new") should still work.
"""
patcher = Patcher()
source = types.ModuleType("ri_legacy_source")
source.VALUE = 42
sys.modules["ri_legacy_source"] = source
patcher.replace_import("ri_legacy.target", "ri_legacy_source")
target = sys.modules["ri_legacy.target"]
self.assertEqual(target.VALUE, 42)
def test_legacy_attrs_still_works(self):
"""
Legacy patcher.replace_import("old", Foo=Bar) should still work.
"""
patcher = Patcher()
patcher.replace_import("ri_legacy_attrs.ops", MyFunc=lambda: 42)
target = sys.modules["ri_legacy_attrs.ops"]
self.assertEqual(target.MyFunc(), 42)
class TestSkipImportParentPackagePreservation(unittest.TestCase):
"""
Regression tests for skip_import() parent-package poisoning.
skip_import("pkg.sub.leaf") should not replace real parent packages
such as pkg or pkg.sub with _StubModule placeholders.
"""
def setUp(self):
self._original_modules = set(sys.modules.keys())
def tearDown(self):
current_modules = set(sys.modules.keys())
for module_name in current_modules - self._original_modules:
if module_name in sys.modules:
del sys.modules[module_name]
def test_real_parent_package_not_stubbed(self):
"""
Real filesystem-backed parents should remain real importable packages.
"""
import importlib
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
root = os.path.join(temp_dir, "si_real_parent")
subpkg = os.path.join(root, "subpkg")
sibling = os.path.join(root, "sibling")
os.makedirs(subpkg)
os.makedirs(sibling)
with open(os.path.join(root, "__init__.py"), "w", encoding="utf-8") as f:
f.write("VALUE = 123\n")
with open(os.path.join(subpkg, "__init__.py"), "w", encoding="utf-8") as f:
f.write("SUB_VALUE = 456\n")
with open(os.path.join(sibling, "__init__.py"), "w", encoding="utf-8") as f:
f.write("SIBLING = 789\n")
sys.path.insert(0, temp_dir)
try:
patcher.skip_import("si_real_parent.subpkg.blocked_leaf")
parent_pkg = importlib.import_module("si_real_parent")
sub_pkg = importlib.import_module("si_real_parent.subpkg")
sibling_pkg = importlib.import_module("si_real_parent.sibling")
self.assertEqual(parent_pkg.VALUE, 123)
self.assertEqual(sub_pkg.SUB_VALUE, 456)
self.assertEqual(sibling_pkg.SIBLING, 789)
self.assertFalse(type(parent_pkg).__name__.startswith("_StubModule"))
self.assertFalse(type(sub_pkg).__name__.startswith("_StubModule"))
finally:
sys.path.remove(temp_dir)
def test_namespace_parent_package_not_stubbed(self):
"""
Namespace-package parents should keep working for sibling imports.
"""
import importlib
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
root = os.path.join(temp_dir, "si_ns_parent")
subpkg = os.path.join(root, "subpkg")
sibling = os.path.join(root, "sibling")
os.makedirs(subpkg)
os.makedirs(sibling)
with open(os.path.join(subpkg, "__init__.py"), "w", encoding="utf-8") as f:
f.write("SUB_VALUE = 456\n")
with open(os.path.join(sibling, "__init__.py"), "w", encoding="utf-8") as f:
f.write("SIBLING = 789\n")
sys.path.insert(0, temp_dir)
try:
patcher.skip_import("si_ns_parent.subpkg.blocked_leaf")
ns_parent = importlib.import_module("si_ns_parent")
sibling_pkg = importlib.import_module("si_ns_parent.sibling")
self.assertTrue(hasattr(ns_parent, "__path__"))
self.assertEqual(sibling_pkg.SIBLING, 789)
self.assertFalse(type(ns_parent).__name__.startswith("_StubModule"))
finally:
sys.path.remove(temp_dir)
def test_skip_import_does_not_execute_real_parent_init(self):
"""
skip_import should not import real parent packages during stub registration.
This protects startup paths where a deep blocked import lives under a real
plugin package whose __init__.py has heavyweight or incompatible imports.
"""
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
root = os.path.join(temp_dir, "si_no_init_import")
subpkg = os.path.join(root, "subpkg")
os.makedirs(subpkg)
with open(os.path.join(root, "__init__.py"), "w", encoding="utf-8") as f:
f.write("raise RuntimeError('parent __init__ should not execute during skip_import')\n")
with open(os.path.join(subpkg, "__init__.py"), "w", encoding="utf-8") as f:
f.write("")
sys.path.insert(0, temp_dir)
try:
patcher.skip_import("si_no_init_import.subpkg.blocked_leaf")
self.assertIn("si_no_init_import.subpkg.blocked_leaf", sys.modules)
finally:
sys.path.remove(temp_dir)
class TestBootstrapPatchesBeforeCollect(unittest.TestCase):
"""
Regression tests for compatibility patches that must apply before collect.
Some Patch.patches() implementations eagerly import third-party modules while
building AtomicPatch objects. Compatibility patches such as numpy alias
restoration must therefore be able to run before later Patch classes are
collected.
"""
def setUp(self):
self._original_modules = dict(sys.modules)
def tearDown(self):
current_modules = set(sys.modules.keys())
original_modules = set(self._original_modules.keys())
for module_name in current_modules - original_modules:
sys.modules.pop(module_name, None)
for module_name in original_modules:
sys.modules[module_name] = self._original_modules[module_name]
def test_apply_before_collect_allows_eager_import_patch_classes(self):
"""
apply_before_collect patches should run before later Patch.patches()
methods trigger import-time compatibility checks.
"""
import importlib
fake_numpy = types.ModuleType("numpy")
sys.modules["numpy"] = fake_numpy
class FakeNumpyCompat(Patch):
name = "fake_numpy_compat"
apply_before_collect = True
@staticmethod
def _restore_alias(np_module, _options):
if hasattr(np_module, "int"):
raise AttributeError("np.int already exists")
np_module.int = int
@classmethod
def patches(cls, options=None):
return [
LegacyPatch(cls._restore_alias, target_module="numpy", options=options),
]
class EagerImportPatch(Patch):
name = "eager_import_patch"
@classmethod
def patches(cls, options=None):
importlib.import_module("bootstrap_pkg.consumer")
return [
AtomicPatch("bootstrap_pkg.consumer.VALUE", 2),
]
patcher = Patcher()
with tempfile.TemporaryDirectory() as temp_dir:
pkg_dir = os.path.join(temp_dir, "bootstrap_pkg")
os.makedirs(pkg_dir)
with open(os.path.join(pkg_dir, "__init__.py"), "w", encoding="utf-8") as f:
f.write("")
with open(os.path.join(pkg_dir, "consumer.py"), "w", encoding="utf-8") as f:
f.write(
"import numpy as np\n"
"if not hasattr(np, 'int'):\n"
" raise RuntimeError('np.int missing during collection')\n"
"VALUE = 1\n"
)
sys.path.insert(0, temp_dir)
try:
patcher.add(FakeNumpyCompat, EagerImportPatch)
patcher.apply()
consumer = importlib.import_module("bootstrap_pkg.consumer")
self.assertIs(fake_numpy.int, int)
self.assertEqual(consumer.VALUE, 2)
finally:
sys.path.remove(temp_dir)
if __name__ == "__main__":
unittest.main()