# Public API of this module: only DropoutWithByteMask is exported by a star-import.
__all__ = ["DropoutWithByteMask"]





from torch.nn import Module

from torch_npu.utils._error_code import ErrCode, ops_error

from ..function import npu_functional as F





class DropoutWithByteMask(Module):
    r"""Applies an NPU-compatible DropoutWithByteMask operation. Only supports NPU devices.

    A module for obtaining the performance benefits of operator fusion in graph mode.

    This DropoutWithByteMask method generates a stateless random uint8 mask and
    performs dropout according to that mask.

    .. note::

        The performance improvement applies only to the 32-core device scenario.

    Args:
        p: probability of an element to be zeroed. Must lie in ``[0, 1]``. Default: 0.5
        inplace: If set to ``True``, will do this operation in-place. Default: ``False``
        max_seed: accepted for backward compatibility only; this module does not
            use it. Default: ``2 ** 10 - 1``

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]`` (NaN is rejected as well).

    Shape:
        - Input: :math:`(*)`. Input can be of any shape
        - Output: :math:`(*)`. Output is of the same shape as input

    Examples::

        >>> m = torch_npu.contrib.module.DropoutWithByteMask(p=0.5)
        >>> input = torch.randn(16, 16).npu()
        >>> output = m(input)
    """

    def __init__(self, p=0.5, inplace=False,
                 max_seed=2 ** 10 - 1):
        super().__init__()

        # Written as ``not 0 <= p <= 1`` rather than ``p < 0 or p > 1`` so that
        # NaN is rejected too: every ordered comparison with NaN is False.
        if not 0 <= p <= 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p) + ops_error(ErrCode.VALUE))
        self.p = p
        self.inplace = inplace
        # NOTE(review): ``max_seed`` is intentionally neither stored nor used; it
        # is kept in the signature only so existing callers passing it still work.

    def forward(self, input1):
        """Apply byte-mask dropout to ``input1``.

        Delegates to ``F.dropout_with_byte_mask`` with this module's ``p``,
        its current ``training`` flag and its ``inplace`` setting.
        """
        return F.dropout_with_byte_mask(input1, self.p, self.training, self.inplace)

    def extra_repr(self):
        """Return the configuration string shown inside this module's ``repr``."""
        return "p={}, inplace={}".format(self.p, self.inplace)