import math
import torch
from torch._inductor import decomposition as inductor_decomp
from torch._decomp import remove_decompositions
aten = torch.ops.aten
prims = torch.ops.prims
quantized = torch.ops.quantized
decomps_to_exclude_npu = [
aten._batch_norm_no_update,
aten._batch_norm_with_update,
aten._batch_norm_with_update_functional,
aten._log_softmax,
aten._log_softmax_backward_data,
aten._softmax,
aten._softmax_backward_data,
aten.batch_norm_backward,
aten.convolution_backward,
aten.embedding,
aten.embedding_backward,
aten.embedding_dense_backward,
aten.gelu.default,
aten.gelu_backward.default,
aten.grid_sampler_2d,
aten.grid_sampler_2d_backward,
aten.linalg_vector_norm,
aten.max_pool2d_with_indices,
aten.max_pool2d_with_indices_backward,
aten.native_batch_norm,
aten.native_group_norm,
aten.nll_loss2d_backward,
aten.nll_loss2d_forward,
aten.nll_loss_backward,
aten.nll_loss_forward,
aten.reflection_pad2d,
aten.reflection_pad2d_backward,
aten.slice.Tensor,
aten.triu,
aten.upsample_bilinear2d,
aten.upsample_bilinear2d_backward,
aten.upsample_nearest1d,
aten.upsample_nearest1d_backward,
aten.upsample_nearest2d,
aten.upsample_nearest2d_backward,
aten.upsample_nearest3d,
aten.upsample_nearest3d_backward,
torch.ops.npu.npu_rotary_mul,
torch.ops.npu.npu_rotary_mul_backward,
]
FP32_MIN_V2 = -8.8
FP32_MAX_V2 = 8.8
DOUBLE_X = 2.0
def tanh(a):
"""
y = (exp(2x) - 1) / (exp(2x) + 1)
with x clipped to [-8.8, 8.8] in float32 before multiply-by-2.
"""
orig_dtype = a.dtype
if orig_dtype != torch.float32:
a = a.to(torch.float32)
x = torch.clamp(a, min=FP32_MIN_V2, max=FP32_MAX_V2)
x2 = x * DOUBLE_X
e2x = torch.exp(x2)
out = (e2x - 1.0) / (e2x + 1.0)
if orig_dtype != torch.float32:
out = out.to(orig_dtype)
return out
def gelu(a: torch.Tensor, approximate: str = "none"):
"""
y = -sqrt(8/pi) * (x + 0.044715 * x^3)
out = x / (1 + exp(y))
"""
orig_dtype = a.dtype
if orig_dtype != torch.float32:
a = a.to(torch.float32)
M_SQRT2 = math.sqrt(2)
M_2_SQRTPI = 2.0 / math.sqrt(math.pi)
kBeta = M_SQRT2 * M_2_SQRTPI
kKappa = 0.044715
a_cube = a * a * a
inner = a + kKappa * a_cube
y = -kBeta * inner
out = a / (1.0 + torch.exp(y))
if orig_dtype != torch.float32:
out = out.to(orig_dtype)
return out
def gelu_backward(grad, self, approximate: str = "none"):
orig_dtype = grad.dtype
if orig_dtype != torch.float32:
grad = grad.to(torch.float32)
self = self.to(torch.float32)
M_SQRT2 = math.sqrt(2)
M_2_SQRTPI = 2.0 / math.sqrt(math.pi)
kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
kKappa = 0.044715
x_sq = self * self
x_cube = x_sq * self
inner = kBeta * (self + kKappa * x_cube)
tanh_inner = torch.tanh(inner)
left = 0.5 * self
right = 1.0 + tanh_inner
left_derivative = 0.5 * right
tanh_derivative = (tanh_inner * tanh_inner) * -1.0 + 1.0
inner_derivative = kBeta * (1.0 + 3.0 * kKappa * x_sq)
right_derivative = left * tanh_derivative * inner_derivative
out = grad * (left_derivative + right_derivative)
if orig_dtype != torch.float32:
out = out.to(orig_dtype)
return out
def sigmoid(a: torch.Tensor) -> torch.Tensor:
orig_dtype = a.dtype
if orig_dtype != torch.float32:
a = a.to(torch.float32)
out = 1 / (1.0 + torch.exp(torch.neg(a)))
if orig_dtype != torch.float32:
out = out.to(orig_dtype)
return out
def patch_decomp():
remove_decompositions(inductor_decomp.decompositions, decomps_to_exclude_npu)
inductor_decomp.register_decomposition([aten.sigmoid.default])(sigmoid)
inductor_decomp.register_decomposition([aten.gelu_backward.default])(gelu_backward)
inductor_decomp.register_decomposition([aten.gelu.default])(gelu)
inductor_decomp.register_decomposition([aten.tanh.default])(tanh)