import asyncio
import copy
import inspect
import json
import logging
import os
import re
import time
import uuid
from dataclasses import dataclass
from typing import Sequence, Any
import json_repair
from openjiuwen.core.foundation.llm.schema.message import (
UserMessage,
SystemMessage,
AssistantMessage,
ToolMessage,
UsageMetadata,
)
from openjiuwen.core.foundation.llm.schema.message_chunk import AssistantMessageChunk
from pydantic import BaseModel
from openjiuwen_deepsearch.common.common_constants import MAX_LLM_RESP_LENGTH
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.config.config import Config
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Message
from openjiuwen_deepsearch.utils.common_utils.stream_utils import get_current_time, MessageType, StreamEvent
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName, NodeId
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import session_context, cancel_context
from openjiuwen_deepsearch.utils.log_utils.log_common import session_id_ctx
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.utils.log_utils.log_metrics import metrics_logger, TIME_LOGGER_TAG
# Module-level logger for the LLM helpers in this file.
logger = logging.getLogger(__name__)
# Agent name substituted by `_validate_invoke_agent_name` when the caller passes None or a blank string.
DEFAULT_AGENT_NAME = "AI"
# Whitelist checked by `_validate_invoke_agent_name`: the default name plus every `AgentLlmName` value.
_ALLOWED_AGENT_NAMES = frozenset({DEFAULT_AGENT_NAME, *(item.value for item in AgentLlmName)})
def format_llm_log_correlation_suffix() -> str:
    """Build the correlation suffix appended to LLM log lines.

    The run log directory is read from the session global state ``log_dir``. If that is
    empty, it falls back to ``config["log_dir"]`` when ``config`` is a dict.

    Returns:
        str: ``" result_run=<basename> log_dir=<absolute path>"`` when a log directory is
        known, ``" result_run=***"`` when sensitive logging is enabled, and ``""`` when no
        session or log directory is available.
    """
    try:
        runtime = session_context.get()
    except LookupError:
        runtime = None
    if runtime is None:
        return ""

    log_dir = runtime.get_global_state("log_dir")
    if not log_dir:
        config = runtime.get_global_state("config") or {}
        if isinstance(config, dict):
            log_dir = config.get("log_dir") or ""
    if not log_dir:
        return ""

    if LogManager.is_sensitive():
        # Never leak filesystem paths when sensitive mode is on.
        return " result_run=***"

    log_dir_text = str(log_dir)
    run_name = os.path.basename(os.path.normpath(log_dir_text))
    return f" result_run={run_name} log_dir={os.path.abspath(log_dir_text)}"
# Fixed set of NodeId values (initial state, action-space search, action run, state validation,
# tool, simple react search). NOTE(review): not referenced in the visible part of this file —
# presumably consumed elsewhere; confirm before changing.
_DEEPSEARCH_NODE_IDS = frozenset({
    NodeId.INITIAL_STATE.value,
    NodeId.FIND_ACTION_SPACE.value,
    NodeId.RUN_ACTION.value,
    NodeId.VALIDATE_NEW_STATE.value,
    NodeId.TOOL.value,
    NodeId.SIMPLE_REACT_SEARCH.value,
})
# In-memory (per-process) accumulator of workflow LLM token usage, keyed by session_id.
_WORKFLOW_LLM_USAGE: dict[str, dict[str, Any]] = {}
# Reference-counted state of the usage-only stream parser patch, keyed by id() of the patched client.
_USAGE_ONLY_PARSER_PATCHES: dict[int, dict[str, Any]] = {}
# Upper bound, in seconds, applied to every configured per-agent LLM wall-clock timeout.
_AGENT_LLM_TIMEOUT_MAX_SECONDS = 3600
def _clamp_agent_llm_timeout(timeout: Any) -> int:
    """Clamp a caller-supplied agent LLM timeout into the server-side policy range.

    Args:
        timeout: Raw timeout value (int, float, numeric string, ...).

    Returns:
        int: ``0`` for unparsable or negative values. Otherwise the integer value, capped at
        ``_AGENT_LLM_TIMEOUT_MAX_SECONDS``. Positive infinity counts as "very large" and is
        capped; negative infinity maps to ``0``.
    """
    try:
        parsed_timeout = int(timeout)
    except OverflowError:
        # int(float("inf")) raises OverflowError, not ValueError. Without this branch a config
        # value such as ``Infinity`` would crash normalisation instead of being clamped.
        return _AGENT_LLM_TIMEOUT_MAX_SECONDS if timeout > 0 else 0
    except (TypeError, ValueError):
        return 0
    if parsed_timeout < 0:
        return 0
    return min(parsed_timeout, _AGENT_LLM_TIMEOUT_MAX_SECONDS)
def normalize_agent_llm_timeouts(value: Any) -> dict[str, int]:
    """Normalise the per-agent LLM wall-clock timeout mapping.

    The feature counts as disabled unless ``value`` is a non-empty dict containing a
    ``default`` key. When enabled, every entry is clamped:

    - ``0`` disables the business-layer wall-clock timeout, so control falls back to the
      SDK-level ``ServiceConfig.llm_timeout``.
    - Very large values are capped to the policy maximum, so callers cannot bypass the
      underlying timeout control.

    Args:
        value: Raw timeout configuration.

    Returns:
        dict[str, int]: The clamped mapping, or an empty dict when the feature is disabled.
    """
    enabled = isinstance(value, dict) and bool(value) and "default" in value
    if not enabled:
        return {}
    return {agent_key: _clamp_agent_llm_timeout(raw) for agent_key, raw in value.items()}
@dataclass(frozen=True)
class ResolvedAgentTimeout:
    """Outcome of resolving the agent LLM wall-clock timeout for one call.

    Attributes:
        timeout: Effective timeout in seconds.
        matched_by: Which rule matched: ``agent_name``, ``node_key`` or ``default``.
        matched_key: The configuration key that actually matched.
        resolved_node_key: Node-level key derived from ``agent_name`` by prefix match;
            ``None`` when no node key could be derived.
    """
    timeout: int  # seconds
    matched_by: str  # "agent_name" | "node_key" | "default"
    matched_key: str
    resolved_node_key: str | None
@dataclass(frozen=True)
class _ConsumeLlmStreamRequest:
    """Bundle of context values needed while consuming an LLM stream.

    Attributes:
        llm: LLM object whose ``stream`` method is iterated.
        stream_kwargs: Keyword arguments forwarded to ``llm.stream``.
        can_write_stream: Whether a session is available for custom stream writes.
        need_stream_out: Whether incremental content should be streamed out.
        session: Current session; its ``write_custom_stream`` receives streamed payloads.
        stream_id: Id of the current output stream, sent as ``message_id``.
        stream_meta: Extra fields merged into every streamed payload.
        agent_name: Calling agent name, sent as ``agent`` in payloads.
    """
    llm: Any
    stream_kwargs: dict[str, Any]
    can_write_stream: bool
    need_stream_out: bool
    session: Any
    stream_id: str | None
    stream_meta: dict[str, Any] | None
    agent_name: str
def _normalize_agent_name(agent_name: Any) -> str:
    """Return a cleaned-up agent name used as a usage/timeout lookup key.

    Args:
        agent_name (Any): Raw agent_name value.

    Returns:
        str: The whitespace-stripped name, or ``"unknown"`` for non-strings and blank strings.
    """
    stripped = agent_name.strip() if isinstance(agent_name, str) else ""
    return stripped or "unknown"
def _validate_invoke_agent_name(agent_name: Any) -> str:
    """Validate and normalise the agent_name passed to the LLM invocation entry point.

    Args:
        agent_name: agent_name supplied by the caller.

    Returns:
        str: ``DEFAULT_AGENT_NAME`` for ``None`` or a blank string. Otherwise the stripped
        name, provided it is whitelisted.

    Raises:
        CustomValueException: If agent_name is not a string, or is neither the default name
            nor one of the ``AgentLlmName`` values.
    """
    if agent_name is None:
        return DEFAULT_AGENT_NAME
    # Non-strings keep candidate=None and fall through to the error below.
    candidate = agent_name.strip() if isinstance(agent_name, str) else None
    if candidate == "":
        return DEFAULT_AGENT_NAME
    if candidate is not None and candidate in _ALLOWED_AGENT_NAMES:
        return candidate
    raise CustomValueException(
        error_code=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.code,
        message=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.errmsg.format(param="agent_name"),
    )
def _resolve_node_agent_key(agent_name: str) -> str | None:
    """Derive the node-level key from agent_name by longest-prefix match.

    Args:
        agent_name: Raw agent_name.

    Returns:
        str | None: The longest ``NodeId`` value that prefixes the normalised name, or
        ``None`` when no node id matches.
    """
    candidate = _normalize_agent_name(agent_name)
    # Longest keys first, so a more specific node id wins over a shorter shared prefix.
    longest_first = sorted((node.value for node in NodeId), key=len, reverse=True)
    return next((node_key for node_key in longest_first if candidate.startswith(node_key)), None)
def _resolve_agent_llm_timeout(agent_name: str, session: Any = None) -> ResolvedAgentTimeout | None:
    """Pick the wall-clock timeout that applies to the given agent.

    Lookup order: exact normalised agent name, then the node key derived from the name,
    then ``default``.

    Args:
        agent_name: Raw agent_name.
        session: Current session; must provide ``get_global_state``.

    Returns:
        ResolvedAgentTimeout | None: The matched timeout. ``None`` when there is no session,
        the config cannot be read, or the feature is not enabled.
    """
    if session is None:
        return None
    try:
        raw_config = session.get_global_state("config.agent_llm_timeouts")
    except Exception:
        return None
    timeouts = normalize_agent_llm_timeouts(raw_config)
    if not timeouts:
        return None

    name = _normalize_agent_name(agent_name)
    node_key = _resolve_node_agent_key(name)
    lookup_order = (
        ("agent_name", name),
        ("node_key", node_key),
        ("default", "default"),
    )
    for source, key in lookup_order:
        if key and key in timeouts:
            return ResolvedAgentTimeout(
                timeout=timeouts[key],
                matched_by=source,
                matched_key=key,
                resolved_node_key=node_key,
            )
    return None
def _build_empty_agent_name_usage(agent_name: str) -> dict[str, Any]:
    """Create a zeroed token-usage record for one agent.

    Args:
        agent_name (str): Caller identifier; normalised before it is stored.

    Returns:
        dict[str, Any]: ``agent_name`` plus zeroed input/output/total token and call counters.
    """
    record: dict[str, Any] = {"agent_name": _normalize_agent_name(agent_name)}
    record.update(dict.fromkeys(("input_tokens", "output_tokens", "total_tokens", "llm_call_count"), 0))
    return record
def _merge_agent_name_usage_list(agent_usage_list: Any) -> list[dict[str, Any]]:
    """Merge per-agent token-usage entries that share a normalised agent name.

    Non-dict entries are ignored and counters are coerced to non-negative ints.

    Args:
        agent_usage_list (Any): Candidate list of per-agent usage dicts.

    Returns:
        list[dict[str, Any]]: One summed record per agent name, in first-seen order; an
        empty list when the input is not a list.
    """
    if not isinstance(agent_usage_list, list):
        return []
    counter_fields = ("input_tokens", "output_tokens", "total_tokens", "llm_call_count")
    by_name: dict[str, dict[str, Any]] = {}
    for entry in (item for item in agent_usage_list if isinstance(item, dict)):
        name = _normalize_agent_name(entry.get("agent_name"))
        if name not in by_name:
            by_name[name] = _build_empty_agent_name_usage(name)
        bucket = by_name[name]
        for field in counter_fields:
            bucket[field] += _to_non_negative_int(entry.get(field, 0))
    return list(by_name.values())
def _build_empty_workflow_llm_usage() -> dict[str, Any]:
    """Create a zeroed workflow-level token-usage record.

    Returns:
        dict[str, Any]: Zeroed totals plus a fresh empty ``agent_name_token_usage`` list.
    """
    usage: dict[str, Any] = dict.fromkeys(
        ("input_tokens", "output_tokens", "total_tokens", "llm_call_count"), 0
    )
    # A new list per call, so records never share per-agent state.
    usage["agent_name_token_usage"] = []
    return usage
def normalize_workflow_llm_usage(usage: Any) -> dict[str, Any]:
    """Coerce an arbitrary value into the canonical workflow token-usage structure.

    Args:
        usage (Any): Candidate usage data; anything other than a dict yields an empty record.

    Returns:
        dict[str, Any]: Non-negative int totals plus merged per-agent usage.
    """
    if not isinstance(usage, dict):
        return _build_empty_workflow_llm_usage()
    normalized: dict[str, Any] = {
        field: _to_non_negative_int(usage.get(field, 0))
        for field in ("input_tokens", "output_tokens", "total_tokens", "llm_call_count")
    }
    normalized["agent_name_token_usage"] = _merge_agent_name_usage_list(
        usage.get("agent_name_token_usage", [])
    )
    return normalized
def is_workflow_llm_usage_empty(usage: dict[str, Any]) -> bool:
    """Tell whether a workflow token-usage record carries no data.

    Args:
        usage (dict[str, Any]): Token-usage record; normalised before inspection.

    Returns:
        bool: True when every total is 0 and there are no per-agent entries.
    """
    normalized = normalize_workflow_llm_usage(usage)
    totals_are_zero = all(
        normalized[field] == 0
        for field in ("input_tokens", "output_tokens", "total_tokens", "llm_call_count")
    )
    return totals_are_zero and not normalized["agent_name_token_usage"]
def _ensure_workflow_llm_usage_initialized(session_id: str, session: Any = None) -> None:
    """Make sure the in-memory accumulator has an entry for this workflow.

    Supports cross-process resume: when the local entry is missing or empty, it is seeded
    from ``search_context.final_result.workflow_llm_token_usage`` in the session global
    state. If that read fails, an empty record is used instead.

    Args:
        session_id (str): Workflow session id; ``""``/``None``/``"-"`` are ignored.
        session (Any): Current session exposing ``get_global_state``; may be None.
    """
    if not session_id or session_id == "-":
        return
    existing = _WORKFLOW_LLM_USAGE.get(session_id)
    if existing is not None and not is_workflow_llm_usage_empty(existing):
        return

    snapshot_usage = None
    if session is not None:
        try:
            snapshot_usage = normalize_workflow_llm_usage(
                session.get_global_state("search_context.final_result.workflow_llm_token_usage")
            )
        except Exception:
            snapshot_usage = None
    _WORKFLOW_LLM_USAGE[session_id] = (
        snapshot_usage if snapshot_usage is not None else _build_empty_workflow_llm_usage()
    )
def get_effective_workflow_llm_usage(session_id: str, session: Any = None) -> dict[str, Any]:
    """Return the token-usage snapshot that currently best describes the workflow.

    The in-memory accumulator wins when it holds data. Otherwise the persisted
    ``search_context.final_result.workflow_llm_token_usage`` from the session is used,
    which covers HITL resume and similar cases.

    Args:
        session_id (str): Workflow session id.
        session (Any): Session exposing ``get_global_state``; when None only local data is used.

    Returns:
        dict[str, Any]: Normalised usage snapshot.
    """
    local_usage = normalize_workflow_llm_usage(get_workflow_llm_usage(session_id))
    if session is None or not is_workflow_llm_usage_empty(local_usage):
        return local_usage
    try:
        persisted_usage = normalize_workflow_llm_usage(
            session.get_global_state("search_context.final_result.workflow_llm_token_usage")
        )
    except Exception:
        return local_usage
    return local_usage if is_workflow_llm_usage_empty(persisted_usage) else persisted_usage
def save_workflow_llm_usage_to_session(session: Any, session_id: str) -> dict[str, Any]:
    """Persist the workflow token-usage snapshot into the session global state.

    Write failures are logged at debug level and otherwise ignored.

    Args:
        session (Any): Session exposing ``update_global_state``; may be None (no write).
        session_id (str): Workflow session id.

    Returns:
        dict[str, Any]: The snapshot that was (or would have been) written.
    """
    snapshot = get_effective_workflow_llm_usage(session_id=session_id, session=session)
    if session is not None:
        try:
            session.update_global_state({"search_context.final_result.workflow_llm_token_usage": snapshot})
        except Exception as e:
            logger.debug("Exception when updating session's global state: %s", e, exc_info=True)
    return snapshot
def add_workflow_llm_usage(
    session_id: str,
    input_tokens: int,
    output_tokens: int,
    total_tokens: int,
    agent_name: str = "",
) -> None:
    """Add one LLM call's token consumption to the workflow accumulator.

    The workflow totals always increase and the call count goes up by one. When
    ``agent_name`` is truthy, the matching per-agent record is updated as well, and is
    created first if missing.

    Args:
        session_id (str): Workflow session id; ``""``/``None``/``"-"`` are ignored.
        input_tokens (int): Input tokens of this call.
        output_tokens (int): Output tokens of this call.
        total_tokens (int): Total tokens of this call.
        agent_name (str): Name of the agent that made the call.
    """
    if not session_id or session_id == "-":
        return
    workflow_usage = _WORKFLOW_LLM_USAGE.setdefault(session_id, _build_empty_workflow_llm_usage())
    delta = {
        "input_tokens": _to_non_negative_int(input_tokens),
        "output_tokens": _to_non_negative_int(output_tokens),
        "total_tokens": _to_non_negative_int(total_tokens),
    }
    for field, amount in delta.items():
        workflow_usage[field] += amount
    workflow_usage["llm_call_count"] += 1

    if not agent_name:
        return
    name = _normalize_agent_name(agent_name)
    per_agent = workflow_usage.setdefault("agent_name_token_usage", [])
    if not isinstance(per_agent, list):
        # Repair a corrupted entry instead of failing the accounting.
        per_agent = []
        workflow_usage["agent_name_token_usage"] = per_agent
    bucket = next(
        (item for item in per_agent if isinstance(item, dict) and item.get("agent_name") == name),
        None,
    )
    if bucket is None:
        bucket = _build_empty_agent_name_usage(name)
        per_agent.append(bucket)
    for field, amount in delta.items():
        bucket[field] += amount
    bucket["llm_call_count"] += 1
def get_workflow_llm_usage(session_id: str) -> dict[str, Any]:
    """Return a deep copy of the accumulated LLM token usage for a workflow.

    Args:
        session_id (str): Workflow session id.

    Returns:
        dict[str, Any]: Copy of the accumulated usage; a zeroed record when the id is
        unusable or nothing has been recorded.
    """
    stored = None
    if session_id and session_id != "-":
        stored = _WORKFLOW_LLM_USAGE.get(session_id)
    if stored is None:
        return _build_empty_workflow_llm_usage()
    # Deep copy so callers cannot mutate the live accumulator.
    return copy.deepcopy(stored)
def pop_workflow_llm_usage(session_id: str) -> dict[str, Any]:
    """Remove a workflow's accumulated LLM token usage and return it.

    Args:
        session_id (str): Workflow session id.

    Returns:
        dict[str, Any]: The removed record; a zeroed record when the id is unusable or
        nothing was recorded.
    """
    removed = None
    if session_id and session_id != "-":
        removed = _WORKFLOW_LLM_USAGE.pop(session_id, None)
    return removed if removed is not None else _build_empty_workflow_llm_usage()
def _to_non_negative_int(value: Any, default: int = 0) -> int:
    """Convert an arbitrary value to a non-negative int, never raising.

    Args:
        value (Any): int/float/str or any other value accepted by ``int()``.
        default (int): Returned when the conversion fails.

    Returns:
        int: ``max(int(value), 0)``; ``default`` when ``int()`` raises TypeError,
        ValueError (e.g. NaN, ``"abc"``) or OverflowError (infinities).
    """
    try:
        return max(int(value), 0)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) — treat infinities like any other unusable value
        # instead of crashing token accounting.
        return default
def _to_dict_safe(value: Any) -> dict[str, Any]:
    """Best-effort conversion of an object to a dict.

    Dicts are returned as-is (not copied). Pydantic models go through ``model_dump()``.
    Other objects with a callable ``model_dump`` have it called, with errors and non-dict
    results mapped to ``{}``. As a last resort the object's ``__dict__`` is returned.

    Args:
        value (Any): dict, Pydantic model or arbitrary object.

    Returns:
        dict[str, Any]: The dict view, or ``{}`` when none is available.
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, BaseModel):
        return value.model_dump()
    dumper = getattr(value, "model_dump", None)
    if not callable(dumper):
        attributes = getattr(value, "__dict__", None)
        return attributes if isinstance(attributes, dict) else {}
    try:
        dumped = dumper()
    except Exception:
        return {}
    return dumped if isinstance(dumped, dict) else {}
def _extract_usage_tokens(usage_payload: Any) -> tuple[int, int, int]:
    """Extract input, output and total token counts from a usage structure.

    Accepts OpenAI-style (``prompt_tokens``/``completion_tokens``), generic
    (``input_tokens``/``output_tokens``) and ``*_token_count`` field names, optionally
    nested under ``token_usage``. Top-level fields override nested ones.

    Args:
        usage_payload (Any): dict, Pydantic model or SDK object describing usage.

    Returns:
        tuple[int, int, int]: ``(input_tokens, output_tokens, total_tokens)``. When no total
        is reported, it is input + output.
    """
    def _first_present(mapping: dict[str, Any], keys: tuple[str, ...], fallback: Any) -> Any:
        # Same result as nested ``mapping.get(k1, mapping.get(k2, ..., fallback))``.
        for key in keys:
            if key in mapping:
                return mapping[key]
        return fallback

    usage = _to_dict_safe(usage_payload)
    nested = usage.get("token_usage")
    if isinstance(nested, dict):
        usage = {**nested, **usage}

    input_tokens = _to_non_negative_int(
        _first_present(usage, ("input_tokens", "prompt_tokens", "prompt_token_count"), 0)
    )
    if input_tokens == 0:
        input_tokens = _to_non_negative_int(_first_present(usage, ("prompt_tokens", "prompt_token_count"), 0))

    output_tokens = _to_non_negative_int(
        _first_present(usage, ("output_tokens", "completion_tokens", "completion_token_count"), 0)
    )
    if output_tokens == 0:
        output_tokens = _to_non_negative_int(
            _first_present(usage, ("completion_tokens", "completion_token_count"), 0)
        )

    fallback_total = input_tokens + output_tokens
    raw_total = _first_present(usage, ("total_tokens", "total_token_count"), None)
    if raw_total is None:
        total_tokens = fallback_total
    else:
        total_tokens = _to_non_negative_int(raw_total, default=fallback_total)
    return input_tokens, output_tokens, total_tokens
def _is_llm_stats_enabled() -> bool:
    """Decide whether LLM call statistics are collected for the current call.

    The per-workflow session flag ``config.stats_info_llm`` wins when present. Otherwise
    the global default ``Config().agent_config.stats_info_llm`` applies. Errors while
    reading the session are logged at debug level.

    Returns:
        bool: True when LLM call statistics are enabled.
    """
    try:
        current_session = session_context.get()
        override = (
            current_session.get_global_state("config.stats_info_llm")
            if current_session is not None
            else None
        )
        if override is not None:
            return bool(override)
    except Exception as e:
        logger.debug("Exception when checking whether llm stats is enabled: %s", e, exc_info=True)
    return bool(Config().agent_config.stats_info_llm)
def _raise_if_cancelled():
    """Raise ``asyncio.CancelledError`` when the context's cancel event is set.

    Called on the LLM call paths (``llm_astream``, per stream chunk, and
    ``ainvoke_llm_with_stats``) so that external cancellation interrupts in-flight
    streaming or non-streaming calls promptly.

    Raises:
        asyncio.CancelledError: If ``cancel_context`` holds an event that is set.
    """
    try:
        cancel_event = cancel_context.get()
    except LookupError:
        # No cancel event installed in this context (e.g. the helper runs outside a workflow
        # run, which llm_astream explicitly supports), so nothing can have requested
        # cancellation. Without this guard every LLM call would fail with LookupError.
        return
    if cancel_event and cancel_event.is_set():
        logger.info("LLM call cancelled via cancel_event")
        raise asyncio.CancelledError("cancelled")
def messages_to_json(messages: Sequence[Any] | Message) -> str:
    """Serialise messages to a pretty-printed JSON string.

    A single ``Message`` is dumped as an object. A sequence is dumped as a list where dicts
    are kept, ``Message`` items are ``model_dump()``-ed, and anything else is ``str()``-ed
    and logged as an unexpected type.

    Args:
        messages: A ``Message``, a sequence of messages, or None.

    Returns:
        str: JSON text (``ensure_ascii=False``, indent 4); ``""`` for None.
    """
    if messages is None:
        return ""
    if isinstance(messages, Message):
        return json.dumps(messages.model_dump(), ensure_ascii=False, indent=4)

    dumped: list[Any] = []
    for item in messages:
        if isinstance(item, dict):
            dumped.append(item)
            continue
        if isinstance(item, Message):
            dumped.append(item.model_dump())
            continue
        dumped.append(str(item))
        if LogManager.is_sensitive():
            logger.error("error message type.")
        else:
            logger.error(f"error message type: {item}")
    return json.dumps(dumped, ensure_ascii=False, indent=4)
def normalize_json_output(input_data: str) -> str:
    """Normalise text that may contain JSON into compact JSON.

    Text without any JSON hint (``{``, ``[``, ```` ```json ````, ```` ```ts ````) is
    returned stripped. Otherwise a leading ```` ```json ````/```` ```ts ```` fence, and the
    matching closing fence, are removed. The rest is parsed with ``json_repair`` and
    re-serialised.

    Args:
        input_data: Text that may contain JSON.

    Returns:
        str: Normalised JSON, or the stripped original text when it is not JSON or
        repair fails.
    """
    text = input_data.strip()
    if not any(marker in text for marker in ('{', '[', '```json', '```ts')):
        return text

    closing_fence = '```'
    for opening_fence in ('```json', '```ts'):
        if not text.startswith(opening_fence):
            continue
        text = text[len(opening_fence):].lstrip('\n')
        if text.endswith(closing_fence):
            text = text[:-len(closing_fence)].rstrip('\n')

    try:
        repaired = json_repair.loads(text)
        return json.dumps(repaired, ensure_ascii=False)
    except Exception as error:
        if LogManager.is_sensitive():
            logger.error("JSON normalization error.")
        else:
            logger.error(f"JSON normalization error: {error}")
        return input_data.strip()
def _extract_json(text: str) -> str:
    """Strip a Markdown code fence around JSON text.

    Removes a leading ```` ``` ```` or ```` ```json ```` line and a trailing ```` ``` ````
    line from the stripped text.

    Args:
        text: Raw model output.

    Returns:
        str: The text without its surrounding fence.
    """
    stripped = text.strip()
    without_opening = re.sub(r"^```(?:json)?\n", "", stripped)
    return re.sub(r"\n```$", "", without_opening)
def _single_provider_error_detail(
    exc: BaseException, *, max_response_text: int = 16_000
) -> dict[str, Any]:
    """Collect the provider-error fields of a single exception into a JSON-friendly dict.

    Mirrors what the OpenAI Python SDK exposes on API errors (cf. ``openai.APIStatusError``)
    so logs match direct ``OpenAI()`` calls. The collected fields are ``body``,
    ``response.status_code``, ``response.text`` (stripped and truncated to
    *max_response_text* chars), and ``request_id``/``code``/``param``/``type``. Missing or
    ``None`` attributes are omitted.

    Args:
        exc: Exception to describe.
        max_response_text: Maximum number of response-text characters kept.

    Returns:
        dict[str, Any]: Error detail fields.
    """
    detail: dict[str, Any] = {
        "exception_type": type(exc).__name__,
        "exception_str": str(exc),
    }
    body = getattr(exc, "body", None)
    if body is not None:
        detail["exception_body"] = body

    response = getattr(exc, "response", None)
    if response is not None:
        status_code = getattr(response, "status_code", None)
        if status_code is not None:
            detail["response_status"] = status_code
        try:
            # Reading ``text`` can raise on some HTTP client responses; keep it best-effort.
            raw_text = getattr(response, "text", None)
            snippet = raw_text.strip() if isinstance(raw_text, str) else ""
            if snippet:
                if len(snippet) > max_response_text:
                    snippet = snippet[:max_response_text] + "...(truncated)"
                detail["response_text"] = snippet
        except Exception as e:
            logger.debug("Exception when extracting response text: %s", e, exc_info=True)

    for attr_name in ("request_id", "code", "param", "type"):
        attr_value = getattr(exc, attr_name, None)
        if attr_value is not None:
            detail[f"exception_{attr_name}"] = attr_value
    return detail
def _format_llm_invoke_exception(
    exc: BaseException,
    *,
    max_response_text: int = 16_000,
    max_formatted_len: int = 48_000,
) -> str:
    """Render an LLM invocation error as JSON matching normal OpenAI client errors.

    Follows the ``__cause__`` chain for up to 8 exceptions, skipping cycles. A single
    exception is rendered as its detail dict; a longer chain as ``{"error_chain": [...]}``.

    Args:
        exc: Exception to format.
        max_response_text: Per-exception response-text limit.
        max_formatted_len: Limit of the final JSON text before truncation.

    Returns:
        str: Indented JSON text, truncated with ``"\\n...(truncated)"`` when too long.
    """
    links: list[dict[str, Any]] = []
    visited: set[int] = set()
    node: BaseException | None = exc
    hops = 0
    while hops < 8 and node is not None and id(node) not in visited:
        visited.add(id(node))
        links.append(_single_provider_error_detail(node, max_response_text=max_response_text))
        node = node.__cause__
        hops += 1

    if not links:
        payload: dict[str, Any] = {
            "exception_type": type(exc).__name__,
            "exception_str": str(exc),
        }
    elif len(links) == 1:
        payload = links[0]
    else:
        payload = {"error_chain": links}

    rendered = json.dumps(payload, ensure_ascii=False, indent=2, default=str)
    if len(rendered) <= max_formatted_len:
        return rendered
    return rendered[:max_formatted_len] + "\n...(truncated)"
def llm_error_chain_blob(exc: BaseException | None, *, max_depth: int = 8) -> str:
"""Concatenate ``str(exc)`` along ``__cause__`` and the root's ``__context__`` chain.
Used for substring heuristics (e.g. context-limit phrases) when the wrapper message
only says "status code is 400".
"""
if exc is None:
return ""
parts: list[str] = []
seen: set[int] = set()
def walk_chain(start: BaseException | None) -> None:
cur = start
for _ in range(max_depth):
if cur is None or id(cur) in seen:
break
seen.add(id(cur))
parts.append(str(cur))
cur = cur.__cause__
walk_chain(exc)
ctx = exc.__context__
if ctx is not None and id(ctx) not in seen:
walk_chain(ctx)
return " ".join(parts)
async def _call_model_method(method, primary_kwargs: dict):
    """Invoke a sync or async model method and return its fully resolved result.

    Coroutine functions are awaited directly. Anything else runs in a worker thread via
    ``asyncio.to_thread``. If that call returns an awaitable (e.g. a sync wrapper around an
    async method, or an object whose ``__call__`` is ``async``), the awaitable is awaited on
    the event loop too, so callers never receive an un-awaited coroutine.

    Args:
        method: Callable to invoke.
        primary_kwargs: Keyword arguments passed to *method*.

    Returns:
        The value produced by *method*, with any returned awaitable resolved.
    """
    if inspect.iscoroutinefunction(method):
        return await method(**primary_kwargs)
    result = await asyncio.to_thread(method, **primary_kwargs)
    if inspect.isawaitable(result):
        return await result
    return result
def _extract_usage_payload_from_stream_chunk(raw_chunk: Any) -> Any:
    """Extract the ``usage`` payload from a raw streaming chunk.

    Supported inputs:

    - objects exposing a ``usage`` attribute;
    - dicts with a ``usage`` key;
    - raw SSE lines (``str``/``bytes``/``bytearray``) of the form ``data:<json>`` or
      ``data: <json>``.

    Args:
        raw_chunk (Any): Raw streaming chunk received by the model client.

    Returns:
        Any: The usage structure. ``None`` when it is absent or empty, for ``[DONE]``
        markers, for unparsable lines, or when the JSON is not an object.
    """
    if raw_chunk is None:
        return None
    if hasattr(raw_chunk, "usage"):
        usage_value = getattr(raw_chunk, "usage")
        return usage_value if usage_value else None
    if isinstance(raw_chunk, dict):
        usage_value = raw_chunk.get("usage")
        return usage_value if usage_value else None

    if isinstance(raw_chunk, (bytes, bytearray)):
        decoded = bytes(raw_chunk).decode("utf-8", errors="ignore").strip()
    elif isinstance(raw_chunk, str):
        decoded = raw_chunk.strip()
    else:
        return None
    if decoded.startswith("data:"):
        # In the SSE format the space after "data:" is optional; accept both forms.
        decoded = decoded[len("data:"):].strip()
    if not decoded or decoded == "[DONE]":
        return None
    try:
        parsed_payload = json.loads(decoded)
    except Exception:
        return None
    if not isinstance(parsed_payload, dict):
        return None
    usage_value = parsed_payload.get("usage")
    return usage_value if usage_value else None
def _build_usage_only_chunk(raw_chunk: Any, model_name: str) -> AssistantMessageChunk | None:
    """Turn a usage-only stream chunk into an empty ``AssistantMessageChunk`` carrying usage.

    Args:
        raw_chunk (Any): Raw streaming chunk.
        model_name (str): Model name recorded in the usage metadata.

    Returns:
        AssistantMessageChunk | None: An empty-content chunk whose ``usage_metadata`` can be
        merged into the aggregate. ``None`` when the chunk has no usage or all counts are 0.
    """
    usage_payload = _extract_usage_payload_from_stream_chunk(raw_chunk)
    if usage_payload is None:
        return None
    token_counts = _extract_usage_tokens(usage_payload)
    if not any(token_counts):
        return None
    input_tokens, output_tokens, total_tokens = token_counts
    metadata = UsageMetadata(
        model_name=model_name,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
    )
    return AssistantMessageChunk(
        content="",
        reasoning_content=None,
        tool_calls=None,
        usage_metadata=metadata,
        finish_reason="null",
    )
def _install_usage_only_chunk_parser(llm_model: Any) -> Any:
    """Install a usage-only chunk parser on the model's shared client; return a restorer.

    The patched ``_parse_stream_chunk`` delegates to the original parser. When that returns
    ``None``, it tries to build a usage-only chunk, so usage reported in otherwise empty
    stream chunks is not lost.

    One client may be reused by several concurrent calls, so the patch is reference-counted:

    1. The first caller installs the patched parser.
    2. Later callers only increment the reference count.
    3. The original parser is restored only when the last reference is released.

    Each returned restorer releases its own reference exactly once; calling it again is a
    no-op, so it cannot release references held by other in-flight calls.

    Args:
        llm_model (Any): openjiuwen model object holding the client in ``_client``.

    Returns:
        Any: A zero-argument restore callable, or ``None`` when the client or its
        ``_parse_stream_chunk`` is unavailable.
    """
    client = getattr(llm_model, "_client", None)
    if client is None:
        return None
    original_parser = getattr(client, "_parse_stream_chunk", None)
    if not callable(original_parser):
        return None
    patch_key = id(client)
    patch_state = _USAGE_ONLY_PARSER_PATCHES.get(patch_key)
    if patch_state is None:
        model_name = getattr(getattr(llm_model, "model_config", None), "model_name", "")

        def _patched_parser(raw_chunk: Any):
            parsed_chunk = original_parser(raw_chunk)
            if parsed_chunk is not None:
                return parsed_chunk
            return _build_usage_only_chunk(raw_chunk, model_name=model_name)

        setattr(client, "_parse_stream_chunk", _patched_parser)
        patch_state = {
            "ref_count": 0,
            "original_parser": original_parser,
            "patched_parser": _patched_parser,
        }
        _USAGE_ONLY_PARSER_PATCHES[patch_key] = patch_state
    patch_state["ref_count"] += 1
    released = False

    def _restore_parser():
        nonlocal released
        # Release this caller's reference once only; repeated calls must not drop
        # references owned by other concurrent streams on the same client.
        if released:
            return
        released = True
        current_state = _USAGE_ONLY_PARSER_PATCHES.get(patch_key)
        if current_state is not patch_state:
            # Our patch record is gone (or was replaced); nothing of ours to release.
            return
        current_state["ref_count"] = max(int(current_state.get("ref_count", 0)) - 1, 0)
        if current_state["ref_count"] > 0:
            return
        try:
            setattr(client, "_parse_stream_chunk", current_state["original_parser"])
        except Exception as e:
            logger.debug("Exception when restoring stream chunk parser: %s", e, exc_info=True)
        finally:
            _USAGE_ONLY_PARSER_PATCHES.pop(patch_key, None)

    return _restore_parser
def _resolve_stream_options(llm_model: Any, need_include_usage: bool) -> dict | None:
"""合并模型已有 stream_options,并按需注入 include_usage。
Args:
llm_model (Any): LLM 模型实例,可能包含 model_config.stream_options。
need_include_usage (bool): 是否强制注入 include_usage。
Returns:
dict | None: 合并后的 stream_options;若无可用配置则返回 ``None``。
"""
merged_options: dict = {}
model_config = getattr(llm_model, "model_config", None)
if model_config is not None:
model_dump = getattr(model_config, "model_dump", None)
if callable(model_dump):
model_config_dict = model_dump()
existed_options = model_config_dict.get("stream_options")
if isinstance(existed_options, dict):
merged_options.update(existed_options)
if need_include_usage:
merged_options["include_usage"] = True
return merged_options or None
async def _consume_llm_stream(
    request: _ConsumeLlmStreamRequest,
) -> Any:
    """Consume a streaming LLM response and aggregate it into one chunk.

    For each chunk:

    - the cancel event is checked first;
    - the chunk is appended (``+=``) to the running aggregate;
    - when streaming output is enabled, its non-empty content is written to the session as
      a ``MESSAGE`` event.

    Once the aggregate ``content`` reaches ``MAX_LLM_RESP_LENGTH`` it is truncated and
    consumption stops. The underlying stream is always closed on exit.

    Args:
        request: Named context values for this stream-consumption step.

    Returns:
        Any: The aggregated response chunk, or ``None`` when the stream yielded nothing.
    """
    full_chunk = None
    stream = request.llm.stream(**request.stream_kwargs)
    try:
        async for chunk in stream:
            _raise_if_cancelled()
            if full_chunk is None:
                full_chunk = chunk
            else:
                full_chunk += chunk
            if len(full_chunk.content) >= MAX_LLM_RESP_LENGTH:
                logger.warning(
                    "[llm_astream] llm response is too long, "
                    "truncate to %s characters", MAX_LLM_RESP_LENGTH
                )
                full_chunk.content = full_chunk.content[:MAX_LLM_RESP_LENGTH]
                break
            chunk_content = getattr(chunk, "content", "")
            if request.can_write_stream and request.need_stream_out and chunk_content:
                payload = {
                    "message_id": request.stream_id,
                    "agent": request.agent_name,
                    "content": chunk_content,
                    "message_type": MessageType.MESSAGE_CHUNK.value,
                    "event": StreamEvent.MESSAGE.value,
                    "created_time": get_current_time(),
                }
                if request.stream_meta:
                    payload.update(dict(request.stream_meta))
                await request.session.write_custom_stream(payload)
    finally:
        # Leaving ``async for`` early (truncation break, cancellation, wall-clock timeout,
        # a failed stream write) does not close an async generator. Close it explicitly so
        # the underlying streaming response is released now rather than at GC time.
        aclose = getattr(stream, "aclose", None)
        if callable(aclose):
            try:
                await aclose()
            except Exception as close_error:
                logger.debug("Exception when closing llm stream: %s", close_error, exc_info=True)
    return full_chunk
async def llm_astream(*args, **kwargs):
    """Call the LLM in streaming mode and return the aggregated response.

    Positional arguments map, in order, to ``llm``, ``messages``, ``model_name`` and
    ``agent_name``; keyword arguments take precedence. When streaming output is enabled and
    a session is available, the incremental content is bracketed by exactly one START and
    one DONE event.

    Args:
        llm (Any): LLM instance exposing an async ``stream(...)`` iterator.
        messages (list): Input messages.
        model_name (str): Model name, forwarded as ``model``.
        agent_name (str): Caller name; used in stream payloads and to resolve the
            per-agent wall-clock timeout.
        tools (Any): Tools bound to this call.
        need_stream_out (bool): Whether to forward incremental content to the session stream.
        stream_meta (dict | None): Extra fields merged into every stream event.
        stream_options (dict | None): Streaming options for the model SDK. When
            ``include_usage`` is truthy, a usage-only chunk parser is installed for the call.

    Returns:
        Any: The aggregated LLM response chunk.

    Raises:
        CustomValueException: ``LLM_WALL_CLOCK_TIMEOUT`` when the configured wall-clock
            timeout elapses; ``LLM_RESPONSE_NONE`` when the stream yields nothing.
        asyncio.CancelledError: When the cancel event is set.
    """
    llm = kwargs.get("llm", args[0] if len(args) > 0 else None)
    messages = kwargs.get("messages", args[1] if len(args) > 1 else None)
    model_name = kwargs.get("model_name", args[2] if len(args) > 2 else None)
    agent_name = kwargs.get("agent_name", args[3] if len(args) > 3 else None)
    tools = kwargs.get("tools", None)
    need_stream_out = kwargs.get("need_stream_out", False)
    stream_meta = kwargs.get("stream_meta", None)
    stream_options = kwargs.get("stream_options", None)
    _raise_if_cancelled()
    full_chunk = None
    can_write_stream = True
    session = None
    try:
        session = session_context.get()
        if session is None:
            can_write_stream = False
            logger.debug("session_context not set, can not write to stream")
    except LookupError:
        can_write_stream = False
        logger.debug("session_context not set, can not write to stream")

    def _make_payload(message_id: str, event: str, content: str = "") -> dict:
        payload = {
            "message_id": message_id,
            "agent": agent_name,
            "content": content,
            "message_type": MessageType.MESSAGE_CHUNK.value,
            "event": event,
            "created_time": get_current_time(),
        }
        if stream_meta:
            payload.update(dict(stream_meta))
        return payload

    should_stream_out = can_write_stream and need_stream_out
    stream_id = None
    if should_stream_out:
        stream_id = str(uuid.uuid4())
        await session.write_custom_stream(_make_payload(stream_id, StreamEvent.START.value, ""))
    restore_usage_parser = None
    resolved_timeout = None
    # Set only once our own asyncio.wait_for guard is active, so that a TimeoutError raised
    # by anything else is not misreported as a wall-clock timeout.
    wall_clock_applied = False
    if isinstance(stream_options, dict) and bool(stream_options.get("include_usage")):
        restore_usage_parser = _install_usage_only_chunk_parser(llm)
    try:
        stream_kwargs = {
            "messages": messages,
            "model": model_name,
            "tools": tools,
        }
        if stream_options is not None:
            stream_kwargs["stream_options"] = stream_options
        consume_request = _ConsumeLlmStreamRequest(
            llm=llm,
            stream_kwargs=stream_kwargs,
            can_write_stream=can_write_stream,
            need_stream_out=need_stream_out,
            session=session,
            stream_id=stream_id,
            stream_meta=stream_meta,
            agent_name=agent_name,
        )
        resolved_timeout = _resolve_agent_llm_timeout(agent_name=agent_name, session=session)
        if resolved_timeout is not None and resolved_timeout.timeout > 0:
            logger.info(
                "[llm_astream] applying wall-clock timeout agent_name=%s "
                "node_key=%s matched_by=%s matched_key=%s timeout=%s",
                agent_name,
                resolved_timeout.resolved_node_key,
                resolved_timeout.matched_by,
                resolved_timeout.matched_key,
                resolved_timeout.timeout,
            )
            wall_clock_applied = True
            full_chunk = await asyncio.wait_for(
                _consume_llm_stream(consume_request),
                timeout=resolved_timeout.timeout,
            )
        else:
            full_chunk = await _consume_llm_stream(consume_request)
    except asyncio.TimeoutError as exc:
        if not wall_clock_applied:
            # Not our wall-clock guard (e.g. raised by the SDK): propagate it unchanged
            # instead of dereferencing a missing ``resolved_timeout``.
            raise
        logger.warning(
            "[llm_astream] wall-clock timeout agent_name=%s node_key=%s matched_by=%s matched_key=%s timeout=%s",
            agent_name,
            resolved_timeout.resolved_node_key,
            resolved_timeout.matched_by,
            resolved_timeout.matched_key,
            resolved_timeout.timeout,
        )
        raise CustomValueException(
            error_code=StatusCode.LLM_WALL_CLOCK_TIMEOUT.code,
            message=StatusCode.LLM_WALL_CLOCK_TIMEOUT.errmsg.format(
                timeout=resolved_timeout.timeout,
                agent_name=agent_name,
                matched_by=resolved_timeout.matched_by,
                matched_key=resolved_timeout.matched_key,
                node_key=resolved_timeout.resolved_node_key,
            ),
        ) from exc
    finally:
        # The single place that emits DONE: success, error, timeout and cancellation alike.
        # Previously the except branches also emitted DONE, duplicating it on every failure.
        if callable(restore_usage_parser):
            restore_usage_parser()
        if should_stream_out:
            await session.write_custom_stream(_make_payload(stream_id, StreamEvent.DONE.value, ""))
    if full_chunk is None:
        logger.error("[llm_astream] llm response is None")
        raise CustomValueException(
            error_code=StatusCode.LLM_RESPONSE_NONE.code,
            message=StatusCode.LLM_RESPONSE_NONE.errmsg)
    return full_chunk
def _parse_invoke_llm_args(args, kwargs) -> dict:
    """Collect the arguments of an LLM invocation into a single dict.

    ``llm`` and ``messages`` may come positionally (indices 0 and 1) or as keywords, with
    keywords taking precedence. The remaining options are keyword-only with defaults.

    Args:
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        dict: Keys ``llm``, ``messages``, ``llm_type``, ``agent_name``, ``schema``,
        ``tools``, ``need_stream_out`` and ``stream_meta``.
    """
    def _positional(index):
        return args[index] if len(args) > index else None

    keyword_defaults = (
        ("llm_type", "basic"),
        ("agent_name", DEFAULT_AGENT_NAME),
        ("schema", None),
        ("tools", None),
        ("need_stream_out", False),
        ("stream_meta", None),
    )
    parsed = {
        "llm": kwargs.get("llm", _positional(0)),
        "messages": kwargs.get("messages", _positional(1)),
    }
    parsed.update({key: kwargs.get(key, default) for key, default in keyword_defaults})
    return parsed
def _strip_blank_message_names(messages: list) -> list:
    """Remove blank ``name`` fields from System/User/Assistant/Tool messages.

    A message that is an instance of one of those four classes and whose
    ``name`` is ``""`` or ``None`` is rebuilt from ``model_dump()`` without the
    ``name`` key. The rebuilt object uses the first of those classes it matches.
    NOTE(review): presumably this exists so an empty ``name`` is not sent to
    the provider; confirm.
    All other messages are kept unchanged, and order is preserved.
    """
    message_types = (SystemMessage, UserMessage, AssistantMessage, ToolMessage)
    stripped = []
    for message in messages:
        message_cls = next((cls for cls in message_types if isinstance(message, cls)), None)
        name = getattr(message, "name", None)
        if message_cls is None or (name != "" and name is not None):
            stripped.append(message)
            continue
        fields = message.model_dump()
        fields.pop("name", None)
        stripped.append(message_cls(**fields))
    return stripped


def _resolve_direct_invoke_method(llm_model):
    """Return the first callable among ``invoke``, ``_ainvoke``, ``ainvoke`` on *llm_model*, else None."""
    for attr_name in ("invoke", "_ainvoke", "ainvoke"):
        method = getattr(llm_model, attr_name, None)
        if callable(method):
            return method
    return None


def _record_llm_call_statistics(session_id, agent_name, response, duration: float) -> None:
    """Add this call's token usage to the workflow totals and write one metrics log line."""
    input_tokens, output_tokens, total_tokens = _extract_usage_tokens(response.usage_metadata)
    llm_stat = {
        "method_name": agent_name,
        "duration": duration,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens
    }
    add_workflow_llm_usage(
        session_id=session_id,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        agent_name=agent_name,
    )
    metrics_logger.info(
        "%s session_id: %s ------ [LLM CALL STATISTICS]: %s", TIME_LOGGER_TAG, session_id, llm_stat
    )


async def ainvoke_llm_with_stats(*args, **kwargs):
    """Call the LLM and, when enabled by config, record call statistics.

    Agents listed in ``_DEEPSEARCH_NODE_IDS`` are called directly through the
    model's first callable of ``invoke`` / ``_ainvoke`` / ``ainvoke``. If the
    model has none of these, ``llm_astream`` is used without stream output.
    All other agents go through ``llm_astream``.

    Args:
        llm (dict): LLM config dict containing ``model`` and ``model_name``.
        messages (list): Input messages (dicts or message models).
        llm_type (str): LLM type identifier, default "basic"; must be a non-blank string.
        agent_name (str): Calling node or method name.
        schema (BaseModel | None): Structured-output model; when None a unified dict is returned.
        tools (Any): Tools bound to this call.
        need_stream_out (bool): Whether to write model output to the session stream.
            Only honored on the ``llm_astream`` path for non-deepsearch agents.
        stream_meta (dict | None): Extra fields attached to stream events.

    Returns:
        dict | BaseModel: A ``schema`` instance when ``schema`` is given,
        otherwise the unified response dict.

    Raises:
        CustomValueException: ``LLM_INSTANCE_NONE_ERROR`` when ``llm`` or
            ``llm["model"]`` is missing; ``PARAM_CHECK_ERROR_COMMON_INVALID`` for a
            blank/non-string ``llm_type`` or empty ``model_name``; ``LLM_CALL_FAILED``
            when the deepsearch-node call fails with a non-``CustomValueException``.
            Errors from ``llm_astream`` on the streaming path propagate unchanged.
    """
    invoke_args = _parse_invoke_llm_args(args, kwargs)
    llm = invoke_args["llm"]
    messages = invoke_args["messages"]
    llm_type = invoke_args["llm_type"]
    agent_name = _validate_invoke_agent_name(invoke_args["agent_name"])
    schema = invoke_args["schema"]
    tools = invoke_args["tools"]
    need_stream_out = invoke_args["need_stream_out"]
    stream_meta = invoke_args["stream_meta"]
    _raise_if_cancelled()
    if not llm:
        raise CustomValueException(
            error_code=StatusCode.LLM_INSTANCE_NONE_ERROR.code,
            message=StatusCode.LLM_INSTANCE_NONE_ERROR.errmsg)
    # Validate every input before touching the usage-statistics state.
    if not isinstance(llm_type, str) or not llm_type.strip():
        raise CustomValueException(
            error_code=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.code,
            message=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.errmsg.format(param="llm_type"))
    model_name = llm.get("model_name", "")
    if not model_name:
        raise CustomValueException(
            error_code=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.code,
            message=StatusCode.PARAM_CHECK_ERROR_COMMON_INVALID.errmsg.format(param="model_name"))
    llm_model = llm.get("model", None)
    if llm_model is None:
        raise CustomValueException(
            error_code=StatusCode.LLM_INSTANCE_NONE_ERROR.code,
            message=StatusCode.LLM_INSTANCE_NONE_ERROR.errmsg)

    stats_info_llm = _is_llm_stats_enabled()
    session_id = session_id_ctx.get()
    current_session = None
    if stats_info_llm:
        try:
            current_session = session_context.get()
        except Exception:
            current_session = None
    _ensure_workflow_llm_usage_initialized(session_id=session_id, session=current_session)

    # perf_counter is monotonic, so durations are immune to wall-clock adjustments.
    started_at = time.perf_counter() if stats_info_llm else None
    messages = transfer_to_jiuwen_messages(messages)
    resolved_stream_options = (
        _resolve_stream_options(llm_model=llm_model, need_include_usage=True)
        if stats_info_llm
        else None
    )

    if agent_name in _DEEPSEARCH_NODE_IDS:
        messages = _strip_blank_message_names(messages)
        try:
            direct_method = _resolve_direct_invoke_method(llm_model)
            if direct_method is not None:
                response = await _call_model_method(
                    direct_method,
                    {"model": model_name, "messages": messages, "tools": tools},
                )
            else:
                response = await llm_astream(
                    llm=llm_model,
                    messages=messages,
                    model_name=model_name,
                    agent_name=agent_name,
                    tools=tools,
                    need_stream_out=False,
                    stream_meta=stream_meta,
                    stream_options=resolved_stream_options,
                )
        except CustomValueException:
            # Already carries a specific status code (e.g. wall-clock timeout); keep it.
            raise
        except Exception as e:
            detail = _format_llm_invoke_exception(e)
            corr = format_llm_log_correlation_suffix()
            logger.warning(
                "[ainvoke_llm_with_stats] LLM invoke failed%s agent=%s: %s",
                corr,
                agent_name or "-",
                "*" if LogManager.is_sensitive() else detail,
                exc_info=not LogManager.is_sensitive(),
            )
            raise CustomValueException(
                StatusCode.LLM_CALL_FAILED.code,
                StatusCode.LLM_CALL_FAILED.errmsg.format(e=detail),
            ) from e
    else:
        response = await llm_astream(
            llm=llm_model,
            messages=messages,
            model_name=model_name,
            agent_name=agent_name,
            tools=tools,
            need_stream_out=need_stream_out,
            stream_meta=stream_meta,
            stream_options=resolved_stream_options,
        )

    if stats_info_llm:
        _record_llm_call_statistics(session_id, agent_name, response, time.perf_counter() - started_at)
    response.content = _extract_json(response.content)
    if schema is not None:
        return schema.model_validate_json(response.content)
    return _unify_responnse(response)
def _unify_responnse(response):
    """Return a deep-copied dict of ``response.model_dump()`` with normalised tool calls.

    Each entry of a non-empty ``tool_calls`` list is adjusted as follows:
      * ``args`` is filled from ``function.arguments`` (via
        ``normalize_json_output`` then ``json.loads``) when ``args`` is empty
        and the arguments string is non-empty;
      * ``name`` is copied from ``function.name`` when that is set;
      * ``type`` is forced to ``"function"`` and ``index`` is removed.
    """
    unified = copy.deepcopy(response.model_dump())
    for call in unified.get("tool_calls") or []:
        function_spec = call.get("function")
        needs_args = not call.get("args") and function_spec and function_spec.get("arguments")
        if needs_args:
            call["args"] = json.loads(normalize_json_output(function_spec.get("arguments")))
        if function_spec and function_spec.get("name"):
            call["name"] = function_spec.get("name")
        call["type"] = "function"
        call.pop("index", None)
    return unified
def transfer_to_jiuwen_messages(origin_messages: list):
    """Convert a list of messages into openjiuwen message models.

    Dict items are mapped by their ``role`` key ("system", "user", "assistant",
    "tool"); ``content`` and ``name`` default to "". Pydantic ``BaseModel``
    items are passed through unchanged. Items with an unknown role or an
    unsupported type are logged at error level and dropped. If the result
    contains no ``UserMessage``, ``UserMessage(content=".")`` is appended.

    For assistant dicts, each dict-shaped tool call is shallow-copied and its
    ``type`` is set to "function" when missing/empty or equal to "tool_call".
    ``usage_metadata`` and ``reason_content`` are forwarded, the latter as
    ``reasoning_content``. Tool dicts without a ``tool_call_id`` get a
    generated ``call_<22 hex chars>`` id.
    """
    def _fix_tool_call(raw_call):
        if not isinstance(raw_call, dict):
            return raw_call
        fixed_call = dict(raw_call)
        if fixed_call.get("type") == "tool_call" or not fixed_call.get("type"):
            fixed_call["type"] = "function"
        return fixed_call

    converted = []
    for item in origin_messages:
        if isinstance(item, BaseModel):
            converted.append(item)
            continue
        if not isinstance(item, dict):
            logger.error("message type:%s not support", type(item))
            continue
        role = item.get("role", "")
        content = item.get("content", "")
        name = item.get("name", "")
        if role == "system":
            converted.append(SystemMessage(content=content, name=name))
        elif role == "user":
            converted.append(UserMessage(content=content, name=name))
        elif role == "assistant":
            tool_calls = [_fix_tool_call(call) for call in (item.get("tool_calls", []) or [])]
            converted.append(
                AssistantMessage(
                    content=content,
                    name=name,
                    tool_calls=tool_calls,
                    usage_metadata=item.get("usage_metadata", None),
                    reasoning_content=item.get("reason_content", ""),
                )
            )
        elif role == "tool":
            call_id = item.get("tool_call_id", "") or f"call_{uuid.uuid4().hex[:22]}"
            converted.append(ToolMessage(content=content, name=name, tool_call_id=call_id))
        else:
            logger.error("role:%s not support", role)

    has_user_message = any(isinstance(msg, UserMessage) for msg in converted)
    if not has_user_message:
        converted.append(UserMessage(content="."))
    return converted
def record_llm_retry_log(*args, **kwargs):
    """Record the retry log of an LLM operation.

    While retries remain (``current_try < max_retries``) a warning is logged;
    otherwise an error is logged. Arguments may be passed positionally in this
    order or as keywords (keywords win): current_try (default 0),
    max_retries (default 3), section_idx, step_title, operation, error,
    extra_info (all default None).

    When ``LogManager.is_sensitive()`` is true, the step title, error detail,
    extra info and traceback are omitted from the log line.
    """
    current_try = kwargs.get("current_try", args[0] if len(args) > 0 else 0)
    max_retries = kwargs.get("max_retries", args[1] if len(args) > 1 else 3)
    section_idx = kwargs.get("section_idx", args[2] if len(args) > 2 else None)
    step_title = kwargs.get("step_title", args[3] if len(args) > 3 else None)
    operation = kwargs.get("operation", args[4] if len(args) > 4 else None)
    error = kwargs.get("error", args[5] if len(args) > 5 else None)
    extra_info = kwargs.get("extra_info", args[6] if len(args) > 6 else None)

    if current_try < max_retries:
        log = logger.warning
        outcome = f"retry , number of retries: {current_try} / {max_retries}"
    else:
        log = logger.error
        outcome = f"Failed to {operation}, the max retries have been reached, max retry : {max_retries}"

    if LogManager.is_sensitive():
        log(f"section_idx: {section_idx} | Error when {operation} | {outcome}")
    else:
        error_detail = f"{error}" if error else ""
        # Pass the exception itself so its traceback is logged even when this is
        # called outside the ``except`` block; a bare True would log "NoneType: None".
        exc_info = error if isinstance(error, BaseException) else error is not None
        log(
            f"section_idx: {section_idx} | step title: {step_title} | "
            f"Error when {operation}: {error_detail} | "
            f"Extra Info: {extra_info} | {outcome}",
            exc_info=exc_info,
        )