import os
import platform
import shutil
import subprocess
import unittest
import torch
import torch.utils.cpp_extension
import torch_npu
from torch_npu.testing.testcase import run_tests, TestCase
PYTORCH_INSTALL_PATH = os.path.dirname(os.path.realpath(torch.__file__))
PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.realpath(torch_npu.__file__))
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
def setup_sanitizer():
os.environ["TORCH_NPU_SANITIZER"] = "1"
import torch_npu.npu._sanitizer as sanitizer
if not sanitizer.npu_sanitizer.enabled:
sanitizer.npu_sanitizer.enable()
def get_event_handler():
import torch_npu.npu._sanitizer as sanitizer
return sanitizer.npu_sanitizer.event_handler
def create_build_path(build_directory):
if os.path.exists(build_directory):
shutil.rmtree(build_directory, ignore_errors=True)
os.makedirs(build_directory, exist_ok=True)
def build_stub(base_dir):
build_stub_cmd = [
"sh",
os.path.join(base_dir, "third_party/acl/libs/build_stub.sh"),
]
if subprocess.call(build_stub_cmd) != 0:
raise RuntimeError(f"Failed to build stub: {build_stub_cmd}")
def reset_sanitizer():
import torch_npu.npu._sanitizer as sanitizer
if sanitizer.npu_sanitizer.dispatch is not None:
try:
sanitizer.npu_sanitizer.dispatch.__exit__(None, None, None)
except Exception:
pass
sanitizer.npu_sanitizer.dispatch = None
sanitizer.npu_sanitizer.event_handler = None
sanitizer.npu_sanitizer.enabled = False
@unittest.skipUnless(IS_ARM64, "Only working on ARM")
class TestSanitizerPluggableAllocator(TestCase):
module = None
build_directory = os.path.join("allocator", "build_sanitizer_pluggable")
@classmethod
def setUpClass(cls):
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(TEST_DIR)
build_stub(BASE_DIR)
create_build_path(cls.build_directory)
CANN_LIB_PATH = os.path.join(BASE_DIR, "third_party/acl/libs")
extra_ldflags = []
extra_ldflags.append("-lascendcl")
extra_ldflags.append(f"-L{CANN_LIB_PATH}")
extra_ldflags.append("-lc10")
extra_ldflags.append(f"-L{PYTORCH_INSTALL_PATH}")
extra_include_paths = [os.path.join(TEST_DIR, "cpp_extensions")]
extra_include_paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, "include"))
extra_include_paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, 'include', 'third_party', 'acl', 'inc'))
cls.module = torch.utils.cpp_extension.load(
name="sanitizer_pluggable_allocator_extensions",
sources=[os.path.join(TEST_DIR, "cpp_extensions", "pluggable_allocator_extensions.cpp")],
extra_include_paths=extra_include_paths,
extra_cflags=["-g"],
extra_ldflags=extra_ldflags,
build_directory=cls.build_directory,
verbose=True,
)
def test_pluggable_allocator_record_stream_warning_and_suppression(self):
"""Pluggable allocator should support both missing-record_stream warning and suppression."""
os_path = os.path.join(
self.build_directory, "sanitizer_pluggable_allocator_extensions.so"
)
allocator = torch_npu.npu.memory.NPUPluggableAllocator(
os_path, "my_malloc", "my_free"
)
torch_npu.npu.memory.change_current_allocator(allocator)
setup_sanitizer()
x = torch.randn(100, device="npu")
stream1 = torch_npu.npu.Stream()
default_stream = torch_npu.npu.default_stream()
stream1.wait_stream(default_stream)
with torch_npu.npu.stream(stream1):
_ = x + 1
warnings = get_event_handler().flush_record_stream_warnings()
self.assertGreater(
len(warnings),
0,
"Missing record_stream should be detected with NPUPluggableAllocator.",
)
reset_sanitizer()
setup_sanitizer()
y = torch.randn(100, device="npu")
stream2 = torch_npu.npu.Stream()
default_stream = torch_npu.npu.default_stream()
y.record_stream(stream2)
stream2.wait_stream(default_stream)
with torch_npu.npu.stream(stream2):
_ = y + 1
warnings = get_event_handler().flush_record_stream_warnings()
self.assertEqual(
len(warnings),
0,
"record_stream should suppress missing-record_stream warning with NPUPluggableAllocator.",
)
if __name__ == "__main__":
run_tests()