"""
Add validation cases for selected torch.fx.graph internal APIs on NPU.
Current covered APIs:
- torch.fx.graph._format_target
- torch.fx.graph._is_from_torch
- torch.fx.graph._origin_type_map.get
- torch.fx.graph._register_custom_builtin
Note:
CodeGen._gen_python_code internally defines a local helper named _format_args.
It is not a torch.fx.graph module-level callable, so it is not added as a
separate test case in this file.
"""
import torch_npu
import torch
import torch.fx.graph as fx_graph
from torch.testing._internal.common_utils import TestCase, run_tests
def _test_fx_graph_custom_builtin_for_npu(x):
return x
class TestFxGraphInternal(TestCase):
    """Validation cases for selected private helpers in ``torch.fx.graph``."""

    def _remove_custom_builtin(self, name):
        """Drop ``name`` from both fx custom-builtin registries.

        ``pop(name, None)`` makes this a no-op for names that are not present.
        """
        fx_graph._custom_builtins.pop(name, None)
        fx_graph._illegal_names.pop(name, None)

    def _preserve_registry_entry(self, registry, name):
        """Snapshot ``registry[name]`` and restore it when the test finishes.

        The restore is registered with ``addCleanup``, so it also runs when an
        assertion fails. This keeps module-global registries from leaking state
        into later tests.

        Args:
            registry: A dict-like registry that the test is about to mutate.
            name: The key to preserve.

        On restore, a key that was absent is removed again, and a key that was
        present gets its original value back.
        """
        existed = name in registry
        saved = registry.get(name)

        def _restore():
            if existed:
                registry[name] = saved
            else:
                registry.pop(name, None)

        self.addCleanup(_restore)

    def test_format_target(self):
        """Dotted targets become attribute access; non-identifiers use getattr."""
        self.assertEqual(
            fx_graph._format_target("root", "foo.bar"),
            "root.foo.bar",
        )
        # "0" is not a valid Python identifier, so it cannot follow a dot.
        self.assertEqual(
            fx_graph._format_target("root", "foo.0"),
            'getattr(root.foo, "0")',
        )

    def test_is_from_torch(self):
        """torch operators are recognized; a locally defined function is not."""
        self.assertTrue(fx_graph._is_from_torch(torch.add))
        self.assertTrue(fx_graph._is_from_torch(torch.relu))

        def user_defined_func(x):
            return x

        self.assertFalse(fx_graph._is_from_torch(user_defined_func))

    def test_origin_type_map_get(self):
        """Builtin containers map to typing hints whose __origin__ is the container."""
        for container in (list, dict, set, tuple):
            with self.subTest(container=container):
                hint = fx_graph._origin_type_map.get(container)
                # Report a missing entry clearly instead of raising
                # AttributeError on ``None.__origin__``.
                self.assertIsNotNone(hint)
                self.assertIs(hint.__origin__, container)
        # A type that is not in the map has no typing counterpart.
        self.assertIsNone(fx_graph._origin_type_map.get(TestFxGraphInternal))

    def test_register_custom_builtin(self):
        """Registering a custom builtin records it in both fx registries."""
        name = "_test_fx_graph_custom_builtin_for_npu"
        import_str = (
            f"from {__name__} import _test_fx_graph_custom_builtin_for_npu"
        )
        obj = _test_fx_graph_custom_builtin_for_npu

        # The registries are module-global. Restore their prior state even if
        # an assertion below fails, then start from a clean slate.
        self._preserve_registry_entry(fx_graph._custom_builtins, name)
        self._preserve_registry_entry(fx_graph._illegal_names, name)
        self._remove_custom_builtin(name)

        fx_graph._register_custom_builtin(name, import_str, obj)

        self.assertIn(name, fx_graph._custom_builtins)
        entry = fx_graph._custom_builtins[name]
        self.assertEqual(entry.import_str, import_str)
        self.assertIs(entry.obj, obj)
        self.assertIs(fx_graph._illegal_names[name], obj)
if __name__ == "__main__":
run_tests()