"""
Add validation cases for torch.fx Graph APIs on NPU:
1. PyTorch community lacks sufficient and direct API validations for some APIs, so this file is added.
2. This file validates torch.fx.Graph.inserting_after,
torch.fx.Graph.inserting_before, torch.fx.graph.magic_methods.format,
torch.fx.graph.inplace_methods.format,
torch.fx.Graph.find_nodes, torch.fx.Graph.graph_copy,
torch.fx.Graph.erase_node, and torch.fx.Graph.get_attr.
"""
import operator
import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.fx import GraphModule, symbolic_trace
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
class TestFxGraphApiNPU(TestCase):
    """Direct validation of torch.fx Graph APIs.

    Each case checks the graph structure produced by the API under test and,
    where the graph is executable, runs it on tensors placed on ``device_type``.
    """

    def test_graph_inserting_after(self):
        """A node created under ``inserting_after(n)`` lands right after ``n``."""
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        neg = graph.call_function(torch.neg, (x,))
        with graph.inserting_after(neg):
            relu = graph.call_function(torch.relu, (neg,))
        graph.output(relu)
        graph.lint()

        nodes = list(graph.nodes)
        self.assertEqual(nodes.index(relu), nodes.index(neg) + 1)
        self.assertEqual(neg.args, (x,))
        self.assertEqual(relu.args, (neg,))
        self.assertEqual(relu.target, torch.relu)

        gm = GraphModule(torch.nn.Module(), graph)
        input_tensor = torch.randn(2, 3).to(device_type)
        self.assertEqual(gm(input_tensor), torch.relu(torch.neg(input_tensor)))

    def test_graph_inserting_before(self):
        """A node created under ``inserting_before(n)`` lands right before ``n``."""
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        relu = graph.call_function(torch.relu, (x,))
        graph.output(relu)
        with graph.inserting_before(relu):
            neg = graph.call_function(torch.neg, (x,))
        # Rewire relu so the graph computes relu(neg(x)).
        relu.args = (neg,)
        graph.lint()

        nodes = list(graph.nodes)
        # Immediate adjacency; this also implies neg precedes relu.
        self.assertEqual(nodes.index(neg), nodes.index(relu) - 1)
        self.assertEqual(neg.args, (x,))
        self.assertEqual(relu.args, (neg,))
        self.assertEqual(neg.target, torch.neg)

        gm = GraphModule(torch.nn.Module(), graph)
        input_tensor = torch.randn(2, 3).to(device_type)
        self.assertEqual(gm(input_tensor), torch.relu(torch.neg(input_tensor)))

    def test_magic_methods_format_codegen(self):
        """``<<`` and ``>>`` are traced to operator nodes and printed infix."""

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        # Integer input in [0, 1024); shifts are only meaningful on ints.
        input_tensor = torch.randint(0, 1024, (10,), dtype=torch.long).to(device_type)
        gm = symbolic_trace(MyModule())
        gm.graph.lint()

        nodes = list(gm.graph.nodes)
        x_node = next(node for node in nodes if node.op == "placeholder")
        lshift_node = next(node for node in nodes if node.target == operator.lshift)
        rshift_node = next(node for node in nodes if node.target == operator.rshift)
        self.assertEqual(lshift_node.args, (x_node, 3))
        self.assertEqual(rshift_node.args, (x_node, 3))
        self.assertIn("x << 3", gm.code)
        self.assertIn("x >> 3", gm.code)

        expected = MyModule()(input_tensor)
        self.assertEqual(gm(input_tensor), expected)

    def test_inplace_methods_format_codegen(self):
        """``operator.imul`` is printed as ``a *= b`` and mutates its input."""
        graph = torch.fx.Graph()
        a = graph.placeholder("a")
        b = graph.placeholder("b")
        imul = graph.call_function(operator.imul, (a, b), {})
        # The graph returns the placeholder itself, so the result is only
        # correct if the in-place multiply really modified ``a``.
        graph.output(a)
        graph.lint()
        self.assertEqual(imul.args, (a, b))
        self.assertEqual(imul.target, operator.imul)

        gm = GraphModule(torch.nn.Module(), graph)
        gm.recompile()
        self.assertIn("a *= b", gm.code)

        input_tensor = torch.ones(2, 3).to(device_type)
        scale = torch.full((2, 3), 3.0).to(device_type)
        expected = torch.full((2, 3), 3.0).to(device_type)
        output = gm(input_tensor, scale)
        self.assertEqual(output, expected)
        self.assertEqual(input_tensor, expected)

    def test_find_nodes(self):
        """``find_nodes`` filters by op (and target) and returns [] on no match."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.p = torch.nn.Parameter(torch.tensor([1.0]))
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return torch.relu(x) + self.linear(x) * self.p

        m = M().to(device_type)
        gm = symbolic_trace(m)

        relu_nodes = gm.graph.find_nodes(op="call_function", target=torch.relu)
        self.assertEqual(len(relu_nodes), 1)
        self.assertEqual(relu_nodes[0].op, "call_function")

        module_nodes = gm.graph.find_nodes(op="call_module")
        self.assertEqual(len(module_nodes), 1)
        self.assertEqual(module_nodes[0].target, "linear")

        attr_nodes = gm.graph.find_nodes(op="get_attr")
        self.assertEqual(len(attr_nodes), 1)
        self.assertEqual(attr_nodes[0].target, "p")

        # ``sort`` may change ordering only, never the set of matches.
        sorted_relu = gm.graph.find_nodes(op="call_function", target=torch.relu, sort=True)
        unsorted_relu = gm.graph.find_nodes(op="call_function", target=torch.relu, sort=False)
        self.assertEqual(set(sorted_relu), set(unsorted_relu))

        self.assertEqual(gm.graph.find_nodes(op="call_function", target=torch.sigmoid), [])

        input_tensor = torch.randn(2, 3).to(device_type)
        self.assertEqual(gm(input_tensor), m(input_tensor))

    def test_graph_copy(self):
        """``graph_copy`` fills ``val_map``, can return the output node, and
        reuses mappings that are already present in ``val_map``."""
        g_src = torch.fx.Graph()
        x = g_src.placeholder("x")
        neg = g_src.call_function(torch.neg, (x,))
        relu = g_src.call_function(torch.relu, (neg,))
        g_src.output((relu,))

        # Plain copy: one val_map entry per non-output source node.
        g_dst = torch.fx.Graph()
        val_map = {}
        g_dst.graph_copy(g_src, val_map)
        self.assertEqual(len(val_map), 3)
        x_dst, neg_dst, relu_dst = val_map[x], val_map[neg], val_map[relu]
        self.assertEqual(neg_dst.args, (x_dst,))
        self.assertEqual(relu_dst.args, (neg_dst,))

        # With return_output_node=True the second element is an output node.
        g_dst2 = torch.fx.Graph()
        rv = g_dst2.graph_copy(g_src, {}, return_output_node=True)
        self.assertIsNotNone(rv)
        _, src_output = rv
        self.assertEqual(src_output.op, "output")

        # A pre-seeded mapping must be used instead of copying that node.
        g_dst3 = torch.fx.Graph()
        existing = g_dst3.placeholder("existing")
        val_map3 = {x: existing}
        g_dst3.graph_copy(g_src, val_map3)
        self.assertEqual(val_map3[neg].args[0], existing)

    def test_erase_node(self):
        """A node with no remaining users can be erased from the graph."""
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        neg = graph.call_function(torch.neg, (x,))
        relu = graph.call_function(torch.relu, (neg,))
        graph.output(relu)

        # Detach all users of ``neg`` first so it is safe to erase.
        neg.replace_all_uses_with(x)
        graph.erase_node(neg)
        self.assertNotIn(neg, list(graph.nodes))
        self.assertEqual(relu.args, (x,))
        graph.lint()

        gm = GraphModule(torch.nn.Module(), graph)
        input_tensor = torch.randn(2, 3).to(device_type)
        self.assertEqual(gm(input_tensor), torch.relu(input_tensor))

    def test_get_attr(self):
        """``Graph.get_attr`` creates ``get_attr`` nodes that resolve to the
        owning module's attributes, matching what symbolic tracing emits."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor([2.0]))
                self.bias = torch.nn.Parameter(torch.tensor([1.0]))

            def forward(self, x):
                return x * self.weight + self.bias

        m = M().to(device_type)

        # Build x * weight + bias by hand so Graph.get_attr itself is called,
        # rather than only inspecting nodes that symbolic_trace produced.
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        weight = graph.get_attr("weight")
        bias = graph.get_attr("bias")
        scaled = graph.call_function(operator.mul, (x, weight))
        shifted = graph.call_function(operator.add, (scaled, bias))
        graph.output(shifted)
        graph.lint()
        for node, attr_name in ((weight, "weight"), (bias, "bias")):
            self.assertEqual(node.op, "get_attr")
            self.assertEqual(node.target, attr_name)
        self.assertEqual(scaled.args, (x, weight))
        self.assertEqual(shifted.args, (scaled, bias))
        manual_gm = GraphModule(m, graph)

        # Symbolic tracing of the same module should yield the same attributes.
        gm = symbolic_trace(m)
        attr_nodes = gm.graph.find_nodes(op="get_attr")
        self.assertEqual(len(attr_nodes), 2)
        targets = {n.target for n in attr_nodes}
        self.assertIn("weight", targets)
        self.assertIn("bias", targets)

        # Compare against the very module the graphs were built from, so the
        # check does not rely on a second instance having equal parameters.
        input_tensor = torch.randn(2, 3).to(device_type)
        expected = m(input_tensor)
        self.assertEqual(manual_gm(input_tensor), expected)
        self.assertEqual(gm(input_tensor), expected)
# When executed as a script, hand control to the run_tests entry point
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()