import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu


class TestAttnCp(TestUtils):
    """Compare an inductor-compiled reduce/broadcast/permute chain with eager.

    ``shape`` is the shape of every randomly generated input and ``dim`` is
    the axis that ``foo`` reduces over and then re-inserts.
    """

    shape = (8, 8, 256, 128)
    dim = -1

    def foo(self, a, b, c):
        """Return ``c`` plus a transposed, broadcast row-sum of ``a + b``.

        The sum ``a + b`` is reduced along ``self.dim``, the reduced axis is
        restored as a size-1 axis, and the result is broadcast to
        ``self.shape`` before ``b`` is added again. The last two axes are
        then swapped and ``c`` is added.

        Args:
            a: Tensor added to ``b`` before the reduction.
            b: Tensor used both in the reduction and after the broadcast.
            c: Tensor added to the permuted result; it has to be
                broadcast-compatible with ``self.shape`` after its last two
                axes are swapped.

        Returns:
            The tensor ``c + permuted``, where ``permuted`` is the broadcast
            intermediate with its last two axes exchanged.
        """
        reduced = (a + b).sum(self.dim).unsqueeze(self.dim)
        expanded = reduced.broadcast_to(self.shape) + b
        return c + expanded.permute(0, 1, 3, 2)

    def test_pointwise_cases(self):
        # Runs ``foo`` through torch.compile (inductor backend) and through
        # plain eager mode on identical inputs, then checks the two results
        # agree within atol/rtol of 1e-3.
        def make_input():
            return torch.randn(self.shape, dtype=torch.float32, device="npu")

        a = make_input()
        b = make_input()
        d = make_input()
        # ``c`` is a contiguous copy of ``d`` with its last two axes swapped,
        # so it lines up with the permuted intermediate inside ``foo``.
        c = d.permute(0, 1, 3, 2).contiguous()

        compiled_foo = torch.compile(self.foo, backend="inductor")
        compiled_out = compiled_foo(a, b, c)
        eager_out = self.foo(a, b, c)
        self.assertEqual(compiled_out, eager_out, atol=1e-3, rtol=1e-3)

# Register the class with PyTorch's parametrized-test machinery.
# NOTE(review): no @parametrize-decorated test is visible in TestAttnCp, so
# this call presumably has nothing to expand here — confirm before removing.
instantiate_parametrized_tests(TestAttnCp)

# When executed as a script, hand control to PyTorch's test runner.
if __name__ == "__main__":
    run_tests()