"""LangGraph ReAct 对话运行时(含历史压缩)。"""

from __future__ import annotations

import asyncio
import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any

import aiosqlite
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    RemoveMessage,
    ToolMessage,
)
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.graph.state import CompiledStateGraph

from app.core.json_util import dumps_json
from app.core.paths import CHECKPOINT_DB, ensure_data_dirs
from app.graph.chat_graph import (
    ChatInvokeContext,
    _as_dict,
    _build_llm,
    build_chat_react_agent,
    chunk_reasoning,
    chunk_text,
    compress_summary_bridge_content,
    compute_compress_threshold_tokens,
    effective_tokens_for_compression,
    estimate_messages_tokens,
    extract_tool_calls,
    messages_to_items,
)
from app.services import conversation_store
from app.services.chat_context import resolve_chat_tools, resolve_effective_system_prompt
from app.services.token_usage import empty_usage, merge_usage_list, usage_from_chat_model_end_event

# Module-level logger named after this module.
_log = logging.getLogger(__name__)

# Node name passed as ``as_node`` when the checkpointed state is rewritten
# through ``aupdate_state``.
_GRAPH_NODE = "agent"

# Default number of most recent messages kept verbatim when older history is
# folded into a summary.
_COMPRESS_KEEP_RAW_MESSAGES = 6


@dataclass(frozen=True)
class _CompressionPlan:
    """Result of planning a history compression for one thread."""

    # Older messages that are rendered to text and summarised.
    head: list[BaseMessage]
    # Recent messages written back unchanged after the summary bridge message.
    tail: list[BaseMessage]
    # Extra instruction interpolated into the summarisation prompt.
    hint: str


def _tool_result_preview(text: str, limit: int = 4000) -> str:
    """Return *text* shortened to at most *limit* characters.

    Text that already fits is returned unchanged. Longer text is cut and
    terminated with ``"..."`` so that the result is exactly *limit* characters
    long. When *limit* is too small to hold the ellipsis the text is cut hard
    instead, and a non-positive *limit* yields an empty string.
    """
    if limit <= 0:
        return ""
    if len(text) <= limit:
        return text
    if limit < 3:
        # No room for the ellipsis: a negative slice end would otherwise keep
        # almost the whole text and overshoot the limit.
        return text[:limit]
    return text[: limit - 3] + "..."


def _stream_tool_call_id(ev: dict[str, Any]) -> str:
    """Return the ID to report for a streamed tool event.

    The ``tool_call_id`` inside the tool input is preferred, then the one on
    the event data, so the ID lines up with the ``ToolMessage`` stored in the
    checkpoint. The event's ``run_id`` is the last resort.
    """
    payload = _stream_event_data(ev)
    tool_input = payload.get("input")

    # Lookup order matters: tool input first, event data second.
    sources: list[dict[str, Any]] = []
    if isinstance(tool_input, dict):
        sources.append(tool_input)
    sources.append(payload)

    for source in sources:
        call_id = str(source.get("tool_call_id") or "").strip()
        if call_id:
            return call_id
    return str(ev.get("run_id") or "")


def _stream_event_data(ev: dict[str, Any]) -> dict[str, Any]:
    """Return the ``data`` entry of a stream event, passed through ``_as_dict``."""
    payload = ev.get("data")
    return _as_dict(payload)


def _coerce_tool_result(out: Any) -> str:
    """Render a tool output as text.

    ``None`` becomes an empty string. A string, dict or list is rendered
    directly; any other object is rendered through its ``content`` attribute
    when that is a string, dict or list. Dicts and lists go through
    ``dumps_json``. Everything else falls back to ``str(out)``.
    """
    if out is None:
        return ""
    # Pick the value to render: the output itself when it is already a plain
    # payload, otherwise whatever its ``content`` attribute holds.
    if isinstance(out, (str, dict, list)):
        payload = out
    else:
        payload = getattr(out, "content", None)
    if isinstance(payload, str):
        return payload
    if isinstance(payload, (dict, list)):
        return dumps_json(payload)
    return str(out)


def _tool_args_from_stream_input(inp: Any) -> dict[str, Any]:
    """Return a copy of the tool input without its ``tool_call_id`` entry.

    Anything that is not a dict yields an empty dict.
    """
    if not isinstance(inp, dict):
        return {}
    arguments = dict(inp)
    arguments.pop("tool_call_id", None)
    return arguments


def _react_graph_cache_key(_thread: dict[str, Any] | None) -> str:
    """Return the graph-cache key for a thread.

    The thread is currently ignored: every thread maps to the same key.
    """
    plain_key = "__plain__"
    return plain_key


def _tools_bind_token(tools: list[Any]) -> str:
    """Build a token identifying a tool list by name and callable identity.

    Each tool contributes ``"<name>:<id of its func or coroutine>"``; the parts
    are joined with ``"|"``. An empty list yields ``"_none_"``.
    """

    def _describe(tool: Any) -> str:
        target = getattr(tool, "func", None) or getattr(tool, "coroutine", None)
        return f"{getattr(tool, 'name', '')}:{id(target)}"

    # Every part contains ":", so the joined token is empty only for no tools.
    token = "|".join(_describe(tool) for tool in tools)
    return token or "_none_"


def _is_valid_compression_tail(messages: list[BaseMessage], start: int) -> bool:
    """Tell whether ``messages[start:]`` is a suffix that may be sent to the model.

    The suffix is rejected when *start* is out of range, when a ``ToolMessage``
    appears without directly following an ``AIMessage`` that issued tool calls,
    or when the run of ``ToolMessage`` objects after such an ``AIMessage`` does
    not cover every tool-call ID that message carries.
    """
    total = len(messages)
    if not 0 <= start < total:
        return False

    pos = start
    while pos < total:
        current = messages[pos]
        pos += 1
        if isinstance(current, ToolMessage):
            # Tool results are only consumed by the inner loop below, so one
            # reached here has no tool-calling AIMessage right before it.
            return False
        if not isinstance(current, AIMessage):
            continue
        calls = extract_tool_calls(current)
        if not calls:
            continue

        expected = {str(call.get("id") or "") for call in calls if call.get("id")}
        answered: set[str] = set()
        while pos < total and isinstance(messages[pos], ToolMessage):
            call_id = str(getattr(messages[pos], "tool_call_id", "") or "")
            if call_id:
                answered.add(call_id)
            pos += 1
        # A non-empty difference means some tool call never got its result.
        if expected - answered:
            return False
    return True


def _find_safe_compression_split(raw: list[BaseMessage], desired_start: int) -> int | None:
    """Find the valid tail start index closest to *desired_start*.

    *desired_start* is clamped to ``1 .. len(raw) - 1``. Indices are then tried
    outward from it, the higher index first at each distance, and the first one
    accepted by ``_is_valid_compression_tail`` is returned. Returns ``None``
    when *raw* has fewer than two messages or no index qualifies.
    """
    total = len(raw)
    if total < 2:
        return None
    pivot = min(max(desired_start, 1), total - 1)

    def _outward():
        # pivot, pivot+1, pivot-1, pivot+2, pivot-2, ...
        yield pivot
        for distance in range(1, total):
            yield pivot + distance
            yield pivot - distance

    for index in _outward():
        if 1 <= index < total and _is_valid_compression_tail(raw, index):
            return index
    return None


def _split_messages_for_compression(
    raw: list[BaseMessage],
    *,
    keep_raw: int = _COMPRESS_KEEP_RAW_MESSAGES,
) -> tuple[list[BaseMessage], list[BaseMessage]] | None:
    """Split *raw* into a head to summarise and a tail to keep verbatim.

    The tail aims for the last *keep_raw* messages, always leaving at least one
    message in the head. The actual boundary is chosen so that it does not fall
    between a tool-calling message and its tool results. Returns ``None`` when
    there are fewer than two messages or no safe boundary exists.
    """
    total = len(raw)
    if total < 2:
        return None
    preferred = total - min(keep_raw, total - 1)
    boundary = _find_safe_compression_split(raw, preferred)
    if boundary is None:
        return None
    older, recent = raw[:boundary], raw[boundary:]
    return (older, recent) if older else None


def _tool_call_names(message: AIMessage) -> list[str]:
    """Return the non-empty, stripped names of the tool calls on *message*.

    Each call may be a dict with a ``"name"`` key or an object with a ``name``
    attribute.
    """
    collected: list[str] = []
    for entry in message.tool_calls or []:
        raw_name = entry.get("name") if isinstance(entry, dict) else getattr(entry, "name", "")
        cleaned = str(raw_name or "").strip()
        if cleaned:
            collected.append(cleaned)
    return collected


def format_messages_for_summary(messages: list[BaseMessage], *, max_each: int = 800) -> str:
    """Render messages as a role-prefixed transcript, one line per entry.

    Human and AI text is stripped, and empty text is skipped. An AI message
    with tool calls also gets a line listing the tool names, placed before its
    text line. Tool messages are always emitted. Each body is cut to *max_each*
    characters. Other message types are ignored.
    """
    rendered: list[str] = []
    for entry in messages:
        if isinstance(entry, HumanMessage):
            spoken = chunk_text(entry).strip()
            if spoken:
                rendered.append(f"用户: {spoken[:max_each]}")
            continue
        if isinstance(entry, AIMessage):
            called = _tool_call_names(entry)
            if called:
                rendered.append(f"助手[调用工具]: {', '.join(called)}")
            reply = chunk_text(entry).strip()
            if reply:
                rendered.append(f"助手: {reply[:max_each]}")
            continue
        if isinstance(entry, ToolMessage):
            tool_name = str(getattr(entry, "name", None) or "tool")
            excerpt = str(entry.content or "")[:max_each]
            rendered.append(f"工具[{tool_name}]: {excerpt}")
    return "\n".join(rendered)


class ChatRuntime:
    """Run ReAct chat turns on LangGraph with a SQLite-backed checkpointer.

    The runtime owns one ``aiosqlite`` connection, the ``AsyncSqliteSaver``
    built on it, and a cache of compiled graphs. Methods that need a graph
    raise ``RuntimeError`` until :meth:`start` has completed.
    """

    def __init__(self) -> None:
        # Everything stays empty until ``start()`` opens the checkpoint database.
        self._conn: aiosqlite.Connection | None = None
        self._saver: Any = None
        self._react_graph_cache: dict[str, CompiledStateGraph] = {}

    @property
    def ready(self) -> bool:
        """Whether at least one compiled graph is cached."""
        return bool(self._react_graph_cache)

    def _get_cached_react_graph(
        self, thread: dict[str, Any] | None, tools: list[Any]
    ) -> CompiledStateGraph:
        """Return the compiled graph for *thread*, rebuilding it when *tools* changed.

        Raises:
            RuntimeError: If the checkpointer has not been created yet.
        """
        if self._saver is None:
            raise RuntimeError("ChatRuntime 尚未初始化")
        key = _react_graph_cache_key(thread)
        token = _tools_bind_token(tools)
        graph = self._react_graph_cache.get(key)
        # The token is stored on the graph object itself; a mismatch means the
        # cached graph was compiled against a different tool list.
        if graph is None or getattr(graph, "_cp_tools_token", None) != token:
            graph = build_chat_react_agent(tools, self._saver)
            graph._cp_tools_token = token  # type: ignore[attr-defined]
            self._react_graph_cache[key] = graph
        return graph

    async def start(self) -> None:
        """Open the checkpoint database and pre-build the tool-less graph.

        Does nothing when a graph is already cached. If anything fails after
        the connection was opened, the connection is closed again and the
        runtime is left untouched so that ``start()`` can be retried.
        """
        if self._react_graph_cache:
            return
        ensure_data_dirs()
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        conn = await aiosqlite.connect(str(CHECKPOINT_DB))
        try:
            saver = AsyncSqliteSaver(conn)
            await saver.setup()
            graph = build_chat_react_agent([], saver)
            # Tag the graph like ``_get_cached_react_graph`` does, so a later
            # request with an empty tool list reuses it instead of rebuilding.
            graph._cp_tools_token = _tools_bind_token([])  # type: ignore[attr-defined]
        except BaseException:
            # Covers cancellation too; the error is re-raised after cleanup.
            await conn.close()
            raise
        self._conn = conn
        self._saver = saver
        self._react_graph_cache[_react_graph_cache_key(None)] = graph

    async def stop(self) -> None:
        """Close the checkpoint connection and drop the saver and cached graphs.

        The state is reset before the connection is closed, so the runtime ends
        up stopped even when ``close()`` raises.
        """
        conn = self._conn
        self._conn = None
        self._saver = None
        self._react_graph_cache.clear()
        if conn is not None:
            await conn.close()

    @staticmethod
    def make_config(thread_id: str, ai: dict[str, Any], system_prompt: str) -> dict[str, Any]:
        """Build the graph config dict carrying thread ID, AI settings and prompt."""
        return {
            "configurable": {
                "thread_id": thread_id,
                "ai": ai,
                "system_prompt": system_prompt,
            }
        }

    @staticmethod
    def _tool_result_ok(result: str) -> bool:
        """Tell whether a tool result text counts as a success.

        A result is a failure only when it parses as a JSON object whose
        ``"ok"`` is ``False`` or whose ``"error"`` is truthy. Text that is not
        valid JSON, or JSON that is not an object, counts as a success.
        """
        try:
            parsed = json.loads(result)
        except json.JSONDecodeError:
            return True
        if not isinstance(parsed, dict):
            return True
        return not (parsed.get("ok") is False or parsed.get("error"))

    async def _snapshot_messages(
        self, thread_id: str, ai: dict[str, Any], system_prompt: str
    ) -> list[BaseMessage]:
        """Return the messages stored in the thread's current checkpoint state."""
        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        config = self.make_config(thread_id, ai, system_prompt)
        snapshot = await graph.aget_state(config)
        return list((snapshot.values or {}).get("messages") or [])

    async def get_messages(
        self, thread_id: str, ai: dict[str, Any], system_prompt: str
    ) -> list[dict[str, Any]]:
        """Return the thread's stored messages converted by ``messages_to_items``."""
        messages = await self._snapshot_messages(thread_id, ai, system_prompt)
        return messages_to_items(messages)

    async def replace_thread_messages(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        new_messages: list[BaseMessage],
    ) -> None:
        """Replace every stored message of the thread with *new_messages*."""
        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        config = self.make_config(thread_id, ai, system_prompt)
        await graph.aupdate_state(
            config,
            {
                "messages": [
                    # Remove-all marker first, then the replacement messages.
                    RemoveMessage(id=REMOVE_ALL_MESSAGES),
                    *new_messages,
                ]
            },
            as_node=_GRAPH_NODE,
        )

    async def plan_compression(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        *,
        actual_input_tokens: int | None = None,
    ) -> _CompressionPlan | None:
        """Decide whether the thread history should be compressed, and how.

        Returns ``None`` when there are fewer than two messages, the token
        count is below the threshold, no safe split exists, or the head renders
        to blank text. Otherwise returns the head/tail split plus the hint used
        in the summarisation prompt.
        """
        thread = conversation_store.get_thread(thread_id)
        effective_prompt = resolve_effective_system_prompt(thread, system_prompt)
        raw = await self._snapshot_messages(thread_id, ai, system_prompt)
        if len(raw) < 2:
            return None

        threshold = compute_compress_threshold_tokens(ai)
        estimated = estimate_messages_tokens(raw, effective_prompt)
        tokens = effective_tokens_for_compression(estimated, actual_input_tokens)
        if tokens < threshold:
            return None

        split = _split_messages_for_compression(raw)
        if split is None:
            return None
        head, tail = split

        # Nothing worth summarising if the head renders to blank text.
        blob = format_messages_for_summary(head)
        if not blob.strip():
            return None

        hint = "保留对渗透测试任务有用的结论、待办、目标范围与关键发现,不要编造"
        return _CompressionPlan(head=head, tail=tail, hint=hint)

    async def maybe_compress_thread(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        *,
        plan: _CompressionPlan | None = None,
        actual_input_tokens: int | None = None,
    ) -> tuple[bool, str | None]:
        """Summarise the older history and rewrite the thread when warranted.

        Uses *plan* when given, otherwise computes one with
        :meth:`plan_compression`. Returns ``(True, summary)`` after the stored
        messages were replaced by a summary bridge message plus the kept tail,
        and ``(False, None)`` when there is no plan or the model returned an
        empty summary.
        """
        resolved = plan or await self.plan_compression(
            thread_id, ai, system_prompt, actual_input_tokens=actual_input_tokens
        )
        if resolved is None:
            return False, None

        blob = format_messages_for_summary(resolved.head)
        summarize_prompt = (
            f"请用中文将以下多轮对话压缩为一段摘要,{resolved.hint}。"
            "控制在 1200 字以内。\n\n----\n"
            f"{blob}"
        )
        llm = _build_llm(ai).bind(max_tokens=900)
        summary_msg = await llm.ainvoke([HumanMessage(content=summarize_prompt)])
        summary = chunk_text(summary_msg).strip()
        if not summary:
            return False, None

        bridge = HumanMessage(content=compress_summary_bridge_content(summary))
        await self.replace_thread_messages(
            thread_id, ai, system_prompt, [bridge, *resolved.tail]
        )
        return True, summary

    async def stream_turn(
        self,
        thread_id: str,
        user_content: str,
        ai: dict[str, Any],
        system_prompt: str,
    ) -> AsyncIterator[dict[str, Any]]:
        """Run one user turn and yield its stream events.

        The current thread ID is set for the duration of the turn and reset to
        ``None`` afterwards, also when the turn fails or is closed early.
        """
        from app.core.context import set_current_thread_id

        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        sys_eff = resolve_effective_system_prompt(thread, system_prompt)

        set_current_thread_id(thread_id)
        try:
            async for event in self._stream_turn_events(
                thread_id,
                user_content,
                ai,
                system_prompt,
                thread,
                tools,
                graph,
                sys_eff,
            ):
                yield event
        finally:
            set_current_thread_id(None)

    async def _stream_turn_events(
        self,
        thread_id: str,
        user_content: str,
        ai: dict[str, Any],
        system_prompt: str,
        thread: dict[str, Any] | None,
        tools: list[Any],
        graph: Any,
        sys_eff: str,
    ) -> AsyncIterator[dict[str, Any]]:
        """Translate graph stream events into the dict events of one turn.

        Yields dicts with exactly one of these keys: ``"reasoning_delta"``,
        ``"delta"``, ``"tool_start"``, ``"tool_end"``, ``"usage"`` or
        ``"error"``. A graph failure yields a single ``"error"`` event and ends
        the stream without a ``"usage"`` event. A cancellation ends the stream
        right after the ``"usage"`` event. *thread* and *tools* are accepted
        but not read here.

        Raises:
            ValueError: If the turn finished with a final AI message that has
                no tool calls and no text, and nothing was streamed.
        """
        user_message = HumanMessage(content=user_content.strip())
        base_cfg = self.make_config(thread_id, ai, sys_eff)
        configurable = dict(base_cfg.get("configurable") or {})
        configurable["llm_system_prompt"] = sys_eff
        config = {**base_cfg, "configurable": configurable, "recursion_limit": 200}
        run_context = ChatInvokeContext(
            ai=ai,
            llm_system_prompt=sys_eff,
            system_prompt=sys_eff,
        )

        streamed_any = False
        reasoning_accum = ""
        usage_parts: list[dict[str, int]] = []
        cancelled = False
        try:
            async for ev in graph.astream_events(
                {"messages": [user_message]},
                config,
                context=run_context,
                version="v2",
            ):
                et = ev.get("event")
                if et == "on_chat_model_end":
                    parsed = usage_from_chat_model_end_event(ev)
                    if parsed:
                        usage_parts.append(parsed)
                    continue
                if et == "on_chat_model_stream":
                    chunk = _stream_event_data(ev).get("chunk")
                    if chunk is None:
                        continue
                    reasoning_piece = chunk_reasoning(chunk)
                    if reasoning_piece:
                        # A piece that extends everything seen so far is treated
                        # as a cumulative snapshot and only its new suffix is
                        # emitted; anything else is treated as a plain increment.
                        if (
                            reasoning_accum
                            and reasoning_piece.startswith(reasoning_accum)
                            and len(reasoning_piece) > len(reasoning_accum)
                        ):
                            reasoning_delta = reasoning_piece[len(reasoning_accum) :]
                            reasoning_accum = reasoning_piece
                        else:
                            reasoning_delta = reasoning_piece
                            reasoning_accum += reasoning_piece
                        if reasoning_delta:
                            streamed_any = True
                            yield {"reasoning_delta": reasoning_delta}
                    text = chunk_text(chunk)
                    if text:
                        streamed_any = True
                        yield {"delta": text}
                elif et == "on_tool_start":
                    name = str(ev.get("name") or "")
                    inp = _stream_event_data(ev).get("input")
                    yield {
                        "tool_start": {
                            "id": _stream_tool_call_id(ev),
                            "name": name,
                            "args": _tool_args_from_stream_input(inp),
                        }
                    }
                elif et == "on_tool_end":
                    name = str(ev.get("name") or "")
                    out = _stream_event_data(ev).get("output")
                    result = _coerce_tool_result(out)
                    yield {
                        "tool_end": {
                            "id": _stream_tool_call_id(ev),
                            "name": name,
                            "result": _tool_result_preview(result),
                            "ok": self._tool_result_ok(result),
                        }
                    }
                elif et == "on_tool_error":
                    err = _stream_event_data(ev).get("error")
                    msg = str(err) if err is not None else "工具执行失败"
                    yield {
                        "tool_end": {
                            "id": _stream_tool_call_id(ev),
                            "name": str(ev.get("name") or ""),
                            "result": _tool_result_preview(msg),
                            "ok": False,
                        }
                    }
        except asyncio.CancelledError:
            # Cancellation is absorbed so the usage collected so far can still
            # be reported below. GeneratorExit is deliberately NOT caught here:
            # an async generator that yields after GeneratorExit raises
            # RuntimeError, so a closed stream must simply unwind.
            cancelled = True
        except Exception as exc:
            _log.exception("chat stream graph error thread_id=%s", thread_id)
            yield {"error": str(exc) or "对话执行失败"}
            return

        yield {"usage": merge_usage_list(usage_parts) if usage_parts else empty_usage()}
        if cancelled:
            return

        # Sanity check: a finished turn whose last AI message is empty and that
        # streamed nothing is reported as a configuration problem.
        messages = await self._snapshot_messages(thread_id, ai, system_prompt)
        last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
        if isinstance(last_ai, AIMessage) and not (last_ai.tool_calls or []):
            if not chunk_text(last_ai).strip() and not streamed_any:
                raise ValueError(
                    "模型未返回有效内容。请检查模型名、Base URL 与 API Key。"
                )

    async def delete_checkpoint_thread(self, thread_id: str) -> None:
        """Delete the thread's checkpoints; a no-op when no saver exists."""
        if self._saver is None:
            return
        await self._saver.adelete_thread(thread_id)


# Module-level runtime instance, created at import time. Construction opens no
# resources; the checkpoint database is only opened by ``start()``.
chat_runtime = ChatRuntime()