import torch
import torch_npu
import torch._dynamo.test_case
class BytecodeHookTests(torch._dynamo.test_case.TestCase):
def test_bytecode_hook(self):
def fn(a, b):
return a - b * 10
def hook(code, out_code):
return code
torch._dynamo.reset()
handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
try:
opt_fn = torch.compile(fn)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
finally:
handle.remove()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()