from typing import Any
import torch
from torch.library import impl, Library
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
import cpp_extension_full
class TestTorchCompileCustomAdd(TestCase):
    """Exercise the ``add_custom`` extension op through ``torch.compile``."""

    def test_add_custom(self):
        """Compiled ``add_custom(x, y)`` must match eager ``x + y``.

        Both operands are float16 tensors of shape [8, 2048] created on the
        ``npu`` device; the wrapping module is compiled with the
        ``npugraph_ex`` backend before being invoked.
        """
        shape = [8, 2048]
        x = torch.randn(shape, device='npu', dtype=torch.float16)
        y = torch.randn(shape, device='npu', dtype=torch.float16)

        class Module(torch.nn.Module):
            """Thin wrapper so the custom op is traced as a module forward."""

            def forward(self, x, y):
                return cpp_extension_full.ops.add_custom(x, y)

        compiled = torch.compile(Module().npu(), backend="npugraph_ex")
        actual = compiled(x, y)

        # The reference sum is evaluated only after the compiled call, exactly
        # as the comparison was originally written.
        self.assertRtolEqual(actual, x + y)
# Run the test suite only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_tests()