import re

import torch
import torch_npu
from torch_npu.testing.common_utils import SupportedDevices
from torch_npu.testing.testcase import TestCase, run_tests
class TestAclgraphRngState(TestCase):
    """Check NPU generator RNG-state handling while an ``NPUGraph`` is being captured.

    Most tests run a random-number sequence eagerly, then capture the same
    sequence into an ``NPUGraph``, replay it, and compare the results with the
    eager run. The remaining tests check that changing the seed during
    capture raises ``RuntimeError``.
    """

    # Message the tests expect from the RuntimeError raised when a capture
    # tries to switch to a different seed. Escaped so that assertRaisesRegex
    # matches the text literally instead of treating "." as a wildcard.
    _SEED_MISMATCH_ERROR = re.escape(
        "NPUGeneratorImpl::set_current_seed can be called during stream "
        "capture only if new seed is the same as the original seed."
    )

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_get_set_rng_state_during_capture(self):
        """get/set_rng_state inside a capture must rewind exactly as in eager mode."""
        torch.npu.set_device(0)

        # Eager reference: snapshot, draw, restore, draw again, then one more draw.
        torch_npu.npu.manual_seed(123)
        eager_state = torch_npu.npu.get_rng_state()
        eager_first = torch.rand(8, device="npu")
        torch_npu.npu.set_rng_state(eager_state)
        eager_second = torch.rand(8, device="npu")
        eager_next = torch.rand(8, device="npu")

        # Same sequence (minus the trailing draw) recorded into a graph.
        torch_npu.npu.manual_seed(123)
        graph = torch.npu.NPUGraph()
        with torch.npu.graph(graph):
            graph_state = torch_npu.npu.get_rng_state()
            graph_first = torch.rand(8, device="npu")
            torch_npu.npu.set_rng_state(graph_state)
            graph_second = torch.rand(8, device="npu")

        # Reseed so the replay starts from the same generator state as the eager run.
        torch_npu.npu.manual_seed(123)
        graph.replay()
        torch.npu.synchronize()
        # After replay the generator must sit where the eager run left it,
        # so the next eager draw matches eager_next.
        graph_next = torch.rand(8, device="npu")

        self.assertEqual(eager_first, eager_second)
        self.assertEqual(graph_first, graph_second)
        self.assertEqual(graph_first, eager_first)
        self.assertEqual(graph_next, eager_next)

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_set_rng_state_with_nonzero_offset_during_capture(self):
        """Like the plain get/set test, but the snapshot is taken after one draw."""
        torch.npu.set_device(0)

        torch_npu.npu.manual_seed(123)
        # Advance the generator once so the saved state has a non-zero offset.
        eager_pre = torch.rand(8, device="npu")
        eager_state = torch_npu.npu.get_rng_state()
        eager_first = torch.rand(8, device="npu")
        torch_npu.npu.set_rng_state(eager_state)
        eager_second = torch.rand(8, device="npu")
        eager_next = torch.rand(8, device="npu")

        torch_npu.npu.manual_seed(123)
        graph = torch.npu.NPUGraph()
        with torch.npu.graph(graph):
            graph_pre = torch.rand(8, device="npu")
            graph_state = torch_npu.npu.get_rng_state()
            graph_first = torch.rand(8, device="npu")
            torch_npu.npu.set_rng_state(graph_state)
            graph_second = torch.rand(8, device="npu")

        torch_npu.npu.manual_seed(123)
        graph.replay()
        torch.npu.synchronize()
        graph_next = torch.rand(8, device="npu")

        self.assertEqual(graph_pre, eager_pre)
        self.assertEqual(eager_first, eager_second)
        self.assertEqual(graph_first, graph_second)
        self.assertEqual(graph_first, eager_first)
        self.assertEqual(graph_next, eager_next)

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_graph_set_rng_state_seed_mismatch_raises(self):
        """Restoring a state that carries a different seed during capture must raise."""
        torch.npu.set_device(0)
        torch_npu.npu.manual_seed(0)
        # NOTE(review): draw once under seed 0, presumably to make sure the
        # default generator is initialised before the snapshots below — confirm.
        torch.rand(1, device="npu")
        # Snapshot a state whose seed (1) differs from the seed active during capture (0).
        torch_npu.npu.manual_seed(1)
        mismatched_state = torch_npu.npu.get_rng_state()
        torch_npu.npu.manual_seed(0)

        graph = torch.npu.NPUGraph()
        with self.assertRaisesRegex(RuntimeError, self._SEED_MISMATCH_ERROR):
            with torch.npu.graph(graph):
                torch_npu.npu.set_rng_state(mismatched_state)

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_graph_checkpoint_preserve_rng_state(self):
        """Checkpoint recompute with preserve_rng_state must match eager bit-for-bit."""
        torch.npu.set_device(0)
        torch.npu.manual_seed(42)

        def fn(x):
            # Random op inside the checkpointed region; recompute must replay the same draw.
            return x * torch.sigmoid(torch.randn(1, device="npu"))

        # NOTE(review): uncaptured warm-up call, presumably so first-use setup
        # happens outside the capture — confirm.
        fn(torch.ones(1, device="npu"))

        torch.npu.manual_seed(42)
        eager_in = torch.ones(1, device="npu", requires_grad=True)
        eager_out = torch.utils.checkpoint.checkpoint(
            fn, eager_in, use_reentrant=False, preserve_rng_state=True
        )
        (eager_in_grad,) = torch.autograd.grad(eager_out, eager_in)

        # Forward, checkpoint recompute and backward are all recorded into the graph.
        graph = torch.npu.NPUGraph()
        with torch.npu.graph(graph):
            graph_in = torch.ones(1, device="npu", requires_grad=True)
            graph_out = torch.utils.checkpoint.checkpoint(
                fn, graph_in, use_reentrant=False, preserve_rng_state=True
            )
            (graph_in_grad,) = torch.autograd.grad(graph_out, graph_in)

        torch.npu.manual_seed(42)
        graph.replay()
        torch.npu.synchronize()
        # prec=0.0: the gradients are compared with zero tolerance.
        self.assertEqual(eager_in_grad, graph_in_grad, prec=0.0)

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_graph_manual_seed_mismatch_raises(self):
        """Calling manual_seed with a different seed during capture must raise."""
        torch.npu.set_device(0)
        torch.npu.manual_seed(0)

        graph = torch.npu.NPUGraph()
        with self.assertRaisesRegex(RuntimeError, self._SEED_MISMATCH_ERROR):
            with torch.npu.graph(graph):
                torch.npu.manual_seed(1)

    @SupportedDevices(["Ascend910B", "Ascend910_93"])
    def test_register_generator_state_under_inference_mode(self):
        """A custom generator registered under inference_mode must replay like eager."""
        torch.npu.set_device(0)
        generator = torch.Generator(device="npu")
        generator.manual_seed(0)

        graph = torch.npu.NPUGraph()
        # Only the registration runs under inference_mode; capture and replay do not.
        with torch.inference_mode():
            graph.register_generator_state(generator)
        with torch.npu.graph(graph):
            graph_out = torch.rand(8, device="npu", generator=generator)

        # Eager reference from an independent generator with the same seed.
        eager_generator = torch.Generator(device="npu")
        eager_generator.manual_seed(0)
        eager_ref = torch.rand(8, device="npu", generator=eager_generator)

        generator.manual_seed(0)
        graph.replay()
        torch.npu.synchronize()
        self.assertEqual(graph_out, eager_ref)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the imported run_tests runner.
    run_tests()