import torch
import torch_npu
import cpp_extension_acs
from torch_npu.testing.testcase import TestCase, run_tests
from torch.library import Library
class Model(torch.nn.Module):
    """Parameter-free module whose forward pass is the ``ascendc_add`` custom op.

    Wrapping the op in an ``nn.Module`` lets the tests hand it to module-level
    APIs such as ``make_graphed_callables`` and ``torch.compile``.
    """

    def forward(self, x, y):
        # The op is looked up on every call, exactly like the direct attribute access.
        add_op = torch.ops.cpp_extension_acs.ascendc_add
        return add_op(x, y)
# Shape shared by every randomly generated test input. It is a tuple so that no
# test can mutate the shape seen by the others. (The name is kept for existing callers.)
length = (8, 2048)
class TestCustomAdd(TestCase):
    """Check ``cpp_extension_acs.ascendc_add`` against ``torch.add``.

    The op is exercised eagerly, through NPU graph capture/replay, through
    ``torch.npu.make_graphed_callables`` and through ``torch.compile`` with the
    ``npugraph_ex`` backend. Every test compares the op's result, moved to the
    CPU, with a reference sum computed on the CPU.
    """

    def get_rand_input(self):
        """Return two int32 NPU tensors of shape ``length`` with values in [1, 100).

        The small value range keeps the elementwise sum far from int32 overflow.
        """
        x = torch.randint(low=1, high=100, size=length, device='npu', dtype=torch.int)
        y = torch.randint(low=1, high=100, size=length, device='npu', dtype=torch.int)
        return x, y

    @staticmethod
    def _cpu_reference(x, y):
        """Return ``x + y`` computed on the CPU, independent of the NPU code path."""
        return torch.add(x.cpu(), y.cpu())

    def test_npugraph(self):
        # Static inputs: the captured graph reads these exact buffers on every
        # replay, so fresh data is copied into them rather than rebinding them.
        static_x, static_y = self.get_rand_input()
        g = torch.npu.NPUGraph()
        model = Model()
        with torch.npu.graph(g):
            # Output tensor produced during capture; the assertion below relies
            # on replay refreshing it in place.
            static_target = model(static_x, static_y)

        real_x, real_y = self.get_rand_input()
        static_x.copy_(real_x)
        static_y.copy_(real_y)
        g.replay()

        self.assertEqual(static_target.cpu(), self._cpu_reference(real_x, real_y))

    def test_make_graphed_callables(self):
        model = Model().npu()
        x, y = self.get_rand_input()
        # (x, y) are the sample arguments used to build the graphed callable;
        # different inputs are then fed through it.
        model = torch.npu.make_graphed_callables(model, (x, y))
        real_x = torch.randint_like(x, low=1, high=100)
        real_y = torch.randint_like(y, low=1, high=100)
        output = model(real_x, real_y)
        self.assertEqual(output.cpu(), self._cpu_reference(real_x, real_y))

    def test_npugraph_ex_backend(self):
        model = Model().npu()
        # fullgraph=True turns any graph break caused by the custom op into an error.
        compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)
        x, y = self.get_rand_input()
        output = compiled_model(x, y)
        self.assertEqual(output.cpu(), self._cpu_reference(x, y))

    def test_add_custom_ops(self):
        # get_rand_input already places the inputs on the NPU; only the result
        # is moved back for comparison.
        x, y = self.get_rand_input()
        output = torch.ops.cpp_extension_acs.ascendc_add(x, y).cpu()
        self.assertEqual(output, self._cpu_reference(x, y))
if __name__ == "__main__":
    # When this file is executed directly, hand control to torch_npu's test runner.
    run_tests()