"""
C++ Service Profiler 库的 Python 绑定。
该模块负责:
- 加载 C++ 库 (libms_service_profiler.so)
- 封装 C++ 函数调用
- 统一管理 Profiler 回调(C++ 调用 mstx,mstx 再调用各 Profiler)
回调架构:
C++ StartProfiler → mstx._on_cpp_start() → 遍历调用所有注册的 start 回调
C++ StopProfiler → mstx._on_cpp_stop() → 遍历调用所有注册的 stop 回调
"""
import ctypes
import json
import logging
from typing import Callable, List

from ms_service_profiler.utils.file_open_check import get_valid_lib_path
class ProfilerCallbackResult:
    """Outcome of a profiler callback registration.

    Attributes:
        mode: Either ``DYNAMIC`` or ``LEGACY`` (any string is stored as given).
        message: Free-form human-readable description of the outcome.
    """

    # The two mode values used by the registration methods in this module.
    DYNAMIC = "dynamic"
    LEGACY = "legacy"

    def __init__(self, mode: str, message: str = ""):
        """Store the registration mode and an optional description."""
        self.message = message
        self.mode = mode

    def _mode_is(self, expected: str) -> bool:
        """Return whether the stored mode equals *expected*."""
        return self.mode == expected

    @property
    def is_dynamic(self) -> bool:
        """True when the stored mode equals ``DYNAMIC``."""
        return self._mode_is(self.DYNAMIC)

    @property
    def is_legacy(self) -> bool:
        """True when the stored mode equals ``LEGACY``."""
        return self._mode_is(self.LEGACY)
class LibServiceProfiler:
    """Python wrapper around the C++ Service Profiler shared library.

    The library is loaded lazily on the first call to :meth:`init` (every
    public method calls it). Each exported C function is stored in a
    ``func_*`` attribute; an attribute left at ``None`` means the symbol is
    unavailable (library not found, failed to load, or the symbol is missing),
    and the corresponding wrapper method then degrades to a no-op or a
    default return value instead of raising.

    Callback fan-out: one ctypes trampoline per event kind is registered with
    the library, and that trampoline invokes every Python callback registered
    through the ``register_profiler_*`` methods.
    """

    lib_service_profiler = None

    # File name handed to ``get_valid_lib_path`` to locate the shared object.
    _SO_NAME = "libms_service_profiler.so"
    # ctypes prototype used for every trampoline: no arguments, no return value.
    _CALLBACK_TYPE = ctypes.CFUNCTYPE(None)
    # Sentinel meaning "leave the ctypes default restype untouched"
    # (``None`` cannot be used because it is a valid restype, i.e. ``void``).
    _UNSET = object()
    _logger = logging.getLogger(__name__)

    def __init__(self) -> None:
        """Create an unloaded wrapper; no library access happens here."""
        self.is_initialized = False
        self.lib = None
        # Span / event primitives.
        self.func_start_span_with_name = None
        self.func_end_span = None
        self.func_span_end_ex = None
        self.func_mark_span_attr = None
        self.func_mark_event = None
        self.func_mark_event_ex = None
        # Profiler lifecycle and enablement queries.
        self.func_start_service_profiler = None
        self.func_stop_service_profiler = None
        self.func_is_enable = None
        # NOTE: the misspelling "dommain" is kept because it is a public attribute name.
        self.func_is_valid_dommain = None
        self.func_add_meta_info = None
        # Configuration getters.
        self.func_get_prof_path = None
        self.func_get_acl_task_time_level = None
        self.func_get_acl_prof_aicore_metrics = None
        self.func_get_torch_prof_step_num = None
        self.func_get_torch_prof_stack = None
        self.func_get_torch_prof_modules = None
        self.func_get_torch_profiler_enable = None
        # Registration entry points exported by the library (if any).
        self._func_register_start_callback = None
        self._func_register_stop_callback = None
        self._func_register_start_metric_callback = None
        self._func_register_stop_metric_callback = None
        # Python-side callback lists fanned out by the trampolines.
        self._start_callbacks: List[Callable[[], None]] = []
        self._stop_callbacks: List[Callable[[], None]] = []
        self._start_metric_callbacks: List[Callable[[], None]] = []
        self._stop_metric_callbacks: List[Callable[[], None]] = []
        # Strong references to the ctypes trampolines; ctypes requires the
        # callback objects to stay alive for as long as C code may call them.
        self._c_callback_refs = []
        self.func_set_profiler_current_step = None
        self._cpp_callbacks_registered = False
        self._cpp_metric_callbacks_registered = False

    # ------------------------------------------------------------------
    # Loading and symbol binding
    # ------------------------------------------------------------------
    def init(self) -> None:
        """Load the shared library and bind its functions (first call only).

        The initialized flag is set before loading, so a failed load is not
        retried on later calls. On failure ``self.lib`` stays ``None`` and
        every wrapper method falls back to its default behaviour.
        """
        if self.is_initialized:
            return
        self.is_initialized = True

        lib_path = get_valid_lib_path(self._SO_NAME)
        if not lib_path:
            self.lib = None
            return
        try:
            self.lib = ctypes.cdll.LoadLibrary(lib_path)
        except Exception as err:  # best effort: profiling must never break the host
            self.lib = None
            self._logger.warning("Failed to load %s: %s", lib_path, err)
            return

        self._init_basic_funcs()
        self._init_config_funcs()
        self._init_callback_funcs()

    def _bind(self, symbol: str, argtypes=None, restype=_UNSET):
        """Look up *symbol* in the loaded library and configure its prototype.

        Args:
            symbol: Exported function name.
            argtypes: Tuple of ctypes argument types, or ``None`` to leave the
                function's ``argtypes`` untouched.
            restype: ctypes return type; omit to keep the ctypes default.

        Returns:
            The foreign function object, or ``None`` when the library does not
            export *symbol*.
        """
        func = getattr(self.lib, symbol, None)
        if func is None:
            return None
        if argtypes is not None:
            func.argtypes = argtypes
        if restype is not self._UNSET:
            func.restype = restype
        return func

    def _init_basic_funcs(self):
        """Bind the span, event and lifecycle functions.

        Every symbol is optional: a missing export leaves the matching
        ``func_*`` attribute as ``None`` rather than raising ``AttributeError``
        and leaving the wrapper half-initialized.
        """
        char_p = ctypes.c_char_p
        handle = ctypes.c_ulonglong
        self.func_start_span_with_name = self._bind("StartSpanWithName", (char_p,), handle)
        self.func_end_span = self._bind("EndSpan", (handle,))
        self.func_mark_span_attr = self._bind("MarkSpanAttr", (char_p, handle))
        self.func_mark_event = self._bind("MarkEvent", (char_p,))
        self.func_mark_event_ex = self._bind("MarkEventEx", (char_p, char_p, char_p))
        self.func_span_end_ex = self._bind("SpanEndEx", (char_p, char_p, char_p, handle))
        self.func_start_service_profiler = self._bind("StartServerProfiler")
        self.func_stop_service_profiler = self._bind("StopServerProfiler")
        self.func_is_enable = self._bind("IsEnable", (ctypes.c_ulong,), ctypes.c_bool)

    def _init_config_funcs(self):
        """Bind the optional configuration getters/setters."""
        char_p = ctypes.c_char_p
        self.func_is_valid_dommain = self._bind("IsValidDomain", (char_p,), ctypes.c_bool)
        self.func_add_meta_info = self._bind("AddMetaInfo", (char_p, char_p))
        self.func_get_prof_path = self._bind("GetProfPath", restype=char_p)
        self.func_get_acl_task_time_level = self._bind("GetAclTaskTimeLevel", restype=char_p)
        self.func_get_acl_prof_aicore_metrics = self._bind(
            "GetAclProfAicoreMetrics", restype=ctypes.c_int
        )
        self.func_get_torch_prof_step_num = self._bind(
            "GetTorchProfStepNum", restype=ctypes.c_int
        )
        self.func_get_torch_prof_stack = self._bind("GetTorchProfStack", restype=ctypes.c_bool)
        self.func_get_torch_prof_modules = self._bind(
            "GetTorchProfModules", restype=ctypes.c_bool
        )
        self.func_get_torch_profiler_enable = self._bind(
            "GetTorchProfilerEnable", restype=ctypes.c_bool
        )
        self.func_set_profiler_current_step = self._bind(
            "SetProfilerCurrentStep", (ctypes.c_int,), None
        )

    def _init_callback_funcs(self):
        """Bind the optional callback-registration entry points."""
        proto = (self._CALLBACK_TYPE,)
        self._func_register_start_callback = self._bind("RegisterProfilerStartCallback", proto)
        self._func_register_stop_callback = self._bind("RegisterProfilerStopCallback", proto)
        self._func_register_start_metric_callback = self._bind(
            "RegisterProfilerStartMetricCallback", proto
        )
        self._func_register_stop_metric_callback = self._bind(
            "RegisterProfilerStopMetricCallback", proto
        )

    # ------------------------------------------------------------------
    # Callback fan-out
    # ------------------------------------------------------------------
    def _dispatch(self, callbacks, phase: str) -> None:
        """Invoke every callback in *callbacks*, isolating failures.

        A callback that raises is logged and skipped so the remaining
        callbacks still run and no exception propagates to the caller.
        """
        for callback in callbacks:
            try:
                callback()
            except Exception:
                self._logger.exception(
                    "Profiler %s callback %r raised; continuing", phase, callback
                )

    def _on_cpp_start(self) -> None:
        """Trampoline target for profiler start: run all start callbacks."""
        self._dispatch(self._start_callbacks, "start")

    def _on_cpp_stop(self) -> None:
        """Trampoline target for profiler stop: run all stop callbacks."""
        self._dispatch(self._stop_callbacks, "stop")

    def _on_cpp_start_metric(self) -> None:
        """Trampoline target for metric start: run all metric-start callbacks."""
        self._dispatch(self._start_metric_callbacks, "start-metric")

    def _on_cpp_stop_metric(self) -> None:
        """Trampoline target for metric stop: run all metric-stop callbacks."""
        self._dispatch(self._stop_metric_callbacks, "stop-metric")

    def _install_trampolines(self, register_start, register_stop, on_start, on_stop) -> bool:
        """Wrap *on_start*/*on_stop* as C callbacks and hand them to the library.

        Returns:
            ``False`` when the library or either registration function is
            unavailable; ``True`` once both trampolines have been registered.
        """
        if self.lib is None or register_start is None or register_stop is None:
            return False
        c_start = self._CALLBACK_TYPE(on_start)
        c_stop = self._CALLBACK_TYPE(on_stop)
        # Keep the trampolines referenced *before* handing them to C.
        self._c_callback_refs.extend((c_start, c_stop))
        register_start(c_start)
        register_stop(c_stop)
        return True

    def _ensure_cpp_callbacks_registered(self) -> bool:
        """Register the start/stop trampolines with the library (once).

        Returns:
            bool: Whether dynamic callbacks are supported and registered.
        """
        if self._cpp_callbacks_registered:
            return True
        self.init()
        self._cpp_callbacks_registered = self._install_trampolines(
            self._func_register_start_callback,
            self._func_register_stop_callback,
            self._on_cpp_start,
            self._on_cpp_stop,
        )
        return self._cpp_callbacks_registered

    def _ensure_cpp_metric_callbacks_registered(self) -> bool:
        """Register the metric start/stop trampolines with the library (once).

        Returns:
            bool: Whether metric callbacks are supported and registered.
        """
        if self._cpp_metric_callbacks_registered:
            return True
        self.init()
        self._cpp_metric_callbacks_registered = self._install_trampolines(
            self._func_register_start_metric_callback,
            self._func_register_stop_metric_callback,
            self._on_cpp_start_metric,
            self._on_cpp_stop_metric,
        )
        return self._cpp_metric_callbacks_registered

    @staticmethod
    def _registration_result(dynamic: bool, ok_message: str, fallback_message: str):
        """Build the ``ProfilerCallbackResult`` for a registration attempt."""
        if dynamic:
            return ProfilerCallbackResult(ProfilerCallbackResult.DYNAMIC, ok_message)
        return ProfilerCallbackResult(ProfilerCallbackResult.LEGACY, fallback_message)

    def register_profiler_start_callback(
        self, callback: Callable[[], None]
    ) -> "ProfilerCallbackResult":
        """Register a callback to run when the profiler starts.

        The callback is always stored; the result reports whether the library
        will actually trigger it (``DYNAMIC``) or not (``LEGACY``).

        Args:
            callback: Zero-argument callable.

        Returns:
            ProfilerCallbackResult: Registration outcome.
        """
        self._start_callbacks.append(callback)
        return self._registration_result(
            self._ensure_cpp_callbacks_registered(),
            "Callback registered successfully",
            "C++ library does not support dynamic callbacks",
        )

    def register_profiler_stop_callback(
        self, callback: Callable[[], None]
    ) -> "ProfilerCallbackResult":
        """Register a callback to run when the profiler stops.

        Args:
            callback: Zero-argument callable.

        Returns:
            ProfilerCallbackResult: Registration outcome.
        """
        self._stop_callbacks.append(callback)
        return self._registration_result(
            self._ensure_cpp_callbacks_registered(),
            "Callback registered successfully",
            "C++ library does not support dynamic callbacks",
        )

    def register_profiler_start_metric_callback(
        self, callback: Callable[[], None]
    ) -> "ProfilerCallbackResult":
        """Register a callback to run when metric collection is switched on."""
        self._start_metric_callbacks.append(callback)
        return self._registration_result(
            self._ensure_cpp_metric_callbacks_registered(),
            "Metric start callback registered successfully",
            "C++ library does not support metric callbacks",
        )

    def register_profiler_stop_metric_callback(
        self, callback: Callable[[], None]
    ) -> "ProfilerCallbackResult":
        """Register a callback to run when metric collection is switched off."""
        self._stop_metric_callbacks.append(callback)
        return self._registration_result(
            self._ensure_cpp_metric_callbacks_registered(),
            "Metric stop callback registered successfully",
            "C++ library does not support metric callbacks",
        )

    def supports_dynamic_callbacks(self) -> bool:
        """Return whether the library exports both start/stop registration hooks."""
        self.init()
        return (
            self.lib is not None
            and self._func_register_start_callback is not None
            and self._func_register_stop_callback is not None
        )

    # ------------------------------------------------------------------
    # Span / event API
    # ------------------------------------------------------------------
    @staticmethod
    def _encode(text) -> bytes:
        """Return *text* as UTF-8 bytes; bytes-like input is passed through."""
        if isinstance(text, (bytes, bytearray)):
            return bytes(text)
        return bytes(text, encoding="utf-8")

    def start_span(self, name=None):
        """Start a span named *name* (empty when ``None``).

        Returns:
            The span handle from the library, or ``0`` when unavailable.
        """
        self.init()
        if self.func_start_span_with_name is None:
            return 0
        return self.func_start_span_with_name(self._encode("" if name is None else name))

    def end_span(self, span_handle):
        """End the span identified by *span_handle* (no-op when unavailable)."""
        self.init()
        if self.func_end_span is not None:
            self.func_end_span(span_handle)

    def mark_span_attr(self, msg, span_handle):
        """Attach *msg* to the span identified by *span_handle*."""
        self.init()
        if self.func_mark_span_attr is not None:
            self.func_mark_span_attr(self._encode(msg), span_handle)

    def mark_event(self, msg):
        """Record an event carrying *msg*."""
        self.init()
        if self.func_mark_event is not None:
            self.func_mark_event(self._encode(msg))

    def mark_event_ex(self, name: str, domain: str, msg: str):
        """Record an extended event.

        Falls back to ``MarkEvent`` with a JSON payload of
        ``name``/``domain``/``msg`` when ``MarkEventEx`` is not exported.
        """
        self.init()
        if self.func_mark_event_ex is not None:
            self.func_mark_event_ex(
                self._encode(name), self._encode(domain), self._encode(msg)
            )
        elif self.func_mark_event is not None:
            legacy = json.dumps(
                {"name": name, "domain": domain, "msg": msg}, ensure_ascii=False
            )
            self.func_mark_event(self._encode(legacy))

    def span_end_ex(self, name: str, domain: str, msg: str, span_handle: int):
        """End a span with extended metadata.

        Falls back to ``MarkSpanAttr`` (JSON payload, when available) followed
        by ``EndSpan`` when ``SpanEndEx`` is not exported.
        """
        self.init()
        if self.func_span_end_ex is not None:
            self.func_span_end_ex(
                self._encode(name), self._encode(domain), self._encode(msg), span_handle
            )
        elif self.func_end_span is not None:
            if self.func_mark_span_attr is not None:
                extra = json.dumps(
                    {"name": name, "domain": domain, "msg": msg}, ensure_ascii=False
                )
                self.func_mark_span_attr(self._encode(extra), span_handle)
            self.func_end_span(span_handle)

    # ------------------------------------------------------------------
    # Lifecycle and configuration
    # ------------------------------------------------------------------
    def start_profiler(self):
        """Call the library's ``StartServerProfiler`` if available."""
        self.init()
        if self.func_start_service_profiler is not None:
            self.func_start_service_profiler()

    def stop_profiler(self):
        """Call the library's ``StopServerProfiler`` if available."""
        self.init()
        if self.func_stop_service_profiler is not None:
            self.func_stop_service_profiler()

    def is_enable(self, profiler_level):
        """Return the library's ``IsEnable(profiler_level)``; ``False`` if unavailable."""
        self.init()
        if self.func_is_enable is None:
            return False
        return self.func_is_enable(profiler_level)

    def is_domain_enable(self, domain_name):
        """Return ``IsValidDomain(domain_name)``; ``True`` when the check is unavailable."""
        self.init()
        if self.func_is_valid_dommain is None:
            return True
        return self.func_is_valid_dommain(self._encode(domain_name))

    def add_meta_info(self, key, value):
        """Pass a key/value pair to the library's ``AddMetaInfo`` if available."""
        self.init()
        if self.func_add_meta_info is not None:
            self.func_add_meta_info(self._encode(key), self._encode(value))

    def get_prof_path(self):
        """Return ``GetProfPath()`` decoded as UTF-8, or ``""`` when unavailable/empty."""
        self.init()
        if self.func_get_prof_path is None:
            return ""
        raw = self.func_get_prof_path()
        return raw.decode("utf-8") if raw else ""

    def is_torch_profiler_enable(self, profiler_level):
        """Return whether the torch profiler flag and *profiler_level* are both enabled."""
        self.init()
        if self.func_get_torch_profiler_enable is None or self.func_is_enable is None:
            return False
        return self.func_get_torch_profiler_enable() and self.func_is_enable(profiler_level)

    def is_torch_profiler_register(self):
        """Return ``GetTorchProfilerEnable()``; ``False`` when unavailable."""
        self.init()
        if self.func_get_torch_profiler_enable is None:
            return False
        return self.func_get_torch_profiler_enable()

    def get_acl_task_time_level(self):
        """Return ``GetAclTaskTimeLevel()`` decoded as UTF-8; ``"L0"`` when unavailable/empty."""
        self.init()
        if self.func_get_acl_task_time_level is None:
            return "L0"
        raw = self.func_get_acl_task_time_level()
        return raw.decode("utf-8") if raw else "L0"

    def get_acl_prof_aicore_metrics(self):
        """Return ``GetAclProfAicoreMetrics()``; ``-1`` when unavailable."""
        self.init()
        if self.func_get_acl_prof_aicore_metrics is None:
            return -1
        return self.func_get_acl_prof_aicore_metrics()

    def get_torch_prof_step_num(self):
        """Return ``GetTorchProfStepNum()``; ``0`` when unavailable."""
        self.init()
        if self.func_get_torch_prof_step_num is None:
            return 0
        return self.func_get_torch_prof_step_num()

    def is_torch_prof_stack(self):
        """Return ``GetTorchProfStack()``; ``False`` when unavailable."""
        self.init()
        if self.func_get_torch_prof_stack is None:
            return False
        return self.func_get_torch_prof_stack()

    def is_torch_prof_modules(self):
        """Return ``GetTorchProfModules()``; ``False`` when unavailable."""
        self.init()
        if self.func_get_torch_prof_modules is None:
            return False
        return self.func_get_torch_prof_modules()

    def set_profiler_current_step(self, step: int) -> None:
        """Report the current step to the library via ``SetProfilerCurrentStep``.

        Args:
            step (int): Current training/inference step number.
        """
        self.init()
        if self.func_set_profiler_current_step is not None:
            self.func_set_profiler_current_step(ctypes.c_int(step))
# Module-level shared wrapper instance. Constructing it does not load the
# shared library; loading is deferred to the first call of init(), which every
# public method of LibServiceProfiler invokes.
service_profiler = LibServiceProfiler()