import os
import time
import multiprocessing
import shutil
import unittest
import platform
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
# Child processes should start from a fresh interpreter so that each one
# initialises the NPU runtime itself, after choosing its own allocator
# configuration. Only request 'spawn' when no start method has been fixed
# yet; a method that was already chosen is left untouched.
if multiprocessing.get_start_method(allow_none=True) is None:
    multiprocessing.set_start_method("spawn")

# The lazy-reclaim test below is only run on 64-bit ARM hosts.
IS_ARM64 = platform.machine() in ("arm64", "aarch64")
def extract_aclrtQueryEventStatus_count(prof_dir):
    """
    Extract the call count of aclrtQueryEventStatus from profiler results.

    Searches ``prof_dir`` recursively for the first ``api_statistic.csv``
    (directories visited in sorted order, so the choice is deterministic) and
    sums the 5th column of every row that mentions ``aclrtQueryEventStatus``.
    The file is parsed with the :mod:`csv` module, so quoted fields and
    fields containing commas do not shift the column, and several matching
    rows are added up instead of making the whole lookup fail.

    Args:
        prof_dir: str, path to the profiling result directory

    Returns:
        count: int, call count of aclrtQueryEventStatus, 0 if not found
    """
    csv_file = _find_first_api_statistic_csv(prof_dir)
    if csv_file is None:
        return 0
    return _sum_api_counts(csv_file, "aclrtQueryEventStatus")


def _find_first_api_statistic_csv(prof_dir):
    """Return the first regular file named api_statistic.csv under prof_dir, or None."""
    target = "api_statistic.csv"
    # os.walk yields nothing for a missing directory, which maps to "not found".
    for dirpath, dirnames, filenames in os.walk(prof_dir):
        # Sorting in place makes os.walk descend in a stable order.
        dirnames.sort()
        candidate = os.path.join(dirpath, target)
        if target in filenames and os.path.isfile(candidate):
            return candidate
    return None


def _sum_api_counts(csv_file, api_name, count_column=4):
    """Sum the integer ``count_column`` cell of every CSV row mentioning ``api_name``."""
    import csv

    total = 0
    with open(csv_file, newline="", encoding="utf-8", errors="replace") as handle:
        for row in csv.reader(handle):
            if len(row) <= count_column:
                continue
            if not any(api_name in field for field in row):
                continue
            try:
                total += int(row[count_column].strip())
            except ValueError:
                # Non-numeric count cell: skip this row rather than discarding
                # the counts of every other matching row.
                continue
    return total
def run_matmul_with_profiling(result_queue, enable_lazy_reclaim):
    """
    Run the multi-stream matmul workload under the NPU profiler and report
    the aclrtQueryEventStatus call count through ``result_queue``.

    Intended as a ``multiprocessing.Process`` target. The allocator
    configuration is selected via the ``PYTORCH_NPU_ALLOC_CONF`` environment
    variable, which is set (or removed) before ``torch.npu.init()`` is called
    in this process.

    Args:
        result_queue: multiprocessing.Queue, receives a single tuple:
            ``("success", count)`` or ``("error", traceback_text)``
        enable_lazy_reclaim: bool, whether to enable lazy reclaim feature
    """
    import traceback

    if enable_lazy_reclaim:
        os.environ["PYTORCH_NPU_ALLOC_CONF"] = "multi_stream_lazy_reclaim:True"
    else:
        os.environ.pop("PYTORCH_NPU_ALLOC_CONF", None)
    prof_dir = os.path.abspath("./prof_" + str(enable_lazy_reclaim))
    try:
        # Drop leftovers from an earlier run that was killed before its
        # cleanup ran; otherwise a stale api_statistic.csv could be counted.
        shutil.rmtree(prof_dir, ignore_errors=True)
        # Initialise inside the try so a failure is reported through the
        # queue instead of leaving the parent with nothing to read.
        torch.npu.init()
        stream0 = torch.npu.Stream()
        stream1 = torch.npu.Stream()
        stream2 = torch.npu.Stream()
        experimental_config = torch_npu.profiler._ExperimentalConfig(
            profiler_level=torch_npu.profiler.ProfilerLevel.Level2)
        with torch_npu.profiler.profile(
                activities=[torch_npu.profiler.ProfilerActivity.NPU,
                            torch_npu.profiler.ProfilerActivity.CPU],
                with_stack=False,
                record_shapes=False,
                profile_memory=False,
                schedule=torch_npu.profiler.schedule(
                    wait=0, warmup=0, active=1, repeat=1, skip_first=0),
                experimental_config=experimental_config,
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(prof_dir)) as prof:
            # Some of these tensors are never read; they exist only to shape
            # the allocator's state.
            a = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            b = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            c = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            d = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            e = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            f = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            # Mark `a` as used on all three side streams.
            a.record_stream(stream0)
            a.record_stream(stream1)
            a.record_stream(stream2)
            for _ in range(50):
                with torch.npu.stream(stream0):
                    torch.matmul(a, b, out=c)
                with torch.npu.stream(stream1):
                    torch.matmul(a, b, out=d)
                with torch.npu.stream(stream2):
                    torch.matmul(a, b, out=e)
            # Release `a` and `f`; presumably `a`'s block now has pending
            # cross-stream events that the allocator must query before reuse.
            a = None
            f = None
            # Allocations made while those events may still be pending.
            for _ in range(10):
                tmp = torch.empty((1024, 1024), dtype=torch.float32, device="npu")
            torch.npu.synchronize()
            a1 = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            f1 = torch.empty((2048, 2048), dtype=torch.float32, device="npu")
            prof.step()
        count = extract_aclrtQueryEventStatus_count(prof_dir)
        result_queue.put(("success", count))
    except Exception:
        # Top-level boundary of the child process: ship the full traceback
        # to the parent so the test failure is diagnosable.
        result_queue.put(("error", traceback.format_exc()))
    finally:
        # Cleanup must never raise after the result has been sent.
        shutil.rmtree(prof_dir, ignore_errors=True)
@unittest.skipUnless(IS_ARM64, "Only working on ARM")
class TestMultiStreamLazyReclaim(TestCase):
    """
    Test the reduction effect of multi_stream_lazy_reclaim feature on event query counts.

    Principle:
    - eager reclaim mode: Calls process_events to query event status before every memory allocation
    - lazy reclaim mode: Only queries in the following cases:
      1. No available memory block found (!block_found)
      2. Event queue exceeds threshold kLazyQuerySize (512)

    Test Method:
    Use multiprocessing to test in two separate (spawned) processes, run in this order:
    - Process 1: Disable multi_stream_lazy_reclaim (eager)
    - Process 2: Enable multi_stream_lazy_reclaim (lazy)
    Each process sets environment variables independently to ensure configuration takes effect.
    """

    def test_lazy_reclaim_reduces_event_queries_counts(self):
        """
        Compare aclrtQueryEventStatus call counts between eager reclaim and lazy reclaim modes.

        Validation Goal:
        Lazy reclaim mode should not issue more aclrtQueryEventStatus calls than
        eager reclaim mode, and eager mode must report a non-zero count so the
        comparison is meaningful.
        """
        from queue import Empty

        process_timeout_s = 300
        # Grace period for reading the result once the child has exited.
        result_timeout_s = 10
        # An explicit spawn context guarantees a fresh interpreter per child
        # even if the global start method was already fixed to another one.
        ctx = multiprocessing.get_context("spawn")
        configs = [
            ("eager", False),
            ("lazy", True),
        ]
        results = {}
        for name, enable_lazy in configs:
            print(f"\n--- Starting {name} reclaim test ---")
            result_queue = ctx.Queue()
            process = ctx.Process(
                target=run_matmul_with_profiling,
                args=(result_queue, enable_lazy),
            )
            process.start()
            process.join(timeout=process_timeout_s)
            if process.is_alive():
                process.terminate()
                process.join()
                self.fail(f"{name} reclaim process timed out and was terminated.")
            try:
                # Without a timeout this would block forever if the child
                # died (e.g. crashed in native code) before reporting.
                status, result = result_queue.get(timeout=result_timeout_s)
            except Empty:
                self.fail(
                    f"{name} reclaim process exited (exitcode={process.exitcode}) "
                    f"without reporting a result."
                )
            self.assertEqual(status, "success", f"{name} reclaim process failed: {result}")
            print(f"---mode {name}------count:{result}")
            results[name] = result
            time.sleep(2)
        eager_counts = results["eager"]
        lazy_counts = results["lazy"]
        print("\n========== Event Query Count Comparison ==========")
        print(f"Eager reclaim (multi_stream_lazy_reclaim:False): {eager_counts}")
        print(f"Lazy reclaim (multi_stream_lazy_reclaim:True): {lazy_counts}")
        # A zero eager count means no profiler data was found/parsed; the
        # comparison below would then pass vacuously.
        self.assertGreater(
            eager_counts,
            0,
            "No aclrtQueryEventStatus calls were found for eager reclaim; "
            "the profiler output was probably not produced or not parsed."
        )
        self.assertLessEqual(
            lazy_counts,
            eager_counts,
            f"Lazy reclaim mode should reduce event queries. "
            f"Eager: {eager_counts}, Lazy: {lazy_counts}. "
            f"If lazy > eager, the optimization may not be working."
        )
if __name__ == "__main__":
    # Executed as a script: hand control to the torch_npu test runner.
    run_tests()