import torch
from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from testutils import TestUtils
import torch_npu


class TestArgmax(TestUtils):
    """Check that ``torch.argmax`` compiled with the inductor backend matches eager mode on NPU."""

    def argmax(self, a, dim):
        """Return the indices of the maximum values of ``a`` along ``dim``.

        Kept as a thin method so the exact same callable can be run eagerly
        and handed to ``torch.compile``.
        """
        return torch.argmax(a, dim)

    def test_argmax(self):
        """Compare eager vs. inductor-compiled argmax over the last dim of a (512, 64) float32 tensor."""
        shape = (512, 64)
        dim = -1
        a = torch.randn(shape, dtype=torch.float32, device='npu')

        # dynamic=False: compile for this concrete shape only.
        argmax_compiled = torch.compile(self.argmax, backend="inductor", dynamic=False)
        expected = self.argmax(a, dim)
        actual = argmax_compiled(a, dim)
        # argmax returns integer indices, so a floating-point tolerance is
        # meaningless here: any index mismatch is a real discrepancy.
        self.assertEqual(expected, actual, atol=0, rtol=0)

# Expand any @parametrize-decorated test methods on TestArgmax into concrete
# test cases (test_argmax as written here carries no @parametrize decorator).
instantiate_parametrized_tests(TestArgmax)

if __name__ == "__main__":
    # Run the module's tests through torch's common_utils runner when executed directly.
    run_tests()