import unittest
import torch
import torch_npu
class TraceStreamEventTests(unittest.TestCase):
    """Check that NPU stream/event calls show up in the graph traced by Dynamo."""

    def test_dynamo_trace_stream_event(self):
        """Compile a function using NPU streams/events and inspect its FX graph.

        A custom ``torch.compile`` backend receives the traced graph module and
        asserts that nodes named ``current_stream``, ``set_stream`` and
        ``record_stream`` are present in it.
        """
        # One entry per backend invocation, so the test cannot pass vacuously
        # when the backend (and therefore its assertions) never runs.
        traced_node_names = []

        def my_backend(gm, example_inputs):
            print(gm.graph)
            # Materialize the names: a generator would be partially consumed
            # by each ``assertIn``, making the later checks order-dependent
            # and the failure message unreadable.
            node_names = [node.name for node in gm.graph.nodes]
            traced_node_names.append(node_names)
            self.assertIn("current_stream", node_names)
            self.assertIn("set_stream", node_names)
            self.assertIn("record_stream", node_names)
            return gm

        @torch.compile(backend=my_backend)
        def test_stream_in_graph(a):
            s = torch.npu.Stream()
            event = torch.npu.Event()
            r = torch.add(a, 2)
            event.record()
            with torch.npu.stream(s):
                event.wait()
                r = torch.add(r, 1)
                r.record_stream(s)
            r = torch.add(r, 1)
            return r

        i = torch.randn([3, 3], device="npu:0")
        r = test_stream_in_graph(i)

        self.assertTrue(
            traced_node_names,
            "custom backend was never invoked; graph was not inspected",
        )
        # Test methods must not return a value (unittest deprecates non-None
        # returns), so check the result instead of returning it.
        self.assertEqual(r.shape, i.shape)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()