import importlib
import inspect
import sys
import types
from typing import List, Dict, Union

# Cached result of the megatron.training probe; None means "not probed yet".
_MEGATRON_TRAINING_AVAILABLE = None


def is_megatron_training_available():
    """
    Report whether the ``megatron.training`` module can be imported.

    The import is attempted on the first call only (or whenever the cached flag
    is None); the outcome is stored in ``_MEGATRON_TRAINING_AVAILABLE`` and
    returned on subsequent calls.

    Returns:
        bool: True if ``import megatron.training`` succeeds, False if it raises
        ModuleNotFoundError. Any other import error propagates and nothing is cached.
    """
    global _MEGATRON_TRAINING_AVAILABLE
    if _MEGATRON_TRAINING_AVAILABLE is None:
        try:
            import megatron.training  # noqa: F401 - imported only to probe availability
        except ModuleNotFoundError:
            _MEGATRON_TRAINING_AVAILABLE = False
        else:
            _MEGATRON_TRAINING_AVAILABLE = True
    return _MEGATRON_TRAINING_AVAILABLE


def get_func_name(func):
    """Return the dotted ``<module>.<qualname>`` path identifying ``func``.

    A string argument is returned unchanged, so callers may pass either a
    callable or an already-dotted name.
    """
    if not isinstance(func, str):
        func = '.'.join([func.__module__, func.__qualname__])
    return func


def dummy_function_wrapper(func_name):
    """Build a stand-in callable for a function that does not exist.

    Args:
        func_name: Name embedded in the error message.

    Returns:
        A function named ``dummy_function`` that accepts any arguments and
        always raises ``RuntimeError`` when called.
    """
    # NOTE: the inner name must not end with "wrapper"/"decorator"; Patch.set_patch_func
    # inspects __name__ to decide whether a callable is a decorator.
    def dummy_function(*args, **kwargs):
        raise RuntimeError(f'function {func_name} no exist')

    return dummy_function


class Patch:
    """One registered patch: a replacement and/or wrappers for a single dotted target.

    The target ``"pkg.module.attr"`` (or ``"pkg.module.Class.attr"``) is split into
    ``orig_module_name`` (``"pkg.module"``) and ``orig_func_name`` (``"attr"``); a
    name without a dot has no attribute part (``orig_func_name`` is None).
    Nothing is modified until :meth:`apply_patch` is called.
    """

    def __init__(self, orig_func_name, new_func, create_dummy):
        split_name = orig_func_name.rsplit('.', 1)
        if len(split_name) == 1:
            self.orig_module_name, self.orig_func_name = orig_func_name, None
        else:
            self.orig_module_name, self.orig_func_name = split_name
        # Owner and original attribute, captured on the first apply_patch() and
        # kept so that remove_patch() can put the original back.
        self.orig_module = None
        self.orig_func = None

        # At most one plain replacement, plus decorators stacked on top of it.
        self.patch_func = None
        # Object installed by the last apply_patch(); None before the first
        # apply and after remove_patch().
        self.final_patch_func = None
        self.wrappers = []
        if new_func is None:
            # No replacement supplied: install a placeholder that raises RuntimeError when called.
            new_func = dummy_function_wrapper(orig_func_name)
        self.set_patch_func(new_func)
        self.is_applied = False
        self.create_dummy = create_dummy

    @property
    def orig_func_id(self):
        return id(self.orig_func)

    @property
    def patch_func_id(self):
        return id(self.patch_func)

    def set_patch_func(self, new_func, force_patch=False):
        """Register ``new_func`` as the replacement, or as a wrapper.

        A callable whose ``__name__`` ends with ``wrapper`` or ``decorator`` is
        appended to ``self.wrappers`` (once per object). Anything else becomes
        the replacement; registering a second replacement raises
        ``RuntimeError`` unless ``force_patch`` is True. Either way the patch is
        marked as not applied so the next ``apply_patch()`` reinstalls it.
        """
        if hasattr(new_func, '__name__') and new_func.__name__.endswith(('wrapper', 'decorator')):
            if new_func not in self.wrappers:
                self.wrappers.append(new_func)
        else:
            if self.patch_func and not force_patch:
                raise RuntimeError('the patch of {} exist !'.format(self.orig_func_name))
            self.patch_func = new_func
        self.is_applied = False

    def remove_wrappers(self, wrapper_names: Union[str, List[str], None] = None):
        """Unregister wrappers by ``__name__``.

        Args:
            wrapper_names: A single name or an iterable of names; every wrapper
                whose ``__name__`` matches any of them is dropped. None drops
                all wrappers.
        """
        if wrapper_names is None:
            self.wrappers.clear()
            return

        if isinstance(wrapper_names, str):
            wrapper_names = [wrapper_names]
        names = list(wrapper_names)
        # Filter in place so that any outside reference to the list stays in sync.
        self.wrappers[:] = [wrapper for wrapper in self.wrappers if wrapper.__name__ not in names]

    def remove_patch(self):
        """Undo the last ``apply_patch()`` and forget the registered replacement.

        The owner that ``apply_patch()`` modified directly is restored first.
        Then every other reference to the installed object found under the
        target attribute name in loaded modules is reset to the original.
        Modules whose name contains ``mindspeed``, and ``torch.classes``, are
        not scanned. When the target is a class, the class is looked up in each
        module by its simple name. Registered wrappers are kept; use
        ``remove_wrappers()`` to drop them.
        """
        installed = self.final_patch_func
        # When nothing is installed there is nothing to undo. Scanning anyway would
        # "restore" every unrelated attribute that merely happens to be None.
        if installed is not None and self.orig_func_name is not None:
            # Undo the direct setattr done by apply_patch(); the scan below would miss
            # it for mindspeed-named modules and for classes not exposed by name.
            if self.orig_module is not None and getattr(self.orig_module, self.orig_func_name, None) is installed:
                setattr(self.orig_module, self.orig_func_name, self.orig_func)

            class_name = self.orig_module_name.split('.')[-1]
            target_is_class = inspect.isclass(self.orig_module)
            for key, value in list(sys.modules.items()):
                # mindspeed modules are skipped; torch.classes is excluded (presumably
                # because probing attributes on it is unsafe — TODO confirm).
                if 'mindspeed' in key or key == 'torch.classes':
                    continue

                if target_is_class and hasattr(value, class_name):
                    value = getattr(value, class_name)

                if getattr(value, self.orig_func_name, None) is installed:
                    setattr(value, self.orig_func_name, self.orig_func)
        self.patch_func = None
        self.final_patch_func = None
        self.is_applied = False

    def apply_patch(self):
        """Install the patch on its target and on every alias of the replaced object.

        No-op when already applied. The installed object is the replacement
        (or the original when no replacement is registered), wrapped by each
        registered wrapper in registration order (the first one is innermost).
        It is set on the owning module/class. Then every attribute with the
        same name in any loaded module that still refers to the object being
        replaced is rebound to it.
        """
        if self.is_applied:
            return

        current_module, current_func = Patch.parse_path(self.orig_module_name, self.orig_func_name, self.create_dummy)

        # Capture the pristine target only once. On re-application, current_func
        # is the object installed by the previous apply_patch().
        if self.orig_module is None:
            self.orig_module, self.orig_func = current_module, current_func

        final_patch_func = self.orig_func if self.patch_func is None else self.patch_func
        for wrapper in self.wrappers:
            final_patch_func = wrapper(final_patch_func)

        if self.orig_func_name is not None:
            setattr(self.orig_module, self.orig_func_name, final_patch_func)
            # Rebind copies of the replaced object held by other modules
            # (e.g. names imported with `from x import f`).
            for value in list(sys.modules.values()):
                if hasattr(value, self.orig_func_name) and getattr(value, self.orig_func_name) is current_func:
                    setattr(value, self.orig_func_name, final_patch_func)
        self.is_applied = True
        self.final_patch_func = final_patch_func

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """Import ``module_path`` prefix by prefix and return ``(owner, attribute)``.

        When a dotted prefix cannot be imported (``ModuleNotFoundError``):

        * if its parent already has an attribute of that name (e.g. a class),
          that object is the owner. ``(owner, getattr(owner, function_name))``
          is returned immediately. If the attribute is missing, a raising
          placeholder is returned when ``create_dummy`` is set; otherwise
          ``RuntimeError`` is raised.
        * otherwise, with ``create_dummy`` an empty placeholder module is
          registered in ``sys.modules`` and attached to its parent. Without
          ``create_dummy`` the original ``ModuleNotFoundError`` propagates.

        If the full path resolves to a module lacking ``function_name``, that
        attribute is created with value None (independently of
        ``create_dummy``). When ``function_name`` is None, ``(module, None)``
        is returned.
        """
        from importlib.machinery import ModuleSpec
        modules = module_path.split('.')
        for i in range(1, len(modules) + 1):
            parent = '.'.join(modules[:i - 1])
            path = '.'.join(modules[:i])
            child_name = modules[i - 1]
            try:
                importlib.import_module(path)
            except ModuleNotFoundError:
                # The parent was imported (or faked) by the previous iteration.
                parent_module = importlib.import_module(parent) if parent else None
                if parent_module is None or not hasattr(parent_module, child_name):
                    if not create_dummy:
                        # Re-raise as-is: keeps the exception's .name/.path and traceback.
                        raise
                    dummy = types.ModuleType(path)
                    dummy.__file__ = 'mindspeed.dummy_module.py'
                    dummy.__spec__ = ModuleSpec(path, None)
                    sys.modules[path] = dummy
                    if parent_module is not None:
                        setattr(parent_module, child_name, dummy)
                else:
                    # `path` names a non-module attribute of an importable parent (e.g. a class).
                    owner = getattr(parent_module, child_name)
                    if hasattr(owner, function_name):
                        return owner, getattr(owner, function_name)
                    elif create_dummy:
                        return owner, dummy_function_wrapper(function_name)
                    else:
                        raise RuntimeError('no exist {} of {}'.format(function_name, owner))

        module = sys.modules[module_path]
        if function_name is None:
            return module, None
        if not hasattr(module, function_name):
            # Placeholder so apply_patch() can install the patch as a brand-new attribute.
            setattr(module, function_name, None)
        return module, getattr(module, function_name)


class MindSpeedPatchesManager:
    """Class-level registry of :class:`Patch` objects, keyed by dotted target path."""

    patches_info: Dict[str, Patch] = {}

    @staticmethod
    def register_patch(orig_func_name, new_func=None, force_patch=False, create_dummy=False):
        """Register a patch for ``orig_func_name``; nothing changes until ``apply_patches()`` runs.

        Args:
            orig_func_name: Dotted target, e.g. ``"pkg.module.func"`` or
                ``"pkg.module.Class.method"``.
            new_func: Replacement for the target. If its ``__name__`` ends with
                ``wrapper`` or ``decorator``, it is instead stacked as a
                decorator around the target; decorators may be registered
                repeatedly. When None on the first registration, a placeholder
                that raises ``RuntimeError`` when called is used.
            force_patch: Allow overwriting an already registered replacement.
                Without it, a second replacement raises ``RuntimeError``.
            create_dummy: Used only on the first registration of a target.
                Lets missing modules along the path be replaced by empty
                placeholder modules instead of raising, so the target can be
                imported.
        """
        registry = MindSpeedPatchesManager.patches_info
        if orig_func_name in registry:
            registry[orig_func_name].set_patch_func(new_func, force_patch)
        else:
            registry[orig_func_name] = Patch(orig_func_name, new_func, create_dummy)

    @staticmethod
    def remove_wrappers(orig_func_name, wrappers_name, remove_check=True):
        """Unregister wrapper(s) named ``wrappers_name`` from the patch of ``orig_func_name``.

        Raises:
            ValueError: No patch is registered for ``orig_func_name``.
            RuntimeError: ``remove_check`` is True and no wrapper was removed.
        """
        registry = MindSpeedPatchesManager.patches_info
        if orig_func_name not in registry:
            raise ValueError('The function <{}> not exist.'.format(orig_func_name))

        patch = registry[orig_func_name]
        count_before = len(patch.wrappers)
        patch.remove_wrappers(wrappers_name)
        nothing_removed = len(patch.wrappers) == count_before
        if remove_check and nothing_removed:
            raise RuntimeError('Remove wrappers has not remove anything.')

    @staticmethod
    def remove_patches():
        """Revert every registered patch and drop all of its wrappers (entries stay registered)."""
        for registered in MindSpeedPatchesManager.patches_info.values():
            registered.remove_patch()
            registered.remove_wrappers()

    @staticmethod
    def apply_patches():
        """Apply all registered patches, in registration order."""
        for registered in MindSpeedPatchesManager.patches_info.values():
            registered.apply_patch()