"""从模型响应 / LangChain 消息中解析 token 用量。"""

from __future__ import annotations

from typing import Any

from langchain_core.messages import AIMessage


def _as_int(value: Any) -> int:
    try:
        if value is None:
            return 0
        return int(value)
    except (TypeError, ValueError):
        return 0


def _cache_from_details(details: Any) -> int:
    if not isinstance(details, dict):
        return 0
    return _as_int(
        details.get("cached_tokens")
        or details.get("cache_read")
        or details.get("cache_read_input_tokens")
    )


def parse_usage_dict(raw: dict[str, Any]) -> dict[str, int] | None:
    if not raw:
        return None

    inp = _as_int(
        raw.get("input_tokens")
        or raw.get("prompt_tokens")
        or raw.get("input")
    )
    out = _as_int(
        raw.get("output_tokens")
        or raw.get("completion_tokens")
        or raw.get("output")
    )
    cache = _as_int(
        raw.get("cache_read_input_tokens")
        or raw.get("cached_tokens")
        or raw.get("cache_tokens")
    )
    details = raw.get("input_token_details") or raw.get("prompt_tokens_details")
    if not cache:
        cache = _cache_from_details(details)

    total = _as_int(raw.get("total_tokens") or raw.get("total"))
    if not total:
        total = inp + out

    if inp == 0 and out == 0 and cache == 0 and total == 0:
        return None

    return {"input": inp, "output": out, "cache": cache, "total": total}


def usage_from_message(message: Any) -> dict[str, int] | None:
    """Extract normalized token usage from an ``AIMessage``.

    Sources are tried in this order, and the first one that parses to a
    non-empty result wins:

    1. ``message.usage_metadata`` (when it is a dict);
    2. ``response_metadata["token_usage"]`` or ``response_metadata["usage"]``
       (when that value is a dict);
    3. ``response_metadata`` itself.

    Returns:
        The parsed usage dict, or ``None`` when *message* is not an
        ``AIMessage`` or no source yields usage.
    """
    if not isinstance(message, AIMessage):
        return None

    meta = getattr(message, "usage_metadata", None)
    if isinstance(meta, dict):
        result = parse_usage_dict(meta)
        if result:
            return result

    response = getattr(message, "response_metadata", None) or {}
    if not isinstance(response, dict):
        return None

    nested = response.get("token_usage") or response.get("usage")
    sources = [nested, response] if isinstance(nested, dict) else [response]
    for source in sources:
        result = parse_usage_dict(source)
        if result:
            return result
    return None


def usage_from_chat_model_end_event(ev: dict[str, Any]) -> dict[str, int] | None:
    """Extract token usage from a chat-model "end" event payload.

    ``ev["data"]["output"]`` is inspected. An ``AIMessage`` is parsed
    directly. Otherwise a ``generations`` collection is looked up on the
    output, either as an attribute or as a dict key, and the usage of every
    generation's ``message`` is summed.

    Entries of ``generations`` may be single generations or lists/tuples of
    generations (one level of nesting is flattened), and each generation may
    be an object with a ``message`` attribute or a dict with a ``"message"``
    key. Previously only a flat list of attribute-style objects was handled,
    so nested or dict-shaped results silently produced no usage.

    Args:
        ev: Event dict; only its ``"data"`` entry is read.

    Returns:
        The merged usage dict, or ``None`` when the event carries no dict
        ``data``, no generations, or no parsable usage.
    """
    data = ev.get("data")
    if not isinstance(data, dict):
        return None

    output = data.get("output")
    if isinstance(output, AIMessage):
        return usage_from_message(output)

    def _field(obj: Any, name: str) -> Any:
        # Accept both attribute-style objects and plain dicts.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    generations = _field(output, "generations")
    if not generations:
        return None

    parts: list[dict[str, int]] = []
    for entry in generations:
        # A nested list/tuple is a batch of generations; anything else is
        # treated as a single generation.
        batch = entry if isinstance(entry, (list, tuple)) else (entry,)
        for gen in batch:
            parsed = usage_from_message(_field(gen, "message"))
            if parsed:
                parts.append(parsed)

    return merge_usage_list(parts) if parts else None


def empty_usage() -> dict[str, int]:
    """Return a new usage record with every counter set to zero."""
    return dict.fromkeys(("input", "output", "cache", "total"), 0)


def merge_usage(a: dict[str, int], b: dict[str, int]) -> dict[str, int]:
    """Add two usage records field by field.

    Each of ``input``, ``output``, ``cache`` and ``total`` is read from both
    records with ``_as_int`` (so a missing or invalid entry counts as 0) and
    summed into a new dict; neither argument is modified.
    """
    return {
        field: _as_int(a.get(field)) + _as_int(b.get(field))
        for field in ("input", "output", "cache", "total")
    }


def merge_usage_list(parts: list[dict[str, int]]) -> dict[str, int]:
    """Sum a sequence of usage records into a single record.

    Starts from ``empty_usage()`` and folds each record in with
    ``merge_usage``, so an empty *parts* yields an all-zero record.
    """
    running_total = empty_usage()
    for usage in parts:
        running_total = merge_usage(running_total, usage)
    return running_total