import torch
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu


class Test_issue57(TestUtils):
    """Compare eager and inductor-compiled results of ``op_sum`` on the 'npu' device."""

    def op_sum(self, view_12, embedding_1, slice_11):
        """Return ``view_12 + (unsqueeze(permute(embedding_1, [2, 0, 1]), 0) + slice_11)``.

        The computation is expressed with explicit ``torch.ops.aten`` overloads.

        Args:
            view_12: Tensor added to the intermediate sum in the final step.
            embedding_1: Tensor whose dimensions are reordered to ``[2, 0, 1]``
                and which then gets a new leading dimension.
            slice_11: Tensor added to the permuted/unsqueezed ``embedding_1``.

        Returns:
            The result of the final ``aten.add.Tensor`` call.
        """
        permute_7 = torch.ops.aten.permute.default(embedding_1, [2, 0, 1])
        unsqueeze_4 = torch.ops.aten.unsqueeze.default(permute_7, 0)
        add_5 = torch.ops.aten.add.Tensor(unsqueeze_4, slice_11)
        return torch.ops.aten.add.Tensor(view_12, add_5)

    def test_issue57(self):
        """Run ``op_sum`` eagerly and through ``torch.compile`` and compare the outputs."""
        device = 'npu'
        # (512, 512, 64) -> permute [2, 0, 1] -> (64, 512, 512) -> unsqueeze(0)
        # -> (1, 64, 512, 512), which matches the shape of view_12.
        embedding_1 = torch.randn((512, 512, 64), device=device, dtype=torch.float32)
        # Passed as op_sum's ``slice_11`` argument; its (1, 1, 1, 512) shape
        # broadcasts over the last dimension of the unsqueezed tensor.
        primals_221 = torch.randn((1, 1, 1, 512), device=device, dtype=torch.float32)
        view_12 = torch.randn((1, 64, 512, 512), device=device, dtype=torch.float32)

        ref = self.op_sum(view_12, embedding_1, primals_221)
        func = torch.compile(self.op_sum, backend="inductor", dynamic=False)
        calc = func(view_12, embedding_1, primals_221)

        self.assertEqual(ref, calc, atol=1e-3, rtol=1e-3)


# Run the tests through the imported ``run_tests`` helper when this file is
# executed as a script.
if __name__ == "__main__":
    run_tests()