"""
将训练日志解析为 Chrome Trace 格式,以便可视化关键阶段耗时。
用法示例:
python perf.py --log test.log --config perf.yaml
python perf.py --log log1.log log2.log --config perf.yaml
"""
from __future__ import annotations
import argparse
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple
from datetime import datetime
try:
import yaml
except ModuleNotFoundError as exc:
raise SystemExit("缺少 PyYAML 依赖,请执行 pip install pyyaml") from exc
# Module-wide logger shared by every function in this script; handlers and
# level are configured in main() via logging.basicConfig.
LOGGER = logging.getLogger("perf")
@dataclass
class EventRule:
    """Compiled description of one event type to extract from a log."""

    # Event name; emitted as both "name" and "cat" of the trace event.
    name: str
    # Regex searched in each log line to detect the start of an event.
    start_pattern: re.Pattern
    # Regex searched in each log line to detect the end of an event.
    end_pattern: re.Pattern
    # Index of the capture group (in both patterns) holding the timestamp.
    time_group: int
    # Index of the capture group holding an event id used to pair a start
    # with its end; None means starts and ends are paired in FIFO order.
    id_group: Optional[int]
    # Optional per-rule process id; the caller-supplied default is used
    # when this is not set.
    pid: Optional[int]
    # Optional per-rule thread id; the caller-supplied default is used
    # when this is not set.
    tid: Optional[int]
def load_config(path: Path) -> Dict:
    """Load the YAML file at *path* and return its ``trace`` mapping.

    Args:
        path: Location of the YAML configuration file (read as UTF-8).

    Returns:
        The mapping stored under the top-level ``trace`` key.

    Raises:
        ValueError: If the document is empty, its top level is not a
            mapping, it has no ``trace`` key, or ``trace`` is not a mapping.
    """
    with path.open("r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)
    # safe_load yields None for an empty document and a scalar/list for
    # non-mapping YAML; report those as a config error instead of letting
    # the membership test raise TypeError (or silently do a substring test).
    if not isinstance(config, dict) or "trace" not in config:
        raise ValueError("配置文件缺少 trace 节点")
    trace_cfg = config["trace"]
    # Callers use trace_cfg.get(...), so anything but a mapping would only
    # fail later with an unrelated AttributeError.
    if not isinstance(trace_cfg, dict):
        raise ValueError("trace 节点必须是映射(键值对)")
    return trace_cfg
def compile_rules(trace_cfg: Dict) -> List[EventRule]:
    """Turn the ``events`` entries of the trace config into EventRule objects.

    Args:
        trace_cfg: The ``trace`` mapping of the configuration file.

    Returns:
        One EventRule per configured event, in configuration order.

    Raises:
        ValueError: If ``events`` is missing or empty, or an entry lacks
            ``name``, ``start_pattern`` or ``end_pattern``.
    """
    entries = trace_cfg.get("events", [])
    if not entries:
        raise ValueError("trace.events 不能为空")

    def _build(entry):
        # Only a missing mandatory key is translated to ValueError; any
        # other failure (bad regex, non-numeric group) propagates as-is.
        try:
            rule_name = entry["name"]
            start_re = re.compile(entry["start_pattern"])
            end_re = re.compile(entry["end_pattern"])
            ts_group = int(entry.get("time_group", 1))
        except KeyError as exc:
            raise ValueError(f"事件配置缺少字段: {exc}") from exc
        raw_id_group = entry.get("id_group")
        return EventRule(
            name=rule_name,
            start_pattern=start_re,
            end_pattern=end_re,
            time_group=ts_group,
            id_group=None if raw_id_group is None else int(raw_id_group),
            pid=entry.get("pid"),
            tid=entry.get("tid"),
        )

    return [_build(entry) for entry in entries]
def _convert_timestamp(ts_str: str, unit: str, dt_format: Optional[str]) -> float:
if unit == "datetime":
if not dt_format:
raise ValueError("timestamp_unit 为 datetime 时必须提供 timestamp_format")
dt = datetime.strptime(ts_str, dt_format)
return dt.timestamp() * 1_000_000.0
value = float(ts_str)
if unit == "unix_ms":
return value * 1000.0
if unit == "unix_s":
return value * 1_000_000.0
if unit == "relative_ms":
return value * 1000.0
raise ValueError(f"不支持的时间单位: {unit}")
def build_trace_from_log(
    log_path: Path,
    rules: Iterable[EventRule],
    pid_default: int,
    tid_default: int,
    timestamp_unit: str,
    timestamp_format: Optional[str],
) -> List[Dict]:
    """Scan one log file and build Chrome Trace complete ("X") events.

    Every line is tested against every rule. For each rule the start
    pattern is tried first; the end pattern is only tried when the start
    pattern did not match that line. Rules with an ``id_group`` pair a
    start with the end carrying the same captured id; rules without one
    pair starts and ends in FIFO order and get a per-rule running number
    as id.

    Args:
        log_path: Log file to read (UTF-8).
        rules: Event rules to apply; may be any iterable, including a
            one-shot generator.
        pid_default: Process id used when a rule sets no ``pid``.
        tid_default: Thread id used when a rule sets no ``tid``.
        timestamp_unit: Unit name forwarded to ``_convert_timestamp``.
        timestamp_format: ``strptime`` format forwarded to
            ``_convert_timestamp``.

    Returns:
        Trace event dicts in the order their end lines were encountered.

    Raises:
        ValueError: Propagated from ``_convert_timestamp`` when a captured
            timestamp cannot be converted.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    # The rules are walked once per log line, so a one-shot iterable would
    # be exhausted after the first line; materialize it up front.
    rule_list = list(rules)

    trace_events: List[Dict] = []
    # Open events of id-based rules, keyed by (rule name, captured id).
    active_with_id: Dict[Tuple[str, str], Dict] = {}
    # Open events of id-less rules: per-rule FIFO of pending starts.
    active_no_id: Dict[str, deque] = {}
    auto_id_counters: Dict[str, int] = {}

    with log_path.open("r", encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, 1):
            text = line.rstrip("\n")
            matched = False
            for rule in rule_list:
                has_id = rule.id_group is not None

                start_match = rule.start_pattern.search(text)
                if start_match:
                    matched = True
                    if has_id:
                        event_id = start_match.group(rule.id_group)
                        if (rule.name, event_id) in active_with_id:
                            LOGGER.warning(
                                "Line %d: %s 已有未结束事件(id=%s),将被覆盖",
                                line_no,
                                rule.name,
                                event_id,
                            )
                    else:
                        counter = auto_id_counters.get(rule.name, 0) + 1
                        auto_id_counters[rule.name] = counter
                        event_id = str(counter)
                    ts = _convert_timestamp(
                        start_match.group(rule.time_group),
                        timestamp_unit,
                        timestamp_format,
                    )
                    record = {"ts": ts, "line": line_no, "id": event_id}
                    if has_id:
                        active_with_id[(rule.name, event_id)] = record
                    else:
                        active_no_id.setdefault(rule.name, deque()).append(record)
                    # A line that starts an event is not also tested as
                    # the end of the same rule.
                    continue

                end_match = rule.end_pattern.search(text)
                if not end_match:
                    continue
                matched = True
                if has_id:
                    event_id = end_match.group(rule.id_group)
                    start_info = active_with_id.pop((rule.name, event_id), None)
                else:
                    pending = active_no_id.get(rule.name)
                    start_info = pending.popleft() if pending else None
                    if pending is not None and not pending:
                        active_no_id.pop(rule.name, None)
                    event_id = start_info["id"] if start_info else None
                if start_info is None:
                    LOGGER.warning(
                        "Line %d: 找到 %s 结束但没有匹配的开始(id=%s)",
                        line_no,
                        rule.name,
                        event_id,
                    )
                    continue
                end_ts = _convert_timestamp(
                    end_match.group(rule.time_group),
                    timestamp_unit,
                    timestamp_format,
                )
                # Clamp so an end timestamp earlier than its start never
                # produces a negative duration.
                duration = max(0.0, end_ts - start_info["ts"])
                trace_events.append(
                    {
                        "name": rule.name,
                        "cat": rule.name,
                        "ph": "X",
                        "ts": start_info["ts"],
                        "dur": duration,
                        # Explicit None checks: a configured pid/tid of 0
                        # is valid and must not fall back to the default.
                        "pid": rule.pid if rule.pid is not None else pid_default,
                        "tid": rule.tid if rule.tid is not None else tid_default,
                        "args": {
                            "id": event_id,
                            "start_line": start_info["line"],
                            "end_line": line_no,
                        },
                    }
                )
            if not matched:
                LOGGER.debug("Line %d: 无匹配事件 -> %s", line_no, text)

    # Report every start that never saw its end.
    for (name, event_id), info in active_with_id.items():
        LOGGER.warning("事件 %s (id=%s) 在行 %d 开始但未找到结束", name, event_id, info["line"])
    for name, pending in active_no_id.items():
        for info in pending:
            LOGGER.warning(
                "事件 %s (auto_id=%s) 在行 %d 开始但未找到结束",
                name,
                info["id"],
                info["line"],
            )
    return trace_events
def save_trace(output: Path, trace_events: List[Dict]) -> None:
    """Write *trace_events* to *output* as a Chrome Trace JSON document.

    The events are wrapped in a ``{"traceEvents": [...]}`` object. Missing
    parent directories of *output* are created first.

    Args:
        output: Destination file path.
        trace_events: Event dicts to serialize.
    """
    target_dir = output.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with output.open("w", encoding="utf-8") as sink:
        json.dump(
            {"traceEvents": trace_events},
            sink,
            ensure_ascii=False,
            indent=2,
        )
    LOGGER.info("已写入 trace 文件:%s", output)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the tool from ``sys.argv``.

    Returns:
        Namespace with ``log`` (list of paths, required), ``config``
        (path, defaults to ``perf.yaml``), ``output`` (path or None) and
        ``verbose`` (bool).
    """
    parser = argparse.ArgumentParser(description="训练日志转 trace 工具")
    # (flag, add_argument keyword arguments) in registration order.
    option_table = (
        (
            "--log",
            {
                "type": Path,
                "nargs": "+",
                "required": True,
                "help": "训练日志路径(可指定多个文件)",
            },
        ),
        (
            "--config",
            {"type": Path, "default": Path("perf.yaml"), "help": "配置文件路径"},
        ),
        (
            "--output",
            {"type": Path, "help": "覆盖配置里的 trace.output,若提供则优先生效"},
        ),
        (
            "--verbose",
            {"action": "store_true", "help": "打印更多调试信息(用于排查匹配问题)"},
        ),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Command-line entry point: parse logs and write the combined trace.

    Reads the options, configures logging, loads and compiles the event
    rules, parses every existing log file given on the command line
    (missing files are skipped with a warning) and writes all extracted
    events to one trace file. The ``--output`` option takes precedence
    over ``trace.output`` from the configuration, which in turn defaults
    to ``trace.json``.
    """
    args = parse_args()
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    trace_cfg = load_config(args.config)
    # Settings shared by every log file are resolved once up front.
    shared_options = {
        "rules": compile_rules(trace_cfg),
        "pid_default": int(trace_cfg.get("pid", 0)),
        "tid_default": int(trace_cfg.get("tid", 0)),
        "timestamp_unit": trace_cfg.get("timestamp_unit", "unix_ms"),
        "timestamp_format": trace_cfg.get("timestamp_format"),
    }

    collected: List[Dict] = []
    for log_path in args.log:
        if log_path.exists():
            LOGGER.info("解析日志 %s", log_path)
            extracted = build_trace_from_log(log_path=log_path, **shared_options)
            collected.extend(extracted)
            LOGGER.info("从 %s 提取了 %d 个事件", log_path, len(extracted))
        else:
            LOGGER.warning("日志文件不存在,跳过: %s", log_path)
    LOGGER.info("总共提取了 %d 个事件", len(collected))

    destination = args.output
    if not destination:
        destination = Path(trace_cfg.get("output", "trace.json"))
    save_trace(destination, collected)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()