# Commit 989d7c43, created on 2025-04-23 (commit history).
import warnings

import torch
import torch.nn as nn
import torch_npu

# Show each deprecation warning issued by the contrib modules in this file only
# once. The filter is scoped by message prefix: an unscoped filter would apply
# to every FutureWarning raised anywhere in the process, since filterwarnings()
# inserts it at the front of the global filter list.
warnings.filterwarnings(
    action='once',
    category=FutureWarning,
    message=r"torch_npu\.contrib\.module\.",
)


class Mish(nn.Module):
    r"""Applies an NPU based Mish operation.

    The calculation formula is as follows:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    .. warning::
        Deprecated: use ``torch.nn.Mish`` instead. Constructing this module
        emits a ``FutureWarning``.

    .. note::
        Mish exists in the official PyTorch release since version 1.9.0.
        This module was added when the PyTorch version adapted for NPU was
        1.5.0, which did not provide Mish.

    Examples::
        >>> from torch_npu.contrib.module import Mish
        >>> m = Mish()
        >>> input_tensor = torch.randn(2, 32, 5, 5)
        >>> output = m(input_tensor)
    """

    def __init__(self):
        super().__init__()
        # stacklevel=2 attributes the warning to the code that constructs the
        # module instead of this line, so users can find the call site to migrate.
        warnings.warn("torch_npu.contrib.module.Mish is deprecated. "
                      "Please use torch.nn.Mish for replacement.",
                      FutureWarning, stacklevel=2)

    def forward(self, x):
        """Apply Mish to ``x`` via ``torch_npu.npu_mish``.

        Args:
            x: input tensor (presumably resident on an NPU device — not
                checked here).

        Returns:
            The result of ``torch_npu.npu_mish(x)``.
        """
        return torch_npu.npu_mish(x)


class SiLU(nn.Module):
    r"""Applies an NPU based Sigmoid Linear Unit (SiLU) function, element-wise.

    The SiLU function is also known as the swish function.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    .. warning::
        Deprecated: use ``torch.nn.SiLU`` instead. Constructing this module
        emits a ``FutureWarning``.

    Examples::
        >>> from torch_npu.contrib.module import SiLU
        >>> m = SiLU()
        >>> input_tensor = torch.randn(2, 32, 5, 5)
        >>> output = m(input_tensor)
    """

    def __init__(self):
        super().__init__()
        # stacklevel=2 attributes the warning to the code that constructs the
        # module instead of this line, so users can find the call site to migrate.
        warnings.warn("torch_npu.contrib.module.SiLU is deprecated. "
                      "Please use torch.nn.SiLU for replacement.",
                      FutureWarning, stacklevel=2)

    def forward(self, x):
        """Apply SiLU to ``x`` via ``torch_npu.npu_silu``.

        Args:
            x: input tensor (presumably resident on an NPU device — not
                checked here).

        Returns:
            The result of ``torch_npu.npu_silu(x)``.
        """
        return torch_npu.npu_silu(x)


# ``Swish`` is the same class object as ``SiLU`` (swish is another name for
# SiLU), so constructing it emits the SiLU deprecation warning.
Swish = SiLU