"""Profiler context manager for NPU profiling using torch_npu.profiler.
This module provides a context manager for profiling NPU operations during
inference. When profiling is disabled, the context manager becomes a no-op.
"""
import os
import torch_npu
class FakeContextManager:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
@staticmethod
def step():
return
@staticmethod
def stop():
return
def create_profiler(enable_profiler=False, profile_save_path="prof", active=3, repeat=1, skip_first=3):
if enable_profiler:
os.makedirs(profile_save_path, exist_ok=True)
experimental_config = torch_npu.profiler._ExperimentalConfig(
profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization
)
profiler = torch_npu.profiler.profile(
activities=[
torch_npu.profiler.ProfilerActivity.NPU,
torch_npu.profiler.ProfilerActivity.CPU,
],
with_stack=False,
record_shapes=False,
profile_memory=False,
experimental_config=experimental_config,
schedule=torch_npu.profiler.schedule(
wait=0,
warmup=0,
active=active,
repeat=repeat,
skip_first=skip_first
),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path)
)
else:
profiler = FakeContextManager()
return profiler
class ProfilerManager:
"""
Profiler manager for profiling using torch_npu.profiler.
This class provides a manager for profiling NPU operations during inference.
"""
def __init__(self, enable_profiler, profile_save_path):
self.enable_profiler = enable_profiler
self.profile_save_path = profile_save_path
self.status = None
self.current_profiler = FakeContextManager()
def set_status(self, is_prefill):
if not self.enable_profiler:
return
if self.status is None and is_prefill:
self.status = "prefill"
self.current_profiler = create_profiler(
enable_profiler=self.enable_profiler,
profile_save_path=os.path.join(self.profile_save_path, "prof", "prefill"),
active=1, repeat=1, skip_first=0)
self.current_profiler.start()
elif self.status == "prefill" and not is_prefill:
self.status = "decode"
self.current_profiler.stop()
self.current_profiler = create_profiler(
enable_profiler=self.enable_profiler,
profile_save_path=os.path.join(self.profile_save_path, "prof", "decode"))
self.current_profiler.start()
return
def step(self):
self.current_profiler.step()
def __del__(self):
if self.enable_profiler:
self.current_profiler.stop()