"""Per-context token attribution for perf spans.
Historical design (deprecated)
------------------------------
``snapshot_service`` used to read ``service._llm.token_tracker`` and
``service._embedder.token_tracker`` directly via reflection, then
:func:`diff` subtracted two snapshots to attribute tokens to a span.
That approach broke under concurrency: ``token_tracker`` is a process-wide
cumulative counter, so any thread that completed *after* siblings would
see those siblings' tokens included in its delta.
Current design (bucket-based)
-----------------------------
:func:`snapshot_service` now pushes a fresh
:class:`providers.token_tracker.TokenBucket` onto a ``ContextVar`` stack
maintained by ``providers.token_tracker``. Every provider call records
into the global counter *and* into every bucket on the active context's
stack. Each thread / asyncio task therefore sees only the LLM and
embedding work performed inside its own perf span, regardless of how
many concurrent extractions or compose calls run alongside it.
:func:`snapshot_service` returns an opaque dict that contains the bucket reference
and the ``ContextVar`` reset token. :func:`diff` consumes that dict to
pop the bucket and render the accumulated totals in the same shape used
by report renderers, rate cards, and downstream consumers.
The caller surface is unchanged on purpose: ``Recorder.snapshot_tokens``
calls this module's :func:`snapshot_service`, and
``Recorder.finalize_tokens`` calls :func:`diff` as ``diff(after, before)``. Existing
HTTP handlers and the ``_background_extract_write`` instrumentation
continue to work without modification.
"""
from __future__ import annotations
import contextvars
import logging
from typing import Any
from providers.token_tracker import TokenBucket, current_buckets, pop_bucket, push_bucket
# Logger for attribution diagnostics (e.g. bucket pops that were skipped).
logger = logging.getLogger("ogmem.perf.attribution")
# Private keys used inside the opaque scope dict built by snapshot_service().
_BUCKET, _RESET_TOKEN = "_bucket", "_reset_token"
def _model_of(obj: Any) -> str | None:
    """Return a best-effort model label for a provider object.

    Args:
        obj: A provider instance, or ``None``.

    Returns:
        ``None`` when ``obj`` is ``None``. Otherwise, the object's ``model``
        attribute if it is a non-empty string, else the name of the
        object's class.
    """
    if obj is None:
        return None
    candidate = getattr(obj, "model", None)
    usable = isinstance(candidate, str) and bool(candidate)
    return candidate if usable else obj.__class__.__name__
def snapshot_service(service: Any) -> dict[str, Any]:
    """Open a new attribution scope.

    Pushes a fresh :class:`TokenBucket` onto the active context's bucket
    stack. The returned dict is opaque to callers — pass it as the
    ``before`` argument to :func:`diff` to close the scope and read the
    accumulated totals.

    The ``service`` argument is retained for API compatibility and used
    only as a fallback for model-name attribution; the per-call model
    identifier flows through :class:`providers.token_tracker.TokenTracker`
    so providers that mix model versions remain accurate.

    Args:
        service: Object that may expose ``_llm`` / ``_embedder`` provider
            attributes; a missing attribute yields a ``None`` fallback.

    Returns:
        Opaque scope dict holding the bucket, its reset token, and the
        fallback LLM / embedding model names.
    """
    # Resolve the fallback model names *before* pushing the bucket. If a
    # provider attribute raises here, nothing has been pushed yet, so the
    # context's bucket stack is not left holding an orphaned entry that no
    # caller has a token to pop.
    fallback_llm = _model_of(getattr(service, "_llm", None))
    fallback_embed = _model_of(getattr(service, "_embedder", None))
    bucket, token = push_bucket()
    return {
        _BUCKET: bucket,
        _RESET_TOKEN: token,
        "llm_model_fallback": fallback_llm,
        "embed_model_fallback": fallback_embed,
    }
def _close(scope: dict[str, Any] | None) -> None:
    """Pop the bucket associated with ``scope`` if it is still open.

    Idempotent: after a successful pop the reset token is removed from
    ``scope``, so closing the same scope again is a no-op. Without this,
    :func:`pop_bucket` would be handed an already-used token a second time.
    ``LookupError`` / ``ValueError`` from :func:`pop_bucket` are logged at
    debug level and otherwise ignored (best-effort cleanup). In that case
    the token is kept in ``scope``.

    Args:
        scope: Dict returned by :func:`snapshot_service`, or ``None``/empty.
    """
    if not scope:
        return
    token = scope.get(_RESET_TOKEN)
    if token is None:
        return
    try:
        pop_bucket(token)
    except (LookupError, ValueError) as exc:
        logger.debug("pop_bucket skipped (token from different context): %s", exc)
        return
    # Mark the scope as closed so a repeated close never reuses the token.
    scope.pop(_RESET_TOKEN, None)
def diff(after: dict[str, Any], before: dict[str, Any]) -> dict[str, Any]:
    """Close an attribution scope and return its accumulated totals.

    The legacy signature is preserved: ``before`` is the dict returned
    by :func:`snapshot_service`, ``after`` is whatever the caller passed
    (it may be a redundant snapshot opened just before this call). Both
    dicts have their buckets popped to keep the stack clean. ``before``
    is closed even if closing ``after`` raises. Passing the same dict as
    both arguments closes it only once.

    Returns a dict shaped like::

        {
            "llm": {... per-bucket llm counters ...},
            "embed": {... per-bucket embed counters ...},
            "llm_model": "gpt-4o-mini",
            "embed_model": "text-embedding-3-small",
        }

    When ``before`` carries no bucket, the counters are empty dicts and
    both model names are ``None``. Otherwise, missing model names fall back
    to the ``*_model_fallback`` values captured by :func:`snapshot_service`.
    """
    bucket: TokenBucket | None = before.get(_BUCKET) if before else None

    # ``after`` was opened later than ``before``, so close it first. That
    # unwinds the bucket stack in reverse push order. ``finally`` ensures
    # the primary scope is popped even if closing ``after`` fails. The
    # identity check avoids closing one scope twice when a caller passes
    # the same snapshot as both arguments.
    try:
        if after is not before:
            _close(after)
    finally:
        _close(before)

    if bucket is None:
        return {
            "llm": {},
            "embed": {},
            "llm_model": None,
            "embed_model": None,
        }

    rendered = bucket.to_attribution_dict()
    fallbacks = before or {}
    if not rendered.get("llm_model"):
        rendered["llm_model"] = fallbacks.get("llm_model_fallback")
    if not rendered.get("embed_model"):
        rendered["embed_model"] = fallbacks.get("embed_model_fallback")
    return rendered
def token_source(snapshot: dict[str, Any]) -> str:
    """Report whether a rendered attribution recorded any provider activity.

    The result is ``"tracker"`` when any LLM counter (``llm_calls``,
    ``input_tokens``, ``output_tokens``) or embedding counter
    (``embed_calls``, ``embed_tokens``) is truthy, and ``"missing"``
    otherwise. These string values are kept for backward compatibility
    with consumers that branch on them.

    Args:
        snapshot: Attribution dict with optional ``"llm"`` / ``"embed"``
            sections; a missing or ``None`` section counts as empty.
    """
    llm_counters = snapshot.get("llm") or {}
    embed_counters = snapshot.get("embed") or {}
    saw_llm = any(
        llm_counters.get(key) for key in ("llm_calls", "input_tokens", "output_tokens")
    )
    saw_embed = any(embed_counters.get(key) for key in ("embed_calls", "embed_tokens"))
    if saw_llm or saw_embed:
        return "tracker"
    return "missing"
def detect_counter_reset(after: dict[str, Any], before: dict[str, Any], field: str = "input_tokens") -> bool:
    """Always report "no reset"; kept only so existing imports keep working.

    This function used to detect rollovers of an upstream cumulative token
    counter. Buckets are created empty, so there is nothing to reset and
    the answer is always ``False``. All arguments are accepted and ignored.
    """
    del after, before, field  # intentionally unused; signature kept for callers
    return False
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["snapshot_service", "diff", "token_source", "detect_counter_reset"]