"""
add validation cases for torch.jit.ScriptFunction APIs:
torch.jit.ScriptFunction
torch.jit.ScriptFunction.get_debug_state
torch.jit.ScriptFunction.save
torch.jit.ScriptFunction.save_to_buffer
"""
import io
import tempfile
import os
import torch
import unittest
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
ProfilingMode,
)
class TestScriptFunctionAPI(TestCase):
def setUp(self):
super().setUp()
@torch.jit.script
def test_fn(x: torch.Tensor) -> torch.Tensor:
return x * 2 + 1
self.script_fn = test_fn
self.dummy_input = torch.tensor([1.0, 2.0, 3.0]).npu()
def test_script_function_type(self):
"""Verify the object is a torch.jit.ScriptFunction."""
self.assertIsInstance(self.script_fn, torch.jit.ScriptFunction)
@unittest.skipIf(
GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"get_debug_state requires profiling graph executor",
)
def test_get_debug_state(self):
"""Verify get_debug_state returns a valid object under profiling mode."""
with enable_profiling_mode_for_profiling_tests():
self.script_fn(self.dummy_input)
self.script_fn(self.dummy_input)
state = self.script_fn.get_debug_state()
self.assertIsNotNone(state)
self.assertTrue(hasattr(state, "execution_plans"))
def test_save_to_buffer(self):
"""Test save_to_buffer returns bytes and can be loaded back."""
buf = self.script_fn.save_to_buffer()
self.assertIsInstance(buf, bytes)
loaded = torch.jit.load(io.BytesIO(buf))
expected = self.script_fn(self.dummy_input)
actual = loaded(self.dummy_input)
self.assertTrue(torch.equal(expected, actual))
def test_save_to_file(self):
"""Test save with a file path and reload."""
with tempfile.TemporaryDirectory() as tmpdir:
fname = os.path.join(tmpdir, "script_fn.pt")
self.script_fn.save(fname)
loaded = torch.jit.load(fname)
expected = self.script_fn(self.dummy_input)
actual = loaded(self.dummy_input)
self.assertTrue(torch.equal(expected, actual))
if __name__ == "__main__":
run_tests()