import torch._ops
from torch._inductor.decomposition import decompositions, pw_cast_for_opmath
from torch._inductor.decomposition import register_decomposition
from torch._prims_common.wrappers import out_wrapper
from .config import get_soc_version, is_ascend950
from .lowering import _add_overload
aten = torch.ops.aten

# Aten ops whose existing Inductor decompositions are removed by
# _register_npu_inductor_decompositons() (after expansion via _add_overload),
# presumably so NPU-specific lowerings/decompositions can take over.
# native_dropout and native_dropout_backward are re-registered with NPU
# versions inside that function.
# NOTE(review): `is_ascend950` comes from .config and is only tested for
# truthiness here; if it is a function rather than a bool, this condition is
# always true -- confirm.
DECOMPOSITION_OVERLOAD_OP = [
    aten.nll_loss_forward,
    aten.nll_loss_backward,
    aten._log_softmax_backward_data,
    aten.addmm,
    aten.gelu,
    aten.native_layer_norm,
    aten.native_dropout,
    aten.native_dropout_backward,
] + (
    # max_pool2d_with_indices is only overridden on Ascend 950.
    [aten.max_pool2d_with_indices] if is_ascend950 else []
)
def _register_npu_inductor_decompositons():
    """Replace Inductor decompositions for selected aten ops with NPU versions.

    Step 1: every op derived from ``DECOMPOSITION_OVERLOAD_OP`` (expanded by
    ``_add_overload`` -- presumably into its concrete overloads; confirm in
    ``.lowering``) is removed from Inductor's global ``decompositions`` table.

    Step 2: NPU decompositions are registered for ``aten.expm1``,
    ``aten.erfc``, ``aten.native_dropout`` and
    ``aten.native_dropout_backward``. Removal runs first so that the dropout
    ops, which are also in the removal list, have no stale entry when they are
    re-registered here.

    Side effects: mutates the module-global ``decompositions`` dict.
    NOTE(review): expm1/erfc are not in the removal list, so a second call
    would attempt duplicate registrations -- confirm this is only called once.
    """
    ops_to_replace = set()
    # The set starts empty and is presumably filled in place by _add_overload.
    _add_overload(DECOMPOSITION_OVERLOAD_OP, ops_to_replace)
    for stale_op in ops_to_replace:
        # pop with a default is a no-op for ops Inductor never decomposed.
        decompositions.pop(stale_op, None)

    @register_decomposition([aten.expm1])
    def expm1(x):
        # exp(x) - 1 built from primitives. This gives up the extra accuracy a
        # native expm1 provides for |x| close to 0.
        return torch.exp(x) - torch.ones_like(x)

    @register_decomposition([aten.erfc])
    def erfc(x):
        # erfc(x) = 1 - erf(x).
        return torch.ones_like(x) - torch.erf(x)

    @register_decomposition(aten.native_dropout)
    @out_wrapper("out0", "out1")
    def native_dropout(tensor_input, p, train):
        # Returns an (output, mask) pair, matching aten.native_dropout.
        if not torch._inductor.config.fallback_random:
            # Default: reuse the stock decomposition from torch._decomp.
            from torch._decomp.decompositions import (
                native_dropout as _generic_native_dropout,
            )
            return _generic_native_dropout(tensor_input, p, train)
        if train and p != 0:
            # fallback_random: call the NPU dropout kernel directly.
            return torch.ops.npu._npu_dropout(tensor_input, p)
        # Not training, or p == 0: pass the input through with an all-True
        # boolean mask.
        return (tensor_input, torch.ones_like(tensor_input, dtype=torch.bool))

    @register_decomposition(aten.native_dropout_backward)
    @out_wrapper()
    def native_dropout_backward(grad_output, mask, scale):
        if not torch._inductor.config.fallback_random:
            # Default: reuse the stock decomposition from torch._decomp.
            from torch._decomp.decompositions import (
                native_dropout_backward as _generic_native_dropout_backward,
            )
            return _generic_native_dropout_backward(grad_output, mask, scale)
        # Recover the drop probability by inverting scale == 1 / (1 - p).
        # A zero scale is treated as p == 1, which also avoids dividing by 0.
        if scale == 0:
            drop_prob = 1
        else:
            drop_prob = 1 - 1 / scale
        # NOTE(review): the kernel presumably expects the mask produced by
        # npu._npu_dropout. The pass-through forward path returns a bool mask
        # instead -- confirm the kernel accepts that.
        return torch.ops.npu.npu_dropout_backward(grad_output, mask, drop_prob)