import importlib
import inspect
import sys
import types
from typing import List, Dict, Union
# Cached result of is_megatron_training_available(); None means "not probed yet".
_MEGATRON_TRAINING_AVAILABLE = None


def is_megatron_training_available():
    """Report whether ``megatron.training`` can be imported.

    The import is attempted at most once; its outcome is stored in the
    module-level ``_MEGATRON_TRAINING_AVAILABLE`` and reused by later calls.

    Returns:
        bool: True if ``import megatron.training`` succeeded, False if it
        raised ModuleNotFoundError.
    """
    global _MEGATRON_TRAINING_AVAILABLE
    if _MEGATRON_TRAINING_AVAILABLE is None:
        try:
            import megatron.training  # noqa: F401  (availability probe only)
        except ModuleNotFoundError:
            _MEGATRON_TRAINING_AVAILABLE = False
        else:
            _MEGATRON_TRAINING_AVAILABLE = True
    return _MEGATRON_TRAINING_AVAILABLE
def get_func_name(func):
    """Return the fully qualified dotted name of ``func``.

    A string is assumed to already be a dotted name and is returned as-is;
    any other object is rendered as ``<__module__>.<__qualname__>``.
    """
    if not isinstance(func, str):
        name_parts = [func.__module__, func.__qualname__]
        func = '.'.join(name_parts)
    return func
def dummy_function_wrapper(func_name):
    """Build a placeholder callable standing in for a missing function.

    The returned callable accepts any arguments and always raises
    RuntimeError mentioning ``func_name``. Its ``__name__`` is
    ``dummy_function``, which does not end in ``wrapper``/``decorator``, so a
    Patch treats it as a replacement rather than a wrapper.
    """
    def dummy_function(*args, **kwargs):
        raise RuntimeError(f'function {func_name} no exist')

    return dummy_function
class Patch:
    """Bookkeeping for one monkey patch on ``<owner path>.<attribute>``.

    A Patch stores an optional replacement function (``patch_func``) and a
    list of wrappers that are layered on top of it -- or on top of the original
    attribute when no replacement is set. Nothing is modified until
    :meth:`apply_patch` runs.
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        """Create an unapplied patch.

        Args:
            orig_func_name: Dotted target path such as ``"pkg.mod.func"`` or
                ``"pkg.mod.Class.method"``. Everything before the last dot is
                the owner path; the last component is the attribute to patch.
                A path with no dot is taken as a bare owner path and
                ``self.orig_func_name`` is None.
            new_func: Replacement or wrapper (see :meth:`set_patch_func`).
                None installs a placeholder from ``dummy_function_wrapper``
                that raises RuntimeError when called.
            create_dummy: Forwarded to :meth:`parse_path` on apply; when True,
                missing modules on the owner path are created as empty
                placeholder modules instead of raising.
        """
        split_name = orig_func_name.rsplit('.', 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_name = orig_func_name, None
        else:
            self.orig_module_name, self.orig_func_name = split_name
        # Owner object (module or class) and its original attribute; both are
        # captured on the first apply_patch and reused afterwards.
        self.orig_module = None
        self.orig_func = None
        # Replacement function, if any (None means "keep the original").
        self.patch_func = None
        # The object actually installed by the last apply_patch (after wrapping).
        self.final_patch_func = None
        # Wrappers, applied in list order: the first one ends up innermost.
        self.wrappers = []
        if new_func is None:
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_id(self):
        """``id()`` of the captured original attribute (``id(None)`` before the first apply)."""
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        """``id()`` of the current replacement function."""
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register ``new_func`` as either a wrapper or the replacement.

        Callables whose ``__name__`` ends with ``wrapper`` or ``decorator`` are
        appended to ``self.wrappers`` (an identical object is not added twice).
        Anything else -- including None, which clears the replacement --
        becomes ``self.patch_func``.

        Args:
            new_func: The wrapper or replacement.
            force_patch: Allow overwriting an existing replacement.

        Raises:
            RuntimeError: A replacement is already set and ``force_patch`` is False.
        """
        # The function's name decides its role: wrappers are layered on top of
        # the target, everything else replaces it outright.
        if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
            if new_func not in self.wrappers:
                self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError('the patch of {} exist !'.format(self.orig_func_name))
            self.patch_func = new_func
        # Any change invalidates what is currently installed.
        self.is_applied = False

    def remove_wrappers(self, wrapper_names: Union[str, List[str]] = None):
        """Drop wrappers by ``__name__``.

        Args:
            wrapper_names: A single name, a list of names, or None to drop
                every wrapper.
        """
        if wrapper_names is None:
            self.wrappers.clear()
            return
        if isinstance(wrapper_names, str):
            wrapper_names = [wrapper_names]
        for name in wrapper_names:
            # In-place removal of every wrapper with this name; the index only
            # advances when nothing was popped so no element is skipped.
            i = 0
            while i < len(self.wrappers):
                if self.wrappers[i].__name__ == name:
                    self.wrappers.pop(i)
                else:
                    i += 1

    def remove_patch(self):
        """Restore the original attribute and forget the replacement.

        Every module in ``sys.modules`` (except those whose name contains
        ``'mindspeed'``, and ``torch.classes``) is scanned. Wherever the
        attribute named ``orig_func_name`` is exactly the object installed by
        the last apply (``final_patch_func``), ``orig_func`` is put back. When
        the owner is a class, the scan inspects each module's attribute named
        like the class instead of the module itself. Wrappers are left in place.
        """
        # Iterate over a snapshot so the loop is unaffected if sys.modules changes.
        for key, value in sys.modules.copy().items():
            # NOTE(review): presumably MindSpeed's own modules are skipped so their
            # references stay bound, and torch.classes because attribute probing on
            # it misbehaves -- confirm.
            if 'mindspeed' in key or 'torch.classes' == key:
                continue
            if inspect.isclass(self.orig_module) and hasattr(value, self.orig_module_name.split('.')[-1]):
                value = getattr(value, self.orig_module_name.split('.')[-1])
            if self.orig_func_name is not None and hasattr(value, self.orig_func_name) \
                    and id(getattr(value, self.orig_func_name)) == id(self.final_patch_func):
                setattr(value, self.orig_func_name, self.orig_func)
        self.patch_func = None
        self.final_patch_func = None
        self.is_applied = False

    def apply_patch(self):
        """Install the patch; a no-op if it is already applied.

        The installed object is ``patch_func`` (or the original attribute when
        there is no replacement) passed through every wrapper in order. It is
        set on the owner. It is also set on every module in ``sys.modules``
        whose same-named attribute is the object found by this call's lookup
        (``current_func``), so copies bound elsewhere -- e.g. by
        ``from owner import func`` -- are replaced too.
        """
        if self.is_applied:
            return
        # On a re-apply, current_func is whatever is installed now (normally the
        # previous final_patch_func), while orig_module/orig_func keep the values
        # captured on the first apply.
        current_module, current_func = Patch.parse_path(self.orig_module_name, self.orig_func_name, self.create_dummy)
        if self.orig_module is None:
            self.orig_module, self.orig_func = current_module, current_func
        final_patch_func = self.orig_func
        if self.patch_func is not None:
            final_patch_func = self.patch_func
        for wrapper in self.wrappers:
            final_patch_func = wrapper(final_patch_func)
        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, final_patch_func)
        # NOTE(review): if parse_path had to create the attribute as None,
        # current_func is None and every module attribute of the same name that
        # happens to be None is rebound here as well -- confirm this is acceptable.
        for _, value in sys.modules.copy().items():
            if self.orig_func_name is not None and hasattr(value, self.orig_func_name) \
                    and id(getattr(value, self.orig_func_name)) == id(current_func):
                setattr(value, self.orig_func_name, final_patch_func)
        self.is_applied = True
        self.final_patch_func = final_patch_func

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Resolve ``module_path`` and return ``(owner, attribute)``.

        Each dotted prefix of ``module_path`` is imported in turn. When a
        prefix cannot be imported (ModuleNotFoundError), one of two things
        happens:

        * If the already-resolved parent has an attribute with that name
          (e.g. a class defined in the module), that attribute becomes the
          owner. It is returned at once together with its ``function_name``
          attribute. If that attribute is missing, a raising placeholder is
          returned when ``create_dummy`` is True; otherwise RuntimeError is
          raised. Any remaining path components are not visited.
        * Otherwise ModuleNotFoundError is raised unless ``create_dummy`` is
          True. In that case an empty placeholder module is registered in
          ``sys.modules`` and attached to its parent.

        When the path resolves to a module that lacks ``function_name``, the
        attribute is created on it with the value None, regardless of
        ``create_dummy``. If ``function_name`` is None, ``(module, None)`` is
        returned.

        Raises:
            ModuleNotFoundError: A path component is missing and
                ``create_dummy`` is False.
            RuntimeError: The owner was resolved through an attribute, it
                lacks ``function_name``, and ``create_dummy`` is False.
        """
        from importlib.machinery import ModuleSpec
        modules = module_path.split('.')
        for i in range(1, len(modules) + 1):
            parent = '.'.join(modules[:i - 1])
            path = '.'.join(modules[:i])
            try:
                importlib.import_module(path)
            except ModuleNotFoundError as e:
                # NOTE(review): this also catches a ModuleNotFoundError raised by a
                # transitive import inside an existing module; with create_dummy a
                # placeholder is then registered for `path` -- confirm intended.
                if not parent or not hasattr(importlib.import_module(parent), modules[i - 1]):
                    if not create_dummy:
                        raise ModuleNotFoundError(e) from e
                    # Register an empty placeholder module so later imports of
                    # `path` (and of its children) resolve through sys.modules.
                    sys.modules[path] = types.ModuleType(path)
                    sys.modules[path].__file__ = 'mindspeed.dummy_module.py'
                    sys.modules[path].__spec__ = ModuleSpec(path, None)
                    if parent:
                        setattr(importlib.import_module(parent), modules[i - 1], sys.modules[path])
                else:
                    # `path` is not a module but an attribute of its parent
                    # (e.g. a class): treat that attribute as the owner.
                    module = getattr(importlib.import_module(parent), modules[i - 1])
                    if hasattr(module, function_name):
                        return module, getattr(module, function_name)
                    elif create_dummy:
                        return module, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError('no exist {} of {}'.format(function_name, module))
        if function_name is not None and not hasattr(sys.modules[module_path], function_name):
            setattr(sys.modules[module_path], function_name, None)
        return sys.modules[module_path], getattr(sys.modules[module_path], function_name) if function_name is not None else None
class MindSpeedPatchesManager:
    """Process-wide registry of Patch objects, keyed by dotted target path.

    Patches are collected with register_patch and only take effect when
    apply_patches is called.
    """

    # Class-level registry shared by all callers: dotted target path -> Patch.
    patches_info: Dict[str, Patch] = {}

    @staticmethod
    def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
        """Patch registration method. When this method is executed, the patch does not take effect in real time.
        It takes effect only after the apply_patches method is invoked. Other details are as follows:
        1. If `orig_func_name` does not exist and create_dummy is set to True, a dummy function is created to ensure
        that the import is normal.
        2. If `new_func` is a plain function, it replaces `orig_func_name`. On the first registration,
        `new_func=None` installs a placeholder that raises RuntimeError when called.
        3. If the `new_func` function name ends with `wrapper` or `decorator`, then `new_func` is decorated on
        `orig_func_name` as a decorator, and the decorator can be superimposed repeatedly.
        4. When force_patch=False, a function cannot be replaced repeatedly (a RuntimeError is raised), although it
        can be decorated repeatedly. When force_patch=True, the earlier replacement is overwritten.
        5. `create_dummy` is honoured on every registration: once any registration for `orig_func_name` passes
        create_dummy=True, missing modules on its path are created when the patch is applied.

        Raises:
            RuntimeError: A replacement for `orig_func_name` already exists and force_patch is False.
        """
        patches = MindSpeedPatchesManager.patches_info
        patch = patches.get(orig_func_name)
        if patch is None:
            patches[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)
            return
        patch.set_patch_func(new_func, force_patch)
        # A later registration may be the first to ask for dummy creation (e.g. a
        # wrapper was registered before the replacement). Previously only the first
        # registration's flag counted, so such a request was silently dropped and
        # apply_patches could raise ModuleNotFoundError instead.
        if create_dummy:
            patch.create_dummy = True

    @staticmethod
    def remove_wrappers(orig_func_name, wrappers_name, remove_check=True):
        """Remove wrappers registered for `orig_func_name`.

        Args:
            orig_func_name: Dotted target path used at registration.
            wrappers_name: A wrapper ``__name__``, a list of them, or None to
                remove every wrapper.
            remove_check: If True, raise when nothing was removed.

        Raises:
            ValueError: No patch is registered for `orig_func_name`.
            RuntimeError: remove_check is True and no wrapper matched.
        """
        patch = MindSpeedPatchesManager.patches_info.get(orig_func_name)
        if patch is None:
            raise ValueError('The function <{}> not exist.'.format(orig_func_name))
        wrappers_len = len(patch.wrappers)
        patch.remove_wrappers(wrappers_name)
        if wrappers_len != len(patch.wrappers):
            # The currently installed object still contains the removed wrappers;
            # mark the patch stale so the next apply_patches re-installs it.
            patch.is_applied = False
        elif remove_check:
            raise RuntimeError('Remove wrappers has not remove anything.')

    @staticmethod
    def remove_patches():
        """Undo every registered patch and drop all of its wrappers.

        Entries stay in ``patches_info``, with their replacement and wrappers cleared.
        """
        for patch in MindSpeedPatchesManager.patches_info.values():
            patch.remove_patch()
            patch.remove_wrappers()

    @staticmethod
    def apply_patches():
        """Apply all patches registered in MindSpeedPatchesManager."""
        for patch in MindSpeedPatchesManager.patches_info.values():
            patch.apply_patch()