r"""Functional interface"""
__all__ = ["dropout_with_byte_mask"]
import torch
import torch_npu as _VF
from torch_npu.utils._error_code import ErrCode, ops_error
# Short module-level alias for ``torch.Tensor``.
# NOTE(review): not referenced by the code visible in this file — presumably
# kept for type comments or external importers; confirm before removing.
Tensor = torch.Tensor
def dropout_with_byte_mask(input1, p=0.5, training=True, inplace=False):
    r"""Applies an NPU compatible dropout_with_byte_mask operation. Only supports NPU devices.

    This dropout_with_byte_mask method generates a stateless random uint8 mask
    and does dropout according to the mask.

    .. note::
        The performance is improved only in the device 32 core scenario.

    Args:
        input1: the input forwarded to ``torch_npu.dropout_with_byte_mask``.
        p: probability of a channel to be zeroed. Default: 0.5
        training: apply dropout if is ``True``. Default: ``True``
        inplace: in-place computing is not supported; any truthy value is
            rejected with ``ValueError``. Default: ``False``

    Returns:
        The result of ``torch_npu.dropout_with_byte_mask(input1, p, training)``.

    Raises:
        ValueError: if ``p`` is smaller than 0 or greater than 1, or if
            ``inplace`` is truthy.
    """
    if p < 0. or p > 1.:
        raise ValueError("dropout probability has to be between 0 and 1, "
                         "but got {}".format(p) + ops_error(ErrCode.VALUE))
    if inplace:
        # The in-place variant is deliberately rejected here, so only the
        # out-of-place kernel is ever dispatched below.
        raise ValueError("dropout_with_byte_mask does not support in-place computing."
                         + ops_error(ErrCode.VALUE))
    return _VF.dropout_with_byte_mask(input1, p, training)