from __future__ import annotations
import json
import warnings
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
trim_messages,
)
from langchain_core.messages.utils import count_tokens_approximately
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import CompiledStateGraph
from langgraph.runtime import Runtime
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.warnings import LangGraphDeprecatedSinceV10
from app.core.json_util import dumps_json
from app.services.token_usage import merge_usage, usage_from_message
# System prompt used whenever the caller supplies an empty one.
_DEFAULT_SYSTEM_FALLBACK = (
    "You are a helpful assistant. "
    "Reply in the same language as the user when appropriate."
)
# Marker that opens the synthetic HumanMessage replacing compressed history with
# its summary; messages_to_items() renders such messages as a "summary" item.
COMPRESS_SUMMARY_PREFIX = "[此前对话已自动压缩为摘要]"


def compress_summary_bridge_content(summary: str) -> str:
    """Return *summary* (trimmed) placed on the line after COMPRESS_SUMMARY_PREFIX."""
    return "\n".join((COMPRESS_SUMMARY_PREFIX, summary.strip()))


def parse_compress_summary_content(text: str) -> str | None:
    """Inverse of compress_summary_bridge_content().

    Returns the trimmed summary body when *text* starts with
    COMPRESS_SUMMARY_PREFIX, otherwise None. A prefix followed only by
    whitespace also yields None.
    """
    if not text.startswith(COMPRESS_SUMMARY_PREFIX):
        return None
    # strip() also removes the newline that separates prefix and body.
    summary = text.removeprefix(COMPRESS_SUMMARY_PREFIX).strip()
    return summary if summary else None
def _as_dict(value: Any) -> dict[str, Any]:
    """Coerce *value* into a dict.

    A dict is returned as-is (same object). A string that, once trimmed, starts
    with "{" and parses as a JSON object is decoded. Everything else, including
    malformed JSON and JSON that is not an object, yields a new empty dict.
    """
    if isinstance(value, dict):
        return value
    if not isinstance(value, str):
        return {}
    stripped = value.strip()
    if not stripped.startswith("{"):
        return {}
    try:
        decoded = json.loads(stripped)
    except json.JSONDecodeError:
        return {}
    return decoded if isinstance(decoded, dict) else {}
def configurable_from_config(cfg: Any) -> dict[str, Any]:
    """Return ``cfg["configurable"]`` as a dict, tolerating malformed configs.

    LangGraph's RunnableConfig.configurable is sometimes serialized to a
    non-dict; it is coerced through _as_dict (which also decodes JSON-object
    strings) so callers can safely call ``.get`` on the result. A non-dict
    *cfg* yields {}.
    """
    return _as_dict(cfg.get("configurable")) if isinstance(cfg, dict) else {}
def ai_config_from_configurable(conf: dict[str, Any]) -> dict[str, Any]:
    """Return the ``"ai"`` settings from a configurable mapping.

    Missing, None, or non-coercible values give {}. Dicts and JSON-object
    strings are handled by _as_dict.
    """
    raw_ai = conf.get("ai")
    if raw_ai is None:
        return {}
    return _as_dict(raw_ai)
@dataclass
class ChatInvokeContext:
    """Per-run context handed to graph.astream_events / ainvoke.

    The dynamic model (and the prompt, when RunnableConfig has no configurable)
    read the AI settings and system prompts from it via to_configurable().
    """

    # AI settings dict; keys read elsewhere in this module include api_key, model,
    # base_url, temperature, max_tokens and context_length.
    ai: dict[str, Any] = field(default_factory=dict)
    # System prompt meant for the model.
    llm_system_prompt: str = ""
    # Alternative system prompt; to_configurable() falls back to llm_system_prompt when empty.
    system_prompt: str = ""

    def to_configurable(self) -> dict[str, Any]:
        """Return this context in RunnableConfig["configurable"] shape.

        ``system_prompt`` is replaced by ``llm_system_prompt`` when it is empty.
        The ``ai`` dict is returned by reference, not copied.
        """
        effective_system_prompt = (
            self.system_prompt if self.system_prompt else self.llm_system_prompt
        )
        return dict(
            ai=self.ai,
            llm_system_prompt=self.llm_system_prompt,
            system_prompt=effective_system_prompt,
        )
def _resolve_run_configurable(
    *,
    config: RunnableConfig | None = None,
    runtime: Runtime[ChatInvokeContext] | None = None,
) -> dict[str, Any]:
    """Return the run's configurable dict.

    A non-empty ``config["configurable"]`` wins. Otherwise a ChatInvokeContext
    found on ``runtime.context`` is used. If neither is available, {} is
    returned. get_config() is deliberately avoided because it is unavailable
    in async code on Python 3.10.
    """
    from_config = configurable_from_config(config) if config is not None else {}
    if from_config:
        return from_config
    context = None if runtime is None else getattr(runtime, "context", None)
    if isinstance(context, ChatInvokeContext):
        return context.to_configurable()
    return {}
def _aimessage_has_tool_calls(msg: AIMessage) -> bool:
    """Return True when *msg* has parsed ``tool_calls`` or raw ``additional_kwargs["tool_calls"]``."""
    return bool(msg.tool_calls) or bool(msg.additional_kwargs.get("tool_calls"))
def repair_messages_for_thinking_model(messages: list[BaseMessage]) -> list[BaseMessage]:
    """Ensure every tool-calling assistant message carries ``reasoning_content``.

    In thinking mode (DeepSeek and similar), assistant messages with tool_calls
    must include a reasoning_content field. When an old checkpoint or a
    streaming merge left it out, an empty string is filled in (accepted by the
    API in testing) to avoid HTTP 400. Patched messages are copies; the input
    list and its messages are not mutated.
    """

    def _repaired(msg: BaseMessage) -> BaseMessage:
        needs_field = (
            isinstance(msg, AIMessage)
            and _aimessage_has_tool_calls(msg)
            and "reasoning_content" not in msg.additional_kwargs
        )
        if not needs_field:
            return msg
        patched_kwargs = dict(msg.additional_kwargs)
        patched_kwargs["reasoning_content"] = ""
        return msg.model_copy(update={"additional_kwargs": patched_kwargs})

    return [_repaired(m) for m in messages]
class ThinkingChatOpenAI(ChatOpenAI):
    """ChatOpenAI for "thinking" models behind OpenAI-compatible gateways (e.g. DeepSeek v4).

    Such gateways return ``reasoning_content`` next to ``content``. It is stored
    on the AIMessage as ``additional_kwargs["reasoning_content"]`` (for both full
    responses and streamed chunks) and sent back unchanged on that assistant
    message in the next request.
    """

    @staticmethod
    def _reasoning_for_request(msg: AIMessage, msg_dict: dict[str, Any]) -> str | None:
        """Return the ``reasoning_content`` to send for *msg*, or None to omit the field.

        Args:
            msg: The assistant message as stored in history.
            msg_dict: Its serialized request dict (checked for ``tool_calls``).

        Returns:
            The stored reasoning as a string. If nothing (or None) is stored, the
            result is "" when the message carries tool calls, because thinking-mode
            APIs require the field on those turns. Otherwise the result is None.
        """
        stored = msg.additional_kwargs.get("reasoning_content")
        # A stored None counts as "missing": never send the literal string "None".
        if stored is not None:
            return stored if isinstance(stored, str) else str(stored)
        if _aimessage_has_tool_calls(msg) or msg_dict.get("tool_calls"):
            return ""
        return None

    def _inject_reasoning_into_payload(
        self, payload: dict[str, Any], messages: list[BaseMessage]
    ) -> dict[str, Any]:
        """Add ``reasoning_content`` to the serialized assistant messages in *payload* (in place).

        Request dicts are paired with *messages* by position. The payload is
        returned untouched when it has no ``messages`` or when the two lists
        differ in length, since the pairing would then be unreliable.
        """
        msg_dicts = payload.get("messages")
        if not msg_dicts or len(msg_dicts) != len(messages):
            return payload
        for msg_dict, msg in zip(msg_dicts, messages, strict=False):
            if not isinstance(msg, AIMessage):
                continue
            rc = self._reasoning_for_request(msg, msg_dict)
            if rc is not None:
                msg_dict["reasoning_content"] = rc
        return payload

    def _get_request_payload(
        self,
        messages: list[BaseMessage],
        *,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the request through ChatOpenAI, then echo reasoning on assistant turns."""
        payload = super()._get_request_payload(messages, stop=stop, **kwargs)
        return self._inject_reasoning_into_payload(payload, messages)

    def _create_chat_result(
        self,
        response: Any,
        generation_info: dict[str, Any] | None = None,
    ) -> ChatResult:
        """Build the result through ChatOpenAI, then copy each choice's reasoning onto its message."""
        result = super()._create_chat_result(response, generation_info)
        # NOTE(review): presumably mirrors ChatOpenAI's own dump so that gateway-only
        # fields such as reasoning_content survive on SDK response objects — confirm.
        response_dict = (
            response
            if isinstance(response, dict)
            else response.model_dump(
                exclude={"choices": {"__all__": {"message": {"parsed"}}}}
            )
        )
        for gen, choice in zip(
            result.generations,
            response_dict.get("choices") or [],
            strict=False,
        ):
            msg_dict = choice.get("message") or {}
            rc = msg_dict.get("reasoning_content")
            if rc is not None:
                gen.message.additional_kwargs["reasoning_content"] = rc
        return result

    def _convert_chunk_to_generation_chunk(
        self,
        chunk: dict[str, Any],
        default_chunk_class: type,
        base_generation_info: dict[str, Any] | None,
    ) -> ChatGenerationChunk | None:
        """Convert a stream chunk through ChatOpenAI and attach its reasoning delta.

        The delta is read from ``choices[0].delta.reasoning_content``, falling back to
        ``choices[0].message.reasoning_content``. Choices are looked up on the chunk
        itself or on a nested ``chunk["chunk"]`` dict.
        """
        gen = super()._convert_chunk_to_generation_chunk(
            chunk, default_chunk_class, base_generation_info
        )
        if gen is None:
            return gen
        nested = chunk.get("chunk")
        if isinstance(nested, dict):
            choices = chunk.get("choices") or nested.get("choices") or []
        else:
            choices = chunk.get("choices") or []
        if not choices:
            return gen
        choice = choices[0]
        delta = choice.get("delta") or {}
        rc = delta.get("reasoning_content")
        if rc is None and isinstance(choice.get("message"), dict):
            rc = choice["message"].get("reasoning_content")
        if rc is None:
            return gen
        msg = gen.message
        if not isinstance(msg, AIMessage):
            return gen
        prev = msg.additional_kwargs.get("reasoning_content") or ""
        # Concatenate only str + str; any other combination is overwritten rather
        # than raising TypeError.
        if isinstance(rc, str) and isinstance(prev, str):
            msg.additional_kwargs["reasoning_content"] = prev + rc
        else:
            msg.additional_kwargs["reasoning_content"] = rc
        return gen
def _numeric_setting(ai: dict[str, Any], key: str, default: Any, cast: Any) -> Any:
    """Read numeric setting *key* via *cast*, using *default* when it is missing, None or blank."""
    raw = ai.get(key)
    if raw is None or (isinstance(raw, str) and not raw.strip()):
        return cast(default)
    # 0 / 0.0 are real values (e.g. temperature 0), so no `or default` here.
    return cast(raw)


def _build_llm(ai: dict[str, Any]) -> ThinkingChatOpenAI:
    """Create the streaming chat model described by the AI settings *ai*.

    Reads ``api_key`` (required), ``model`` (default "gpt-4o-mini"),
    ``temperature`` (default 0.7), ``max_tokens`` (default 8192) and an
    optional ``base_url``. Numeric settings that are missing, None or blank
    fall back to their defaults, so a stored JSON null does not raise TypeError.

    Raises:
        ValueError: if no API key is configured, or float()/int() rejects a
            numeric setting string.
    """
    api_key = str(ai.get("api_key") or "").strip()
    if not api_key:
        raise ValueError("未配置 API Key,请先在系统设置中保存 AI 设置")
    kwargs: dict[str, Any] = {
        "model": ai.get("model") or "gpt-4o-mini",
        "temperature": _numeric_setting(ai, "temperature", 0.7, float),
        "max_tokens": _numeric_setting(ai, "max_tokens", 8192, int),
        "api_key": api_key,
        "streaming": True,
    }
    base_url = str(ai.get("base_url") or "").strip()
    if base_url:
        kwargs["base_url"] = base_url
    return ThinkingChatOpenAI(**kwargs)
def compute_max_input_tokens(ai: dict[str, Any]) -> int:
    """Token budget for the trimmed model input.

    The budget is 85% of ``context_length`` (default 128000) minus the
    ``max_tokens`` output reservation (default 8192). Falsy settings use the
    defaults.

    NOTE(review): the result never drops below 4096, even if that exceeds what
    a small context window can actually hold.
    """
    context_window = int(ai.get("context_length") or 128000)
    output_reserve = int(ai.get("max_tokens") or 8192)
    budget = int(context_window * 0.85) - output_reserve
    return budget if budget > 4096 else 4096
def compute_compress_threshold_tokens(ai: dict[str, Any]) -> int:
    """Estimated history size (system prompt included) that triggers summary compression.

    The threshold is 75% of the configured ``context_length`` (default 128000),
    with a floor of 4096.
    """
    context_window = int(ai.get("context_length") or 128000)
    return max(int(context_window * 0.75), 4096)
def estimate_messages_tokens(messages: list[BaseMessage], system_prompt: str) -> int:
    """Approximate token count of the system prompt plus *messages*.

    Uses the same approximate counter that trim_messages is configured with in
    prepare_messages_for_llm. An empty *system_prompt* is replaced by the
    default fallback prompt.
    """
    system_text = system_prompt if system_prompt else _DEFAULT_SYSTEM_FALLBACK
    counted: list[BaseMessage] = [SystemMessage(content=system_text)]
    counted.extend(messages)
    return int(count_tokens_approximately(counted))
def effective_tokens_for_compression(
    estimated: int,
    actual_input_tokens: int | None = None,
) -> int:
    """Token count used to decide on compression.

    Returns the larger of the local estimate and the input tokens reported by
    the API for this turn. A missing, zero or negative reported value is
    ignored.
    """
    reported = int(actual_input_tokens or 0)
    return max(estimated, reported) if reported > 0 else estimated
def prepare_messages_for_llm(
    history: list[BaseMessage],
    system_prompt: str,
    ai: dict[str, Any],
) -> list[BaseMessage]:
    """Build the model input: system prompt plus the most recent history that fits.

    History is first repaired for thinking models. It is then trimmed from the
    oldest end to compute_max_input_tokens(ai), keeping the system message and
    starting on a human message. Individual messages are never cut
    (allow_partial=False), so tool_calls / ToolMessage pairs stay intact.

    NOTE(review): if even the latest human turn does not fit, the result may
    contain only the system message — confirm this is acceptable.
    """
    repaired = repair_messages_for_thinking_model(history)
    system_message = SystemMessage(content=system_prompt or _DEFAULT_SYSTEM_FALLBACK)
    budget = compute_max_input_tokens(ai)
    return trim_messages(
        [system_message, *repaired],
        strategy="last",
        token_counter="approximate",
        max_tokens=budget,
        include_system=True,
        allow_partial=False,
        start_on="human",
    )
def reasoning_text(message: BaseMessage) -> str:
    """Return the accumulated ``reasoning_content`` of an AIMessage (thinking models).

    Non-AI messages and messages without (or with a None) reasoning give "".
    Non-string values are stringified.
    """
    if isinstance(message, AIMessage):
        extra = getattr(message, "additional_kwargs", None) or {}
        value = extra.get("reasoning_content")
        if value is not None:
            return str(value)
    return ""
def chunk_reasoning(chunk: BaseMessage) -> str:
    """Return the ``reasoning_content`` delta carried by a streamed chunk.

    Returns "" when the chunk has no additional_kwargs dict or no reasoning.
    Non-string values are stringified.
    """
    extra = getattr(chunk, "additional_kwargs", None) or {}
    value = extra.get("reasoning_content") if isinstance(extra, dict) else None
    return "" if value is None else str(value)
def chunk_text(chunk: BaseMessage) -> str:
    """Return the plain text of a message or chunk.

    Handles OpenAI Chat-style content, which is either a str or a list of
    str / ``{"type", "text"}`` blocks. Dict blocks contribute their "text"
    (or "content") value. List items that are neither str nor dict are
    ignored. Any other non-empty content is stringified.
    """
    content = getattr(chunk, "content", "") or ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    return "".join(
        block
        if isinstance(block, str)
        else str(block.get("text") or block.get("content") or "")
        for block in content
        if isinstance(block, (str, dict))
    )
def _normalize_tool_args(raw: Any) -> dict[str, Any]:
    """Coerce tool-call arguments into a dict.

    Dicts pass through unchanged. Blank strings give {}. Strings holding a
    JSON object are decoded. Other strings (invalid JSON, or JSON that is not
    an object) are wrapped as ``{"input": raw}`` with the original,
    unstripped text. Any other type gives {}.
    """
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {}
    text = raw.strip()
    if text == "":
        return {}
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        decoded = None
    if isinstance(decoded, dict):
        return decoded
    return {"input": raw}
def _parse_one_tool_call(call: Any) -> tuple[str, str, dict[str, Any]]:
    """Return ``(name, id, args)`` for one tool call.

    *call* may be a dict (e.g. a LangChain ToolCall) or an object with
    name / id / args attributes. Missing values become "" / {}, and the id is
    trimmed. For dicts whose ``args`` normalize to empty, an OpenAI-style
    ``arguments`` value is tried instead.
    """
    if not isinstance(call, dict):
        return (
            str(getattr(call, "name", "") or ""),
            str(getattr(call, "id", "") or "").strip(),
            _normalize_tool_args(getattr(call, "args", {})),
        )
    args = _normalize_tool_args(call.get("args"))
    if not args and call.get("arguments") is not None:
        args = _normalize_tool_args(call.get("arguments"))
    return str(call.get("name") or ""), str(call.get("id") or "").strip(), args
def _parse_raw_openai_tool_calls(extra: Any) -> list[dict[str, Any]]:
    """Convert raw OpenAI-format tool calls into ``{"id", "name", "args"}`` dicts.

    Expects a list shaped like ``[{"id", "function": {"name", "arguments"}}, ...]``.
    Top-level ``name`` / ``args`` keys are used as fallbacks. Non-dict entries
    are skipped, and a non-list input gives [].
    """
    if not isinstance(extra, list):
        return []
    parsed: list[dict[str, Any]] = []
    for item in extra:
        if not isinstance(item, dict):
            continue
        function = item.get("function")
        if not isinstance(function, dict):
            function = {}
        arguments = _normalize_tool_args(function.get("arguments"))
        if not arguments and item.get("args") is not None:
            arguments = _normalize_tool_args(item.get("args"))
        parsed.append(
            {
                "id": str(item.get("id") or "").strip(),
                "name": str(function.get("name") or item.get("name") or ""),
                "args": arguments,
            }
        )
    return parsed
def extract_tool_calls(message: AIMessage) -> list[dict[str, Any]]:
    """Normalize the tool calls of *message* into ``{"id", "name", "args"}`` dicts.

    Structured ``message.tool_calls`` are taken first. Raw OpenAI-format
    entries from ``additional_kwargs["tool_calls"]`` are then added only when
    they are not duplicates:

    - Calls with an id are de-duplicated by id.
    - Id-less calls are matched by (name, args). Each structured id-less call
      absorbs at most one identical raw id-less call, so the raw mirror of the
      same call is not listed twice, while genuinely extra raw calls are kept.
    """
    out: list[dict[str, Any]] = []
    seen_ids: set[str] = set()
    # signature -> structured id-less calls not yet matched by a raw mirror
    unmatched_anonymous: dict[str, int] = {}

    def _signature(name: str, args: dict[str, Any]) -> str:
        try:
            return json.dumps([name, args], sort_keys=True, default=str)
        except (TypeError, ValueError):  # e.g. non-str keys or cycles
            return repr((name, args))

    for call in getattr(message, "tool_calls", None) or []:
        name, cid, args = _parse_one_tool_call(call)
        if cid:
            if cid in seen_ids:
                continue
            seen_ids.add(cid)
        else:
            sig = _signature(name, args)
            unmatched_anonymous[sig] = unmatched_anonymous.get(sig, 0) + 1
        out.append({"id": cid, "name": name, "args": args})

    extra = (getattr(message, "additional_kwargs", None) or {}).get("tool_calls")
    for raw in _parse_raw_openai_tool_calls(extra):
        name, cid, args = _parse_one_tool_call(raw)
        if cid:
            if cid in seen_ids:
                continue
            seen_ids.add(cid)
        else:
            sig = _signature(name, args)
            if unmatched_anonymous.get(sig, 0) > 0:
                # Raw copy of a structured id-less call already listed.
                unmatched_anonymous[sig] -= 1
                continue
        out.append({"id": cid, "name": name, "args": args})
    return out
def _tool_result_ok(result: str) -> bool:
    """Return False only when *result* is a JSON object with ``"ok": false``.

    Non-JSON text, other JSON values, and objects without an explicit
    ``ok is False`` all count as success.
    """
    try:
        payload = json.loads(result)
    except json.JSONDecodeError:
        return True
    return not (isinstance(payload, dict) and payload.get("ok") is False)
def _tool_dict(name: str, cid: str, args: dict[str, Any], result: str) -> dict[str, Any]:
    """Build the frontend record of a completed tool call.

    ``ok`` is derived from *result* by _tool_result_ok.
    """
    tool: dict[str, Any] = dict(id=cid, name=name, args=args, result=result, status="done")
    tool["ok"] = _tool_result_ok(result)
    return tool
def _collect_tool_results(messages: list[BaseMessage], start: int) -> tuple[dict[str, str], list[str], int]:
    """Gather the run of consecutive ToolMessages beginning at index *start*.

    Returns a tuple of:
    - results keyed by non-empty ``tool_call_id``,
    - all result texts in message order,
    - the index of the first message after the run.
    """
    by_id: dict[str, str] = {}
    in_order: list[str] = []
    cursor = start
    while cursor < len(messages):
        tool_msg = messages[cursor]
        if not isinstance(tool_msg, ToolMessage):
            break
        body = chunk_text(tool_msg)
        in_order.append(body)
        call_id = str(getattr(tool_msg, "tool_call_id", "") or "")
        if call_id:
            by_id[call_id] = body
        cursor += 1
    return by_id, in_order, cursor
def _tool_result_for_call(
    cid: str,
    idx: int,
    results_by_id: dict[str, str],
    ordered_results: list[str],
) -> str:
    """Find the result text for the *idx*-th call of a round.

    Matching by call id comes first. If that fails, the *idx*-th result in
    message order is used. Otherwise the result is ``results_by_id.get(cid, "")``,
    which is "" for maps built by _collect_tool_results (they never contain
    an empty key).
    """
    matched_by_id = bool(cid) and cid in results_by_id
    if matched_by_id:
        return results_by_id[cid]
    has_positional = idx < len(ordered_results)
    return ordered_results[idx] if has_positional else results_by_id.get(cid, "")
def _append_reasoning_segment(entry: dict[str, Any], reasoning: str) -> None:
    """Record a block of reasoning on *entry* (blank reasoning is ignored).

    The text is appended to ``entry["reasoning"]``, separated by a blank line.
    It is also merged into the last segment if that segment is a reasoning
    segment; otherwise it starts a new reasoning segment. Thinking separated by
    tools or text therefore renders as separate components.
    """
    chunk = (reasoning or "").strip()
    if not chunk:
        return
    segments: list[dict[str, Any]] = entry.setdefault("segments", [])

    def _joined(previous: str) -> str:
        return f"{previous}\n\n{chunk}".strip() if previous else chunk

    entry["reasoning"] = _joined((entry.get("reasoning") or "").strip())
    if segments and segments[-1].get("type") == "reasoning":
        last = segments[-1]
        last["content"] = _joined(str(last.get("content") or "").strip())
    else:
        segments.append({"type": "reasoning", "content": chunk})
def _append_text_segment(entry: dict[str, Any], text: str) -> None:
    """Append *text* to *entry* as a new text segment and to ``entry["content"]``.

    The content is joined with a blank line. Blank text is ignored.
    """
    piece = (text or "").strip()
    if not piece:
        return
    entry.setdefault("segments", []).append({"type": "text", "content": piece})
    existing = (entry.get("content") or "").strip()
    entry["content"] = f"{existing}\n\n{piece}".strip() if existing else piece
def _merge_assistant_text(entry: dict[str, Any], text: str, reasoning: str = "") -> None:
    """Add *reasoning* and then *text* to *entry*; blank parts are skipped."""
    for append_segment, value in (
        (_append_reasoning_segment, reasoning),
        (_append_text_segment, text),
    ):
        append_segment(entry, value)
def _attach_usage(entry: dict[str, Any], message: AIMessage) -> None:
    """Fold the token usage reported on *message* into ``entry["usage"]``.

    Usage is read via usage_from_message. If *entry* already has usage, the two
    are combined with merge_usage. Messages without usage leave *entry*
    unchanged.
    """
    usage = usage_from_message(message)
    if not usage:
        return
    previous = entry.get("usage")
    entry["usage"] = merge_usage(previous, usage) if previous else usage
def _assistant_entry_with_tools(
    text: str,
    reasoning: str,
    tools_block: list[dict[str, Any]],
) -> dict[str, Any]:
    """Create the assistant item for a tool-calling turn.

    Segments are laid out as reasoning, then one segment per tool (each holding
    a shallow copy of the tool dict), then text. ``entry["tools"]`` is
    *tools_block* itself, not a copy.
    """
    entry: dict[str, Any] = {"role": "assistant", "tools": tools_block, "segments": []}
    _append_reasoning_segment(entry, reasoning)
    entry["segments"].extend({"type": "tool", "tool": dict(tool)} for tool in tools_block)
    _append_text_segment(entry, text)
    return entry
def _tool_round_items(
    messages: list[BaseMessage],
    ai_index: int,
    calls: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], int]:
    """Pair *calls* of messages[ai_index] with the ToolMessages right after it.

    Returns the tool dicts and the index just past those ToolMessages. The
    positional fallback index is relative to this round only, because the
    ordered results are collected per round.
    """
    results_by_id, ordered_results, next_index = _collect_tool_results(messages, ai_index + 1)
    tools: list[dict[str, Any]] = []
    for idx, call in enumerate(calls):
        name, cid, args = _parse_one_tool_call(call)
        result = _tool_result_for_call(cid, idx, results_by_id, ordered_results)
        tools.append(_tool_dict(name, cid, args, result))
    return tools, next_index


def _assistant_tool_turn_item(
    messages: list[BaseMessage],
    start: int,
    calls: list[dict[str, Any]],
) -> tuple[dict[str, Any], int]:
    """Fold a tool-calling AIMessage and the ReAct rounds that follow into one item.

    Consumes the ToolMessages after messages[start], then every following
    AIMessage (with its own ToolMessages) until a non-AI message appears.
    Returns the item and the index of the first unconsumed message.
    """
    n = len(messages)
    first = messages[start]
    tools_block, i = _tool_round_items(messages, start, calls)
    entry = _assistant_entry_with_tools(
        chunk_text(first).strip(), reasoning_text(first).strip(), tools_block
    )
    _attach_usage(entry, first)
    while i < n and isinstance(messages[i], AIMessage):
        follow = messages[i]
        follow_text = chunk_text(follow).strip()
        follow_reasoning = reasoning_text(follow).strip()
        follow_calls = extract_tool_calls(follow)
        if follow_calls:
            _append_reasoning_segment(entry, follow_reasoning)
            round_tools, i = _tool_round_items(messages, i, follow_calls)
            tools_block.extend(round_tools)
            entry.setdefault("segments", []).extend(
                {"type": "tool", "tool": dict(td)} for td in round_tools
            )
            entry["tools"] = tools_block
            # Same layout as the first round (reasoning, tools, text); this
            # round's text used to be dropped.
            _append_text_segment(entry, follow_text)
        else:
            _merge_assistant_text(entry, follow_text, follow_reasoning)
            i += 1
        # Every model round reports its own usage; count all of them.
        _attach_usage(entry, follow)
    return entry, i


def messages_to_items(messages: list[BaseMessage]) -> list[dict[str, Any]]:
    """Convert checkpoint messages into items the frontend can render.

    - HumanMessage becomes ``{"role": "user"}``, or ``{"role": "summary"}`` for
      compressed-history bridge messages. Blank messages are skipped.
    - An AIMessage with tool calls becomes one ``{"role": "assistant"}`` item
      that also absorbs the following ToolMessages and AIMessages (multi-round
      ReAct turns). Its ``segments`` keep reasoning, tools and text in order.
    - An AIMessage without tool calls becomes an assistant item when it has
      text or reasoning.
    - Anything else (system messages, stray ToolMessages) is skipped.
    """
    items: list[dict[str, Any]] = []
    i = 0
    n = len(messages)
    while i < n:
        message = messages[i]
        if isinstance(message, HumanMessage):
            text = chunk_text(message).strip()
            if text:
                summary_body = parse_compress_summary_content(text)
                if summary_body is not None:
                    items.append({"role": "summary", "content": summary_body})
                else:
                    items.append({"role": "user", "content": text})
            i += 1
            continue
        if isinstance(message, AIMessage):
            calls = extract_tool_calls(message)
            if calls:
                # Always advances past messages[i] (at least start + 1).
                entry, i = _assistant_tool_turn_item(messages, i, calls)
                items.append(entry)
                continue
            text = chunk_text(message).strip()
            reasoning = reasoning_text(message).strip()
            if text or reasoning:
                item: dict[str, Any] = {"role": "assistant", "segments": []}
                _append_reasoning_segment(item, reasoning)
                _append_text_segment(item, text)
                _attach_usage(item, message)
                items.append(item)
        i += 1
    return items
def make_chat_prompt():
    """Create the ``prompt`` callable for create_react_agent.

    The prompt rebuilds the model input from the full ``state["messages"]`` on
    every call. It does not write ``llm_input_messages``: otherwise, inside the
    ReAct loop, tool results would already be in the checkpoint while the model
    still read the old trimmed list.
    """

    def chat_prompt(
        state: dict[str, Any],
        *,
        config: RunnableConfig,
        runtime: Runtime[ChatInvokeContext] | None = None,
    ) -> list[BaseMessage]:
        configurable = _resolve_run_configurable(config=config, runtime=runtime)
        system_text = str(
            configurable.get("llm_system_prompt")
            or configurable.get("system_prompt")
            or ""
        )
        return prepare_messages_for_llm(
            list(state.get("messages") or []),
            system_text,
            ai_config_from_configurable(configurable),
        )

    return chat_prompt
def make_dynamic_chat_model(bound_tools: list[Any]):
    """Create the dynamic ``model`` callable for create_react_agent.

    On each call it builds a ThinkingChatOpenAI from the run context's AI
    settings and binds *bound_tools* when there are any. Context trimming is
    not done here: build_chat_react_agent passes make_chat_prompt() as the
    agent prompt, and that prompt does the trimming.
    """

    def resolve_llm(
        state: dict[str, Any],
        runtime: Runtime[ChatInvokeContext],
    ) -> Any:
        del state  # the model depends only on the run context
        settings = ai_config_from_configurable(_resolve_run_configurable(runtime=runtime))
        llm = _build_llm(settings)
        return llm.bind_tools(bound_tools) if bound_tools else llm

    return resolve_llm
def format_tool_error_for_model(exc: Exception) -> str:
    """Serialize an uncaught tool exception as the ToolMessage body.

    The model reads this body and can correct itself. ``"ok": False`` makes
    _tool_result_ok mark the call as failed. An empty exception message falls
    back to the exception's class name.

    NOTE(review): ToolNode may inspect this parameter's annotation to decide
    which exceptions it handles — keep it ``Exception``.
    """
    error_type = type(exc).__name__
    return dumps_json(
        {
            "ok": False,
            "error": str(exc) or error_type,
            "error_type": error_type,
        }
    )
def build_chat_react_agent(
    tools: Sequence[Any],
    checkpointer: BaseCheckpointSaver,
) -> CompiledStateGraph:
    """Compile the chat ReAct agent.

    - The prompt (make_chat_prompt) trims the model input.
    - The checkpointer keeps the full message history.
    - The tools node loops with the dynamic model.
    - Tool exceptions are turned into ToolMessages by format_tool_error_for_model.
    - With no tools, an empty list is passed instead of a ToolNode.
    """
    tool_list = list(tools)
    if tool_list:
        tools_node: ToolNode | list[Any] = ToolNode(
            tool_list, handle_tool_errors=format_tool_error_for_model
        )
    else:
        tools_node = []
    with warnings.catch_warnings():
        # Silence only the LangGraphDeprecatedSinceV10 warning raised by create_react_agent.
        warnings.simplefilter("ignore", LangGraphDeprecatedSinceV10)
        return create_react_agent(
            make_dynamic_chat_model(tool_list),
            tools_node,
            prompt=make_chat_prompt(),
            context_schema=ChatInvokeContext,
            checkpointer=checkpointer,
            version="v2",
            name="compilot_chat_agent",
        )