"""LangGraph ReAct 对话运行时(含历史压缩)。"""
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any
import aiosqlite
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
RemoveMessage,
ToolMessage,
)
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.graph.state import CompiledStateGraph
from app.core.json_util import dumps_json
from app.core.paths import CHECKPOINT_DB, ensure_data_dirs
from app.graph.chat_graph import (
ChatInvokeContext,
_as_dict,
_build_llm,
build_chat_react_agent,
chunk_reasoning,
chunk_text,
compress_summary_bridge_content,
compute_compress_threshold_tokens,
effective_tokens_for_compression,
estimate_messages_tokens,
extract_tool_calls,
messages_to_items,
)
from app.services import conversation_store
from app.services.chat_context import resolve_chat_tools, resolve_effective_system_prompt
from app.services.token_usage import empty_usage, merge_usage_list, usage_from_chat_model_end_event
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
# Node name passed as ``as_node`` when the checkpointed state is rewritten.
_GRAPH_NODE = "agent"
# Default number of most recent messages kept verbatim when history is compressed.
_COMPRESS_KEEP_RAW_MESSAGES = 6
@dataclass(frozen=True)
class _CompressionPlan:
    """Immutable description of one planned history compression."""

    # Older messages that get summarised and replaced.
    head: list[BaseMessage]
    # Most recent messages that are kept verbatim after the summary.
    tail: list[BaseMessage]
    # Extra instruction text interpolated into the summarisation prompt.
    hint: str
def _tool_result_preview(text: str, limit: int = 4000) -> str:
if len(text) <= limit:
return text
return text[: limit - 3] + "..."
def _stream_tool_call_id(ev: dict[str, Any]) -> str:
    """Pick the identifier to report for a streamed tool event.

    Lookup order: ``tool_call_id`` inside the event's ``input`` mapping, then
    ``tool_call_id`` directly on the event data, and finally the event's
    ``run_id``. Preferring the tool call id keeps streamed events aligned with
    the ``ToolMessage`` entries stored in the checkpoint.
    """
    payload = _stream_event_data(ev)
    tool_input = payload.get("input")
    # The tool input is only consulted when it is a mapping.
    sources = (tool_input, payload) if isinstance(tool_input, dict) else (payload,)
    for source in sources:
        call_id = str(source.get("tool_call_id") or "").strip()
        if call_id:
            return call_id
    return str(ev.get("run_id") or "")
def _stream_event_data(ev: dict[str, Any]) -> dict[str, Any]:
    """Return the ``data`` entry of a stream event, passed through ``_as_dict``."""
    payload = ev.get("data")
    return _as_dict(payload)
def _coerce_tool_result(out: Any) -> str:
if out is None:
return ""
if isinstance(out, str):
return out
if isinstance(out, (dict, list)):
return dumps_json(out)
content = getattr(out, "content", None)
if isinstance(content, str):
return content
if isinstance(content, (dict, list)):
return dumps_json(content)
return str(out)
def _tool_args_from_stream_input(inp: Any) -> dict[str, Any]:
if isinstance(inp, dict):
return {k: v for k, v in inp.items() if k != "tool_call_id"}
return {}
def _react_graph_cache_key(_thread: dict[str, Any] | None) -> str:
return "__plain__"
def _tools_bind_token(tools: list[Any]) -> str:
parts: list[str] = []
for t in tools:
fn = getattr(t, "func", None) or getattr(t, "coroutine", None)
parts.append(f"{getattr(t, 'name', '')}:{id(fn)}")
return "|".join(parts) if parts else "_none_"
def _is_valid_compression_tail(messages: list[BaseMessage], start: int) -> bool:
"""保留的 tail 必须是可发给模型的合法后缀(tool 与 tool_calls 成对)。"""
if start < 0 or start >= len(messages):
return False
if isinstance(messages[start], ToolMessage):
return False
i = start
n = len(messages)
while i < n:
msg = messages[i]
if isinstance(msg, HumanMessage):
i += 1
continue
if isinstance(msg, AIMessage):
calls = extract_tool_calls(msg)
i += 1
if not calls:
continue
required = {str(c.get("id") or "") for c in calls if c.get("id")}
seen: set[str] = set()
while i < n and isinstance(messages[i], ToolMessage):
cid = str(getattr(messages[i], "tool_call_id", "") or "")
if cid:
seen.add(cid)
i += 1
if required and not required.issubset(seen):
return False
continue
if isinstance(msg, ToolMessage):
return False
i += 1
return True
def _find_safe_compression_split(raw: list[BaseMessage], desired_start: int) -> int | None:
n = len(raw)
if n < 2:
return None
desired_start = max(1, min(desired_start, n - 1))
candidates: list[int] = []
for offset in range(n):
for delta in (offset, -offset):
candidate = desired_start + delta
if 1 <= candidate < n and candidate not in candidates:
candidates.append(candidate)
for start in candidates:
if _is_valid_compression_tail(raw, start):
return start
return None
def _split_messages_for_compression(
    raw: list[BaseMessage],
    *,
    keep_raw: int = _COMPRESS_KEEP_RAW_MESSAGES,
) -> tuple[list[BaseMessage], list[BaseMessage]] | None:
    """Split history into a head to summarise and a tail to keep verbatim.

    The split aims to keep ``keep_raw`` trailing messages (always leaving at
    least one message for the head) and is moved as needed so that it does not
    fall between a tool call and its results. Returns ``None`` when there are
    fewer than two messages or no safe split exists.
    """
    total = len(raw)
    if total < 2:
        return None
    keep = min(keep_raw, total - 1)
    cut = _find_safe_compression_split(raw, total - keep)
    if cut is None:
        return None
    older, recent = raw[:cut], raw[cut:]
    return (older, recent) if older else None
def _tool_call_names(message: AIMessage) -> list[str]:
names: list[str] = []
for call in message.tool_calls or []:
if isinstance(call, dict):
name = str(call.get("name") or "").strip()
else:
name = str(getattr(call, "name", "") or "").strip()
if name:
names.append(name)
return names
def format_messages_for_summary(messages: list[BaseMessage], *, max_each: int = 800) -> str:
    """Render messages as a plain-text transcript for the summarisation prompt.

    Human and assistant text is stripped and cut to ``max_each`` characters;
    empty text is skipped. Assistant tool calls are listed by name on their own
    line, and tool results are labelled with the tool name (``"tool"`` when the
    message has none). Other message types are ignored. Lines are joined with
    newlines.
    """
    rendered: list[str] = []
    for item in messages:
        if isinstance(item, HumanMessage):
            said = chunk_text(item).strip()
            if said:
                rendered.append(f"用户: {said[:max_each]}")
            continue
        if isinstance(item, AIMessage):
            called = _tool_call_names(item)
            if called:
                rendered.append(f"助手[调用工具]: {', '.join(called)}")
            reply = chunk_text(item).strip()
            if reply:
                rendered.append(f"助手: {reply[:max_each]}")
            continue
        if isinstance(item, ToolMessage):
            tool_name = str(getattr(item, "name", None) or "tool")
            excerpt = str(item.content or "")[:max_each]
            rendered.append(f"工具[{tool_name}]: {excerpt}")
    return "\n".join(rendered)
class ChatRuntime:
    """ReAct chat runtime: checkpointed graph, history compression and streaming.

    ``start()`` opens the SQLite checkpoint database and builds the initial
    graph; ``stop()`` releases both. All other methods require a started
    runtime (they raise ``RuntimeError`` through ``_get_cached_react_graph``
    otherwise), except ``delete_checkpoint_thread`` which is a no-op then.
    """

    def __init__(self) -> None:
        self._conn: aiosqlite.Connection | None = None
        self._saver: Any = None
        self._react_graph_cache: dict[str, CompiledStateGraph] = {}

    @property
    def ready(self) -> bool:
        """Whether at least one compiled graph is cached, i.e. ``start()`` succeeded."""
        return bool(self._react_graph_cache)

    def _get_cached_react_graph(
        self, thread: dict[str, Any] | None, tools: list[Any]
    ) -> CompiledStateGraph:
        """Return the compiled graph for ``thread``, rebuilding it when ``tools`` changed.

        Raises:
            RuntimeError: If the runtime has not been started.
        """
        if self._saver is None:
            raise RuntimeError("ChatRuntime 尚未初始化")
        key = _react_graph_cache_key(thread)
        token = _tools_bind_token(tools)
        graph = self._react_graph_cache.get(key)
        if graph is None or getattr(graph, "_cp_tools_token", None) != token:
            graph = build_chat_react_agent(tools, self._saver)
            # Remember which tool set this graph was built for.
            graph._cp_tools_token = token
            self._react_graph_cache[key] = graph
        return graph

    async def start(self) -> None:
        """Open the checkpoint database and build the initial graph.

        Calling it again on a started runtime does nothing. If initialisation
        fails, the freshly opened connection is closed and the runtime is left
        untouched so that ``start()`` can be retried.
        """
        if self._react_graph_cache:
            return
        ensure_data_dirs()
        from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

        conn = await aiosqlite.connect(str(CHECKPOINT_DB))
        try:
            saver = AsyncSqliteSaver(conn)
            await saver.setup()
            graph = build_chat_react_agent([], saver)
            # Stamp the token so the first lookup with no tools reuses this
            # graph instead of rebuilding it.
            graph._cp_tools_token = _tools_bind_token([])
        except BaseException:
            # Do not leak the connection when setup or graph construction fails.
            await conn.close()
            raise
        self._conn = conn
        self._saver = saver
        self._react_graph_cache[_react_graph_cache_key(None)] = graph

    async def stop(self) -> None:
        """Close the checkpoint connection and drop all cached state.

        State is reset even when closing the connection raises.
        """
        try:
            if self._conn is not None:
                await self._conn.close()
        finally:
            self._conn = None
            self._saver = None
            self._react_graph_cache.clear()

    @staticmethod
    def make_config(thread_id: str, ai: dict[str, Any], system_prompt: str) -> dict[str, Any]:
        """Build the graph config carrying the thread id, AI settings and system prompt."""
        return {
            "configurable": {
                "thread_id": thread_id,
                "ai": ai,
                "system_prompt": system_prompt,
            }
        }

    async def _snapshot_messages(
        self, thread_id: str, ai: dict[str, Any], system_prompt: str
    ) -> list[BaseMessage]:
        """Return a copy of the messages currently checkpointed for ``thread_id``."""
        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        config = self.make_config(thread_id, ai, system_prompt)
        snapshot = await graph.aget_state(config)
        return list((snapshot.values or {}).get("messages") or [])

    async def get_messages(
        self, thread_id: str, ai: dict[str, Any], system_prompt: str
    ) -> list[dict[str, Any]]:
        """Return the thread's checkpointed messages converted by ``messages_to_items``."""
        messages = await self._snapshot_messages(thread_id, ai, system_prompt)
        return messages_to_items(messages)

    async def replace_thread_messages(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        new_messages: list[BaseMessage],
    ) -> None:
        """Replace the thread's whole checkpointed history with ``new_messages``."""
        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        config = self.make_config(thread_id, ai, system_prompt)
        await graph.aupdate_state(
            config,
            {
                "messages": [
                    # Clear everything first, then append the replacement.
                    RemoveMessage(id=REMOVE_ALL_MESSAGES),
                    *new_messages,
                ]
            },
            as_node=_GRAPH_NODE,
        )

    async def plan_compression(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        *,
        actual_input_tokens: int | None = None,
    ) -> _CompressionPlan | None:
        """Decide whether the thread's history should be compressed, and how.

        Returns ``None`` when there are fewer than two messages, the token
        count is below the compression threshold, no safe split exists, or the
        head renders to an empty transcript.
        """
        thread = conversation_store.get_thread(thread_id)
        effective_prompt = resolve_effective_system_prompt(thread, system_prompt)
        raw = await self._snapshot_messages(thread_id, ai, system_prompt)
        if len(raw) < 2:
            return None
        threshold = compute_compress_threshold_tokens(ai)
        estimated = estimate_messages_tokens(raw, effective_prompt)
        tokens = effective_tokens_for_compression(estimated, actual_input_tokens)
        if tokens < threshold:
            return None
        split = _split_messages_for_compression(raw)
        if split is None:
            return None
        head, tail = split
        blob = format_messages_for_summary(head)
        if not blob.strip():
            return None
        hint = "保留对渗透测试任务有用的结论、待办、目标范围与关键发现,不要编造"
        return _CompressionPlan(head=head, tail=tail, hint=hint)

    async def maybe_compress_thread(
        self,
        thread_id: str,
        ai: dict[str, Any],
        system_prompt: str,
        *,
        plan: _CompressionPlan | None = None,
        actual_input_tokens: int | None = None,
    ) -> tuple[bool, str | None]:
        """Summarise the head of the history and store summary plus tail.

        Uses ``plan`` when given, otherwise computes one with
        ``plan_compression``.

        Returns:
            ``(True, summary)`` when the history was replaced, ``(False, None)``
            when nothing needed compressing or the model returned no summary.
        """
        resolved = plan or await self.plan_compression(
            thread_id, ai, system_prompt, actual_input_tokens=actual_input_tokens
        )
        if resolved is None:
            return False, None
        blob = format_messages_for_summary(resolved.head)
        summarize_prompt = (
            f"请用中文将以下多轮对话压缩为一段摘要,{resolved.hint}。"
            "控制在 1200 字以内。\n\n----\n"
            f"{blob}"
        )
        llm = _build_llm(ai).bind(max_tokens=900)
        summary_msg = await llm.ainvoke([HumanMessage(content=summarize_prompt)])
        summary = chunk_text(summary_msg).strip()
        if not summary:
            return False, None
        bridge = HumanMessage(content=compress_summary_bridge_content(summary))
        await self.replace_thread_messages(
            thread_id, ai, system_prompt, [bridge, *resolved.tail]
        )
        return True, summary

    async def stream_turn(
        self,
        thread_id: str,
        user_content: str,
        ai: dict[str, Any],
        system_prompt: str,
    ) -> AsyncIterator[dict[str, Any]]:
        """Run one user turn and yield its stream events.

        The current thread id is set for the duration of the stream and reset
        to ``None`` afterwards, including when the consumer stops early.
        """
        from app.core.context import set_current_thread_id

        thread = conversation_store.get_thread(thread_id)
        tools = resolve_chat_tools(thread)
        graph = self._get_cached_react_graph(thread, tools)
        sys_eff = resolve_effective_system_prompt(thread, system_prompt)
        events = self._stream_turn_events(
            thread_id,
            user_content,
            ai,
            system_prompt,
            thread,
            tools,
            graph,
            sys_eff,
        )
        set_current_thread_id(thread_id)
        try:
            async for event in events:
                yield event
        finally:
            try:
                # Close the inner generator deterministically instead of
                # leaving it to garbage collection when the consumer stops early.
                await events.aclose()
            finally:
                set_current_thread_id(None)

    async def _stream_turn_events(
        self,
        thread_id: str,
        user_content: str,
        ai: dict[str, Any],
        system_prompt: str,
        thread: dict[str, Any] | None,
        tools: list[Any],
        graph: Any,
        sys_eff: str,
    ) -> AsyncIterator[dict[str, Any]]:
        """Translate graph stream events into UI events for one turn.

        Yields dicts with one of the keys ``reasoning_delta``, ``delta``,
        ``tool_start``, ``tool_end``, ``error`` or ``usage``. A graph error is
        reported as a single ``error`` event that ends the stream.

        Raises:
            ValueError: If the turn finished with an empty final assistant
                message and nothing was streamed.
        """
        user_message = HumanMessage(content=user_content.strip())
        base_cfg = self.make_config(thread_id, ai, sys_eff)
        configurable = dict(base_cfg.get("configurable") or {})
        configurable["llm_system_prompt"] = sys_eff
        config = {**base_cfg, "configurable": configurable, "recursion_limit": 200}
        run_context = ChatInvokeContext(
            ai=ai,
            llm_system_prompt=sys_eff,
            system_prompt=sys_eff,
        )
        streamed_any = False
        reasoning_accum = ""
        usage_parts: list[dict[str, int]] = []
        cancelled = False
        try:
            async for ev in graph.astream_events(
                {"messages": [user_message]},
                config,
                context=run_context,
                version="v2",
            ):
                et = ev.get("event")
                if et == "on_chat_model_end":
                    parsed = usage_from_chat_model_end_event(ev)
                    if parsed:
                        usage_parts.append(parsed)
                    continue
                if et == "on_chat_model_stream":
                    chunk = _stream_event_data(ev).get("chunk")
                    if chunk is None:
                        continue
                    reasoning_piece = chunk_reasoning(chunk)
                    if reasoning_piece:
                        if (
                            reasoning_accum
                            and reasoning_piece.startswith(reasoning_accum)
                            and len(reasoning_piece) > len(reasoning_accum)
                        ):
                            # The piece repeats everything seen so far: emit only the new suffix.
                            reasoning_delta = reasoning_piece[len(reasoning_accum) :]
                            reasoning_accum = reasoning_piece
                        else:
                            reasoning_delta = reasoning_piece
                            reasoning_accum += reasoning_piece
                        if reasoning_delta:
                            streamed_any = True
                            yield {"reasoning_delta": reasoning_delta}
                    text = chunk_text(chunk)
                    if text:
                        streamed_any = True
                        yield {"delta": text}
                elif et == "on_tool_start":
                    name = str(ev.get("name") or "")
                    inp = _stream_event_data(ev).get("input")
                    yield {
                        "tool_start": {
                            "id": _stream_tool_call_id(ev),
                            "name": name,
                            "args": _tool_args_from_stream_input(inp),
                        }
                    }
                elif et == "on_tool_end":
                    name = str(ev.get("name") or "")
                    out = _stream_event_data(ev).get("output")
                    result = _coerce_tool_result(out)
                    ok = True
                    try:
                        parsed = json.loads(result)
                        # A JSON object with ``ok: false`` or a truthy ``error`` marks a failure.
                        if isinstance(parsed, dict) and (
                            parsed.get("ok") is False or parsed.get("error")
                        ):
                            ok = False
                    except json.JSONDecodeError:
                        # Non-JSON tool output is treated as a success.
                        pass
                    yield {
                        "tool_end": {
                            "id": _stream_tool_call_id(ev),
                            "name": name,
                            "result": _tool_result_preview(result),
                            "ok": ok,
                        }
                    }
                elif et == "on_tool_error":
                    err = _stream_event_data(ev).get("error")
                    msg = str(err) if err is not None else "工具执行失败"
                    yield {
                        "tool_end": {
                            "id": _stream_tool_call_id(ev),
                            "name": str(ev.get("name") or ""),
                            "result": _tool_result_preview(msg),
                            "ok": False,
                        }
                    }
        except asyncio.CancelledError:
            # GeneratorExit is deliberately NOT caught here: once the consumer
            # has closed the generator, the ``yield`` below would raise
            # "async generator ignored GeneratorExit".
            # NOTE(review): cancellation is still absorbed so that the usage
            # collected so far can be reported — confirm callers rely on this.
            cancelled = True
        except Exception as exc:
            _log.exception("chat stream graph error thread_id=%s", thread_id)
            yield {"error": str(exc) or "对话执行失败"}
            return
        yield {"usage": merge_usage_list(usage_parts) if usage_parts else empty_usage()}
        if cancelled:
            return
        messages = await self._snapshot_messages(thread_id, ai, system_prompt)
        last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
        if isinstance(last_ai, AIMessage) and not (last_ai.tool_calls or []):
            if not chunk_text(last_ai).strip() and not streamed_any:
                raise ValueError(
                    "模型未返回有效内容。请检查模型名、Base URL 与 API Key。"
                )

    async def delete_checkpoint_thread(self, thread_id: str) -> None:
        """Delete the checkpoints of ``thread_id``; does nothing before ``start()``."""
        if self._saver is None:
            return
        await self._saver.adelete_thread(thread_id)
# Module-level ChatRuntime instance, created at import time.
chat_runtime = ChatRuntime()