import torch
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu
class Test_issue70(TestUtils):
    """Eager-vs-compiled consistency check for a last-dimension ``mean``.

    The test runs the same reduction twice on one NPU tensor: once directly
    and once through ``torch.compile`` with the ``inductor`` backend, then
    asserts that both results agree within tolerance.
    """

    def op_forward(self, x):
        """Return the mean of ``x`` taken over its last dimension."""
        return x.mean(-1)

    def test_issue70(self):
        # Random input of shape (1, 1, 7168), moved onto the NPU device.
        sample = torch.randn((1, 1, 7168)).npu()

        # Wrapping with torch.compile is lazy; compilation happens on the
        # first call below.
        compiled_forward = torch.compile(self.op_forward, backend="inductor")

        # Reference result from eager execution, then the compiled result.
        eager_result = self.op_forward(sample)
        compiled_result = compiled_forward(sample)

        # Both paths must agree within absolute/relative tolerance of 1e-3.
        self.assertEqual(eager_result, compiled_result, atol=1e-3, rtol=1e-3)
# Run the test cases in this module only when it is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    run_tests()