import pkgutil
from collections import defaultdict
from importlib import import_module
from typing import Callable, List, Optional
PatchFn = Callable[[], None]
class PatchManager:
"""
Central patch registry for torch_npu.
PatchManager discovers patch modules in two ways:
1. Auto-import modules under torch_npu._init.patches whose names end with
'_patches';
2. Import external patch modules declared by DEFAULT_EXTRA_PATCH_MODULES
or registered through register_patch_module().
Patch modules should register patch functions through:
@PatchManager.register_patch("group_name")
def apply_xxx_patch():
...
"""
DEFAULT_PATCH_ORDER = [
"monkey",
"api",
"distributed",
"dynamo",
"profiler",
"npu",
"warning",
"asd",
]
DEFAULT_EXTRA_PATCH_MODULES = [
]
PATCH_MODULE_SUFFIX = "_patches"
_applied_patch_count = defaultdict(int)
_builtin_patches_registered = False
_custom_full_patch_order: Optional[List[str]] = None
_patch_groups = defaultdict(list)
_patch_modules: List[str] = []
@classmethod
def _add_patch(cls, group: str, fn: PatchFn):
if fn not in cls._patch_groups[group]:
cls._patch_groups[group].append(fn)
@classmethod
def register_patch(cls, group: str, fn: Optional[PatchFn] = None):
"""
Register a patch function into a patch group.
Supports:
PatchManager.register_patch("graph", apply_graph_patch)
and:
@PatchManager.register_patch("graph")
def apply_graph_patch():
...
"""
if not isinstance(group, str) or not group:
raise ValueError("patch group must be a non-empty string")
def decorator(real_fn: PatchFn):
cls._add_patch(group, real_fn)
return real_fn
if fn is not None:
return decorator(fn)
return decorator
@classmethod
def _resolve_patch_order(cls) -> List[str]:
"""
Resolve final patch execution order.
- Built-in groups follow DEFAULT_PATCH_ORDER.
- Groups not listed in DEFAULT_PATCH_ORDER are appended after default groups.
"""
if cls._custom_full_patch_order is not None:
base_order = list(cls._custom_full_patch_order)
else:
base_order = list(cls.DEFAULT_PATCH_ORDER)
extra_groups = [group for group in cls._patch_groups if group not in base_order]
return base_order + extra_groups
@classmethod
def _register_builtin_patches(cls):
"""
Discover and import patch modules.
This method imports:
1. built-in patch modules under torch_npu._init.patches whose names end
with '_patches';
2. default external patch modules declared in DEFAULT_EXTRA_PATCH_MODULES;
3. external patch modules registered by register_patch_module().
Importing a patch module triggers @PatchManager.register_patch(...)
decorators inside that module.
"""
if cls._builtin_patches_registered:
return
import torch_npu._init.patches as patches_pkg
prefix = patches_pkg.__name__ + "."
for module_info in sorted(
pkgutil.iter_modules(patches_pkg.__path__), key=lambda x: x.name
):
module_name = module_info.name
if module_name == "patch_manager":
continue
if not module_name.endswith(cls.PATCH_MODULE_SUFFIX):
continue
import_module(prefix + module_name)
for module_name in cls.DEFAULT_EXTRA_PATCH_MODULES:
cls.register_patch_module(module_name)
for module_name in cls._patch_modules:
import_module(module_name)
cls._builtin_patches_registered = True
@classmethod
def apply_registered_patches(cls, group: str):
"""
Apply newly registered patches in one group.
This supports delayed apply:
if new patch functions are registered after this group was applied,
calling this method again only applies newly added patch functions.
"""
cls._register_builtin_patches()
patches = cls._patch_groups.get(group, [])
start = cls._applied_patch_count[group]
for patch in patches[start:]:
patch()
cls._applied_patch_count[group] = len(patches)
@staticmethod
def _patch_excepthook():
"""
Patch Python global exception hook.
Kept separate from normal patches because it is paired with shutdown
exception handling.
"""
from torch_npu.utils._error_code import _except_handler
_except_handler.patch_excepthook()
@classmethod
def register_patch_module(cls, module_name: str):
"""
Register an extra module that contains @PatchManager.register_patch(...)
decorators. The module will be imported before patches are applied.
"""
if not isinstance(module_name, str) or not module_name:
raise ValueError("patch module name must be a non-empty string")
if module_name not in cls._patch_modules:
cls._patch_modules.append(module_name)
@classmethod
def set_patch_order(cls, order: List[str]):
"""
Override base patch group order.
Must be called before _apply_all_patches.
Registered groups not listed in this order will still be appended after
the base order by _resolve_patch_order().
"""
cls._custom_full_patch_order = list(order)
@classmethod
def clear_for_test(cls):
"""
Test-only helper. Do not use in normal runtime.
"""
cls._patch_groups.clear()
cls._patch_modules.clear()
cls._custom_full_patch_order = None
cls._builtin_patches_registered = False
cls._applied_patch_count.clear()
def _apply_all_patches():
PatchManager._register_builtin_patches()
for group in PatchManager._resolve_patch_order():
PatchManager.apply_registered_patches(group)
PatchManager._patch_excepthook()