from __future__ import annotations

import json
import warnings
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
    trim_messages,
)
from langchain_core.messages.utils import count_tokens_approximately
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import CompiledStateGraph
from langgraph.runtime import Runtime
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.warnings import LangGraphDeprecatedSinceV10

from app.core.json_util import dumps_json
from app.services.token_usage import merge_usage, usage_from_message


# System prompt used whenever the caller supplies an empty one
# (estimate_messages_tokens and prepare_messages_for_llm both fall back to it).
_DEFAULT_SYSTEM_FALLBACK = (
    "You are a helpful assistant. Reply in the same language as the user when appropriate."
)

# Marker that starts the HumanMessage carrying a compressed-history summary
# (text means: "[earlier conversation was automatically compressed into a summary]").
# compress_summary_bridge_content() writes it; parse_compress_summary_content() detects it.
COMPRESS_SUMMARY_PREFIX = "[此前对话已自动压缩为摘要]"


def compress_summary_bridge_content(summary: str) -> str:
    """Return the message text that carries *summary* in place of compressed history.

    The text is ``COMPRESS_SUMMARY_PREFIX``, a newline, then the stripped summary.
    """
    body = summary.strip()
    return COMPRESS_SUMMARY_PREFIX + "\n" + body


def parse_compress_summary_content(text: str) -> str | None:
    """Extract the summary body from a compressed-history bridge message.

    Returns:
        The stripped text after ``COMPRESS_SUMMARY_PREFIX``. Returns None when *text* does
        not start with the prefix, or when nothing but whitespace follows it.
    """
    prefix = COMPRESS_SUMMARY_PREFIX
    if text[: len(prefix)] != prefix:
        return None
    remainder = text[len(prefix) :].strip()
    return remainder if remainder else None


def _as_dict(value: Any) -> dict[str, Any]:
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        text = value.strip()
        if text.startswith("{"):
            try:
                parsed = json.loads(text)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass
    return {}


def configurable_from_config(cfg: Any) -> dict[str, Any]:
    """Return ``cfg["configurable"]`` as a dict, or ``{}``.

    LangGraph's RunnableConfig.configurable is sometimes serialised to a non-dict value.
    Coercing it here keeps later ``conf.get(...)`` calls from crashing.
    """
    return _as_dict(cfg.get("configurable")) if isinstance(cfg, dict) else {}


def ai_config_from_configurable(conf: dict[str, Any]) -> dict[str, Any]:
    """Return the ``ai`` settings from a configurable dict, coerced to a dict (``{}`` if absent)."""
    raw = conf.get("ai")
    if raw is None:
        return {}
    return _as_dict(raw)


@dataclass
class ChatInvokeContext:
    """Run context passed to ``graph.astream_events`` / ``ainvoke``.

    The dynamic model and the prompt read the AI settings from it (through
    ``Runtime.context``) when the RunnableConfig carries no ``configurable``.
    """

    # AI settings dict (api_key, model, base_url, ... as read by _build_llm).
    ai: dict[str, Any] = field(default_factory=dict)
    # System prompt intended for the model.
    llm_system_prompt: str = ""
    # Alternative system prompt; to_configurable() falls back to llm_system_prompt when empty.
    system_prompt: str = ""

    def to_configurable(self) -> dict[str, Any]:
        """Render this context in the ``configurable`` dict shape used by RunnableConfig."""
        effective_system = self.system_prompt if self.system_prompt else self.llm_system_prompt
        return dict(
            ai=self.ai,
            llm_system_prompt=self.llm_system_prompt,
            system_prompt=effective_system,
        )


def _resolve_run_configurable(
    *,
    config: RunnableConfig | None = None,
    runtime: Runtime[ChatInvokeContext] | None = None,
) -> dict[str, Any]:
    """Return the run's configurable dict.

    A non-empty ``config["configurable"]`` wins. Otherwise a ChatInvokeContext found on
    ``runtime.context`` is used, and ``{}`` if there is neither. This avoids
    ``get_config()``, which is unavailable under asyncio on Python 3.10.
    """
    from_config = configurable_from_config(config) if config is not None else {}
    if from_config:
        return from_config
    context = None if runtime is None else getattr(runtime, "context", None)
    if not isinstance(context, ChatInvokeContext):
        return {}
    return context.to_configurable()


def _aimessage_has_tool_calls(msg: AIMessage) -> bool:
    if msg.tool_calls:
        return True
    extra = msg.additional_kwargs.get("tool_calls")
    return bool(extra)


def repair_messages_for_thinking_model(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Ensure every tool-calling assistant message carries ``reasoning_content``.

    In thinking mode (DeepSeek and similar), an assistant message with tool_calls must
    include a ``reasoning_content`` field. Messages from old checkpoints, or ones where the
    stream merge lost the field, get an empty string; the API accepts that and would
    otherwise reject the request with HTTP 400.

    Returns a new list. Patched messages are copies; all others are passed through as-is.
    """

    def _patched(msg: BaseMessage) -> BaseMessage:
        needs_patch = (
            isinstance(msg, AIMessage)
            and _aimessage_has_tool_calls(msg)
            and "reasoning_content" not in msg.additional_kwargs
        )
        if not needs_patch:
            return msg
        return msg.model_copy(
            update={"additional_kwargs": {**msg.additional_kwargs, "reasoning_content": ""}}
        )

    return [_patched(m) for m in messages]


class ThinkingChatOpenAI(ChatOpenAI):
    """
    ChatOpenAI for "thinking" models behind OpenAI-compatible gateways (e.g. DeepSeek v4).

    Such models return a ``reasoning_content`` field. It must be written back onto the
    AIMessage and sent back unchanged in the next request. This subclass:

    * copies ``reasoning_content`` from full responses into
      ``AIMessage.additional_kwargs`` (``_create_chat_result``);
    * copies it from streaming deltas onto each message chunk
      (``_convert_chunk_to_generation_chunk``);
    * adds it to the assistant messages of every outgoing request payload
      (``_get_request_payload``).
    """

    @staticmethod
    def _reasoning_for_request(msg: AIMessage, msg_dict: dict[str, Any]) -> str | None:
        """Return the ``reasoning_content`` to send for assistant message *msg*.

        Args:
            msg: The assistant message being serialised.
            msg_dict: Its already-serialised request dict.

        Returns:
            The stored value, converted with ``str`` if it is not a string.
            ``""`` when nothing is stored but the message carries tool calls.
            Otherwise None, meaning "do not add the field".
        """
        if "reasoning_content" in msg.additional_kwargs:
            rc = msg.additional_kwargs["reasoning_content"]
            return rc if isinstance(rc, str) else str(rc)
        # Tool-call turns must carry the field in thinking mode (see
        # repair_messages_for_thinking_model), so send "" instead of omitting it.
        if _aimessage_has_tool_calls(msg) or msg_dict.get("tool_calls"):
            return ""
        return None

    def _inject_reasoning_into_payload(
        self, payload: dict[str, Any], messages: list[BaseMessage]
    ) -> dict[str, Any]:
        """Set ``reasoning_content`` in place on the assistant entries of ``payload["messages"]``.

        Returns *payload*, mutated. The payload is left untouched in two cases:
        its ``messages`` list is missing or empty, or the list's length differs from
        *messages*. In the second case the dicts cannot be paired one-to-one with the
        source messages.
        """
        msg_dicts = payload.get("messages")
        if not msg_dicts or len(msg_dicts) != len(messages):
            return payload
        # Request dicts and source messages are paired by position.
        for msg_dict, msg in zip(msg_dicts, messages, strict=False):
            if not isinstance(msg, AIMessage):
                continue
            rc = self._reasoning_for_request(msg, msg_dict)
            if rc is not None:
                msg_dict["reasoning_content"] = rc
        return payload

    def _get_request_payload(
        self,
        messages: list[BaseMessage],
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload with ChatOpenAI, then add ``reasoning_content`` to assistant messages.

        NOTE(review): the positional pairing assumes *messages* is already a list of
        messages (as the agent passes it); confirm for other input types.
        """
        payload = super()._get_request_payload(messages, stop=stop, **kwargs)
        return self._inject_reasoning_into_payload(payload, messages)

    def _create_chat_result(
        self,
        response: Any,
        generation_info: dict[str, Any] | None = None,
    ) -> ChatResult:
        """Build the ChatResult with ChatOpenAI, then copy ``reasoning_content`` across.

        Each choice's ``message.reasoning_content`` is copied into the
        ``additional_kwargs`` of the matching generation's message.
        """
        result = super()._create_chat_result(response, generation_info)
        # Work on a plain dict: either the raw dict response, or a dump of the SDK object
        # with the ``parsed`` field of each choice's message excluded.
        response_dict = (
            response
            if isinstance(response, dict)
            else response.model_dump(
                exclude={"choices": {"__all__": {"message": {"parsed"}}}}
            )
        )
        # Generations and choices are paired by position.
        for gen, choice in zip(
            result.generations,
            response_dict.get("choices") or [],
            strict=False,
        ):
            msg_dict = choice.get("message") or {}
            rc = msg_dict.get("reasoning_content")
            if rc is not None:
                gen.message.additional_kwargs["reasoning_content"] = rc
        return result

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict[str, Any],
        default_chunk_class: type,
        base_generation_info: dict[str, Any] | None,
    ) -> ChatGenerationChunk | None:
        """Convert a streaming chunk with ChatOpenAI, then attach its ``reasoning_content``.

        The value is read from ``choices[0].delta``, or else from ``choices[0].message``.
        String values are appended to any ``reasoning_content`` already on the chunk's
        message; other values replace it.

        Returns:
            The (possibly updated) generation chunk, or None when ChatOpenAI produced none.
        """
        gen = super()._convert_chunk_to_generation_chunk(
            chunk, default_chunk_class, base_generation_info
        )
        if gen is None:
            return gen
        # Choices may sit at the top level or inside a nested "chunk" dict.
        nested = chunk.get("chunk")
        if isinstance(nested, dict):
            choices = chunk.get("choices") or nested.get("choices") or []
        else:
            choices = chunk.get("choices") or []
        if not choices:
            return gen
        # Only the first choice is inspected.
        choice = choices[0]
        delta = choice.get("delta") or {}
        rc = delta.get("reasoning_content")
        if rc is None and isinstance(choice.get("message"), dict):
            rc = choice["message"].get("reasoning_content")
        if rc is None:
            return gen
        msg = gen.message
        if not isinstance(msg, AIMessage):
            return gen
        # NOTE(review): this handles a single chunk's delta. Accumulation across chunks
        # presumably happens when LangChain adds message chunks together; confirm.
        prev = msg.additional_kwargs.get("reasoning_content") or ""
        if isinstance(rc, str):
            msg.additional_kwargs["reasoning_content"] = prev + rc
        else:
            msg.additional_kwargs["reasoning_content"] = rc
        return gen


def _build_llm(ai: dict[str, Any]) -> ThinkingChatOpenAI:
    api_key = (ai.get("api_key") or "").strip()
    if not api_key:
        raise ValueError("未配置 API Key,请先在系统设置中保存 AI 设置")

    kwargs: dict[str, Any] = {
        "model": ai.get("model") or "gpt-4o-mini",
        "temperature": float(ai.get("temperature", 0.7)),
        "max_tokens": int(ai.get("max_tokens", 8192)),
        "api_key": api_key,
        "streaming": True,
    }
    base_url = (ai.get("base_url") or "").strip()
    if base_url:
        kwargs["base_url"] = base_url
    return ThinkingChatOpenAI(**kwargs)


def compute_max_input_tokens(ai: dict[str, Any]) -> int:
    """Token budget for the trimmed prompt (system message plus history).

    The budget is 85% of ``context_length`` (default 128000) minus the reply budget
    ``max_tokens`` (default 8192), and never less than 4096.
    """
    context_window = int(ai.get("context_length") or 128000)
    reply_budget = int(ai.get("max_tokens") or 8192)
    budget = int(context_window * 0.85) - reply_budget
    return budget if budget > 4096 else 4096


def compute_compress_threshold_tokens(ai: dict[str, Any]) -> int:
    """Token count at which history (including the system prompt) is compressed into a summary.

    The threshold is 75% of the configured ``context_length`` (default 128000), and never
    less than 4096.
    """
    window = int(ai.get("context_length") or 128000)
    threshold = int(window * 0.75)
    return threshold if threshold > 4096 else 4096


def estimate_messages_tokens(messages: list[BaseMessage], system_prompt: str) -> int:
    """Approximate token count of the system prompt plus *messages*.

    Uses the same approximate counter as the ``token_counter="approximate"`` setting in
    prepare_messages_for_llm.
    """
    prompt = system_prompt if system_prompt else _DEFAULT_SYSTEM_FALLBACK
    counted = count_tokens_approximately([SystemMessage(content=prompt), *messages])
    return int(counted)


def effective_tokens_for_compression(
    estimated: int,
    actual_input_tokens: int | None = None,
) -> int:
    """Token count used for the compression decision.

    Returns the larger of the estimate and this round's real API input tokens. A missing
    or non-positive actual count is ignored.
    """
    actual = int(actual_input_tokens or 0)
    return max(estimated, actual) if actual > 0 else estimated


def prepare_messages_for_llm(
    history: list[BaseMessage],
    system_prompt: str,
    ai: dict[str, Any],
) -> list[BaseMessage]:
    """Build the model input: system prompt plus the most recent history that fits the budget.

    Steps:
        1. Repair tool-calling assistant messages for thinking models.
        2. Prepend the system prompt, or the default fallback.
        3. Keep the latest messages within compute_max_input_tokens(ai).

    Individual messages are never cut (``allow_partial=False``), which keeps
    tool_calls / ToolMessage pairs intact. The kept history starts on a human message.
    """
    repaired = repair_messages_for_thinking_model(history)
    prompt = SystemMessage(content=system_prompt or _DEFAULT_SYSTEM_FALLBACK)
    budget = compute_max_input_tokens(ai)
    return trim_messages(
        [prompt, *repaired],
        max_tokens=budget,
        strategy="last",
        token_counter="approximate",
        include_system=True,
        allow_partial=False,
        start_on="human",
    )


def reasoning_text(message: BaseMessage) -> str:
    """Accumulated ``reasoning_content`` of an AIMessage from a thinking model.

    Returns "" for non-AI messages or when the field is absent.
    """
    if isinstance(message, AIMessage):
        extras = getattr(message, "additional_kwargs", None) or {}
        value = extras.get("reasoning_content")
        if value is not None:
            return str(value)
    return ""


def chunk_reasoning(chunk: BaseMessage) -> str:
    """``reasoning_content`` increment carried by one streaming chunk, as a string.

    Returns "" if there is none.
    """
    extras = getattr(chunk, "additional_kwargs", None) or {}
    value = extras.get("reasoning_content") if isinstance(extras, dict) else None
    return "" if value is None else str(value)


def chunk_text(chunk: BaseMessage) -> str:
    """Plain text of a message or chunk.

    Handles OpenAI-Chat-style content, which is either a string or a list of blocks.
    In a list, string blocks are used verbatim and dict blocks contribute their ``text``
    (or ``content``). Any other block type is ignored.
    """
    content = getattr(chunk, "content", "") or ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content) if content else ""

    def _block_text(block: Any) -> str:
        if isinstance(block, str):
            return block
        if isinstance(block, dict):
            return str(block.get("text") or block.get("content") or "")
        return ""

    return "".join(_block_text(b) for b in content)


def _normalize_tool_args(raw: Any) -> dict[str, Any]:
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        text = raw.strip()
        if not text:
            return {}
        try:
            parsed = json.loads(text)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
        return {"input": raw}
    return {}


def _parse_one_tool_call(call: Any) -> tuple[str, str, dict[str, Any]]:
    """Return ``(name, id, args)`` for a tool call given as a dict or as an object.

    For dicts, ``arguments`` is consulted when ``args`` normalises to an empty dict.
    The id is stripped of surrounding whitespace.
    """
    if not isinstance(call, dict):
        return (
            str(getattr(call, "name", "") or ""),
            str(getattr(call, "id", "") or "").strip(),
            _normalize_tool_args(getattr(call, "args", {})),
        )
    args = _normalize_tool_args(call.get("args"))
    if not args and call.get("arguments") is not None:
        args = _normalize_tool_args(call.get("arguments"))
    return str(call.get("name") or ""), str(call.get("id") or "").strip(), args


def _parse_raw_openai_tool_calls(extra: Any) -> list[dict[str, Any]]:
    """Convert raw OpenAI-style ``tool_calls`` entries into ``{"id", "name", "args"}`` dicts.

    Name and arguments are read from the nested ``function`` dict, falling back to top-level
    ``name`` / ``args``. Non-dict entries are skipped; a non-list input yields ``[]``.
    """
    if not isinstance(extra, list):
        return []
    parsed: list[dict[str, Any]] = []
    for item in (x for x in extra if isinstance(x, dict)):
        function = item.get("function")
        if not isinstance(function, dict):
            function = {}
        args = _normalize_tool_args(function.get("arguments"))
        if not args and item.get("args") is not None:
            args = _normalize_tool_args(item.get("args"))
        parsed.append(
            {
                "id": str(item.get("id") or "").strip(),
                "name": str(function.get("name") or item.get("name") or ""),
                "args": args,
            }
        )
    return parsed


def extract_tool_calls(message: AIMessage) -> list[dict[str, Any]]:
    """Return the tool calls of *message* as ``{"id", "name", "args"}`` dicts.

    Parsed ``message.tool_calls`` come first. Raw OpenAI-style entries from
    ``additional_kwargs["tool_calls"]`` are added only when they are not duplicates.

    Calls with an id are de-duplicated by id. A raw call without an id cannot be matched
    to a parsed one, so it is only used when there are no parsed calls at all. Otherwise
    the same id-less call would be listed twice, once parsed and once raw.
    """
    seen_ids: set[str] = set()
    out: list[dict[str, Any]] = []

    def _append(call: Any, *, allow_anonymous: bool = True) -> None:
        name, cid, args = _parse_one_tool_call(call)
        if cid:
            if cid in seen_ids:
                return
            seen_ids.add(cid)
        elif not allow_anonymous:
            return
        out.append({"id": cid, "name": name, "args": args})

    parsed_calls = list(getattr(message, "tool_calls", None) or [])
    for call in parsed_calls:
        _append(call)

    extra = (getattr(message, "additional_kwargs", None) or {}).get("tool_calls")
    for call in _parse_raw_openai_tool_calls(extra):
        _append(call, allow_anonymous=not parsed_calls)

    return out


def _tool_result_ok(result: str) -> bool:
    ok = True
    try:
        parsed = json.loads(result)
        if isinstance(parsed, dict) and parsed.get("ok") is False:
            ok = False
    except json.JSONDecodeError:
        pass
    return ok


def _tool_dict(name: str, cid: str, args: dict[str, Any], result: str) -> dict[str, Any]:
    """Frontend record for one finished tool call.

    ``status`` is always "done". ``ok`` is derived from the result text.
    """
    succeeded = _tool_result_ok(result)
    return dict(
        id=cid,
        name=name,
        args=args,
        result=result,
        status="done",
        ok=succeeded,
    )


def _collect_tool_results(messages: list[BaseMessage], start: int) -> tuple[dict[str, str], list[str], int]:
    """Gather the run of consecutive ToolMessages beginning at index *start*.

    Returns:
        A tuple ``(by_id, in_order, end)``:

        * ``by_id`` maps tool_call_id to result text; messages without an id are omitted.
        * ``in_order`` lists every result text in message order.
        * ``end`` is the index of the first message after the run.
    """
    by_id: dict[str, str] = {}
    in_order: list[str] = []
    cursor = start
    while cursor < len(messages):
        tool_msg = messages[cursor]
        if not isinstance(tool_msg, ToolMessage):
            break
        body = chunk_text(tool_msg)
        in_order.append(body)
        call_id = str(getattr(tool_msg, "tool_call_id", "") or "")
        if call_id:
            by_id[call_id] = body
        cursor += 1
    return by_id, in_order, cursor


def _tool_result_for_call(
    cid: str,
    idx: int,
    results_by_id: dict[str, str],
    ordered_results: list[str],
) -> str:
    if cid and cid in results_by_id:
        return results_by_id[cid]
    if idx < len(ordered_results):
        return ordered_results[idx]
    return results_by_id.get(cid, "")


def _append_reasoning_segment(entry: dict[str, Any], reasoning: str) -> None:
    """每轮思考单独成段,避免多轮思考挤在同一组件。"""
    reasoning = (reasoning or "").strip()
    if not reasoning:
        return
    entry.setdefault("segments", [])
    segs: list[dict[str, Any]] = entry["segments"]
    prev_all = (entry.get("reasoning") or "").strip()
    entry["reasoning"] = f"{prev_all}\n\n{reasoning}".strip() if prev_all else reasoning
    if segs and segs[-1].get("type") == "reasoning":
        prev = (str(segs[-1].get("content") or "")).strip()
        segs[-1]["content"] = f"{prev}\n\n{reasoning}".strip() if prev else reasoning
    else:
        segs.append({"type": "reasoning", "content": reasoning})


def _append_text_segment(entry: dict[str, Any], text: str) -> None:
    text = (text or "").strip()
    if not text:
        return
    entry.setdefault("segments", [])
    entry["segments"].append({"type": "text", "content": text})
    prev = (entry.get("content") or "").strip()
    entry["content"] = f"{prev}\n\n{text}".strip() if prev else text


def _merge_assistant_text(entry: dict[str, Any], text: str, reasoning: str = "") -> None:
    """Append a follow-up reply to *entry*: its reasoning segment first, then its text."""
    _append_reasoning_segment(entry, reasoning=reasoning)
    _append_text_segment(entry, text=text)


def _attach_usage(entry: dict[str, Any], message: AIMessage) -> None:
    """Add *message*'s token usage to ``entry["usage"]``.

    Usage is merged with any existing value. Messages without usage leave *entry* untouched.
    """
    usage = usage_from_message(message)
    if usage:
        existing = entry.get("usage")
        entry["usage"] = merge_usage(existing, usage) if existing else usage


def _assistant_entry_with_tools(
    text: str,
    reasoning: str,
    tools_block: list[dict[str, Any]],
) -> dict[str, Any]:
    """Start an assistant item for a tool-calling message.

    Segments are laid out as reasoning, then one segment per tool (each holding a shallow
    copy of the tool dict), then text. ``tools`` is the *tools_block* list itself.
    """
    entry: dict[str, Any] = {"role": "assistant", "tools": tools_block, "segments": []}
    _append_reasoning_segment(entry, reasoning)
    entry["segments"].extend({"type": "tool", "tool": dict(t)} for t in tools_block)
    _append_text_segment(entry, text)
    return entry


def _tool_dicts_for_calls(
    calls: list[dict[str, Any]],
    results_by_id: dict[str, str],
    ordered_results: list[str],
) -> list[dict[str, Any]]:
    """Build tool dicts for ONE AIMessage's calls from the ToolMessages right after it.

    ``ordered_results`` holds only that message's own results, so the positional fallback
    index is the call's 0-based position within the message.
    """
    tools: list[dict[str, Any]] = []
    for idx, call in enumerate(calls):
        name, cid, args = _parse_one_tool_call(call)
        result = _tool_result_for_call(cid, idx, results_by_id, ordered_results)
        tools.append(_tool_dict(name, cid, args, result))
    return tools


def _merge_follow_up_turns(
    entry: dict[str, Any],
    tools_block: list[dict[str, Any]],
    messages: list[BaseMessage],
    start: int,
) -> int:
    """Fold the consecutive AIMessages from *start*, and their ToolMessages, into *entry*.

    Returns the index of the first message that does not belong to this assistant turn.
    """
    i = start
    n = len(messages)
    while i < n and isinstance(messages[i], AIMessage):
        follow = messages[i]
        follow_text = chunk_text(follow).strip()
        follow_reasoning = reasoning_text(follow).strip()
        follow_calls = extract_tool_calls(follow)
        if follow_calls:
            _append_reasoning_segment(entry, follow_reasoning)
            results_by_id, ordered_results, next_i = _collect_tool_results(messages, i + 1)
            for td in _tool_dicts_for_calls(follow_calls, results_by_id, ordered_results):
                tools_block.append(td)
                entry.setdefault("segments", []).append({"type": "tool", "tool": td})
            entry["tools"] = tools_block
            # Keep any text written alongside the tool calls. The order matches the first
            # message of the turn: reasoning, then tools, then text.
            _append_text_segment(entry, follow_text)
            i = next_i
        else:
            # Text/reasoning reply; a completely empty AIMessage adds no segments.
            _merge_assistant_text(entry, follow_text, follow_reasoning)
            i += 1
        # Every model call in the turn spent tokens, including tool-calling ones.
        _attach_usage(entry, follow)
    return i


def messages_to_items(messages: list[BaseMessage]) -> list[dict[str, Any]]:
    """Convert checkpoint messages into the item list rendered by the frontend.

    * HumanMessage: becomes ``{"role": "user"}``, or ``{"role": "summary"}`` for a
      compressed-history bridge message. Blank human messages are skipped.
    * AIMessage with tool calls: starts an assistant item. The item also absorbs the
      following ToolMessages and the subsequent AIMessages of the same ReAct turn.
      ``segments`` keep reasoning, tool calls and text in chronological order, and
      ``usage`` is summed over every model call in the turn.
    * Other AIMessages with text or reasoning: become a plain assistant item.
    * Anything else is skipped: other message types, ToolMessages not preceded by a
      tool-calling AIMessage, and empty AIMessages.
    """
    items: list[dict[str, Any]] = []
    i = 0
    n = len(messages)
    while i < n:
        message = messages[i]
        if isinstance(message, HumanMessage):
            text = chunk_text(message).strip()
            if text:
                summary_body = parse_compress_summary_content(text)
                if summary_body is not None:
                    items.append({"role": "summary", "content": summary_body})
                else:
                    items.append({"role": "user", "content": text})
            i += 1
            continue
        if isinstance(message, AIMessage):
            text = chunk_text(message).strip()
            reasoning = reasoning_text(message).strip()
            calls = extract_tool_calls(message)
            if calls:
                results_by_id, ordered_results, j = _collect_tool_results(messages, i + 1)
                tools_block = _tool_dicts_for_calls(calls, results_by_id, ordered_results)
                entry = _assistant_entry_with_tools(text, reasoning, tools_block)
                _attach_usage(entry, message)
                i = _merge_follow_up_turns(entry, tools_block, messages, j)
                items.append(entry)
                continue
            if text or reasoning:
                item: dict[str, Any] = {"role": "assistant", "segments": []}
                _append_reasoning_segment(item, reasoning)
                _append_text_segment(item, text)
                _attach_usage(item, message)
                items.append(item)
            i += 1
            continue
        i += 1
    return items


def make_chat_prompt():
    """Build the ``prompt`` callable for create_react_agent.

    The callable derives the model input from the full ``state["messages"]`` on every
    call. It does not write ``llm_input_messages``. Otherwise, inside the ReAct loop, tool
    results already in the checkpoint could be missed because the model would still read
    an old trimmed list.
    """

    def chat_prompt(
        state: dict[str, Any],
        *,
        config: RunnableConfig,
        runtime: Runtime[ChatInvokeContext] | None = None,
    ) -> list[BaseMessage]:
        """Return the trimmed message list (system prompt plus history) for this model call."""
        configurable = _resolve_run_configurable(config=config, runtime=runtime)
        ai_settings = ai_config_from_configurable(configurable)
        system_text = str(
            configurable.get("llm_system_prompt") or configurable.get("system_prompt") or ""
        )
        history = list(state.get("messages") or [])
        return prepare_messages_for_llm(history, system_text, ai_settings)

    return chat_prompt


def make_dynamic_chat_model(bound_tools: list[Any]):
    """Build the dynamic ``model`` callable for create_react_agent.

    On every call the callable builds a ThinkingChatOpenAI from the run's AI settings,
    and binds *bound_tools* when there are any. It does not trim messages; input trimming
    is done by the ``prompt`` callable from make_chat_prompt.
    """

    def resolve_llm(
        state: dict[str, Any],
        runtime: Runtime[ChatInvokeContext],
    ) -> Any:
        del state  # the model depends only on the run's AI settings, not on the state
        ai = ai_config_from_configurable(_resolve_run_configurable(runtime=runtime))
        llm = _build_llm(ai)
        return llm.bind_tools(bound_tools) if bound_tools else llm

    return resolve_llm


def format_tool_error_for_model(exc: Exception) -> str:
    """Serialise an uncaught tool exception as ToolMessage content the model can react to.

    The JSON object has ``ok`` = False, ``error`` = the exception message (or the type name
    when the message is empty), and ``error_type`` = the exception's type name.
    """
    error_type = type(exc).__name__
    message = str(exc) or error_type
    return dumps_json({"ok": False, "error": message, "error_type": error_type})


def build_chat_react_agent(
    tools: Sequence[Any],
    checkpointer: BaseCheckpointSaver,
) -> CompiledStateGraph:
    """Compile the chat ReAct agent.

    * The prompt callable trims the model input.
    * The checkpointer keeps the full message history.
    * Tools run in a ToolNode loop. Uncaught tool exceptions are turned into ToolMessage
      content by format_tool_error_for_model.
    * With no tools, an empty list is passed instead of a ToolNode.

    LangGraphDeprecatedSinceV10 warnings raised while building the agent are suppressed.
    """
    tool_list = list(tools)
    tools_node: ToolNode | list[Any]
    if tool_list:
        tools_node = ToolNode(tool_list, handle_tool_errors=format_tool_error_for_model)
    else:
        tools_node = []

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", LangGraphDeprecatedSinceV10)
        graph = create_react_agent(
            make_dynamic_chat_model(tool_list),
            tools_node,
            prompt=make_chat_prompt(),
            context_schema=ChatInvokeContext,
            checkpointer=checkpointer,
            version="v2",
            name="compilot_chat_agent",
        )
    return graph