"""
Legacy patcher API compatibility tests.
This module tests backward compatibility with the old patcher API (pre-87ec1895):
- LegacyPatchWrapper (Patch(func) syntax)
- LegacyPatcherBuilder (PatcherBuilder class)
- Old-style patch functions (msda, dc, mdc, batch_matmul, etc.)
- Mixed usage of old and new APIs
- Delegation to new Patcher implementation
"""
import importlib.util
import io
import os
import sys
import types
import unittest
from typing import Dict, List
from types import ModuleType
from unittest.mock import MagicMock, patch
# Resolve the repository root (three directory levels above this test file)
# and, from it, the directory holding the patcher package sources.
_project_root = os.path.abspath(__file__)
for _level in range(3):
    _project_root = os.path.dirname(_project_root)
del _level
_patcher_dir = os.path.join(_project_root, "mx_driving", "patcher")
def _load_module_from_file(module_name: str, file_path: str):
    """Load a module directly from a file without importing its parent package.

    The module is registered in ``sys.modules`` under ``module_name`` before
    its body executes, so modules loaded afterwards can import it by that
    dotted name without triggering the parent package's ``__init__.py``.

    Args:
        module_name: Fully qualified name to register the module under.
        file_path: Path of the Python source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``file_path``
            (for example, a file with an unsupported extension).
        Exception: Anything raised while executing the module body is
            re-raised after the half-initialised module has been removed
            from ``sys.modules``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot build an import spec for {module_name!r} from {file_path!r}"
        )
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror importlib's own behaviour: never leave a partially executed
        # module in sys.modules, or later imports would silently reuse it.
        sys.modules.pop(module_name, None)
        raise
    return module
# Load each patcher source file directly, one after another in this fixed
# order (presumably following their import dependencies), registering every
# file under its real dotted name. The package ``__init__`` is loaded last.
(
    _patcher_logger,
    _reporting,
    _version_module,
    _patch_module,
    _patcher_module,
    _legacy_module,
    _patcher_init_module,
) = (
    _load_module_from_file(_qualified_name, os.path.join(_patcher_dir, _file_name))
    for _qualified_name, _file_name in (
        ("mx_driving.patcher.patcher_logger", "patcher_logger.py"),
        ("mx_driving.patcher.reporting", "reporting.py"),
        ("mx_driving.patcher.version", "version.py"),
        ("mx_driving.patcher.patch", "patch.py"),
        ("mx_driving.patcher.patcher", "patcher.py"),
        ("mx_driving.patcher.legacy", "legacy.py"),
        ("mx_driving.patcher", "__init__.py"),
    )
)
# Names under test, taken from the directly loaded modules.
# New-style patch primitives (patch.py).
AtomicPatch, BasePatch, LegacyPatch, Patch = (
    _patch_module.AtomicPatch,
    _patch_module.BasePatch,
    _patch_module.LegacyPatch,
    _patch_module.Patch,
)
# Patch orchestrator (patcher.py).
Patcher = _patcher_module.Patcher
# Result reporting types (reporting.py).
PatchResult, PatchStatus = _reporting.PatchResult, _reporting.PatchStatus
# Legacy compatibility layer (legacy.py).
LegacyPatchWrapper, LegacyPatcherBuilder, PatcherBuilder = (
    _legacy_module.LegacyPatchWrapper,
    _legacy_module.LegacyPatcherBuilder,
    _legacy_module.PatcherBuilder,
)
class TestLegacyPatchWrapper(unittest.TestCase):
    """Unit tests for ``LegacyPatchWrapper``.

    The wrapper adapts an old-style patch callable with the signature
    ``(module, options)``.
    """

    def setUp(self):
        """Clear the one-shot migration warning flag so every test starts fresh."""
        LegacyPatchWrapper._migration_warning_shown = False

    def test_basic_wrapper(self):
        """A bare wrapper exposes the function, its name and neutral defaults."""
        def my_patch(module, options):
            pass

        wrapped = LegacyPatchWrapper(my_patch)
        self.assertEqual(wrapped.name, "my_patch")
        self.assertEqual(wrapped.func, my_patch)
        self.assertEqual(wrapped.options, {})
        self.assertEqual(wrapped.priority, 0)
        self.assertFalse(wrapped.is_applied)

    def test_wrapper_with_options(self):
        """Options passed to the constructor are stored unchanged."""
        def my_patch(module, options):
            pass

        wrapped = LegacyPatchWrapper(my_patch, options={'key': 'value'})
        self.assertEqual(wrapped.options, {'key': 'value'})

    def test_wrapper_with_priority(self):
        """A lower priority value compares as smaller."""
        def patch1(module, options):
            pass

        def patch2(module, options):
            pass

        high = LegacyPatchWrapper(patch1, priority=10)
        low = LegacyPatchWrapper(patch2, priority=5)
        self.assertLess(low, high)

    def test_wrapper_sorting(self):
        """``sorted`` orders wrappers by ascending priority."""
        def patch1(module, options):
            pass

        def patch2(module, options):
            pass

        def patch3(module, options):
            pass

        ordered = sorted([
            LegacyPatchWrapper(patch1, priority=10),
            LegacyPatchWrapper(patch2, priority=5),
            LegacyPatchWrapper(patch3, priority=15),
        ])
        self.assertEqual([w.name for w in ordered], ["patch2", "patch1", "patch3"])
class TestPatchMetaclass(unittest.TestCase):
    """Tests for the ``Patch`` metaclass that distinguishes old and new usage.

    * ``Patch(func)`` returns a ``LegacyPatchWrapper`` (old style).
    * ``class MyPatch(Patch): ...`` is ordinary subclassing (new style).
    """

    def setUp(self):
        """Clear the one-shot migration warning flag before each test."""
        LegacyPatchWrapper._migration_warning_shown = False

    def test_patch_with_callable_returns_wrapper(self):
        """Calling ``Patch`` with a function yields a legacy wrapper."""
        def my_patch(module, options):
            pass

        wrapped = Patch(my_patch)
        self.assertIsInstance(wrapped, LegacyPatchWrapper)
        self.assertEqual(wrapped.name, "my_patch")

    def test_patch_subclass_works_normally(self):
        """Subclassing ``Patch`` produces a regular class."""
        class MyPatch(Patch):
            name = "my_patch"

            @classmethod
            def patches(cls, options=None):
                return []

        self.assertIsInstance(MyPatch, type)
        self.assertTrue(issubclass(MyPatch, Patch))
        self.assertEqual(MyPatch.name, "my_patch")

    def test_patch_with_options(self):
        """``Patch(func, options=...)`` forwards the options to the wrapper."""
        def my_patch(module, options):
            pass

        wrapped = Patch(my_patch, options={'key': 'value'})
        self.assertIsInstance(wrapped, LegacyPatchWrapper)
        self.assertEqual(wrapped.options, {'key': 'value'})
class TestLegacyPatcherBuilder(unittest.TestCase):
    """Tests for ``LegacyPatcherBuilder``, the old builder API.

    The builder is expected to delegate to the new ``Patcher`` implementation.
    """

    def setUp(self):
        """Register a throwaway target module and reset the warning flag."""
        module = types.ModuleType('legacy_builder_test')
        module.func = lambda x: x
        self.mock_module = module
        sys.modules['legacy_builder_test'] = module
        LegacyPatchWrapper._migration_warning_shown = False

    def tearDown(self):
        """Unregister the throwaway target module."""
        sys.modules.pop('legacy_builder_test', None)

    def test_basic_builder(self):
        """A fresh builder has no module patches and an empty blacklist."""
        fresh = LegacyPatcherBuilder()
        self.assertEqual(fresh._module_patches, {})
        self.assertEqual(fresh._blacklist, set())

    def test_add_module_patch(self):
        """``add_module_patch`` records one patch under the module name."""
        def my_patch(module, options):
            module.func = lambda x: x * 2

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("legacy_builder_test", Patch(my_patch))
        registered = builder._module_patches
        self.assertIn("legacy_builder_test", registered)
        self.assertEqual(len(registered["legacy_builder_test"]), 1)

    def test_chained_add_module_patch(self):
        """``add_module_patch`` returns the builder, so calls can be chained."""
        def patch1(module, options):
            pass

        def patch2(module, options):
            pass

        chained = LegacyPatcherBuilder().add_module_patch(
            "module1", Patch(patch1)
        ).add_module_patch("module2", Patch(patch2))
        for module_name in ("module1", "module2"):
            self.assertIn(module_name, chained._module_patches)

    def test_disable_patches(self):
        """Names passed to ``disable_patches`` land in the blacklist."""
        builder = LegacyPatcherBuilder()
        builder.disable_patches("patch1", "patch2")
        for patch_name in ("patch1", "patch2"):
            self.assertIn(patch_name, builder._blacklist)

    def test_with_profiling(self):
        """``with_profiling`` stores path, level and skip_first."""
        builder = LegacyPatcherBuilder()
        builder.with_profiling("/path/to/prof", level=1, skip_first=50)
        profiling = builder._profiling_options
        self.assertIsNotNone(profiling)
        expected = {'path': "/path/to/prof", 'level': 1, 'skip_first': 50}
        for key, value in expected.items():
            self.assertEqual(profiling[key], value)

    def test_brake_at(self):
        """``brake_at`` stores the requested step."""
        builder = LegacyPatcherBuilder()
        builder.brake_at(100)
        self.assertEqual(builder._brake_step, 100)

    def test_build_returns_legacy_patcher(self):
        """``build()`` returns an object usable via ``apply`` and as a context manager."""
        built = LegacyPatcherBuilder().build()
        self.assertIsNotNone(built)
        for attribute in ('apply', '__enter__', '__exit__'):
            self.assertTrue(hasattr(built, attribute))

    def test_build_creates_new_patcher_internally(self):
        """``build()`` wraps a new-style ``Patcher`` instance."""
        built = LegacyPatcherBuilder().build()
        self.assertIsInstance(built._patcher, Patcher)

    def test_full_workflow(self):
        """The legacy build -> apply workflow patches the target module."""
        def my_patch(module, options):
            module.func = lambda x: x * 10

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("legacy_builder_test", Patch(my_patch))
        with builder.build():
            self.assertEqual(self.mock_module.func(5), 50)
class TestPatcherBuilderAlias(unittest.TestCase):
    """``PatcherBuilder`` must remain the very same object as ``LegacyPatcherBuilder``."""

    def test_patcherbuilder_is_alias(self):
        """The legacy module exports ``PatcherBuilder`` as an alias."""
        self.assertIs(PatcherBuilder, LegacyPatcherBuilder)

    def test_patcher_module_exposes_legacy_alias(self):
        """``mx_driving.patcher.patcher`` also resolves ``PatcherBuilder`` (lazily)."""
        resolved = getattr(_patcher_module, "PatcherBuilder")
        self.assertIs(resolved, LegacyPatcherBuilder)
class TestLegacyPatcherDelegation(unittest.TestCase):
    """The object returned by the legacy builder forwards work to the new ``Patcher``."""

    def setUp(self):
        """Register a throwaway target module and reset the warning flag."""
        module = types.ModuleType('delegation_test')
        module.func = lambda x: x
        self.mock_module = module
        sys.modules['delegation_test'] = module
        LegacyPatchWrapper._migration_warning_shown = False

    def tearDown(self):
        """Unregister the throwaway target module."""
        sys.modules.pop('delegation_test', None)

    def test_apply_delegates_to_patcher(self):
        """``apply()`` flips ``is_applied`` and patches the target."""
        def my_patch(module, options):
            module.func = lambda x: x * 5

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("delegation_test", Patch(my_patch))
        built = builder.build()
        self.assertFalse(built.is_applied)
        built.apply()
        self.assertTrue(built.is_applied)
        self.assertEqual(self.mock_module.func(10), 50)

    def test_context_manager_delegates(self):
        """Entering the context applies the patches."""
        def my_patch(module, options):
            module.func = lambda x: x * 3

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("delegation_test", Patch(my_patch))
        with builder.build() as active:
            self.assertTrue(active.is_applied)
            self.assertEqual(self.mock_module.func(10), 30)
class TestMixedUsage(unittest.TestCase):
    """Old and new patcher APIs used together.

    Scenarios:
    - an old ``(module, options)`` function applied through the new ``Patcher``;
    - a new ``Patch`` class applied from inside the old ``PatcherBuilder``;
    - both APIs used in the same session.
    """

    _MODULE_NAMES = ('mixed_test_module1', 'mixed_test_module2')

    def setUp(self):
        """Register two throwaway target modules and reset the warning flag."""
        self.mock_module1, self.mock_module2 = (
            types.ModuleType(name) for name in self._MODULE_NAMES
        )
        for module in (self.mock_module1, self.mock_module2):
            module.func = lambda x: x
            sys.modules[module.__name__] = module
        LegacyPatchWrapper._migration_warning_shown = False

    def tearDown(self):
        """Unregister the throwaway target modules."""
        for name in self._MODULE_NAMES:
            sys.modules.pop(name, None)

    def test_old_patch_with_new_patcher(self):
        """An old-style function wrapped in ``LegacyPatch`` runs under ``Patcher``."""
        def my_patch(module, options):
            module.func = lambda x: x * 5

        new_patcher = Patcher()
        new_patcher.add(LegacyPatch(my_patch, target_module="mixed_test_module1"))
        new_patcher.apply()
        self.assertEqual(self.mock_module1.func(10), 50)

    def test_new_patch_class_with_old_builder(self):
        """A new ``Patch`` class can be applied from an old-style builder patch."""
        class MyNewPatch(Patch):
            name = "my_new_patch"

            @classmethod
            def patches(cls, options=None):
                return [AtomicPatch("mixed_test_module2.func", lambda x: x * 7)]

        def apply_new_patch(module, options):
            MyNewPatch.apply_all()

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("mixed_test_module2", Patch(apply_new_patch))
        with builder.build():
            self.assertEqual(self.mock_module2.func(10), 70)

    def test_both_apis_in_same_session(self):
        """The new and the old API can each patch a module in one session."""
        new_patcher = Patcher()
        new_patcher.add(AtomicPatch("mixed_test_module1.func", lambda x: x * 2))
        new_patcher.apply()
        self.assertEqual(self.mock_module1.func(10), 20)

        def old_patch(module, options):
            module.func = lambda x: x * 3

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("mixed_test_module2", Patch(old_patch))
        with builder.build():
            self.assertEqual(self.mock_module2.func(10), 30)
class TestPatchApplyAll(unittest.TestCase):
    """``Patch.apply_all()`` applies a Patch class without an explicit ``Patcher``."""

    def setUp(self):
        """Register a throwaway module with two identity functions."""
        module = types.ModuleType('apply_all_test')
        module.func1 = lambda x: x
        module.func2 = lambda x: x
        self.mock_module = module
        sys.modules['apply_all_test'] = module

    def tearDown(self):
        """Unregister the throwaway module."""
        sys.modules.pop('apply_all_test', None)

    def test_apply_all_basic(self):
        """Every atomic patch is applied and one result is returned per patch."""
        class MyPatch(Patch):
            name = "my_patch"

            @classmethod
            def patches(cls, options=None):
                return [
                    AtomicPatch("apply_all_test.func1", lambda x: x * 2),
                    AtomicPatch("apply_all_test.func2", lambda x: x * 3),
                ]

        outcome = MyPatch.apply_all()
        self.assertEqual(len(outcome), 2)
        self.assertEqual(self.mock_module.func1(10), 20)
        self.assertEqual(self.mock_module.func2(10), 30)

    def test_apply_all_with_options(self):
        """Options given to ``apply_all`` reach ``patches()``."""
        class MyPatch(Patch):
            name = "my_patch"

            @classmethod
            def patches(cls, options=None):
                factor = (options or {}).get('multiplier', 10)
                return [
                    AtomicPatch("apply_all_test.func1", lambda x, m=factor: x * m),
                ]

        outcome = MyPatch.apply_all(options={'multiplier': 5})
        self.assertEqual(len(outcome), 1)
        self.assertEqual(self.mock_module.func1(10), 50)
class TestLegacyPatchFunctions(unittest.TestCase):
    """The old-style patch functions (msda, dc, mdc, ...) are still exported.

    They are expected to delegate to the new Patch class implementations.
    """

    def test_patch_functions_are_callable(self):
        """Each legacy function name resolves to a callable on the package."""
        functions = [
            getattr(_patcher_init_module, name)
            for name in ("msda", "dc", "mdc", "batch_matmul", "index")
        ]
        for func in functions:
            self.assertTrue(callable(func))

    def test_patch_functions_signature(self):
        """Legacy functions keep the ``(module, options)`` signature."""
        import inspect

        parameter_names = list(inspect.signature(_patcher_init_module.msda).parameters)
        self.assertEqual(len(parameter_names), 2)
        self.assertEqual(parameter_names[0], 'module')
        self.assertEqual(parameter_names[1], 'options')
class TestDefaultPatcherBuilder(unittest.TestCase):
    """Tests for the ``default_patcher_builder`` proxy in the legacy module."""

    def test_default_patcher_builder_exists(self):
        """The proxy object is exported and not ``None``."""
        self.assertIsNotNone(_legacy_module.default_patcher_builder)

    def test_default_patcher_builder_has_methods(self):
        """The proxy exposes the full legacy builder API."""
        builder = _legacy_module.default_patcher_builder
        expected_methods = (
            'add_module_patch', 'disable_patches', 'build', 'with_profiling', 'brake_at',
        )
        for method_name in expected_methods:
            self.assertTrue(hasattr(builder, method_name))

    def test_default_patch_classes_put_numpy_before_mmdet3d(self):
        """Default patch order should restore NumPy aliases before mmdet3d patches."""
        pkg = _patcher_init_module
        order = pkg._DEFAULT_PATCH_CLASSES
        numpy_pos = order.index(pkg.NumpyCompat)
        dataset_pos = order.index(pkg.NuScenesDataset)
        metric_pos = order.index(pkg.NuScenesMetric)
        for dependent_pos in (dataset_pos, metric_pos):
            self.assertLess(numpy_pos, dependent_pos)

    def test_default_patcher_builder_mirrors_numpy_before_mmdet3d(self):
        """Legacy default_patcher_builder should mirror the same safe ordering."""
        built = _legacy_module.default_patcher_builder.build()
        parents = [
            getattr(item, "_parent_name", "")
            for item, _ in built._patcher._collect_all_patches()
        ]
        numpy_pos = parents.index("numpy_compat")
        dataset_pos = parents.index("nuscenes_dataset")
        metric_pos = parents.index("nuscenes_metric")
        for dependent_pos in (dataset_pos, metric_pos):
            self.assertLess(numpy_pos, dependent_pos)
class TestPatchFunctionToClassMapping(unittest.TestCase):
    """Tests for the mapping from old-style patch function names to Patch classes."""

    def test_legacy_name_to_class_mapping_exists(self):
        """``_LEGACY_NAME_TO_CLASS`` maps every known legacy name to a class."""
        legacy_name_to_class = _patcher_init_module._LEGACY_NAME_TO_CLASS
        known_names = ["msda", "dc", "mdc", "batch_matmul", "index"]
        for name in known_names:
            self.assertIn(name, legacy_name_to_class)
            self.assertIsInstance(legacy_name_to_class[name], type)

    def test_legacy_name_to_class_for_unknown_function(self):
        """An unknown function name has no entry in ``_LEGACY_NAME_TO_CLASS``."""
        legacy_name_to_class = _patcher_init_module._LEGACY_NAME_TO_CLASS
        self.assertNotIn("unknown_patch_function", legacy_name_to_class)
class TestLoggingIntegration(unittest.TestCase):
    """Logging behaviour of the legacy API."""

    def setUp(self):
        """Register a throwaway target module and reset the warning flag."""
        self.mock_module = types.ModuleType('logging_test')
        self.mock_module.func = lambda x: x
        sys.modules['logging_test'] = self.mock_module
        LegacyPatchWrapper._migration_warning_shown = False

    def tearDown(self):
        """Unregister the throwaway target module."""
        sys.modules.pop('logging_test', None)

    def test_legacy_wrapper_shows_migration_warning(self):
        """Creating a wrapper sets the migration flag, and the flag stays set."""
        LegacyPatchWrapper._migration_warning_shown = False

        def my_patch(module, options):
            pass

        LegacyPatchWrapper(my_patch)
        self.assertTrue(LegacyPatchWrapper._migration_warning_shown)

        def another_patch(module, options):
            pass

        LegacyPatchWrapper(another_patch)
        self.assertTrue(LegacyPatchWrapper._migration_warning_shown)

    def test_legacy_warning_mentions_compatibility_notice(self):
        """Migration warning should explain that it is informational, not a failure."""
        LegacyPatchWrapper._migration_warning_shown = False
        # Resolve the logger from its defining module instead of a module
        # global bound later in this file, and let mock.patch.object restore
        # the ``info`` attribute exactly as it was (no stray instance
        # attribute left behind).
        logger = _patcher_logger.patcher_logger

        def my_patch(module, options):
            pass

        with patch.object(logger, "info") as mock_info:
            LegacyPatchWrapper(my_patch)

        self.assertEqual(mock_info.call_count, 1)
        message = mock_info.call_args[0][0]
        self.assertIn("compatibility-layer notice", message)
        self.assertIn("not a patch failure", message)
# Module-level alias for the ``patcher_logger`` object exported by the
# directly loaded ``mx_driving.patcher.patcher_logger`` module. It is bound
# after the classes defined above, so methods in those classes that use this
# name only resolve it when they run.
patcher_logger = _patcher_logger.patcher_logger
class TestPatchNameAutoDefault(unittest.TestCase):
    """``Patch.name`` falls back to the subclass name (set in ``__init_subclass__``)."""

    def test_name_defaults_to_class_name(self):
        """Without an explicit ``name`` the class ``__name__`` is used."""
        class MyCustomPatch(Patch):
            @classmethod
            def patches(cls, options=None):
                return []

        self.assertEqual(MyCustomPatch.name, "MyCustomPatch")

    def test_explicit_name_preserved(self):
        """An explicit ``name`` attribute is not overwritten."""
        class MyPatch(Patch):
            name = "my_explicit_name"

            @classmethod
            def patches(cls, options=None):
                return []

        self.assertEqual(MyPatch.name, "my_explicit_name")

    def test_two_subclasses_independent(self):
        """Sibling subclasses each receive their own default name."""
        class AlphaPatch(Patch):
            @classmethod
            def patches(cls, options=None):
                return []

        class BetaPatch(Patch):
            @classmethod
            def patches(cls, options=None):
                return []

        for cls, expected in ((AlphaPatch, "AlphaPatch"), (BetaPatch, "BetaPatch")):
            self.assertEqual(cls.name, expected)
class TestLegacyPatchReadableNames(unittest.TestCase):
    """``LegacyPatch`` derives readable names for internal helper closures."""

    def test_internal_apply_name_infers_from_builder(self):
        """An ``_apply`` closure takes its name from the ``build_*_patch`` factory."""
        def build_mmcv_epoch_runner_patch():
            def _apply(module, _options):
                return None
            return LegacyPatch(_apply, target_module="mmcv")

        legacy_patch = build_mmcv_epoch_runner_patch()
        self.assertEqual(legacy_patch.name, "mmcv_epoch_runner")

    def test_explicit_patch_name_override_wins(self):
        """A ``__patch_name__`` attribute on the function overrides inference."""
        def builder():
            def _apply(module, _options):
                return None
            _apply.__patch_name__ = "custom_patch_name"
            return LegacyPatch(_apply, target_module="mmcv")

        legacy_patch = builder()
        self.assertEqual(legacy_patch.name, "custom_patch_name")
class TestPatcherDisableEnhanced(unittest.TestCase):
    """``Patcher.disable()`` accepts Patch classes as well as BasePatch instances."""

    def _register_module(self, name):
        """Install a throwaway module with an identity ``func``; remove it after the test."""
        module = types.ModuleType(name)
        module.func = lambda x: x
        sys.modules[name] = module
        self.addCleanup(sys.modules.pop, name, None)
        return module

    def test_disable_by_patch_class(self):
        """patcher.disable(PatchClass) works."""
        target = self._register_module("disable_test_mod")

        class MyPatch(Patch):
            name = "disable_test"

            @classmethod
            def patches(cls, options=None):
                return [AtomicPatch("disable_test_mod.func", lambda x: x * 99)]

        patcher = Patcher()
        patcher.add(MyPatch)
        patcher.disable(MyPatch)
        patcher.apply()
        self.assertEqual(target.func(1), 1)

    def test_disable_by_patch_instance(self):
        """patcher.disable(atomic_patch_instance) works."""
        target = self._register_module("disable_inst_mod")
        atomic = AtomicPatch("disable_inst_mod.func", lambda x: x * 99)

        patcher = Patcher()
        patcher.add(atomic)
        patcher.disable(atomic)
        patcher.apply()
        self.assertEqual(target.func(1), 1)
class TestGlobalInsertionOrder(unittest.TestCase):
    """Regression tests for global insertion order (tech debt #02)."""

    def setUp(self):
        """Register a throwaway module whose ``func`` returns ``"original"``."""
        self.mock_mod = types.ModuleType("order_global_test")
        self.mock_mod.func = lambda: "original"
        sys.modules["order_global_test"] = self.mock_mod

    def tearDown(self):
        """Unregister the throwaway module."""
        sys.modules.pop("order_global_test", None)

    def test_direct_then_class_order(self):
        """add(direct); add(Class) -> Class should win (added last)."""
        class MyPatch(Patch):
            name = "order_class"

            @classmethod
            def patches(cls, options=None):
                return [AtomicPatch("order_global_test.func", lambda: "from-class")]

        direct = AtomicPatch("order_global_test.func", lambda: "from-direct")
        p = Patcher()
        p.add(direct)
        p.add(MyPatch)
        p.apply()
        self.assertEqual(self.mock_mod.func(), "from-class",
                         "Patch class added AFTER direct should win")

    def test_class_then_direct_order(self):
        """add(Class); add(direct) -> direct should win (added last)."""
        class MyPatch(Patch):
            name = "order_class"

            @classmethod
            def patches(cls, options=None):
                return [AtomicPatch("order_global_test.func", lambda: "from-class")]

        direct = AtomicPatch("order_global_test.func", lambda: "from-direct")
        p = Patcher()
        p.add(MyPatch)
        p.add(direct)
        p.apply()
        self.assertEqual(self.mock_mod.func(), "from-direct",
                         "Direct patch added AFTER class should win")

    def test_interleaved_order(self):
        """Interleaved add: direct -> class -> direct keeps insertion order."""
        d1 = AtomicPatch("order_global_test.func", lambda: "d1")

        class CP(Patch):
            name = "cp"

            @classmethod
            def patches(cls, options=None):
                return [AtomicPatch("order_global_test.func", lambda: "cp")]

        d2 = AtomicPatch("order_global_test.func", lambda: "d2")
        p = Patcher()
        p.add(d1)
        p.add(CP)
        p.add(d2)

        # Compare against the actual patch objects. Searching for the class
        # patch only *after* d1's index would make the ordering check true by
        # construction and say nothing about insertion order.
        on_target = [
            item for item, _ in p._collect_all_patches()
            if item.name == "order_global_test.func"
        ]
        self.assertEqual(len(on_target), 3,
                         "d1, the class patch and d2 should all be collected")
        self.assertIs(on_target[0], d1, "d1 (added first) should be collected first")
        self.assertIsNot(on_target[1], d1, "class patch should sit between d1 and d2")
        self.assertIsNot(on_target[1], d2, "class patch should sit between d1 and d2")
        self.assertIs(on_target[2], d2, "d2 (added last) should be collected last")

        p.apply()
        self.assertEqual(self.mock_mod.func(), "d2",
                         "Direct patch added last should win")
class TestConflictsWithEnforcement(unittest.TestCase):
    """Regression tests for conflicts_with enforcement (tech debt #03)."""

    @staticmethod
    def _make_conflicting_pair():
        """Build empty ``PatchA``/``PatchB`` classes where A declares a conflict with B."""
        class PatchA(Patch):
            name = "patch_a"
            conflicts_with = ["patch_b"]

            @classmethod
            def patches(cls, options=None):
                return []

        class PatchB(Patch):
            name = "patch_b"

            @classmethod
            def patches(cls, options=None):
                return []

        return PatchA, PatchB

    def test_conflict_raises_on_apply(self):
        """Two conflicting patches should raise ValueError on apply."""
        patch_a, patch_b = self._make_conflicting_pair()
        patcher = Patcher()
        patcher.add(patch_a, patch_b)
        with self.assertRaises(ValueError) as ctx:
            patcher.apply()
        message = str(ctx.exception)
        for fragment in (
            "patch_a",
            "patch_b",
            "default_patcher",
            "patcher.disable('patch_b').add(PatchA)",
        ):
            self.assertIn(fragment, message)

    def test_conflict_ok_when_one_disabled(self):
        """Conflict is resolved when one patch is disabled."""
        target = types.ModuleType("conflict_ok_mod")
        target.func = lambda: "original"
        sys.modules["conflict_ok_mod"] = target
        try:
            class PatchA(Patch):
                name = "patch_a"
                conflicts_with = ["patch_b"]

                @classmethod
                def patches(cls, options=None):
                    return [AtomicPatch("conflict_ok_mod.func", lambda: "a")]

            class PatchB(Patch):
                name = "patch_b"

                @classmethod
                def patches(cls, options=None):
                    return [AtomicPatch("conflict_ok_mod.func", lambda: "b")]

            patcher = Patcher()
            patcher.add(PatchA, PatchB)
            patcher.disable(PatchA)
            patcher.apply()
            self.assertEqual(target.func(), "b")
        finally:
            sys.modules.pop("conflict_ok_mod", None)

    def test_no_conflict_no_error(self):
        """Patches without conflicts_with should work normally (apply must not raise)."""
        class PatchX(Patch):
            name = "patch_x"

            @classmethod
            def patches(cls, options=None):
                return []

        class PatchY(Patch):
            name = "patch_y"

            @classmethod
            def patches(cls, options=None):
                return []

        patcher = Patcher()
        patcher.add(PatchX, PatchY)
        patcher.apply()

    def test_bidirectional_conflict_detected(self):
        """Conflict should be detected regardless of add order."""
        patch_a, patch_b = self._make_conflicting_pair()
        for ordering in ((patch_a, patch_b), (patch_b, patch_a)):
            patcher = Patcher()
            patcher.add(*ordering)
            with self.assertRaises(ValueError):
                patcher.apply()

    def test_conflict_warning_on_add_includes_default_patcher_guidance(self):
        """Adding a conflicting patch should warn with a simple disable-then-add hint."""
        patch_a, patch_b = self._make_conflicting_pair()
        patcher = Patcher()
        patcher.add(patch_b)
        with self.assertLogs("mx_driving.patcher", level="WARNING") as logs:
            patcher.add(patch_a)
        output = "\n".join(logs.output)
        for fragment in (
            "patch_a",
            "patch_b",
            "default_patcher",
            "patcher.disable('patch_b').add(PatchA)",
        ):
            self.assertIn(fragment, output)
class TestLegacyUnknownLabelFix(unittest.TestCase):
    """
    Regression tests for legacy unknown-label fix.
    Covers plan Section 8 requirements:
    1. Custom legacy patch no longer grouped as 'unknown'
    2. Built-in legacy patch expansion with _parent_name
    3. Mixed built-in + custom order preservation
    """

    def setUp(self):
        """Suppress the one-shot migration notice for the duration of each test."""
        # Restore the previous flag value afterwards; otherwise forcing it to
        # True here would leak class-level state into later test classes.
        previous = getattr(LegacyPatchWrapper, "_migration_warning_shown", False)
        self.addCleanup(setattr, LegacyPatchWrapper, "_migration_warning_shown", previous)
        LegacyPatchWrapper._migration_warning_shown = True

    def test_custom_legacy_patch_not_unknown(self):
        """Custom legacy patch should use func name as group, not 'unknown'."""
        mock_mod = types.ModuleType("custom_label_test")
        mock_mod.target_func = lambda: "original"
        sys.modules["custom_label_test"] = mock_mod
        try:
            def my_custom_patch(module, options):
                module.target_func = lambda: "patched"

            builder = LegacyPatcherBuilder()
            builder.add_module_patch("custom_label_test", Patch(my_custom_patch))
            inner = builder.build()._patcher
            parent_names = [pn for _, pn in inner._collect_all_patches()]
            self.assertNotIn("", parent_names,
                             "parent_name should not be empty (would become 'unknown')")
            self.assertIn("my_custom_patch", parent_names,
                          "Custom legacy patch should use func name as parent_name")
        finally:
            sys.modules.pop("custom_label_test", None)

    def test_builtin_legacy_patch_expanded_with_parent_name(self):
        """Built-in legacy patch should expand to child patches with _parent_name."""
        legacy_name_to_class = _patcher_init_module._LEGACY_NAME_TO_CLASS
        batch_matmul_func = _patcher_init_module.batch_matmul
        self.assertIn(batch_matmul_func.__name__, legacy_name_to_class,
                      "batch_matmul should be in _LEGACY_NAME_TO_CLASS")
        batch_matmul_cls = legacy_name_to_class[batch_matmul_func.__name__]

        builder = LegacyPatcherBuilder()
        builder.add_module_patch("torch", Patch(batch_matmul_func))
        inner = builder.build()._patcher
        child_patches = [p for p, _ in inner._collect_all_patches()]
        self.assertGreater(len(child_patches), 0,
                           "Built-in legacy patch should expand to child patches")
        for cp in child_patches:
            self.assertTrue(hasattr(cp, "_parent_name"),
                            f"Child patch {cp.name} should have _parent_name")
            self.assertEqual(cp._parent_name, batch_matmul_cls.name,
                             f"Child patch _parent_name should be '{batch_matmul_cls.name}'")

    def test_mixed_builtin_custom_order_preserved(self):
        """Mixed built-in + custom legacy patches preserve insertion order per module key."""
        mock_mod_a = types.ModuleType("order_mod_a")
        mock_mod_a.func = lambda: "a"
        mock_mod_b = types.ModuleType("order_mod_b")
        mock_mod_b.func = lambda: "b"
        sys.modules["order_mod_a"] = mock_mod_a
        sys.modules["order_mod_b"] = mock_mod_b
        legacy_name_to_class = _patcher_init_module._LEGACY_NAME_TO_CLASS
        batch_matmul_func = _patcher_init_module.batch_matmul
        try:
            def custom_first(module, options):
                module.func = lambda: "patched_a"

            def custom_last(module, options):
                pass

            builder = LegacyPatcherBuilder()
            builder.add_module_patch("order_mod_a", Patch(custom_first))
            builder.add_module_patch("torch", Patch(batch_matmul_func))
            builder.add_module_patch("order_mod_b", Patch(custom_last))
            inner = builder.build()._patcher
            names = [p.name for p, _ in inner._collect_all_patches()]

            custom_first_idx = names.index("custom_first")
            custom_last_idx = names.index("custom_last")
            batch_matmul_cls = legacy_name_to_class[batch_matmul_func.__name__]
            bm_child_names = {p.name for p in batch_matmul_cls.patches()}
            bm_indices = [i for i, n in enumerate(names) if n in bm_child_names]

            self.assertGreater(len(bm_indices), 0, "batch_matmul should have child patches")
            self.assertTrue(all(custom_first_idx < i for i in bm_indices),
                            "custom_first should come before batch_matmul children")
            self.assertTrue(all(i < custom_last_idx for i in bm_indices),
                            "batch_matmul children should come before custom_last")
        finally:
            sys.modules.pop("order_mod_a", None)
            sys.modules.pop("order_mod_b", None)
class TestPatchMmcvVersion(unittest.TestCase):
    """Regression test for patch_mmcv_version wrapper bug (tech debt #01)."""

    def test_patch_mmcv_version_passes_argument(self):
        """patch_mmcv_version should forward expected_version, not builtin str."""
        code = _patcher_init_module.patch_mmcv_version.__code__
        self.assertNotIn("str", code.co_names,
                         "patch_mmcv_version should not reference builtin str")

    def test_patch_mmcv_version_calls_ensure_correctly(self):
        """patch_mmcv_version("1.6.0") should pass "1.6.0" to ensure_mmcv_version."""
        captured = []

        def mock_ensure(version):
            captured.append(version)

        # patch.object restores the real function even if the call raises.
        with patch.object(_patcher_init_module, "ensure_mmcv_version", mock_ensure):
            _patcher_init_module.patch_mmcv_version("1.6.0")
        self.assertEqual(captured, ["1.6.0"],
                         "patch_mmcv_version should forward the version string")

    def test_ensure_mmcv_version_no_mmcv(self):
        """ensure_mmcv_version should not raise when mmcv is not importable."""
        ensure = _patcher_init_module.ensure_mmcv_version
        # A ``None`` entry in sys.modules makes ``import mmcv`` raise
        # ImportError (and importlib.util.find_spec return None) even when
        # mmcv is installed. Merely popping the entry would let the import
        # machinery find and load the real package again.
        missing = object()
        saved = sys.modules.get("mmcv", missing)
        sys.modules["mmcv"] = None
        try:
            ensure("1.6.0")
        finally:
            if saved is missing:
                sys.modules.pop("mmcv", None)
            else:
                sys.modules["mmcv"] = saved
# Allow running this file directly (``python <this file>``) in addition to
# discovery by a test runner.
if __name__ == "__main__":
    unittest.main()