import os
from unittest.mock import patch
import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu
# Turn off Inductor's FX graph cache for this process (presumably so the
# torch.compile call below really compiles instead of reusing a cached graph).
torch._inductor.config.fx_graph_cache = False
# Request NPU inductor accuracy checking. The test below asserts that
# torch_npu._inductor.config.check_accuracy starts out True.
# NOTE(review): this is assigned after `import torch_npu`; if torch_npu reads
# INDUCTOR_ASCEND_CHECK_ACCURACY at import time, it should be set before that
# import instead — confirm where the variable is read.
os.environ["INDUCTOR_ASCEND_CHECK_ACCURACY"] = "1"
class TestCheckAccuracy(TestUtils):
    """Tests that NPU inductor's data-dump / accuracy-check hooks honour config flags."""

    def test_check_accuracy_1(self):
        """Toggle ``dump_fx_graph`` / ``check_accuracy`` and count hook successes.

        ``NPUCachingAutotuner.data_dump`` and ``NPUCachingAutotuner.check_accuracy``
        are wrapped so that every truthy return value is counted. With both
        config flags off, a compiled call must produce no counted success; with
        both flags on, each hook must succeed exactly once. In both phases the
        compiled result must equal the eager result.

        The config flags are process-wide, so their original values are restored
        in a ``finally`` block even when an assertion fails mid-test.
        """
        from torch_npu._inductor.npu_triton_heuristics import NPUCachingAutotuner

        count_data_dump = 0
        count_check_accuracy = 0

        def run(x, y):
            return F.relu(x) - y

        src_data_dump = NPUCachingAutotuner.data_dump
        src_check_accuracy = NPUCachingAutotuner.check_accuracy

        # The wrappers are installed on the class as plain functions, so their
        # first positional argument is the autotuner instance, not this TestCase.
        def wrap_data_dump(autotuner, *args, **kwargs):
            nonlocal count_data_dump
            status = src_data_dump(autotuner, *args, **kwargs)
            if status:
                count_data_dump += 1
            return status

        def wrap_check_accuracy(autotuner, *args, **kwargs):
            nonlocal count_check_accuracy
            status = src_check_accuracy(autotuner, *args, **kwargs)
            if status:
                count_check_accuracy += 1
            return status

        x = torch.randn(10).npu()
        y = torch.randn(10).npu()
        expected = run(x, y)  # eager reference result

        compiled_run = torch.compile(run)
        # Warm-up call made before the hooks are patched: the counters are only
        # incremented by the patched wrappers, so this call is never counted.
        _ = compiled_run(x, y)

        npu_config = torch_npu._inductor.config
        with patch.object(NPUCachingAutotuner, "data_dump", wrap_data_dump), \
                patch.object(NPUCachingAutotuner, "check_accuracy", wrap_check_accuracy):
            self.assertTrue(npu_config.dump_fx_graph)
            self.assertTrue(npu_config.check_accuracy)
            orig_dump_fx_graph = npu_config.dump_fx_graph
            orig_check_accuracy = npu_config.check_accuracy
            try:
                # Phase 1: both flags disabled -> neither hook may report success.
                npu_config.dump_fx_graph = False
                npu_config.check_accuracy = False
                z = compiled_run(x, y)
                self.assertEqual(count_data_dump, 0)
                self.assertEqual(count_check_accuracy, 0)
                self.assertEqual(z, expected)

                # Phase 2: both flags enabled -> each hook reports success once.
                npu_config.dump_fx_graph = True
                npu_config.check_accuracy = True
                z = compiled_run(x, y)
                self.assertEqual(count_data_dump, 1)
                self.assertEqual(count_check_accuracy, 1)
                self.assertEqual(z, expected)
            finally:
                # Restore the process-wide flags so a failure above does not
                # leave later tests running with dumping/checking disabled.
                npu_config.dump_fx_graph = orig_dump_fx_graph
                npu_config.check_accuracy = orig_check_accuracy
if __name__ == "__main__":
    # Run this module's test cases via PyTorch's internal test runner
    # (torch.testing._internal.common_utils.run_tests).
    run_tests()