from __future__ import annotations
import contextvars
import functools
import inspect
import logging
import time
import uuid
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any
from slime.utils.types import Sample
TRACE_VERSION = 1
SGLANG_TRACE_META_KEYS = (
"prompt_tokens",
"completion_tokens",
"cached_tokens",
"pd_prefill_bootstrap_queue_duration",
"pd_prefill_forward_duration",
"pd_prefill_transfer_queue_duration",
"pd_prefill_retry_count",
"pd_decode_prealloc_duration",
"pd_decode_transfer_duration",
"pd_decode_forward_duration",
"pd_bootstrap_duration",
"pd_alloc_waiting_duration",
"pd_transfer_speed_gb_s",
"pd_transfer_total_mb",
)
logger = logging.getLogger(__name__)
_TRACE_STACK: contextvars.ContextVar[tuple[tuple[str, str], ...]] = contextvars.ContextVar(
"slime_trace_stack",
default=(),
)
_TRACE_HANDLE_STACK: contextvars.ContextVar[tuple[tuple[TraceHandle, ...], ...]] = contextvars.ContextVar(
"slime_trace_handle_stack",
default=(),
)
_TRACE_AUTO_INFER_WARNED: set[str] = set()
@dataclass
class TraceHandle:
trace_id: str
carrier: dict[str, Any]
sample_id: int | str | None = None
group_id: int | str | None = None
attempt: int = 0
parent_span_id: str | None = None
@dataclass
class TraceSpanContext:
target: Sample | TraceHandle | list[Sample | TraceHandle]
handles: list[TraceHandle]
end_attrs: dict[str, Any] = field(default_factory=dict)
end_events: list[dict[str, Any]] = field(default_factory=list)
closed: bool = False
def set(self, key: str, value: Any) -> TraceSpanContext:
self.end_attrs[key] = value
self._sync_end_events({key: value})
return self
def update(self, attrs: dict[str, Any] | None) -> TraceSpanContext:
if attrs:
self.end_attrs.update(attrs)
self._sync_end_events(attrs)
return self
def set_attr(self, key: str, value: Any) -> TraceSpanContext:
return self.set(key, value)
def update_attrs(self, attrs: dict[str, Any] | None) -> TraceSpanContext:
return self.update(attrs)
def build_end_attrs(self) -> dict[str, Any] | None:
return dict(self.end_attrs) or None
def finalize(self, end_events: list[dict[str, Any]]) -> None:
self.end_events = end_events
self.closed = True
if self.end_attrs:
self._sync_end_events(self.end_attrs)
def _sync_end_events(self, attrs: dict[str, Any]) -> None:
if not self.end_events or not attrs:
return
for event in self.end_events:
event.setdefault("attrs", {})
event["attrs"].update(attrs)
def _noop_handle() -> TraceHandle:
return TraceHandle(
trace_id="",
carrier={
"version": TRACE_VERSION,
"trace_id": "",
"events": [],
"sample_id": None,
"group_id": None,
"attempt": 0,
},
)
def _log_trace_error(action: str, exc: Exception) -> None:
logger.debug("trace %s skipped: %s", action, exc, exc_info=True)
def _new_trace_id() -> str:
return uuid.uuid4().hex
def _new_span_id() -> str:
return uuid.uuid4().hex
def build_sglang_meta_trace_attrs(meta: dict[str, Any]) -> dict[str, Any]:
attrs = {key: meta[key] for key in SGLANG_TRACE_META_KEYS if key in meta and meta[key] is not None}
attrs["finish_reason"] = meta["finish_reason"]["type"]
return attrs
def _ensure_trace_carrier(
carrier: dict[str, Any] | None,
*,
trace_id: str | None = None,
sample_id: int | str | None = None,
group_id: int | str | None = None,
attempt: int = 0,
) -> dict[str, Any]:
if carrier is None:
carrier = {}
carrier.setdefault("version", TRACE_VERSION)
carrier.setdefault("trace_id", trace_id or _new_trace_id())
carrier.setdefault("events", [])
if sample_id is not None:
carrier["sample_id"] = sample_id
else:
carrier.setdefault("sample_id", None)
if group_id is not None:
carrier["group_id"] = group_id
else:
carrier.setdefault("group_id", None)
carrier["attempt"] = int(carrier.get("attempt", attempt))
return carrier
def bind_trace(sample: Sample) -> TraceHandle:
try:
sample.trace = _ensure_trace_carrier(
getattr(sample, "trace", None),
sample_id=sample.index,
group_id=sample.group_index,
)
return TraceHandle(
trace_id=sample.trace["trace_id"],
carrier=sample.trace,
sample_id=sample.trace.get("sample_id"),
group_id=sample.trace.get("group_id"),
attempt=int(sample.trace.get("attempt", 0)),
)
except Exception as exc:
_log_trace_error("bind", exc)
return _noop_handle()
def bind_trace_carrier(
carrier: dict[str, Any] | None,
*,
trace_id: str | None = None,
sample_id: int | str | None = None,
group_id: int | str | None = None,
attempt: int = 0,
parent_span_id: str | None = None,
) -> TraceHandle:
try:
trace = _ensure_trace_carrier(
carrier,
trace_id=trace_id,
sample_id=sample_id,
group_id=group_id,
attempt=attempt,
)
return TraceHandle(
trace_id=trace["trace_id"],
carrier=trace,
sample_id=trace.get("sample_id"),
group_id=trace.get("group_id"),
attempt=int(trace.get("attempt", 0)),
parent_span_id=parent_span_id,
)
except Exception as exc:
_log_trace_error("bind_carrier", exc)
handle = _noop_handle()
handle.parent_span_id = parent_span_id
return handle
def export_trace(handle: TraceHandle) -> dict[str, Any]:
try:
return {
"version": TRACE_VERSION,
"trace_id": handle.trace_id,
"sample_id": handle.sample_id,
"group_id": handle.group_id,
"attempt": handle.attempt,
"parent_span_id": handle.parent_span_id or _get_current_parent_span_id(handle.trace_id),
}
except Exception as exc:
_log_trace_error("export", exc)
return {
"version": TRACE_VERSION,
"trace_id": "",
"sample_id": None,
"group_id": None,
"attempt": 0,
"parent_span_id": None,
}
def import_trace(payload: dict[str, Any], carrier: dict[str, Any] | None = None) -> TraceHandle:
try:
return bind_trace_carrier(
carrier,
trace_id=payload.get("trace_id"),
sample_id=payload.get("sample_id"),
group_id=payload.get("group_id"),
attempt=int(payload.get("attempt", 0)),
parent_span_id=payload.get("parent_span_id"),
)
except Exception as exc:
_log_trace_error("import", exc)
return _noop_handle()
def trace_event(
target: Sample | TraceHandle | list[Sample | TraceHandle], name: str, *, attrs: dict[str, Any] | None = None
):
try:
timestamp = time.time()
for handle in _coerce_handles(target):
_append_event(handle, kind="event", name=name, timestamp=timestamp, attrs=attrs)
except Exception as exc:
_log_trace_error(f"event:{name}", exc)
@contextmanager
def trace_span(
target: Sample | TraceHandle | list[Sample | TraceHandle],
name: str,
*,
attrs: dict[str, Any] | None = None,
record_error: bool = True,
):
try:
handles = _coerce_handles(target)
except Exception as exc:
_log_trace_error(f"span:{name}", exc)
handles = []
if not handles:
yield target
return
timestamp = time.time()
stack_before = _TRACE_STACK.get()
handle_stack_before = _TRACE_HANDLE_STACK.get()
span_records: list[tuple[TraceHandle, str]] = []
new_entries: list[tuple[str, str]] = []
try:
for handle in handles:
span_id = _new_span_id()
parent_span_id = handle.parent_span_id or _get_current_parent_span_id(handle.trace_id, stack=stack_before)
_append_event(
handle,
kind="span_start",
name=name,
timestamp=timestamp,
span_id=span_id,
parent_span_id=parent_span_id,
attrs=attrs,
)
span_records.append((handle, span_id))
new_entries.append((handle.trace_id, span_id))
token = _TRACE_STACK.set(stack_before + tuple(new_entries))
handle_token = _TRACE_HANDLE_STACK.set(handle_stack_before + (tuple(handles),))
except Exception as exc:
_log_trace_error(f"span:{name}", exc)
yield target
return
span_context = TraceSpanContext(
target=handles[0] if len(handles) == 1 else handles,
handles=handles,
)
try:
yield span_context
except Exception as exc:
try:
end_attrs = span_context.build_end_attrs()
if record_error:
error_attrs = {"error_type": type(exc).__name__, "error_message": str(exc)}
if end_attrs:
end_attrs.update(error_attrs)
else:
end_attrs = error_attrs
span_context.finalize(_record_span_end(span_records, name=name, attrs=end_attrs))
except Exception as trace_exc:
_log_trace_error(f"span_end:{name}", trace_exc)
raise
else:
try:
span_context.finalize(_record_span_end(span_records, name=name, attrs=span_context.build_end_attrs()))
except Exception as exc:
_log_trace_error(f"span_end:{name}", exc)
finally:
try:
_TRACE_STACK.reset(token)
except Exception as exc:
_log_trace_error(f"span_reset:{name}", exc)
try:
_TRACE_HANDLE_STACK.reset(handle_token)
except Exception as exc:
_log_trace_error(f"span_handle_reset:{name}", exc)
def trace_next_attempt(
target: Sample | TraceHandle | list[Sample | TraceHandle],
*,
attrs: dict[str, Any] | None = None,
):
try:
handles = _coerce_handles(target)
for handle in handles:
next_attempt = int(handle.carrier.get("attempt", 0)) + 1
handle.carrier["attempt"] = next_attempt
handle.attempt = next_attempt
attempt_attrs = {"attempt": next_attempt}
if attrs:
attempt_attrs.update(attrs)
trace_event(handle, "attempt_start", attrs=attempt_attrs)
if not handles:
return target
return handles[0] if len(handles) == 1 else handles
except Exception as exc:
_log_trace_error("next_attempt", exc)
return target
def trace_function(
name: str,
*,
target: str | None = None,
target_getter: Callable[..., Sample | TraceHandle | list[Sample | TraceHandle] | None] | None = None,
attrs_getter: Callable[..., dict[str, Any] | None] | None = None,
record_error: bool = True,
):
def decorator(fn):
if inspect.iscoroutinefunction(fn):
@functools.wraps(fn)
async def async_wrapper(*args, **kwargs):
resolved_target = _resolve_trace_function_target(
fn,
args,
kwargs,
target=target,
target_getter=target_getter,
)
if resolved_target is None:
return await fn(*args, **kwargs)
attrs = _resolve_trace_function_attrs(fn, args, kwargs, attrs_getter=attrs_getter)
with trace_span(resolved_target, name, attrs=attrs, record_error=record_error):
return await fn(*args, **kwargs)
return async_wrapper
@functools.wraps(fn)
def sync_wrapper(*args, **kwargs):
resolved_target = _resolve_trace_function_target(
fn,
args,
kwargs,
target=target,
target_getter=target_getter,
)
if resolved_target is None:
return fn(*args, **kwargs)
attrs = _resolve_trace_function_attrs(fn, args, kwargs, attrs_getter=attrs_getter)
with trace_span(resolved_target, name, attrs=attrs, record_error=record_error):
return fn(*args, **kwargs)
return sync_wrapper
return decorator
def _record_span_end(
span_records: list[tuple[TraceHandle, str]],
*,
name: str,
attrs: dict[str, Any] | None,
) -> list[dict[str, Any]]:
timestamp = time.time()
events = []
for handle, span_id in span_records:
events.append(
_append_event(
handle,
kind="span_end",
name=name,
timestamp=timestamp,
span_id=span_id,
attrs=attrs,
)
)
return events
def _append_event(
handle: TraceHandle,
*,
kind: str,
name: str,
timestamp: float,
attrs: dict[str, Any] | None = None,
span_id: str | None = None,
parent_span_id: str | None = None,
) -> dict[str, Any]:
event = {
"type": kind,
"name": name,
"ts": timestamp,
"trace_id": handle.trace_id,
"sample_id": handle.sample_id,
"group_id": handle.group_id,
"attempt": int(handle.carrier.get("attempt", handle.attempt)),
}
if span_id is not None:
event["span_id"] = span_id
if parent_span_id is not None:
event["parent_span_id"] = parent_span_id
if attrs:
event["attrs"] = dict(attrs)
handle.carrier["events"].append(event)
return event
def _coerce_handles(target: Sample | TraceHandle | list[Sample | TraceHandle]) -> list[TraceHandle]:
target = _adapt_trace_target(target)
if isinstance(target, TraceHandle):
return [target]
if isinstance(target, Sample):
return [bind_trace(target)]
if isinstance(target, list):
handles = []
for item in target:
handles.extend(_coerce_handles(item))
return handles
return []
def _get_current_parent_span_id(
trace_id: str,
*,
stack: tuple[tuple[str, str], ...] | None = None,
) -> str | None:
stack = _TRACE_STACK.get() if stack is None else stack
for current_trace_id, span_id in reversed(stack):
if current_trace_id == trace_id:
return span_id
return None
def _resolve_trace_function_target(
fn,
args: tuple[Any, ...],
kwargs: dict[str, Any],
*,
target: str | None,
target_getter: Callable[..., Sample | TraceHandle | list[Sample | TraceHandle] | None] | None,
):
try:
bound = inspect.signature(fn).bind_partial(*args, **kwargs)
except Exception as exc:
_log_trace_error(f"trace_function_bind:{getattr(fn, '__qualname__', fn)}", exc)
bound = None
if target is not None:
if bound is None or target not in bound.arguments:
logger.warning(
"trace_function target '%s' not found for %s; tracing disabled for this call",
target,
getattr(fn, "__qualname__", repr(fn)),
)
return None
resolved = _normalize_trace_target(bound.arguments.get(target))
if resolved is None:
logger.warning(
"trace_function target '%s' for %s is not a supported trace target; tracing disabled for this call",
target,
getattr(fn, "__qualname__", repr(fn)),
)
return resolved
if target_getter is not None:
try:
resolved = _normalize_trace_target(target_getter(*args, **kwargs))
return resolved
except Exception as exc:
_log_trace_error(f"trace_function_target_getter:{getattr(fn, '__qualname__', fn)}", exc)
return None
inferred = _infer_trace_target(bound.arguments.values() if bound is not None else args)
if inferred is not None:
warn_key = getattr(fn, "__module__", "") + "." + getattr(fn, "__qualname__", repr(fn))
if warn_key not in _TRACE_AUTO_INFER_WARNED:
_TRACE_AUTO_INFER_WARNED.add(warn_key)
logger.warning(
"trace_function auto-inferred target for %s; inference may be ambiguous, prefer explicit target=...",
getattr(fn, "__qualname__", repr(fn)),
)
return inferred
return _get_current_trace_target()
def _resolve_trace_function_attrs(
fn,
args: tuple[Any, ...],
kwargs: dict[str, Any],
*,
attrs_getter: Callable[..., dict[str, Any] | None] | None,
) -> dict[str, Any] | None:
if attrs_getter is None:
return None
try:
attrs = attrs_getter(*args, **kwargs)
if attrs is None:
return None
if isinstance(attrs, dict):
return attrs
logger.warning(
"trace_function attrs_getter for %s returned non-dict %s; ignoring attrs",
getattr(fn, "__qualname__", repr(fn)),
type(attrs).__name__,
)
return None
except Exception as exc:
_log_trace_error(f"trace_function_attrs_getter:{getattr(fn, '__qualname__', fn)}", exc)
return None
def _infer_trace_target(values) -> Sample | TraceHandle | list[Sample | TraceHandle] | None:
for value in values:
normalized = _normalize_trace_target(value)
if normalized is not None:
return normalized
return None
def _normalize_trace_target(value):
value = _adapt_trace_target(value)
if isinstance(value, (Sample, TraceHandle)):
return value
if isinstance(value, list) and value:
if all(_normalize_trace_target(item) is not None for item in value):
return value
return None
def _adapt_trace_target(value):
if value is None:
return None
if isinstance(value, (Sample, TraceHandle)):
return value
if isinstance(value, list):
return [_adapt_trace_target(item) for item in value]
if _looks_like_sample_box(value):
generation = getattr(value, "generation", None)
if generation:
return generation
return getattr(value, "prompt_sample", None)
return value
def _get_current_trace_target() -> TraceHandle | list[TraceHandle] | None:
handle_stack = _TRACE_HANDLE_STACK.get()
if not handle_stack:
return None
current_handles = list(handle_stack[-1])
if not current_handles:
return None
if len(current_handles) == 1:
return current_handles[0]
return current_handles
def _looks_like_sample_box(value: Any) -> bool:
cls = getattr(value, "__class__", None)
if cls is None or getattr(cls, "__name__", "") != "SampleBox":
return False
return hasattr(value, "prompt_sample") and hasattr(value, "generation")