"""Rate card loader and per-span $-cost calculator.
Rates are expressed in USD per 1M tokens. The default card lives next to
this file (``rate_cards.yaml``); override via ``OGMEM_PERF_RATE_CARD=/path``.
Unknown models fall through to the ``default`` entry within each family.
If neither the model nor a ``default`` entry exists, a ``KeyError`` is
raised — no silent fallback (per CLAUDE.md §工程原则).
"""
from __future__ import annotations
import logging
import os
from functools import lru_cache
from pathlib import Path
from typing import Any
logger = logging.getLogger("ogmem.perf.rate_cards")
_DEFAULT_CARD_PATH = Path(__file__).parent / "rate_cards.yaml"
def _load_yaml(path: Path) -> dict[str, Any]:
try:
import yaml
except ImportError as exc:
raise RuntimeError(
"PyYAML is required for perf.rate_cards; install pyyaml"
) from exc
with path.open("r", encoding="utf-8") as fh:
return yaml.safe_load(fh) or {}
@lru_cache(maxsize=4)
def _load_cached(resolved_path: str) -> dict[str, Any]:
"""Load and cache a rate card from an already-resolved path string."""
p = Path(resolved_path)
if not p.exists():
raise FileNotFoundError(f"rate card not found: {p}")
data = _load_yaml(p)
if "llm" not in data and "embedding" not in data:
raise ValueError(f"invalid rate card {p}: missing llm/embedding families")
return data
def load_rate_card(path: str | os.PathLike[str] | None = None) -> dict[str, Any]:
    """Load a rate card from YAML.

    The source is chosen in order: the explicit *path*, then the
    ``OGMEM_PERF_RATE_CARD`` environment variable, then the packaged default.
    Empty strings count as unset.

    The chosen path is ``~``-expanded and made absolute before it is used as the
    cache key. Otherwise a relative path would be cached under its relative
    spelling, and a later call after ``os.chdir`` would get back the card loaded
    from the old directory.

    Args:
        path: Explicit path, or ``None`` to honour ``OGMEM_PERF_RATE_CARD``
            and finally the packaged default.

    Returns:
        Parsed dict with ``llm`` and/or ``embedding`` families. The object is
        shared through a cache; do not mutate it.

    Raises:
        FileNotFoundError: the resolved file does not exist.
        ValueError: the file is not a valid rate card.
    """
    chosen = path or os.environ.get("OGMEM_PERF_RATE_CARD") or _DEFAULT_CARD_PATH
    resolved = Path(chosen).expanduser().resolve()
    return _load_cached(str(resolved))
def _lookup(card: dict[str, Any], family: str, model: str | None) -> dict[str, float]:
entries = card.get(family, {}) or {}
if model and model in entries:
return entries[model]
if model and "/" in model:
suffix = model.split("/", 1)[1]
if suffix in entries:
return entries[suffix]
if "default" in entries:
return entries["default"]
raise KeyError(
f"rate card family={family!r} has no entry for model={model!r} and no 'default'"
)
def _per_million(rate: float, tokens: int) -> float:
if not tokens:
return 0.0
if tokens < 0:
raise ValueError(f"negative token count: {tokens}")
return float(tokens) * float(rate) / 1_000_000.0
def compute_cost(
    llm_tokens: dict[str, int],
    embed_tokens: dict[str, int],
    llm_model: str | None,
    embed_model: str | None,
    card: dict[str, Any] | None = None,
) -> dict[str, float]:
    """Compute $-cost breakdown for a span's token deltas.

    Args:
        llm_tokens: dict with ``input_tokens`` / ``output_tokens`` /
            ``cache_read`` / ``cache_write``. An empty dict skips LLM pricing.
        embed_tokens: dict with ``embed_tokens``. An empty dict skips
            embedding pricing.
        llm_model: resolved LLM model name or ``None``.
        embed_model: resolved embedding model name or ``None``.
        card: pre-loaded rate card; only ``None`` loads the default.

    Returns:
        ``{"llm": float, "embedding": float, "total": float}`` in USD, each
        rounded to 8 decimal places.

    Raises:
        KeyError: a non-empty token dict has no matching rate entry and the
            family has no ``default``.
        ValueError: a priced token count is negative.
    """
    # Only None means "use the default card". An explicitly passed card (even
    # an empty one) is honoured, so a mis-built card fails loudly instead of
    # being silently swapped for the packaged default.
    if card is None:
        card = load_rate_card()

    llm_cost = 0.0
    if llm_tokens:
        rates = _lookup(card, "llm", llm_model)
        llm_cost += _per_million(rates.get("input", 0.0), llm_tokens.get("input_tokens", 0))
        llm_cost += _per_million(rates.get("output", 0.0), llm_tokens.get("output_tokens", 0))
        # Cache reads and writes are billed only when the card prices them. Both
        # go through _per_million, so negative counts are rejected the same way
        # as input/output counts.
        if "cache_read" in rates:
            llm_cost += _per_million(rates["cache_read"], llm_tokens.get("cache_read", 0))
        if "cache_write" in rates:
            llm_cost += _per_million(rates["cache_write"], llm_tokens.get("cache_write", 0))

    embed_cost = 0.0
    if embed_tokens:
        rates = _lookup(card, "embedding", embed_model)
        embed_cost += _per_million(rates.get("input", 0.0), embed_tokens.get("embed_tokens", 0))

    return {
        "llm": round(llm_cost, 8),
        "embedding": round(embed_cost, 8),
        "total": round(llm_cost + embed_cost, 8),
    }