import torch
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu


class Test_issue59(TestUtils):
    """Regression test for issue 59: eager vs. inductor-compiled results on NPU.

    Runs a small normalization-style reduction/elementwise graph eagerly and
    through ``torch.compile(backend="inductor", dynamic=False)`` and checks
    that all outputs agree within a 1e-3 tolerance.
    """

    def layernorm_backward(self, x, y, z):
        """Evaluate a layer-norm-like pattern over all elements of ``x``.

        Args:
            x: tensor to be centred and normalized.
            y: tensor multiplied into the normalized value.
            z: tensor added after the multiplication by ``y``.

        Returns:
            Tuple ``(normed, shifted, inv_std_scaled)``:
            ``normed = (x - centre) * inv_std``,
            ``shifted = normed * y + z`` and
            ``inv_std_scaled = inv_std / torch.numel(inv_std)``.
        """
        # NOTE(review): torch.sum without a dim returns a 0-dim tensor whose
        # numel is 1, so every torch.numel(...) division below divides by 1.
        # ``centre`` is therefore the full sum, not a mean. The op sequence is
        # kept unchanged because it is what this regression test traces.
        # TODO confirm whether torch.numel(x) was intended.
        total = torch.sum(x)
        centre = total / torch.numel(total)
        centred = x - centre
        squared = centred * centred
        sq_total = torch.sum(squared)
        # 1e-05 is added before rsqrt, presumably to keep it away from 0.
        variance_eps = sq_total / torch.numel(sq_total) + 1e-05
        inv_std = torch.rsqrt(variance_eps)
        normed = centred * inv_std
        scaled = normed * y
        shifted = scaled + z
        inv_std_scaled = inv_std / torch.numel(inv_std)
        return normed, shifted, inv_std_scaled

    def test_issue59(self):
        """Inductor-compiled outputs (static shapes) must match eager on NPU."""
        shape = (1, 1024)
        # The generator draws x, y, z in that order, so random-number
        # consumption matches three sequential torch.randn calls.
        x, y, z = (
            torch.randn(shape, device='npu', dtype=torch.float32)
            for _ in range(3)
        )

        # Run eagerly first, then compile and run the same function.
        expected = self.layernorm_backward(x, y, z)
        compiled_fn = torch.compile(
            self.layernorm_backward, backend="inductor", dynamic=False
        )
        actual = compiled_fn(x, y, z)

        for want, got in zip(expected, actual):
            self.assertEqual(want, got, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
    # Run this module's tests through PyTorch's internal test runner when the
    # file is executed directly.
    run_tests()