"""从模型响应 / LangChain 消息中解析 token 用量。"""
from __future__ import annotations
from typing import Any
from langchain_core.messages import AIMessage
def _as_int(value: Any) -> int:
try:
if value is None:
return 0
return int(value)
except (TypeError, ValueError):
return 0
def _cache_from_details(details: Any) -> int:
if not isinstance(details, dict):
return 0
return _as_int(
details.get("cached_tokens")
or details.get("cache_read")
or details.get("cache_read_input_tokens")
)
def parse_usage_dict(raw: dict[str, Any]) -> dict[str, int] | None:
if not raw:
return None
inp = _as_int(
raw.get("input_tokens")
or raw.get("prompt_tokens")
or raw.get("input")
)
out = _as_int(
raw.get("output_tokens")
or raw.get("completion_tokens")
or raw.get("output")
)
cache = _as_int(
raw.get("cache_read_input_tokens")
or raw.get("cached_tokens")
or raw.get("cache_tokens")
)
details = raw.get("input_token_details") or raw.get("prompt_tokens_details")
if not cache:
cache = _cache_from_details(details)
total = _as_int(raw.get("total_tokens") or raw.get("total"))
if not total:
total = inp + out
if inp == 0 and out == 0 and cache == 0 and total == 0:
return None
return {"input": inp, "output": out, "cache": cache, "total": total}
def usage_from_message(message: Any) -> dict[str, int] | None:
    """Return normalised token usage for an ``AIMessage``, if it carries any.

    Sources are tried in order and the first one that parses to a non-empty
    result wins:

    1. ``message.usage_metadata``
    2. ``response_metadata["token_usage"]`` or ``response_metadata["usage"]``
    3. ``response_metadata`` itself

    Returns ``None`` when *message* is not an ``AIMessage`` or no source
    yields usable counts.
    """
    if not isinstance(message, AIMessage):
        return None

    def candidates():
        # Lazily produced so later attributes are only read when needed.
        yield getattr(message, "usage_metadata", None)
        meta = getattr(message, "response_metadata", None) or {}
        if isinstance(meta, dict):
            yield meta.get("token_usage") or meta.get("usage")
            yield meta

    for candidate in candidates():
        if isinstance(candidate, dict):
            usage = parse_usage_dict(candidate)
            if usage:
                return usage
    return None
def usage_from_chat_model_end_event(ev: dict[str, Any]) -> dict[str, int] | None:
data = ev.get("data")
if not isinstance(data, dict):
return None
output = data.get("output")
if isinstance(output, AIMessage):
return usage_from_message(output)
generations = getattr(output, "generations", None)
if generations:
parts: list[dict[str, int]] = []
for gen in generations:
msg = getattr(gen, "message", None)
parsed = usage_from_message(msg)
if parsed:
parts.append(parsed)
if parts:
return merge_usage_list(parts)
return None
def empty_usage() -> dict[str, int]:
    """Return a fresh usage dict with every counter set to zero."""
    return dict.fromkeys(("input", "output", "cache", "total"), 0)
def merge_usage(a: dict[str, int], b: dict[str, int]) -> dict[str, int]:
    """Add two usage dicts field by field and return the result as a new dict.

    Missing or non-numeric fields count as zero because each value goes
    through ``_as_int``. Neither argument is modified.
    """
    return {
        field: _as_int(a.get(field)) + _as_int(b.get(field))
        for field in ("input", "output", "cache", "total")
    }
def merge_usage_list(parts: list[dict[str, int]]) -> dict[str, int]:
    """Sum a sequence of usage dicts into a single usage dict.

    Starts from ``empty_usage()`` and folds each entry in with ``merge_usage``,
    so an empty *parts* yields the all-zero usage dict.
    """
    running = empty_usage()
    for usage in parts:
        running = merge_usage(running, usage)
    return running