import json
import sys
import os
import threading
from typing import Optional, Dict, Any, List
from .logger import logger
from dataclasses import dataclass
from dataclasses import field
@dataclass
class FunctionContext:
    """Context record carried around a hooked function call.

    NOTE(review): the field names suggest ``local_values`` holds captured
    local variables and ``return_value`` the call's result; nothing in this
    section populates them, so confirm against the hook implementation.
    """

    # Each instance gets its own fresh dict (default_factory avoids a shared
    # mutable default between instances).
    local_values: dict = field(default_factory=dict)
    # None until the caller assigns a value.
    return_value: Any = field(default=None)
def check_profiling_enabled() -> bool:
    """Report whether performance profiling is switched on.

    Profiling counts as enabled when the ``SERVICE_PROF_CONFIG_PATH``
    environment variable is present and non-empty. Otherwise a debug
    message is logged.

    Returns:
        bool: True if profiling is enabled, False otherwise.
    """
    prof_config_path = os.environ.get('SERVICE_PROF_CONFIG_PATH')
    if prof_config_path:
        return True
    # Unset or empty: hooks will not be installed.
    logger.debug("SERVICE_PROF_CONFIG_PATH not set, skipping hooks")
    return False
def load_yaml_config(config_path: str) -> Optional[List[Dict[str, Any]]]:
    """Load a YAML file that should contain a list of hook configurations.

    This loader is best-effort. It never raises; every failure is logged and
    reported through the return value.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Optional[List[Dict[str, Any]]]:
            - the parsed list, when the top-level YAML node is a list;
            - ``[]`` (after logging a warning), when the document parses to
              any other non-null value (e.g. a mapping or a scalar);
            - ``None``, when PyYAML is not installed, the file does not exist,
              the document is empty/null, or opening/parsing fails for any
              other reason.

    Note:
        List elements are returned as parsed; they are not validated to be
        dicts.
    """
    try:
        import yaml
    except ImportError:
        logger.error("PyYAML is required for configuration loading")
        return None

    # Keep the try narrow: only file I/O and YAML parsing can fail here.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load only builds plain Python types, never arbitrary objects.
            config = yaml.safe_load(f)
    except FileNotFoundError:
        logger.warning(f"Configuration file does not exist: {config_path}")
        return None
    except Exception as e:
        logger.error(f"Failed to load YAML configuration: {e}")
        return None

    if config is None:
        # Empty (or explicitly null) document.
        return None
    if isinstance(config, list):
        return config
    logger.warning("Configuration file should be a list of hook configurations")
    return []
def parse_version_tuple(version_str: str) -> tuple:
    """Parse a version string into a ``(major, minor, patch)`` tuple.

    Anything after the first ``+`` (local version) and after the first ``-``
    is ignored. The remaining dot-separated components are read left to
    right, and each contributes its leading run of digits. This way a PEP 440
    pre-release suffix attached to a component (``rc1``, ``a0``, ``b2``)
    does not discard that component's number. Parsing stops at the first
    component with no leading digits, or right after a component that
    carries such a suffix. Missing components are padded with 0 and extra
    components are dropped.

    Args:
        version_str: Version string such as "1.2.3+dev", "0.9.2" or
            "0.10.1rc1".

    Returns:
        tuple: ``(major, minor, patch)`` as ints. Returns ``(0, 0, 0)`` for
        non-string input or a string with no leading numeric component.

    Example:
        >>> parse_version_tuple("1.2.3+dev")
        (1, 2, 3)
        >>> parse_version_tuple("0.9")
        (0, 9, 0)
        >>> parse_version_tuple("0.10.1rc1")
        (0, 10, 1)
    """
    import re

    if not isinstance(version_str, str):
        return (0, 0, 0)

    release = version_str.split("+")[0].split("-")[0]
    nums = []
    for part in release.split("."):
        part = part.strip()
        match = re.match(r"\d+", part)
        if match is None:
            # e.g. "dev0" / "post1": no numeric prefix, so stop.
            break
        nums.append(int(match.group()))
        if match.end() != len(part):
            # Suffixed component such as "1rc1": keep 1, ignore the rest.
            break
    while len(nums) < 3:
        nums.append(0)
    return tuple(nums[:3])
def get_package_version(package_name: str) -> Optional[str]:
    """Return the installed version of *package_name*, best-effort.

    Two strategies are tried in order:
      1. ``importlib.metadata.version`` (distribution metadata, Python 3.8+).
      2. Importing the module and reading its ``__version__`` attribute.

    Every exception is swallowed; this function never raises.

    Args:
        package_name: Name such as "vllm" or "sglang". It is used as the
            distribution name for strategy 1 and as the module name for
            strategy 2.

    Returns:
        Optional[str]: The version found, or None when neither strategy
        yields one. With strategy 2 the value is whatever the module exposes
        as ``__version__``.
    """
    # Strategy 1: distribution metadata. Covers ImportError on interpreters
    # without importlib.metadata, PackageNotFoundError, and anything else.
    try:
        from importlib import metadata as _dist_metadata
        return _dist_metadata.version(package_name)
    except Exception:
        pass

    # Strategy 2: import the module itself (this runs its top-level code).
    try:
        from importlib import import_module
        imported = import_module(package_name)
        return getattr(imported, "__version__", None)
    except Exception:
        return None
class SharedHookState:
    """Mutable bookkeeping state shared between hooks.

    Holds two per-request dictionaries plus a re-entrant lock.
    NOTE(review): nothing in this section acquires ``_lock``; presumably
    code elsewhere uses it to guard the dictionaries. Confirm.
    """

    def __init__(self):
        """Create empty per-request tables and their lock."""
        # Re-entrant: the same thread may acquire it more than once.
        self._lock = threading.RLock()
        # Keyed by request id; presumably maps to the prompt's token count.
        self.request_id_to_prompt_token_len: Dict[str, int] = {}
        # Keyed by request id; presumably maps to an iteration counter.
        self.request_id_to_iter: Dict[str, int] = {}


# Process-wide singleton, created lazily by get_shared_state().
_GLOBAL_SHARED_STATE: Optional[SharedHookState] = None
# Serializes first-time creation of _GLOBAL_SHARED_STATE.
_GLOBAL_STATE_LOCK = threading.Lock()


def get_shared_state() -> SharedHookState:
    """Return the process-wide SharedHookState, creating it on first use.

    Thread-safe. A lock-free fast path returns an existing instance.
    Creation happens under ``_GLOBAL_STATE_LOCK`` with a re-check, so
    concurrent first callers all receive the same object.
    """
    global _GLOBAL_SHARED_STATE
    existing = _GLOBAL_SHARED_STATE
    if existing is not None:
        return existing
    with _GLOBAL_STATE_LOCK:
        # Another thread may have created it while we waited for the lock.
        if _GLOBAL_SHARED_STATE is None:
            _GLOBAL_SHARED_STATE = SharedHookState()
        return _GLOBAL_SHARED_STATE
def install_symbol_watcher(watcher) -> bool:
    """Install a module-load watcher.

    Prefers ``ms_service_metric.core.module.symbol_watcher.SymbolWatcher``.
    If that module cannot be imported, falls back to inserting *watcher* at
    the front of ``sys.meta_path``, ahead of the default import finders.

    Args:
        watcher: The finder object (a SymbolWatchFinder instance, per the
            caller contract). For the SymbolWatcher path it must provide
            ``on_symbol_module_loaded(module_name)``. For the fallback path
            it must be a valid ``sys.meta_path`` finder.

    Returns:
        bool: True if SymbolWatcher was used; False if the sys.meta_path
        fallback was used.

    Note:
        Only an ImportError from the SymbolWatcher import triggers the
        fallback. Exceptions raised by ``SymbolWatcher()``, ``start()`` or
        ``watch()`` propagate to the caller.
    """
    try:
        from ms_service_metric.core.module.symbol_watcher import SymbolWatcher
    except ImportError:
        # Position 0 makes the watcher see imports before the default finders.
        sys.meta_path.insert(0, watcher)
        logger.debug("Symbol watcher installed via sys.meta_path")
        return False

    symbol_watcher = SymbolWatcher()
    symbol_watcher.start()
    # The lambda looks up watcher.on_symbol_module_loaded at call time
    # rather than binding the method once here.
    symbol_watcher.watch(lambda module_name: watcher.on_symbol_module_loaded(module_name))
    logger.debug("Symbol watcher installed via SymbolWatcher")
    return True