import sys
import types

import torch

from torch_npu._init.patches.patch_manager import PatchManager
from torch_npu.contrib.function import npu_functional
from torch_npu.contrib.module import npu_modules
from torch_npu.utils._error_code import ErrCode, pta_error


all_monkey_patches = [
    ["nn.functional", npu_functional],
    ["nn", npu_modules],
]


def _apply_patches(monkey_patches):
    def _getattr(module_list, root_module=torch):
        if len(module_list) <= 1:
            return root_module

        if hasattr(root_module, module_list[0]):
            return _getattr(module_list[1:], getattr(root_module, module_list[0]))

        empty_module_name = f"{root_module.__name__}.{module_list[0]}"
        sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
        setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
        return _getattr(module_list[1:], getattr(root_module, module_list[0]))

    for dest, patch in monkey_patches:
        dest_module = _getattr(dest.split("."), root_module=torch)
        last_module_level = dest.split(".")[-1]

        if not isinstance(patch, types.ModuleType):
            setattr(dest_module, last_module_level, patch)
            continue

        if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
            setattr(dest_module, last_module_level, patch)
            sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
            continue

        if not hasattr(patch, "__all__"):
            raise NotImplementedError(
                "Patch module must have __all__ definition."
                + pta_error(ErrCode.NOT_SUPPORT)
            )

        dest_module = getattr(dest_module, last_module_level)
        for attr in patch.__all__:
            setattr(dest_module, attr, getattr(patch, attr))


@PatchManager.register_patch("monkey")
def apply_monkey_patches():
    _apply_patches(all_monkey_patches)