"""Public API: ``@record_stage`` decorator and ``span`` context manager."""
from __future__ import annotations
import functools
import logging
import time
from contextlib import contextmanager
from typing import Any, Callable, Iterator
from uuid import uuid4
from perf import span_stack
from perf.recorder import SpanEvent, clocks, get_recorder
# Module logger. Used only to report ``rec.emit`` failures so that a broken
# sink never silently hides a failed stage or span.
logger: logging.Logger = logging.getLogger(name="ogmem.perf.decorators")
def _ids_from_params(params: Any) -> tuple[str | None, str | None]:
    """Extract ``(session_id, trace_id)`` from a bridge ``params`` mapping.

    The camelCase keys (``sessionId`` / ``traceId``) are consulted first. If
    the value there is missing or falsy, the snake_case key (``session_id`` /
    ``trace_id``) is tried. Any falsy result, such as an empty string, is
    normalised to ``None``. Inputs that are not a ``dict`` yield
    ``(None, None)``.
    """
    if isinstance(params, dict):
        session = params.get("sessionId") or params.get("session_id") or None
        trace = params.get("traceId") or params.get("trace_id") or None
        return session, trace
    return None, None
def _ids_from_kwargs(args: tuple, kwargs: dict) -> tuple[str | None, str | None]:
    """Recover ``(session_id, trace_id)`` from a call's arguments.

    ``args[0]`` is skipped because on methods it is ``self``. If the decorated
    callable is a plain function, its first positional argument is therefore
    never inspected. Every remaining positional argument is scanned, followed
    by the keyword-argument values in insertion order. Under the bridge
    convention ``method(self, params)``, ``args[1]`` is normally the first
    candidate.

    Returns the pair from the first ``dict`` candidate in which at least one
    ID is found, otherwise ``(None, None)``.
    """
    candidates = (*args[1:], *kwargs.values())
    for cand in candidates:
        if not isinstance(cand, dict):
            continue
        found = _ids_from_params(cand)
        if any(found):
            return found
    return None, None
def record_stage(stage: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorate a ``MemoryService`` lifecycle method.

    Emits a single ``SpanEvent`` (the stage root) summarising wall time,
    CPU time, token deltas, $-cost, and any sub-span metadata accumulated
    during execution.

    Errors are recorded but **re-raised**, per CLAUDE.md (engineering
    principles: "no fallback paths").

    Args:
        stage: Stage name stamped on the emitted event.

    Returns:
        A decorator for the target callable. If ``get_recorder().enabled`` is
        false at call time, the callable runs untouched and nothing is
        recorded.

    Raises:
        Whatever the wrapped callable raises, re-raised unchanged. If the call
        succeeded but ``rec.emit`` fails, the emit error is raised. If the call
        itself failed, an emit error is only logged so that the original
        exception keeps propagating.
    """

    def _decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def _wrapper(*args: Any, **kwargs: Any) -> Any:
            rec = get_recorder()
            if not rec.enabled:
                return fn(*args, **kwargs)

            started_at, wall_start, cpu_start = clocks()
            sid, tid = _ids_from_kwargs(args, kwargs)
            # Root frame of this stage: an empty ``span`` and no parent.
            frame = span_stack.SpanFrame(
                stage=stage,
                span="",
                parent_span=None,
                started_at=started_at,
                wall_start=wall_start,
                cpu_start=cpu_start,
                session_id=sid,
                trace_id=tid,
                run_id=rec.run_id,
                token_snapshot=rec.snapshot_tokens(),
            )
            span_stack.push(frame)
            ok = True
            error: str | None = None
            try:
                return fn(*args, **kwargs)
            except BaseException as exc:
                # BaseException rather than Exception. KeyboardInterrupt,
                # SystemExit and asyncio.CancelledError would otherwise be
                # emitted as a successful stage (ok=True, error=None).
                ok = False
                error = f"{type(exc).__name__}: {exc}"
                raise
            finally:
                popped = span_stack.pop()
                # If the stack hands nothing back, fall back to our own frame
                # for *every* field. In particular, the token delta must
                # still be measured against this stage's starting snapshot:
                # an empty baseline would attribute all tokens ever consumed
                # to this stage.
                src = popped if popped is not None else frame
                after = rec.snapshot_tokens()
                tokens, cost, source = rec.finalize_tokens(after, src.token_snapshot)
                wall_ms = (time.perf_counter() - src.wall_start) * 1000.0
                cpu_ms = (time.process_time() - src.cpu_start) * 1000.0
                event = SpanEvent(
                    run_id=rec.run_id,
                    session_id=src.session_id,
                    trace_id=src.trace_id,
                    stage=stage,
                    span="",
                    parent_span=None,
                    started_at=src.started_at,
                    wall_ms=round(wall_ms, 3),
                    cpu_ms=round(cpu_ms, 3),
                    ok=ok,
                    error=error,
                    llm_model=tokens.get("llm_model"),
                    embed_model=tokens.get("embed_model"),
                    tokens={
                        "llm": tokens.get("llm", {}),
                        "embed": tokens.get("embed", {}),
                    },
                    cost_usd=cost,
                    token_source=source,
                    meta=dict(src.meta),
                )
                try:
                    rec.emit(event)
                except Exception as exc:
                    logger.error("emit failed: %s", exc, exc_info=True)
                    # Only surface the emit failure when it would not mask
                    # an exception already propagating from ``fn``.
                    if ok:
                        raise

        return _wrapper

    return _decorator
@contextmanager
def span(name: str, **meta: Any) -> Iterator[span_stack.SpanFrame]:
    """Context manager for a named sub-span inside a decorated stage.

    Usage::

        with span("extract_llm", num_messages=len(msgs)) as s:
            candidates = extractor.extract(...)
            s.meta["num_candidates"] = len(candidates)

    The yielded frame's ``meta`` starts as a copy of ``**meta``. Anything
    written to it inside the block is included in the emitted event.

    If the recorder is disabled, this is a lightweight no-op. A detached
    frame is yielded so callers can still write ``s.meta``, but nothing is
    pushed or emitted.

    Args:
        name: Sub-span name.
        **meta: Initial metadata for the span.

    Raises:
        Whatever the ``with`` body raises, re-raised unchanged. If the body
        succeeded but ``rec.emit`` fails, the emit error is raised. Otherwise
        the emit error is only logged.
    """
    rec = get_recorder()
    if not rec.enabled:
        _dummy = span_stack.SpanFrame(
            stage="(disabled)", span=name, parent_span=None,
            started_at=0.0, wall_start=0.0, cpu_start=0.0,
        )
        _dummy.meta.update(meta)
        yield _dummy
        return

    parent = span_stack.top()
    stage_frame = span_stack.stage_root()
    stage_name = stage_frame.stage if stage_frame else "(orphan)"
    started_at, wall_start, cpu_start = clocks()
    frame = span_stack.SpanFrame(
        stage=stage_name,
        span=name,
        # The root stage frame has span "", so direct children store "".
        # The emitted event normalises that back to None.
        parent_span=(parent.span if parent else None) or "",
        started_at=started_at,
        wall_start=wall_start,
        cpu_start=cpu_start,
        session_id=parent.session_id if parent else None,
        trace_id=parent.trace_id if parent else None,
        run_id=rec.run_id,
        token_snapshot=rec.snapshot_tokens(),
    )
    frame.meta.update(meta)
    span_stack.push(frame)
    ok = True
    error: str | None = None
    try:
        yield frame
    except BaseException as exc:
        # BaseException rather than Exception. KeyboardInterrupt,
        # SystemExit, asyncio.CancelledError and GeneratorExit thrown into
        # this generator would otherwise be emitted as a successful span.
        ok = False
        error = f"{type(exc).__name__}: {exc}"
        raise
    finally:
        popped = span_stack.pop()
        # If the stack hands nothing back, fall back to our own frame for
        # every field. It still holds the starting token snapshot (an empty
        # baseline would inflate tokens/cost), the inherited IDs, and any
        # meta the caller wrote through the yielded frame.
        src = popped if popped is not None else frame
        after = rec.snapshot_tokens()
        tokens, cost, source = rec.finalize_tokens(after, src.token_snapshot)
        wall_ms = (time.perf_counter() - src.wall_start) * 1000.0
        cpu_ms = (time.process_time() - src.cpu_start) * 1000.0
        event = SpanEvent(
            run_id=rec.run_id,
            session_id=src.session_id,
            trace_id=src.trace_id,
            stage=stage_name,
            span=name,
            parent_span=src.parent_span or None,
            started_at=src.started_at,
            wall_ms=round(wall_ms, 3),
            cpu_ms=round(cpu_ms, 3),
            ok=ok,
            error=error,
            llm_model=tokens.get("llm_model"),
            embed_model=tokens.get("embed_model"),
            tokens={
                "llm": tokens.get("llm", {}),
                "embed": tokens.get("embed", {}),
            },
            cost_usd=cost,
            token_source=source,
            meta=dict(src.meta),
        )
        try:
            rec.emit(event)
        except Exception as exc:
            logger.error("span emit failed: %s", exc, exc_info=True)
            # Only surface the emit failure when it would not mask an
            # exception already propagating out of the ``with`` body.
            if ok:
                raise