import gc
import multiprocessing
import os
from torch_npu.testing.testcase import run_tests, TestCase
import torch
def _wrapper(func, env_config, exit_code):
"""Wrapper function to set environment variable and run test in subprocess."""
os.environ["PYTORCH_NPU_ALLOC_CONF"] = env_config
try:
func()
exit_code.value = 0
except Exception as e:
print(f"Exception: {e}")
exit_code.value = 1
def run_in_subprocess(
    env_config: str, func, expect_error: bool = False, *, timeout=None
):
    """Run test in subprocess to ensure environment variable takes effect.

    PYTORCH_NPU_ALLOC_CONF is parsed during allocator initialization. Setting the
    environment variable in the current process won't trigger re-parsing since the
    allocator is already initialized. Running in a fresh subprocess ensures the
    environment variable is parsed from scratch.

    Args:
        env_config: Value assigned to PYTORCH_NPU_ALLOC_CONF inside the child.
        func: Zero-argument callable run in the child. It is sent to the child
            by pickling (the "spawn" start method), so it must be picklable,
            e.g. a module-level function or a staticmethod.
        expect_error: When True, succeed only if ``func`` raised an exception.
        timeout: Seconds to wait for the child; ``None`` waits indefinitely.
            A child still running after the timeout is terminated.

    Returns:
        True if the child behaved as expected: ``func`` returned normally
        (``expect_error=False``) or raised an exception
        (``expect_error=True``). False otherwise, including when the child
        timed out, crashed, or never reported a result.
    """
    ctx = multiprocessing.get_context("spawn")
    # -1 = "no result reported"; _wrapper sets 0 (success) or 1 (exception).
    exit_code = ctx.Value("i", -1)
    proc = ctx.Process(target=_wrapper, args=(func, env_config, exit_code))
    proc.start()
    proc.join(timeout)
    if proc.is_alive():
        # The child hung; kill it so the run can continue and report failure
        # in either mode (a hang is neither success nor the expected error).
        proc.terminate()
        proc.join()
        return False
    if expect_error:
        # Only an exception caught by _wrapper counts as the expected error.
        # A child that crashed or failed to start leaves -1 behind and must
        # not be mistaken for a correctly rejected configuration.
        return exit_code.value == 1
    return exit_code.value == 0
class TestLargeSegmentSize(TestCase):
    """Test large_segment_size_mb via PYTORCH_NPU_ALLOC_CONF.

    Every case allocates one 5MB float32 tensor. That request is below
    kMinLargeAlloc (10MB), so the allocator is expected to reserve a whole
    large segment for it (20MB by default). Cases that need a non-default
    configuration run in a fresh spawned process via ``run_in_subprocess``,
    because the configuration is only parsed when the allocator initializes.
    """

    @staticmethod
    def _alloc_mid_range():
        """Allocate and return a 5MB float32 tensor (4 bytes/element) on the NPU."""
        return torch.empty(5 * 1024 * 1024 // 4, device="npu", dtype=torch.float32)

    def test_large_segment_size_default(self):
        """Default large_segment_size is 20MB.
        A 5MB allocation should reserve at least 20MB.

        Runs in the current process, so it assumes the test run was not
        started with a PYTORCH_NPU_ALLOC_CONF that changes the segment size.
        """
        # Release unreferenced tensors and unused cached blocks so earlier
        # allocations in this process inflate memory_reserved() as little
        # as possible.
        gc.collect()
        torch.npu.empty_cache()
        _ = self._alloc_mid_range()
        reserved = torch.npu.memory_reserved()
        self.assertGreaterEqual(reserved, 20 * 1024 * 1024)

    @staticmethod
    def _test_large_segment_size_via_env():
        # Runs in the child process started by run_in_subprocess.
        _ = TestLargeSegmentSize._alloc_mid_range()
        reserved = torch.npu.memory_reserved()
        print(f"reserved={reserved}")
        assert reserved >= 50 * 1024 * 1024, f"Expected >= 50MB, got {reserved}"
        # Upper bound (mirroring the priority test) so the check pins the
        # 50MB granularity instead of accepting any larger reservation.
        assert reserved < 100 * 1024 * 1024, f"Expected < 100MB, got {reserved}"

    def test_large_segment_size_via_env(self):
        """Set large_segment_size_mb=50 via PYTORCH_NPU_ALLOC_CONF.
        Verify that a mid-range allocation reserves in 50MB granularity."""
        success = run_in_subprocess(
            "large_segment_size_mb:50", self._test_large_segment_size_via_env
        )
        self.assertTrue(success, "Subprocess failed")

    @staticmethod
    def _test_large_segment_size_mb_priority():
        # Runs in the child process started by run_in_subprocess.
        _ = TestLargeSegmentSize._alloc_mid_range()
        reserved = torch.npu.memory_reserved()
        print(f"reserved={reserved}")
        assert reserved >= 100 * 1024 * 1024, f"Expected >= 100MB, got {reserved}"
        assert reserved < 150 * 1024 * 1024, f"Expected < 150MB, got {reserved}"

    def test_large_segment_size_mb_priority(self):
        """When both segment_size_mb and large_segment_size_mb are set,
        large_segment_size_mb takes priority."""
        success = run_in_subprocess(
            "expandable_segments:True,segment_size_mb:50,large_segment_size_mb:100",
            self._test_large_segment_size_mb_priority,
        )
        self.assertTrue(success, "Subprocess failed")

    @staticmethod
    def _test_invalid_large_segment_size():
        # The parent test expects this allocation to raise. Do not raise
        # anything here yourself: returning normally is how the child signals
        # that the invalid configuration was (wrongly) accepted.
        _ = TestLargeSegmentSize._alloc_mid_range()
        print("Should not reach here")

    def test_invalid_large_segment_size_rejected(self):
        """large_segment_size_mb must be > 10MB (kMinLargeAlloc)."""
        has_error = run_in_subprocess(
            "large_segment_size_mb:5",
            self._test_invalid_large_segment_size,
            expect_error=True,
        )
        self.assertTrue(has_error, "Expected error but subprocess succeeded")

    @staticmethod
    def _test_max_split_size_less_than_large_segment():
        # The parent test expects this allocation to raise. Do not raise
        # anything here yourself: returning normally is how the child signals
        # that the invalid configuration was (wrongly) accepted.
        _ = TestLargeSegmentSize._alloc_mid_range()
        print("Should not reach here")

    def test_max_split_size_less_than_large_segment_rejected(self):
        """max_split_size_mb must be >= large_segment_size_mb."""
        has_error = run_in_subprocess(
            "large_segment_size_mb:50,max_split_size_mb:30",
            self._test_max_split_size_less_than_large_segment,
            expect_error=True,
        )
        self.assertTrue(has_error, "Expected error but subprocess succeeded")
# Keep this guard: run_in_subprocess uses the "spawn" start method, which
# re-imports this module in each child when the file is run as a script.
# Without the guard every child would start the whole test suite again.
if __name__ == "__main__":
    run_tests()