import torch
import torch_npu
import cpp_extension_acs
from torch_npu.testing.testcase import TestCase, run_tests
class Model(torch.nn.Module):
    """Minimal ``nn.Module`` around the ``ascendc_trig`` custom operator.

    Wrapping the op in a module lets the tests hand it to NPU graph capture,
    ``make_graphed_callables`` and ``torch.compile``. The tests in this file
    expect the op to write sin(x) into ``out_sin``, cos(x) into ``out_cos``
    and to return tan(x).
    """

    def forward(self, x, out_sin, out_cos):
        """Invoke the custom op and return its result (the tan tensor)."""
        trig_op = torch.ops.cpp_extension_acs.ascendc_trig
        return trig_op(x, out_sin, out_cos)
# Shape of every random test tensor; a tuple so the module-level constant
# cannot be mutated by accident (torch factory functions accept tuples).
length = (8, 2048)
class TestCustomTrig(TestCase):
    """Exercise ``ascendc_trig`` eagerly, under NPU graph capture, through
    ``make_graphed_callables`` and through ``torch.compile``.

    Every test checks the op's outputs (sin and cos written into the
    caller-provided buffers, tan returned) against references computed by
    PyTorch's CPU kernels.
    """

    def get_rand_input(self):
        """Return ``(x, out_sin, out_cos)`` allocated on the NPU.

        ``x`` is float32 drawn uniformly from [0, 1) with shape ``length``;
        the two output buffers are uninitialized tensors shaped like ``x``.
        """
        x = torch.rand(length, device='npu', dtype=torch.float32)
        out_sin = torch.empty_like(x)
        out_cos = torch.empty_like(x)
        return x, out_sin, out_cos

    def test_npugraph(self):
        """Capture the op in an NPUGraph, refresh its static input, replay."""
        static_x, static_out_sin, static_out_cos = self.get_rand_input()
        g = torch.npu.NPUGraph()
        model = Model().npu()
        # Replays reuse the tensors seen at capture time, so new data must be
        # copied into static_x (not rebound) and results read from the
        # static_* tensors afterwards.
        with torch.npu.graph(g):
            static_out_tan = model(static_x, static_out_sin, static_out_cos)

        real_x, real_out_sin, real_out_cos = self.get_rand_input()
        static_x.copy_(real_x)
        static_out_sin.copy_(real_out_sin)
        static_out_cos.copy_(real_out_cos)
        g.replay()
        self.check_res(real_x, static_out_sin, static_out_cos, static_out_tan)

    def test_make_graphed_callables(self):
        """Graph the model via make_graphed_callables, then run fresh data."""
        model = Model().npu()
        x, out_sin, out_cos = self.get_rand_input()
        model = torch.npu.make_graphed_callables(model, (x, out_sin, out_cos))
        real_x = torch.rand_like(x)
        real_out_tan = model(real_x, out_sin, out_cos)
        self.check_res(real_x, out_sin, out_cos, real_out_tan)

    def test_npugraph_ex_backend(self):
        """Compile the model with the ``npugraph_ex`` torch.compile backend."""
        model = Model().npu()
        compiled_model = torch.compile(
            model, backend="npugraph_ex", fullgraph=True, dynamic=True
        )
        x, out_sin, out_cos = self.get_rand_input()
        out_tan = compiled_model(x, out_sin, out_cos)
        self.check_res(x, out_sin, out_cos, out_tan)

    def test_trig_inplace_ops(self):
        """Call the registered op directly (eager mode)."""
        x, out_sin, out_cos = self.get_rand_input()
        out_tan = torch.ops.cpp_extension_acs.ascendc_trig(x, out_sin, out_cos)
        self.check_res(x, out_sin, out_cos, out_tan)

    def check_res(self, x, out_sin, out_cos, out_tan):
        """Assert the op outputs match sin/cos/tan of ``x`` computed on CPU.

        Args:
            x: input that was fed to the op.
            out_sin, out_cos: buffers the op wrote into.
            out_tan: tensor the op returned.
        """
        # Compute the golden values with PyTorch's CPU kernels so the
        # reference is independent of the NPU implementation under test.
        x_cpu = x.cpu()
        cpu_out_sin = torch.sin(x_cpu)
        cpu_out_cos = torch.cos(x_cpu)
        cpu_out_tan = torch.tan(x_cpu)
        # Bring the device results to CPU so both sides share a device.
        self.assertRtolEqual(out_sin.cpu(), cpu_out_sin)
        self.assertRtolEqual(out_cos.cpu(), cpu_out_cos)
        self.assertRtolEqual(out_tan.cpu(), cpu_out_tan)
if __name__ == "__main__":
    # Entry point when this file is executed directly: delegate to the
    # run_tests helper imported from torch_npu.testing.testcase.
    run_tests()