"""
Feature: ACLGraph supports pin_memory during graph capture.
Description: Test that pin_memory() works correctly when called inside an
ACLGraph capture region.
Expectation: Capture-time pinned blocks are isolated from the default host
pool, replays see original captured data.
"""
import gc
import os
import subprocess
import sys
import textwrap
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestAclgraphPinMemory(TestCase):
def setUp(self):
super().setUp()
torch.npu.set_device(0)
gc.collect()
torch_npu.npu.host_empty_cache()
torch_npu.npu.reset_peak_host_memory_stats()
torch_npu.npu.reset_accumulated_host_memory_stats()
def tearDown(self):
gc.collect()
torch_npu.npu.host_empty_cache()
super().tearDown()
@staticmethod
def _warmup(stream):
with torch.npu.stream(stream):
_ = torch.ones(64, device="npu") * 2
torch.npu.synchronize()
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
@torch.no_grad()
def test_pin_memory_h2d_during_capture(self):
"""End-to-end: capture H2D from a pinned tensor, replay must reproduce data."""
s = torch.npu.Stream()
self._warmup(s)
g = torch.npu.NPUGraph()
gpu_dst = torch.zeros(64, device="npu")
with torch.npu.graph(g, stream=s):
pinned = torch.full((64,), 7.0).pin_memory()
gpu_dst.copy_(pinned, non_blocking=True)
self.assertTrue(pinned.is_pinned())
g.replay()
torch.npu.synchronize()
self.assertEqual(gpu_dst.cpu(), torch.full((64,), 7.0))
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
@torch.no_grad()
def test_captured_pinned_block_isolated_from_default_pool(self):
"""After del of a capture-time pinned tensor, the block must stay in the
private pool (host_empty_cache cannot free it) and must not be handed
out to post-capture pin_memory() allocations. Replay must keep reading
the original captured data."""
s = torch.npu.Stream()
self._warmup(s)
g = torch.npu.NPUGraph()
gpu_dst = torch.zeros(64, device="npu")
with torch.npu.graph(g, stream=s):
pinned = torch.full((64,), 7.0).pin_memory()
gpu_dst.copy_(pinned, non_blocking=True)
g.replay()
torch.npu.synchronize()
self.assertEqual(gpu_dst.cpu(), torch.full((64,), 7.0))
captured_ptr = pinned.data_ptr()
self.assertNotEqual(captured_ptr, 0)
del pinned
gc.collect()
before = torch_npu.npu.host_memory_stats()
torch_npu.npu.host_empty_cache()
after = torch_npu.npu.host_memory_stats()
self.assertEqual(
before["allocations.current"], after["allocations.current"],
"captured pinned block was freed by host_empty_cache — private pool "
"isolation is broken",
)
self.assertEqual(
before["allocated_bytes.current"], after["allocated_bytes.current"],
)
scramble = [torch.full((64,), float(-i - 1)).pin_memory()
for i in range(32)]
self.assertNotIn(
captured_ptr, {t.data_ptr() for t in scramble},
"captured pinned block was reused by post-capture pin_memory()",
)
for _ in range(5):
g.replay()
torch.npu.synchronize()
self.assertEqual(
gpu_dst.cpu(), torch.full((64,), 7.0),
"replay read corrupted data — captured block was recycled",
)
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
@torch.no_grad()
def test_intra_capture_block_reuse_stays_in_private_pool(self):
"""Inside a single capture, allocate/del/allocate of pinned blocks must
work and both captured H2D copies must replay correctly."""
s = torch.npu.Stream()
self._warmup(s)
g = torch.npu.NPUGraph()
dst1 = torch.zeros(64, device="npu")
dst2 = torch.zeros(64, device="npu")
with torch.npu.graph(g, stream=s):
t1 = torch.full((64,), 3.0).pin_memory()
dst1.copy_(t1, non_blocking=True)
del t1
t2 = torch.full((64,), 5.0).pin_memory()
dst2.copy_(t2, non_blocking=True)
self.assertTrue(t2.is_pinned())
t2_ptr = t2.data_ptr()
del t2
gc.collect()
post = [torch.full((64,), -1.0).pin_memory() for _ in range(16)]
self.assertNotIn(t2_ptr, {p.data_ptr() for p in post})
g.replay()
torch.npu.synchronize()
self.assertEqual(dst1.cpu(), torch.full((64,), 3.0))
self.assertEqual(dst2.cpu(), torch.full((64,), 5.0))
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
@torch.no_grad()
def test_two_graphs_shared_pool_no_cross_contamination(self):
"""Two graphs sharing a pool: each captured H2D must replay its own data."""
s = torch.npu.Stream()
self._warmup(s)
pool = torch.npu.graph_pool_handle()
g1 = torch.npu.NPUGraph()
g2 = torch.npu.NPUGraph()
d1 = torch.zeros(64, device="npu")
d2 = torch.zeros(64, device="npu")
with torch.npu.graph(g1, stream=s, pool=pool):
t1 = torch.full((64,), 11.0).pin_memory()
d1.copy_(t1, non_blocking=True)
with torch.npu.graph(g2, stream=s, pool=pool):
t2 = torch.full((64,), 22.0).pin_memory()
d2.copy_(t2, non_blocking=True)
self.assertTrue(t1.is_pinned())
self.assertTrue(t2.is_pinned())
self.assertNotEqual(t1.data_ptr(), t2.data_ptr())
for _ in range(3):
g2.replay()
g1.replay()
g1.replay()
g2.replay()
torch.npu.synchronize()
self.assertEqual(d1.cpu(), torch.full((64,), 11.0))
self.assertEqual(d2.cpu(), torch.full((64,), 22.0))
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
@torch.no_grad()
def test_release_pool_frees_unused_captured_blocks(self):
"""Capture-time pinned blocks that were never used by any captured op
must be reclaimable after g.reset() + host_empty_cache().
"""
s = torch.npu.Stream()
self._warmup(s)
gc.collect()
torch_npu.npu.host_empty_cache()
baseline = torch_npu.npu.host_memory_stats()
N = 10
for _ in range(N):
g = torch.npu.NPUGraph()
with torch.npu.graph(g, stream=s):
t = torch.full((1024,), 1.0).pin_memory()
del t
g.reset()
del g
gc.collect()
torch_npu.npu.host_empty_cache()
final = torch_npu.npu.host_memory_stats()
self.assertEqual(
final["allocations.current"], baseline["allocations.current"],
f"unused captured pinned blocks leaked "
f"{final['allocations.current'] - baseline['allocations.current']} "
f"over {N} capture/reset cycles — release_pool did not hand blocks "
f"back to the default pool, or empty_cache did not drain them",
)
self.assertEqual(
final["allocated_bytes.current"],
baseline["allocated_bytes.current"],
)
@SupportedDevices(['Ascend910B', 'Ascend910_93'])
def test_capture_rejects_expandable_segments(self):
"""pin_memory_expandable_segments=True must be rejected at capture_begin.
Runs in a subprocess because setAllocator() uses std::call_once and
cannot be reconfigured within the current process. Requires CANN
driver >= 25.5.0 to actually enable the expandable segments path;
on older drivers the setting is silently downgraded and this path
cannot be exercised."""
probe = textwrap.dedent(
"""
import os
os.environ['PYTORCH_NPU_ALLOC_CONF'] = 'pin_memory_expandable_segments:True'
import torch, torch_npu
torch.npu.set_device(0)
# Touch allocator so config parsing runs and any driver downgrade
# warning is emitted before we attempt capture.
_ = torch.empty(1, device='npu')
g = torch.npu.NPUGraph()
s = torch.npu.Stream()
try:
with torch.npu.graph(g, stream=s):
pass
except RuntimeError as e:
msg = str(e)
if 'pin_memory_expandable_segments' in msg:
print('OK_REJECT')
else:
print('OTHER_ERROR:' + msg)
else:
print('NO_ERROR_RAISED')
"""
).strip()
env = os.environ.copy()
env['PYTORCH_NPU_ALLOC_CONF'] = 'pin_memory_expandable_segments:True'
r = subprocess.run(
[sys.executable, '-c', probe],
capture_output=True, text=True, timeout=180, env=env,
)
combined = (r.stdout or '') + (r.stderr or '')
driver_downgraded = (
'pin_memory_expandable_segments setting failure' in combined
or 'does not support this feature' in combined
)
if driver_downgraded and 'OK_REJECT' not in r.stdout:
self.skipTest(
"CANN driver does not support pin_memory_expandable_segments "
"(needs >= 25.5.0); cannot exercise capture_begin rejection path"
)
self.assertIn('OK_REJECT', r.stdout,
msg=f"stdout={r.stdout!r}\nstderr={r.stderr!r}")
if __name__ == "__main__":
run_tests()