import torch
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu


class Test_issue62(TestUtils):
    """Compare eager and inductor-compiled results of ``op_func`` on an NPU device."""

    def op_func(self, addmm_5, add):
        """Normalize ``add`` over dim 2, then scale and shift it with chunks of ``addmm_5``.

        The op sequence is written with explicit ``torch.ops.aten`` / ``prims``
        calls (it looks like a captured FX graph -- presumably the one that
        triggered the issue this test is named after), so the individual ops
        and their order are kept as they are.

        Args:
            addmm_5: Tensor split along dim 1 into chunks of width 1536. Only
                chunk 0 (added) and chunk 1 (multiplied, after adding 1) are
                used. The test passes a float16 tensor of shape (2, 9216).
            add: Tensor normalized over dim 2. The test passes a float16
                tensor of shape (2, 4096, 1536).

        Returns:
            ``normalized * (chunk1.unsqueeze(1) + 1) + chunk0.unsqueeze(1)``,
            where ``normalized`` is ``add`` normalized over dim 2 and cast to
            float16.
        """
        chunks = torch.ops.aten.split.Tensor(addmm_5, 1536, 1)
        # Only the first two chunks contribute to the result; any remaining
        # chunks are intentionally left untouched.
        shift = chunks[0]
        scale = chunks[1]

        contiguous_input = torch.ops.aten.clone.default(add, memory_format=torch.contiguous_format)
        # Statistics are computed on a float32 copy of the input.
        input_fp32 = torch.ops.prims.convert_element_type.default(contiguous_input, torch.float32)
        # correction=0 selects the biased (population) variance.
        var_mean = torch.ops.aten.var_mean.correction(input_fp32, [2], correction=0, keepdim=True)
        variance = var_mean[0]
        mean = var_mean[1]
        variance_eps = torch.ops.aten.add.Tensor(variance, 1e-06)
        inv_std = torch.ops.aten.rsqrt.default(variance_eps)
        centered = torch.ops.aten.sub.Tensor(contiguous_input, mean)
        normalized = torch.ops.aten.mul.Tensor(centered, inv_std)
        normalized_fp16 = torch.ops.prims.convert_element_type.default(normalized, torch.float16)

        # The slices below span all of dim 0 (end == 2**63 - 1); they are kept
        # so the traced op sequence stays the same.
        scale_rows = torch.ops.aten.slice.Tensor(scale, 0, 0, 9223372036854775807)
        # unsqueeze(1) inserts a size-1 dim so the chunk broadcasts over dim 1
        # of the normalized tensor.
        scale_expanded = torch.ops.aten.unsqueeze.default(scale_rows, 1)
        scale_plus_one = torch.ops.aten.add.Tensor(scale_expanded, 1)
        scaled = torch.ops.aten.mul.Tensor(normalized_fp16, scale_plus_one)

        shift_rows = torch.ops.aten.slice.Tensor(shift, 0, 0, 9223372036854775807)
        shift_expanded = torch.ops.aten.unsqueeze.default(shift_rows, 1)
        result = torch.ops.aten.add.Tensor(scaled, shift_expanded)
        return result

    def test_issue62(self):
        """Check that the inductor-compiled ``op_func`` matches eager execution."""
        addmm_5 = torch.randn((2, 9216), device='npu:0', dtype=torch.float16)
        add = torch.randn((2, 4096, 1536), device='npu:0', dtype=torch.float16)

        # Eager result is the reference.
        std_ret = self.op_func(addmm_5, add)
        compiled_func = torch.compile(self.op_func, backend="inductor")
        inductor_ret = compiled_func(addmm_5, add)
        # Explicit 1e-2 tolerances are used for the float16 comparison.
        self.assertEqual(std_ret, inductor_ret, atol=1e-2, rtol=1e-2)


# Run the tests through PyTorch's internal test runner when executed as a script.
if __name__ == "__main__":
    run_tests()