import os
import shutil
import torch
import torch_npu.jit
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.utils._path_manager import PathManager
class TestJitOpsFusion(TestCase):
test_jit_model_path = os.path.join(
os.path.realpath(os.path.dirname(__file__)), "test_jit_fusion")
@classmethod
def setUpClass(cls):
PathManager.make_dir_safety(TestJitOpsFusion.test_jit_model_path)
@classmethod
def tearDownClass(cls):
assert os.path.exists(TestJitOpsFusion.test_jit_model_path)
PathManager.remove_path_safety(TestJitOpsFusion.test_jit_model_path)
def test_func_fast_gelu(self):
def ori_func(x):
x = x**2
x = torch.nn.functional.gelu(x)
return x
x = torch.rand(3, 3).npu()
jit_model = torch.jit.trace(ori_func, x)
torch_npu.jit.optimize(jit_model)
match_kinds = [n.kind() == 'npu::fast_gelu' for n in jit_model.graph.nodes()]
self.assertEqual(any(match_kinds), True)
def test_module_fast_gelu(self):
class Ori_Module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
x = x**2
x = torch.nn.functional.gelu(x)
return x
x = torch.rand(3, 3).npu()
jit_model = torch.jit.trace(Ori_Module(), x)
torch_npu.jit.optimize(jit_model)
match_kinds = [n.kind() == 'npu::fast_gelu' for n in jit_model.graph.nodes()]
self.assertEqual(any(match_kinds), True)
def test_fast_gelu_result_check(self):
def ori_func(x):
x = x**2
x = torch.nn.functional.gelu(x)
return x
x = torch.rand(3, 3).npu()
jit_model = torch.jit.trace(ori_func, x)
pre_result = jit_model(x)
torch_npu.jit.optimize(jit_model)
model_path = os.path.join(TestJitOpsFusion.test_jit_model_path, 'rewrite.pt')
torch.jit.save(jit_model, model_path)
assert os.path.isfile(model_path)
jit_model = torch.jit.load(model_path)
post_result = jit_model(x)
self.assertAlmostEqual(pre_result, post_result, delta=0.05)
if __name__ == '__main__':
run_tests()