import os
from unittest.mock import patch
import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests
from testutils import TestUtils
import torch_npu
torch._inductor.config.fx_graph_cache = False
os.environ["INDUCTOR_ASCEND_DUMP_FX_GRAPH"] = "1"
class TestForceFallback(TestUtils):
def test_case1(self):
op_list = []
def opoverload_call(self, /, *args, **kwargs):
op_list.append(str(self))
return self._op(*args, **kwargs)
def run(x, y):
return F.relu(x) + y
x = torch.randn(10).npu()
y = torch.randn(10).npu()
g = run(x, y)
run = torch.compile(run)
_ = run(x, y)
with patch.object(torch._ops.OpOverload, "__call__", opoverload_call):
op_list.clear()
z = run(x, y)
self.assertTrue(len(op_list) == 0)
self.assertEqual(z, g)
op_list.clear()
torch_npu._inductor.config.force_fallback_kernel_id = [0]
z = run(x, y)
self.assertTrue("aten.relu.default" in op_list)
self.assertTrue("aten.add.Tensor" in op_list)
self.assertEqual(z, g)
op_list.clear()
torch_npu._inductor.config.force_fallback_kernel_id = 'all'
z = run(x, y)
self.assertTrue("aten.relu.default" in op_list)
self.assertTrue("aten.add.Tensor" in op_list)
self.assertEqual(z, g)
torch_npu._inductor.config.force_fallback_kernel_id = []
if __name__ == "__main__":
run_tests()