import unittest
from itertools import chain
import torch
from torch import nn
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
callback_stream = torch.npu.Stream()
def callback_add(params):
global callback_stream
with torch.npu.stream(callback_stream):
x, y, result = params
result.copy_(x + y)
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.result = torch.rand([5, 5]).npu()
def forward(self, graph, x, y):
call_params = [torch.matmul(x, y), torch.matmul(x, y), self.result]
for _ in range(10000):
torch_npu.npu._launch_host_func(torch.npu.current_stream(), callback_add, call_params)
return self.result
class TestAclgraphLaunchHostFunc(TestCase):
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
def test_launch_host_func(self):
torch_npu.npu.set_compile_mode(jit_compile=False)
torch_npu.npu.set_device(0)
self.capture_stream = torch_npu.npu.Stream()
self.graph = torch_npu.npu.NPUGraph()
torch_npu.npu._subscribe_report(self.capture_stream)
a = torch.randn([5, 5]).npu()
b = torch.randn([5, 5]).npu()
model = MyModel()
with torch_npu.npu.stream(self.capture_stream):
with torch_npu.npu.graph(self.graph, stream=self.capture_stream):
self.res = model.forward(self.graph, a, b)
torch.npu.synchronize()
for _ in range(5):
self.graph.replay()
torch.npu.synchronize()
real = torch.matmul(a, b) + torch.matmul(a, b)
self.assertEqual(self.res.cpu(), real.cpu())
torch_npu.npu._unsubscribe_report(self.capture_stream)
if __name__ == '__main__':
run_tests()