import multiprocessing
import os
import torch_npu
from torch_npu.testing.testcase import run_tests, TestCase
import torch
def _wrapper(func, env_config, exit_code):
"""Wrapper function to set environment variable and run test in subprocess."""
os.environ["PYTORCH_NPU_ALLOC_CONF"] = env_config
try:
func()
exit_code.value = 0
except Exception as e:
print(f"Exception: {e}")
exit_code.value = 1
def run_in_subprocess(env_config: str, func, expect_error: bool = False):
"""Run test in subprocess to ensure environment variable takes effect.
PYTORCH_NPU_ALLOC_CONF is parsed during allocator initialization. Setting the
environment variable in the current process won't trigger re-parsing since the
allocator is already initialized. Running in a fresh subprocess ensures the
environment variable is parsed from scratch.
"""
ctx = multiprocessing.get_context("spawn")
exit_code = ctx.Value("i", -1)
p = ctx.Process(target=_wrapper, args=(func, env_config, exit_code))
p.start()
p.join()
if expect_error:
return exit_code.value != 0
else:
return exit_code.value == 0
class TestPerProcessMemoryFraction(TestCase):
"""Test per_process_memory_fraction via PYTORCH_NPU_ALLOC_CONF."""
@staticmethod
def _test_valid_fraction():
total_memory = torch_npu.npu.get_device_properties(0).total_memory
torch_npu.npu.empty_cache()
application = int(total_memory * 0.2)
try:
torch.empty(application, dtype=torch.int8, device="npu")
raise AssertionError("Should have raised OOM")
except RuntimeError as e:
if "out of memory" not in str(e).lower():
raise
def test_valid_fraction_via_env(self):
"""Set per_process_memory_fraction=0.1 via PYTORCH_NPU_ALLOC_CONF.
Verify that memory allocation is limited to 10% of total memory."""
success = run_in_subprocess(
"per_process_memory_fraction:0.1", self._test_valid_fraction
)
self.assertTrue(success, "Subprocess failed")
@staticmethod
def _test_fraction_one():
total_memory = torch_npu.npu.get_device_properties(0).total_memory
torch_npu.npu.empty_cache()
application = int(total_memory * 0.5)
_ = torch.empty(application, dtype=torch.int8, device="npu")
def test_fraction_one_via_env(self):
"""Set per_process_memory_fraction=1.0 via PYTORCH_NPU_ALLOC_CONF.
Should allow full memory usage."""
success = run_in_subprocess(
"per_process_memory_fraction:1.0", self._test_fraction_one
)
self.assertTrue(success, "Subprocess failed")
@staticmethod
def _test_invalid_fraction():
x = torch.empty(1024, device="npu")
def test_invalid_fraction_too_large_rejected(self):
"""per_process_memory_fraction must be <= 1.0."""
has_error = run_in_subprocess(
"per_process_memory_fraction:2.0",
self._test_invalid_fraction,
expect_error=True,
)
self.assertTrue(has_error, "Expected error but subprocess succeeded")
def test_invalid_fraction_negative_rejected(self):
"""per_process_memory_fraction must be >= 0.0."""
has_error = run_in_subprocess(
"per_process_memory_fraction:-0.1",
self._test_invalid_fraction,
expect_error=True,
)
self.assertTrue(has_error, "Expected error but subprocess succeeded")
def test_invalid_fraction_missing_value_rejected(self):
"""per_process_memory_fraction requires a value."""
has_error = run_in_subprocess(
"per_process_memory_fraction:",
self._test_invalid_fraction,
expect_error=True,
)
self.assertTrue(has_error, "Expected error but subprocess succeeded")
@staticmethod
def _test_fraction_with_other_options():
total_memory = torch_npu.npu.get_device_properties(0).total_memory
torch_npu.npu.empty_cache()
application = int(total_memory * 0.4)
_ = torch.empty(application, dtype=torch.int8, device="npu")
def test_fraction_with_other_options(self):
"""Test per_process_memory_fraction combined with other options."""
success = run_in_subprocess(
"expandable_segments:True,per_process_memory_fraction:0.5",
self._test_fraction_with_other_options,
)
self.assertTrue(success, "Subprocess failed")
if __name__ == "__main__":
run_tests()