import warnings
from functools import wraps

from torch_npu._init.patches.patch_manager import PatchManager


_WARN_MSG = {
    "DropoutWithByteMask": (
        "torch.nn.DropoutWithByteMask is deprecated and will be removed in future version. "
        "Use torch_npu.contrib.module.DropoutWithByteMask instead."
    ),
    "dropout_with_byte_mask": (
        "torch.nn.functional.dropout_with_byte_mask is deprecated and will be removed in future version. "
        "Use torch_npu.contrib.function.dropout_with_byte_mask instead."
    ),
}


def _wrap_torch_patch_warning_func(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(_WARN_MSG[func.__name__])
        return func(*args, **kwargs)

    return wrapper


@PatchManager.register_patch("warning")
def apply_npu_show_warning_patch():
    from torch_npu.utils.utils import _apply_npu_show_warning

    _apply_npu_show_warning()


@PatchManager.register_patch("warning")
def apply_deprecated_api_warning_patch():
    import torch

    torch.nn.DropoutWithByteMask = _wrap_torch_patch_warning_func(
        torch.nn.DropoutWithByteMask
    )
    torch.nn.functional.dropout_with_byte_mask = _wrap_torch_patch_warning_func(
        torch.nn.functional.dropout_with_byte_mask
    )