import asyncio
import base64
import copy
import json
import logging
import os
from pathlib import Path
import time
import uuid
from typing import Any, AsyncGenerator, Optional
from openjiuwen.core.application.workflow_agent.workflow_agent import (
WorkflowAgent as LegacyWorkflowAgent,
)
from openjiuwen.core.runner.runner import Runner
from openjiuwen.core.session.checkpointer import CheckpointerFactory
from openjiuwen.core.session.stream.base import CustomSchema, OutputSchema
from openjiuwen.core.single_agent.legacy.agent import WorkflowFactory
from openjiuwen.core.single_agent.legacy.config import WorkflowAgentConfig
from openjiuwen.core.single_agent.schema.agent_card import AgentCard
from openjiuwen.core.workflow import Workflow, WorkflowCard, WorkflowOutput
from pydantic import ValidationError
from openjiuwen_deepsearch.algorithm.prompts.template import get_prompt_section
from openjiuwen_deepsearch.algorithm.report_template.template_generator import TemplateGenerator
from openjiuwen_deepsearch.algorithm.search_agent.action_pool import ActionPool
from openjiuwen_deepsearch.algorithm.search_agent.deepsearch_agent import (
parse_and_validate_find_action_result,
parse_and_validate_init_state_result,
parse_and_validate_state_creation_result,
)
from openjiuwen_deepsearch.algorithm.search_nodes.utils import anonymize_config_for_logging
from openjiuwen_deepsearch.algorithm.search_nodes.llm_utils import _run_llm_via_ainvoke
from openjiuwen_deepsearch.algorithm.search_nodes.run_action import (
_parse_one_native_tool_call,
get_tool_definitions,
)
from openjiuwen_deepsearch.algorithm.search_nodes.tool_node import (
ExecuteToolConfig,
execute_tool,
format_tool_result_for_message,
)
from openjiuwen_deepsearch.algorithm.search_nodes.utils import (
SaveSearchFinalResultConfig,
Termination,
_save_and_return_search_final_result,
to_dict_safe,
to_json_safe,
)
from openjiuwen_deepsearch.algorithm.search_tools.retriever_tool import RetrieveTool
from openjiuwen_deepsearch.algorithm.search_tools.web_fetch_tool import WebFetch
from openjiuwen_deepsearch.algorithm.search_tools.web_search_tool import WebSearch
from openjiuwen_deepsearch.algorithm.user_feedback_processor.action_definitions import (
_is_report_feedback_payload,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.config.config import (
AgentConfig,
Config,
CustomLocalSearchConfig,
CustomWebSearchConfig,
LocalSearchEngineConfig,
MilvusConfig,
PerQuestionParams,
SearchWorkflowConfig,
WebSearchEngineConfig,
)
from openjiuwen_deepsearch.framework.openjiuwen.agent.base_node import init_router
from openjiuwen_deepsearch.framework.openjiuwen.agent.editor_team_manager_node import (
DependencyEditorTeamNode,
EditorTeamNode,
)
from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import (
DependencyOutlineInteractionNode,
DependencyOutlineNode,
EndNode,
IntentRecognitionNode,
EntryNode,
FeedbackHandlerNode,
FindActionSpaceNode,
GenerateQuestionsNode,
InitializeStateNode,
OutlineInteractionNode,
OutlineNode,
ReporterNode,
RunActionNode,
SearchEndNode,
SearchStartNode,
SourceTracerInferNode,
SourceTracerNode,
StartNode,
ToolNode,
UserFeedbackProcessorNode,
ValidateNewStateNode,
VLMChartGeneratorNode,
)
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import (
Action,
Result,
SearchFinalResult,
State,
)
from openjiuwen_deepsearch.framework.openjiuwen.core.workflow_agent import WorkflowControllerConfig
from openjiuwen_deepsearch.framework.openjiuwen.core.workflow_agent.workflow_agent import WorkflowAgent
from openjiuwen_deepsearch.framework.openjiuwen.llm.llm_adapter import LlmConfigCategory
from openjiuwen_deepsearch.framework.openjiuwen.tools import (
update_local_search_mapping,
update_web_search_mapping,
)
from openjiuwen_deepsearch.llm.llm_request_adapter import resolve_llm_thinking_enabled
from openjiuwen_deepsearch.llm.llm_wrapper import create_llm_obj
from openjiuwen_deepsearch.utils.common_utils.llm_utils import (
get_effective_workflow_llm_usage,
is_workflow_llm_usage_empty,
pop_workflow_llm_usage,
)
from openjiuwen_deepsearch.utils.common_utils.security_utils import ensure_safe_directory, zero_secret
from openjiuwen_deepsearch.utils.common_utils.stream_utils import (
MessageType,
StreamEvent,
get_current_time,
)
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName, NodeId
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import (
llm_context,
local_search_context,
session_context,
tool_context,
web_search_context,
)
from openjiuwen_deepsearch.utils.log_utils.log_common import session_id_ctx
from openjiuwen_deepsearch.utils.log_utils.log_interface import record_interface_log
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
from openjiuwen_deepsearch.utils.log_utils.log_metrics import TIME_LOGGER_TAG, metrics_logger
from openjiuwen_deepsearch.utils.rate_limiter_utils.qps_limiter import qps_rate_limiter
from openjiuwen_deepsearch.utils.run_telemetry import emit, emit_messages_updated
from openjiuwen_deepsearch.utils.validation_utils.field_validation import (
validate_agent_required_field,
)
from openjiuwen_deepsearch.utils.validation_utils.param_validation import (
validate_generate_template_params,
validate_run_agent_params,
)
def _build_retrieve_tool(milvus_cfg: MilvusConfig) -> RetrieveTool:
    """Create a ``RetrieveTool`` from a Milvus configuration object.

    The connection and embedder settings are copied field by field into the
    settings dict handed to ``RetrieveTool``. ``retriever_class`` is forwarded
    as a keyword argument only when the configuration sets it; otherwise no
    such keyword is passed.

    Args:
        milvus_cfg: Milvus and embedder settings to build the tool from.

    Returns:
        RetrieveTool: The tool constructed from ``milvus_cfg``.
    """
    custom_retriever = milvus_cfg.retriever_class
    optional_kwargs = {} if custom_retriever is None else {"retriever_class": custom_retriever}

    # Names are identical on the config object and in the settings dict.
    setting_names = (
        "milvus_host",
        "milvus_port",
        "database_name",
        "collection_name",
        "embedder_model_name",
        "embedder_api_key",
        "embedder_base_url",
        "embedder_timeout",
    )
    settings = {name: getattr(milvus_cfg, name) for name in setting_names}
    return RetrieveTool(settings, **optional_kwargs)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _redact_agent_config_for_workflow_inputs(agent_config: Any) -> dict:
    """Return an anonymized copy of ``agent_config`` for use as workflow input.

    The configuration is converted with ``to_dict_safe``, deep-copied so the
    result does not share objects with the converted value, and then passed
    through ``anonymize_config_for_logging``.
    """
    plain_config = to_dict_safe(agent_config)
    detached_config = copy.deepcopy(plain_config)
    return anonymize_config_for_logging(detached_config)
class BaseAgent:
    """Base class for agents."""

    async def run(
        self,
        message: str,
        conversation_id: str,
        agent_config: dict,
        report_template: str = "",
        interrupt_feedback: str = "",
    ):
        """Run the agent; concrete agents are expected to override this.

        Args:
            message (str): The input query.
            conversation_id (str): Conversation ID.
            agent_config (dict): Agent configuration.
            report_template (str): Report template.
            interrupt_feedback (str): User feedback for human-in-the-loop interrupts.

        Raises:
            CustomValueException: Always, with ``StatusCode.AGENT_RUN_NOT_SUPPORT``;
                this base implementation does not produce any output.
        """
        unsupported = StatusCode.AGENT_RUN_NOT_SUPPORT
        raise CustomValueException(unsupported.code, unsupported.errmsg)

    async def generate_template(self, file_name: str, file_stream: str, is_template: bool, agent_config: dict):
        """Generate a report template from an uploaded file.

        Args:
            file_name (str): File name, including its extension.
            file_stream (str): Base64-encoded file content.
            is_template (bool): ``True`` if the file is already a template,
                ``False`` if the template should be derived from a report.
            agent_config (dict): Agent configuration.

        Returns:
            dict: The result of ``TemplateGenerator.generate_template`` on
            success, or ``{"status": "fail", "template_content": "",
            "error_message": <text>}`` when any exception is raised. In
            sensitive-logging mode the error text is a fixed generic message.
        """
        started_at = time.time()
        succeeded = False
        log_details = {}
        try:
            validate_generate_template_params(file_name, file_stream, is_template)
            validate_agent_required_field(agent_config)
            outcome = await TemplateGenerator.generate_template(
                file_name=file_name,
                file_stream=file_stream,
                is_template=is_template,
                agent_config=agent_config,
            )
            succeeded = outcome.get("status", "").lower() == "success"
            if not succeeded:
                log_details = {"exception_info": outcome.get("error_message", "")}
            return outcome
        except Exception as e:
            # Exception details are withheld from both the log and the response
            # when sensitive-logging mode is on.
            if LogManager.is_sensitive():
                logger.error("[extract_template]")
                error_msg = "Error when generating template."
            else:
                logger.error(f"[extract_template] {e}")
                error_msg = str(e)
            log_details = {"exception_info": error_msg}
            return {"status": "fail", "template_content": "", "error_message": error_msg}
        finally:
            # The interface log is written on every exit path.
            elapsed_minutes = (time.time() - started_at) / 60
            record_interface_log(
                role="SVR",
                session_id="-",
                api_name="generate_template",
                duration_min=elapsed_minutes,
                success=succeeded,
                response_info=log_details,
            )
class DeepresearchAgent(BaseAgent):
    """Deep-research agent that generates reports.

    Uses a general-purpose model, executes tasks in parallel, and does not
    require a report template.
    """

    def __init__(self):
        """Set up the workflow schemas and create the workflow agent."""
        self.research_name = self._get_default_research_name()
        self.version = "1"
        self.agent = None
        # Declared input parameters of the workflow card.
        self.workflow_input_schema = {
            "query": {
                "type": "string",
            },
            "thread_id": {
                "type": "string",
            },
            "conversation_id": {
                "type": "string",
            },
            "report_template": {
                "type": "string",
            },
            "interrupt_feedback": {
                "type": "string",
            },
            "agent_config": {
                "type": "object",
            },
        }
        # Placeholder mapping passed as ``inputs_schema`` of the start node.
        self.startnode_input_schema = {
            "query": "${query}",
            "thread_id": "${thread_id}",
            "conversation_id": "${conversation_id}",
            "report_template": "${report_template}",
            "interrupt_feedback": "${interrupt_feedback}",
            "agent_config": "${agent_config}",
        }
        self.research_workflow = None
        self._create_research_workflow_agent()

    def _get_default_research_name(self) -> str:
        """Return the name used for the workflow, its card and the agent card."""
        return "research_workflow"

    @staticmethod
    def _build_workflow_provider(builder, workflow_card: WorkflowCard):
        """Build a workflow provider that can be called repeatedly.

        Args:
            builder: Callable used to create a workflow instance.
            workflow_card: Card metadata of the workflow.

        Returns:
            callable: A provider that calls ``builder()`` on every invocation
            and carries the card's ``id``, ``version``, ``name``,
            ``description`` and ``input_params`` as attributes.
        """

        def _provider():
            return builder()

        _provider.id = workflow_card.id
        _provider.version = workflow_card.version
        _provider.name = workflow_card.name
        _provider.description = workflow_card.description
        _provider.input_params = workflow_card.input_params
        return _provider

    @staticmethod
    def _build_interrupt_message(thread_id: str, chunk: OutputSchema):
        """Serialize an interaction chunk as a "waiting for user input" JSON event."""
        payload_id = getattr(getattr(chunk, "payload", None), "id", "")
        interrupt_message = {
            "conversation_id": thread_id,
            "agent": payload_id,
            "section_idx": getattr(chunk, "section_idx", "0"),
            "plan_idx": getattr(chunk, "plan_idx", "0"),
            "step_idx": getattr(chunk, "step_idx", "0"),
            "message_id": str(uuid.uuid4()),
            "role": "assistant",
            "content": chunk.payload.value,
            "message_type": MessageType.INTERRUPT.value,
            "event": StreamEvent.WAITING_USER_INPUT.value,
            "created_time": getattr(chunk, "created_time", ""),
        }
        # Serialize once; the same text is logged and returned.
        serialized = json.dumps(interrupt_message, ensure_ascii=False)
        if not LogManager.is_sensitive():
            logger.debug("[OUTPUT] Interrupt event: %s", serialized)
        return serialized

    @staticmethod
    def _build_output_message(thread_id: str, chunk: CustomSchema):
        """Serialize a stream chunk as a JSON message, with defaults for missing fields."""
        output_message = {
            "conversation_id": thread_id,
            "section_idx": getattr(chunk, "section_idx", "0"),
            "plan_idx": getattr(chunk, "plan_idx", "0"),
            "step_idx": getattr(chunk, "step_idx", "0"),
            "message_id": getattr(chunk, "message_id", ""),
            "agent": getattr(chunk, "agent", "Default"),
            "role": "assistant",
            "content": getattr(chunk, "content", ""),
            "message_type": getattr(chunk, "message_type", ""),
            "event": getattr(chunk, "event", ""),
            "created_time": getattr(chunk, "created_time", ""),
        }
        # ``finish_reason`` is only emitted when the chunk actually has it.
        if hasattr(chunk, "finish_reason"):
            output_message["finish_reason"] = chunk.finish_reason
        # Serialize once; the same text is logged and returned.
        serialized = json.dumps(output_message, ensure_ascii=False)
        if not LogManager.is_sensitive():
            logger.debug("[OUTPUT] Message event: %s", serialized)
        return serialized

    @staticmethod
    async def _release_checkpointer_session(conversation_id: str):
        """Explicitly release the checkpointer's session state.

        Intended to avoid leftover state in distributed deployments. Failures
        are logged as warnings and never propagated.
        """
        try:
            checkpointer = CheckpointerFactory.get_checkpointer()
            if not checkpointer:
                return
            release_result = checkpointer.release(conversation_id)
            # ``release`` may be sync or async; only await an awaitable result.
            if hasattr(release_result, "__await__"):
                await release_result
        except Exception as e:
            if not LogManager.is_sensitive():
                logger.warning(f"[DeepResearchAgent.run] Failed to release checkpointer session: {e}")
            else:
                logger.warning("[DeepResearchAgent.run] Failed to release checkpointer session.")

    @staticmethod
    def _register_web_search_tool(custom_web: CustomWebSearchConfig, search_config: WebSearchEngineConfig):
        """Register the web search tool.

        Returns:
            tuple: ``(engine_name, engine_mapping)``.

        Raises:
            CustomValueException: If the configured engine is not in the registry.
        """
        search_engine_mapping = update_web_search_mapping(
            custom_web.custom_web_search_file, custom_web.custom_web_search_func
        )
        web_engine_name = search_config.search_engine_name
        if web_engine_name not in search_engine_mapping:
            error_msg = f"Failed to register web engine: {web_engine_name}, engine is not found in the registry."
            logger.error(f"[Tool Init] {error_msg}")
            raise CustomValueException(StatusCode.WEB_SEARCH_INSTANCE_OBTAIN_ERROR.code, message=error_msg)
        return web_engine_name, search_engine_mapping

    @staticmethod
    def _register_local_search_tool(custom_local: CustomLocalSearchConfig, search_config: LocalSearchEngineConfig):
        """Register the local search tool.

        Returns:
            tuple: ``(engine_name, engine_mapping)``.

        Raises:
            CustomValueException: If the ``native`` engine is selected without
                ``knowledge_base_configs``, or the configured engine is not in
                the registry.
        """
        local_engine_mapping = update_local_search_mapping(
            custom_local.custom_local_search_file,
            custom_local.custom_local_search_func,
        )
        engine_name = search_config.search_engine_name
        if engine_name == "native":
            if not search_config.knowledge_base_configs:
                error_msg = "native local search requires knowledge_base_configs"
                logger.error(f"[Tool Init] {error_msg}")
                raise CustomValueException(
                    StatusCode.LOCAL_SEARCH_INSTANCE_OBTAIN_ERROR.code,
                    message=error_msg,
                )
        if engine_name not in local_engine_mapping:
            error_msg = f"Failed to register local engine: {engine_name}, engine is not found in the registry."
            logger.error(f"[Tool Init] {error_msg}")
            raise CustomValueException(
                StatusCode.LOCAL_SEARCH_INSTANCE_OBTAIN_ERROR.code,
                message=error_msg,
            )
        return engine_name, local_engine_mapping

    @staticmethod
    async def _aopen_local_search_engines():
        """Call ``aopen()`` on every local search engine in context that has it (best effort)."""
        for name, engine in (local_search_context.get() or {}).items():
            if hasattr(engine, "aopen"):
                try:
                    await engine.aopen()
                    logger.debug("LocalSearch engine [%s] opened.", name)
                except Exception as e:
                    logger.warning("Failed to open local search engine [%s]: %s", name, e)

    @staticmethod
    async def _aclose_local_search_engines():
        """Call ``aclose()`` on every local search engine in context that has it (best effort)."""
        for name, engine in (local_search_context.get() or {}).items():
            if hasattr(engine, "aclose"):
                try:
                    await engine.aclose()
                    logger.debug("LocalSearch engine [%s] async closed.", name)
                except Exception as e:
                    logger.warning("Failed to async close local search engine [%s]: %s", name, e)

    @staticmethod
    def _reset_context_tokens(llm_token, web_search_token, local_search_token):
        """Reset each context variable whose token is not ``None``."""
        if llm_token is not None:
            llm_context.reset(llm_token)
        if web_search_token is not None:
            web_search_context.reset(web_search_token)
        if local_search_token is not None:
            local_search_context.reset(local_search_token)

    @staticmethod
    def _build_stream_error_payload(conversation_id: str, final_result_info: dict):
        """Build the framework ERROR event carrying ``final_result_info`` as JSON content."""
        return {
            "conversation_id": conversation_id,
            "message_id": str(uuid.uuid4()),
            "agent": NodeId.FRAMEWORK.value,
            "role": "assistant",
            "content": json.dumps(final_result_info, ensure_ascii=False),
            "message_type": MessageType.MESSAGE_CHUNK.value,
            "event": StreamEvent.ERROR.value,
            "created_time": get_current_time(),
        }

    @staticmethod
    def _build_stream_end_payload(conversation_id: str):
        """Build the framework "ALL END" event that terminates a stream."""
        return {
            "conversation_id": conversation_id,
            "message_id": str(uuid.uuid4()),
            "agent": NodeId.FRAMEWORK.value,
            "role": "assistant",
            "content": "ALL END",
            "message_type": MessageType.MESSAGE_CHUNK.value,
            "event": StreamEvent.SUMMARY_RESPONSE.value,
            "created_time": get_current_time(),
        }

    @staticmethod
    async def _emit_error_and_end_stream(conversation_id: str, final_result_info: dict):
        """Yield the serialized error event followed by the end event.

        A failure while emitting is logged as a warning and not propagated.
        """
        try:
            yield json.dumps(
                DeepresearchAgent._build_stream_error_payload(conversation_id, final_result_info), ensure_ascii=False
            )
            yield json.dumps(DeepresearchAgent._build_stream_end_payload(conversation_id), ensure_ascii=False)
        except Exception as stream_err:
            logger.warning("[DeepResearchAgent.run] Failed to emit error stream event: %s", stream_err)

    @staticmethod
    def _prepare_stream_query(message: str, interrupt_feedback: str):
        """Build the query sent to the workflow.

        Returns:
            tuple: ``(query, is_report_feedback)``. When ``interrupt_feedback``
            is set and ``message`` is not a report-feedback payload, the query
            is a JSON object holding both; otherwise it is ``message`` itself.
        """
        is_report_feedback = _is_report_feedback_payload(message)
        if interrupt_feedback and not is_report_feedback:
            return json.dumps({"interrupt_feedback": interrupt_feedback, "feedback": message}), is_report_feedback
        return message, is_report_feedback

    async def _consume_stream_chunks(
        self,
        conversation_id: str,
        message: str,
        decoded_template: str,
        interrupt_feedback: str,
        session_agent_config: dict,
    ):
        """Run the workflow agent and translate its chunks into stream messages.

        Yields:
            tuple: ``(serialized_message, is_all_end, final_result_info)``.
            ``is_all_end`` becomes ``True`` once a chunk with content
            ``"ALL END"`` is seen; ``final_result_info`` holds the latest
            non-empty value returned by ``parse_endnode_content``.
        """
        is_all_end = False
        final_result_info = {}
        filter_dup_flag = False
        stream_query, is_report_feedback = self._prepare_stream_query(message, interrupt_feedback)
        # Only a redacted copy of the configuration is handed to the workflow inputs.
        workflow_agent_config = _redact_agent_config_for_workflow_inputs(session_agent_config)
        async for chunk in Runner.run_agent_streaming(
            agent=self.agent,
            inputs={
                "query": stream_query,
                "thread_id": conversation_id,
                "conversation_id": conversation_id,
                "report_template": decoded_template,
                "interrupt_feedback": interrupt_feedback,
                "resume_interaction": is_report_feedback,
                "agent_config": workflow_agent_config,
            },
        ):
            if getattr(chunk, "type", "") == "__interaction__":
                filter_dup_flag = False
                yield self._build_interrupt_message(conversation_id, chunk), is_all_end, final_result_info
                continue
            # After GENERATE_QUESTIONS reports DONE, chunks are dropped until
            # the next interaction chunk clears the flag.
            if filter_dup_flag:
                continue
            if isinstance(chunk, CustomSchema):
                agent = getattr(chunk, "agent", "")
                event = getattr(chunk, "event", "")
                if agent == NodeId.GENERATE_QUESTIONS.value and event == StreamEvent.DONE.value:
                    filter_dup_flag = True
                endnode_info = parse_endnode_content(chunk)
                if endnode_info:
                    final_result_info = endnode_info
                if getattr(chunk, "content", "") == "ALL END":
                    is_all_end = True
                yield self._build_output_message(conversation_id, chunk), is_all_end, final_result_info

    async def run(
        self,
        message: Optional[str] = None,
        conversation_id: Optional[str] = None,
        agent_config: Optional[dict] = None,
        report_template: str = "",
        interrupt_feedback: str = "",
    ):
        """Execute the workflow once and stream the resulting messages.

        Args:
            message: User input message or feedback content.
            conversation_id: Conversation ID; also used as the workflow thread_id.
            agent_config: Agent configuration dict for this run.
            report_template: Report template (base64 or plain text).
            interrupt_feedback: Marker for interactive interrupt feedback.

        Yields:
            str: JSON-serialized stream event messages.

        Raises:
            CustomValueException: If parameter validation fails or the
                configuration is invalid.
        """
        validate_run_agent_params(message, conversation_id, report_template, interrupt_feedback)
        validate_agent_required_field(agent_config)
        start_time = time.time()
        llm_token = None
        web_search_token = None
        local_search_token = None
        try:
            session_agent_config = AgentConfig.model_validate(agent_config)
            llm_configs = session_agent_config.llm_config
            if LlmConfigCategory.GENERAL.value not in llm_configs:
                raise CustomValueException(
                    error_code=StatusCode.LLM_CONFIG_NONE.code, message=StatusCode.LLM_CONFIG_NONE.errmsg
                )
            thinking_enabled = resolve_llm_thinking_enabled()
            all_llms = {}
            for llm_config in llm_configs.values():
                llm_obj = create_llm_obj(llm_config, thinking_enabled=thinking_enabled)
                all_llms[llm_config.model_name] = llm_obj
            llm_token = llm_context.set(all_llms)
            web_search_token, local_search_token = self._initialize_tools(session_agent_config)
            await self._aopen_local_search_engines()
        except CustomValueException:
            self._reset_context_tokens(llm_token, web_search_token, local_search_token)
            raise
        except ValidationError as e:
            self._reset_context_tokens(llm_token, web_search_token, local_search_token)
            # In sensitive-logging mode the validation details are not echoed back.
            if LogManager.is_sensitive():
                raise CustomValueException(
                    StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR_NO_PRINT.code,
                    StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR_NO_PRINT.errmsg,
                ) from e
            raise CustomValueException(
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(e=str(e)),
            ) from e
        except Exception:
            # Any other setup failure (e.g. an engine constructor raising) must
            # not leave already-set context variables behind.
            self._reset_context_tokens(llm_token, web_search_token, local_search_token)
            raise

        token = session_id_ctx.set(conversation_id)
        stats_info_llm_enabled = bool(session_agent_config.stats_info_llm)
        decoded_template = report_template
        if report_template:
            decoded_template = self._handle_report_template(report_template)
        is_all_end = False
        final_result_info = {}
        # A ContextVar token can be used for reset only once, so remember
        # whether the error path has already released the session.
        session_released = False
        try:
            session_agent_config = session_agent_config.model_dump()
            async for payload, stream_end, stream_info in self._consume_stream_chunks(
                conversation_id=conversation_id,
                message=message,
                decoded_template=decoded_template,
                interrupt_feedback=interrupt_feedback,
                session_agent_config=session_agent_config,
            ):
                is_all_end = stream_end
                final_result_info = stream_info
                yield payload
        except Exception as e:
            if not LogManager.is_sensitive() or isinstance(e, CustomValueException):
                logger.error(f"[DeepResearchAgent.run] Session closed with error: {e}")
                final_result_info = {"exception_info": str(e)}
            else:
                logger.error("[DeepResearchAgent.run] Session closed with error.")
                final_result_info = {"exception_info": "Session closed with error."}
            if stats_info_llm_enabled:
                try:
                    current_session = session_context.get()
                except Exception:
                    current_session = None
                workflow_usage = get_effective_workflow_llm_usage(
                    session_id=conversation_id,
                    session=current_session,
                )
                if not is_workflow_llm_usage_empty(workflow_usage):
                    final_result_info["workflow_llm_token_usage"] = workflow_usage
            async for payload in self._emit_error_and_end_stream(conversation_id, final_result_info):
                yield payload
            if stats_info_llm_enabled:
                pop_workflow_llm_usage(conversation_id)
            await self.agent.release_session(conversation_id)
            await self._release_checkpointer_session(conversation_id)
            session_id_ctx.reset(token)
            session_released = True
        finally:
            metrics_logger.info(
                f"{TIME_LOGGER_TAG} thread_id: {conversation_id} ------ [DeepResearchAgent[0].run]"
                f" executed time: {(time.time() - start_time) :.2f} s"
            )
            has_error = bool(final_result_info.get("exception_info"))
            record_interface_log(
                role="SVR",
                session_id=conversation_id,
                api_name="run",
                duration_min=(time.time() - start_time) / 60,
                success=not has_error,
                response_info=final_result_info if has_error else {},
            )
            try:
                await self._aclose_local_search_engines()
            except Exception as e:
                if not LogManager.is_sensitive():
                    logger.warning(f"Failed to close local search engines: {e}")
                else:
                    logger.warning("Failed to close local search engines.")
            finally:
                self._reset_context_tokens(llm_token, web_search_token, local_search_token)
            # Session state is kept unless the workflow reported "ALL END".
            if is_all_end:
                zero_secret(
                    session_agent_config.get("web_search_engine_config", {}).get(
                        "search_api_key", bytearray("", encoding="utf-8")
                    )
                )
                zero_secret(
                    session_agent_config.get("local_search_engine_config", {}).get(
                        "search_api_key", bytearray("", encoding="utf-8")
                    )
                )
                # Skip when the error path above already released the session:
                # resetting the same token twice raises RuntimeError.
                if not session_released:
                    await self.agent.release_session(conversation_id)
                    await self._release_checkpointer_session(conversation_id)
                    session_id_ctx.reset(token)
            if stats_info_llm_enabled and is_all_end:
                pop_workflow_llm_usage(conversation_id)

    def _build_research_workflow(self):
        """Build the research workflow graph: components, routers and connections."""
        _id = self.research_name
        name = self.research_name
        version = self.version
        card = WorkflowCard(
            id=_id,
            version=version,
            name=name,
        )
        flow = Workflow(card=card)
        flow.set_start_comp(
            start_comp_id=NodeId.START.value, component=StartNode(), inputs_schema=self.startnode_input_schema
        )

        # Components.
        flow.add_workflow_comp(NodeId.INTENT_RECOGNITION.value, IntentRecognitionNode())
        flow.add_workflow_comp(NodeId.ENTRY.value, EntryNode())
        flow.add_workflow_comp(NodeId.GENERATE_QUESTIONS.value, GenerateQuestionsNode())
        flow.add_workflow_comp(NodeId.FEEDBACK_HANDLER.value, FeedbackHandlerNode())
        flow.add_workflow_comp(NodeId.OUTLINE.value, OutlineNode())
        flow.add_workflow_comp(NodeId.OUTLINE_INTERACTION.value, OutlineInteractionNode())
        flow.add_workflow_comp(NodeId.EDITOR_TEAM.value, EditorTeamNode())
        flow.add_workflow_comp(NodeId.REPORTER.value, ReporterNode())
        flow.add_workflow_comp(NodeId.VLM_CHART_GENERATOR.value, VLMChartGeneratorNode())
        flow.add_workflow_comp(NodeId.SOURCE_TRACER.value, SourceTracerNode())
        flow.add_workflow_comp(NodeId.SOURCE_TRACER_INFER.value, SourceTracerInferNode())
        flow.add_workflow_comp(NodeId.USER_FEEDBACK_PROCESSOR.value, UserFeedbackProcessorNode())
        flow.set_end_comp(NodeId.END.value, EndNode())
        flow.add_connection(NodeId.START.value, NodeId.INTENT_RECOGNITION.value)

        # Routers: each is created for a source node and its possible targets.
        entry_router = init_router(
            NodeId.ENTRY.value, [NodeId.OUTLINE.value, NodeId.GENERATE_QUESTIONS.value, NodeId.END.value]
        )
        # Note: this router is given a single target string, not a list.
        intent_recognition_router = init_router(NodeId.INTENT_RECOGNITION.value, NodeId.ENTRY.value)
        generate_questions_router = init_router(
            NodeId.GENERATE_QUESTIONS.value, [NodeId.FEEDBACK_HANDLER.value, NodeId.END.value]
        )
        outline_router = init_router(
            NodeId.OUTLINE.value, [NodeId.OUTLINE_INTERACTION.value, NodeId.EDITOR_TEAM.value, NodeId.END.value]
        )
        outline_interaction_router = init_router(
            NodeId.OUTLINE_INTERACTION.value, [NodeId.OUTLINE.value, NodeId.EDITOR_TEAM.value, NodeId.END.value]
        )
        reporter_router = init_router(NodeId.REPORTER.value, [NodeId.END.value, NodeId.VLM_CHART_GENERATOR.value])
        feedback_handler_router = init_router(NodeId.FEEDBACK_HANDLER.value, [NodeId.OUTLINE.value, NodeId.END.value])
        editor_team_router = init_router(NodeId.EDITOR_TEAM.value, [NodeId.REPORTER.value, NodeId.END.value])
        user_feedback_processor_router = init_router(
            NodeId.USER_FEEDBACK_PROCESSOR.value, [NodeId.USER_FEEDBACK_PROCESSOR.value, NodeId.END.value]
        )

        # Conditional and fixed connections.
        flow.add_conditional_connection(NodeId.INTENT_RECOGNITION.value, router=intent_recognition_router)
        flow.add_conditional_connection(NodeId.ENTRY.value, router=entry_router)
        flow.add_conditional_connection(NodeId.GENERATE_QUESTIONS.value, router=generate_questions_router)
        flow.add_conditional_connection(NodeId.OUTLINE.value, router=outline_router)
        flow.add_conditional_connection(NodeId.FEEDBACK_HANDLER.value, router=feedback_handler_router)
        flow.add_conditional_connection(NodeId.REPORTER.value, router=reporter_router)
        flow.add_conditional_connection(NodeId.EDITOR_TEAM.value, router=editor_team_router)
        flow.add_conditional_connection(NodeId.OUTLINE_INTERACTION.value, router=outline_interaction_router)
        flow.add_connection(NodeId.VLM_CHART_GENERATOR.value, NodeId.SOURCE_TRACER.value)
        flow.add_connection(NodeId.SOURCE_TRACER.value, NodeId.SOURCE_TRACER_INFER.value)
        flow.add_connection(NodeId.SOURCE_TRACER_INFER.value, NodeId.USER_FEEDBACK_PROCESSOR.value)
        flow.add_conditional_connection(NodeId.USER_FEEDBACK_PROCESSOR.value, router=user_feedback_processor_router)
        return flow

    def _create_research_workflow_agent(self):
        """Create the Deepresearch workflow agent instance and register its workflow provider."""
        workflow_card = WorkflowCard(
            id=self.research_name,
            version=self.version,
            name=self.research_name,
            description=self.research_name,
            input_params=self.workflow_input_schema,
        )
        card = AgentCard(
            id=self.research_name,
            name=self.research_name,
            description=self.research_name,
        )
        config = WorkflowControllerConfig(
            id=self.research_name,
            version=self.version,
            description=self.research_name,
            workflows=[workflow_card],
        )
        self.agent = WorkflowAgent(card=card, config=config)
        self.agent.add_workflows([self._build_workflow_provider(self._build_research_workflow, workflow_card)])

    def _handle_report_template(self, report_template):
        """Return the report template as text, decoding it from base64 when possible.

        If ``report_template`` cannot be decoded as base64-encoded UTF-8, it is
        treated as plain text and returned unchanged.
        """
        try:
            decoded_template = base64.b64decode(report_template).decode("utf-8")
            # Use the module logger (not the root logger) like the rest of this file.
            logger.debug("[DeepresearchAgent.run] Successfully decoded base64 report_template")
        except Exception as e:
            if not LogManager.is_sensitive():
                logger.warning(f"[DeepresearchAgent.run] Failed to decode base64 report template: {e}")
            else:
                logger.warning("[DeepresearchAgent.run] Failed to decode base64 report template.")
            decoded_template = report_template
        return decoded_template

    def _initialize_tools(self, agent_config: AgentConfig):
        """Initialize the search tools and publish them through context variables.

        Returns:
            tuple: ``(web_search_token, local_search_token)`` — the tokens
            needed to reset ``web_search_context`` and ``local_search_context``.
        """
        custom_web = agent_config.custom_web_search_config
        custom_local = agent_config.custom_local_search_config
        web_search_config = agent_config.web_search_engine_config
        local_search_config = agent_config.local_search_engine_config
        web_engine_name, web_mapping = self._register_web_search_tool(custom_web, web_search_config)
        local_engine_name, local_mapping = self._register_local_search_tool(custom_local, local_search_config)

        # Do everything that can fail before touching the context variables:
        # the tokens are only handed to the caller on success, so a failure
        # after a ``set`` would leave that context variable un-resettable.
        web_engine = web_mapping[web_engine_name](**web_search_config.model_dump())
        local_engine = local_mapping[local_engine_name](**local_search_config.model_dump())
        qps_rate_limiter.set_max_qps(agent_config.web_search_max_qps)

        web_search_token = web_search_context.set({web_engine_name: web_engine})
        local_search_token = local_search_context.set({local_engine_name: local_engine})
        return web_search_token, local_search_token
class DeepresearchDependencyAgent(DeepresearchAgent):
    """Deep-research agent that generates reports with dependency-driven task execution.

    Uses a general-purpose model and does not require a report template.
    """

    def _get_default_research_name(self) -> str:
        """Return the name used for the workflow, its card and the agent card."""
        return "research_workflow_dependency_driving"

    def _create_research_workflow_agent(self):
        """Create the workflow agent and register the dependency-driven workflow provider."""
        research_name = self.research_name
        dependency_card = WorkflowCard(
            id=research_name,
            version=self.version,
            name=research_name,
            description=research_name,
            input_params=self.workflow_input_schema,
        )
        agent_card = AgentCard(id=research_name, name=research_name, description=research_name)
        controller_config = WorkflowControllerConfig(
            id=research_name,
            version=self.version,
            description=research_name,
            workflows=[dependency_card],
        )
        self.agent = WorkflowAgent(card=agent_card, config=controller_config)
        provider = self._build_workflow_provider(self._build_research_dependency_workflow, dependency_card)
        self.agent.add_workflows([provider])

    def _build_research_dependency_workflow(self):
        """Build the dependency-driven research workflow graph.

        Components, routers and connections are declared as ordered tables and
        applied in that order.
        """
        flow = Workflow(card=WorkflowCard(id=self.research_name, version=self.version, name=self.research_name))
        flow.set_start_comp(
            start_comp_id=NodeId.START.value, component=StartNode(), inputs_schema=self.startnode_input_schema
        )

        # (node id, component class) pairs, registered in this order.
        component_table = (
            (NodeId.INTENT_RECOGNITION, IntentRecognitionNode),
            (NodeId.ENTRY, EntryNode),
            (NodeId.GENERATE_QUESTIONS, GenerateQuestionsNode),
            (NodeId.FEEDBACK_HANDLER, FeedbackHandlerNode),
            (NodeId.OUTLINE, DependencyOutlineNode),
            (NodeId.OUTLINE_INTERACTION, DependencyOutlineInteractionNode),
            (NodeId.DEPENDENCY_EDITOR_TEAM, DependencyEditorTeamNode),
            (NodeId.REPORTER, ReporterNode),
            (NodeId.VLM_CHART_GENERATOR, VLMChartGeneratorNode),
            (NodeId.SOURCE_TRACER, SourceTracerNode),
            (NodeId.SOURCE_TRACER_INFER, SourceTracerInferNode),
            (NodeId.USER_FEEDBACK_PROCESSOR, UserFeedbackProcessorNode),
        )
        for node_id, component_cls in component_table:
            flow.add_workflow_comp(node_id.value, component_cls())
        flow.set_end_comp(NodeId.END.value, EndNode())
        flow.add_connection(NodeId.START.value, NodeId.INTENT_RECOGNITION.value)

        # (source, targets) pairs; routers are created in this order.
        end_id = NodeId.END.value
        routing_table = (
            (NodeId.ENTRY.value, [NodeId.OUTLINE.value, NodeId.GENERATE_QUESTIONS.value, end_id]),
            # This router is given a single target string, not a list.
            (NodeId.INTENT_RECOGNITION.value, NodeId.ENTRY.value),
            (NodeId.GENERATE_QUESTIONS.value, [NodeId.FEEDBACK_HANDLER.value, end_id]),
            (
                NodeId.OUTLINE.value,
                [NodeId.OUTLINE_INTERACTION.value, NodeId.DEPENDENCY_EDITOR_TEAM.value, end_id],
            ),
            (
                NodeId.OUTLINE_INTERACTION.value,
                [NodeId.OUTLINE.value, NodeId.DEPENDENCY_EDITOR_TEAM.value, end_id],
            ),
            (NodeId.REPORTER.value, [end_id, NodeId.VLM_CHART_GENERATOR.value]),
            (NodeId.FEEDBACK_HANDLER.value, [NodeId.OUTLINE.value, end_id]),
            (NodeId.DEPENDENCY_EDITOR_TEAM.value, [NodeId.REPORTER.value, end_id]),
            (NodeId.USER_FEEDBACK_PROCESSOR.value, [NodeId.USER_FEEDBACK_PROCESSOR.value, end_id]),
        )
        routers = {source: init_router(source, targets) for source, targets in routing_table}

        # Conditional connections attached before the fixed tail of the graph.
        conditional_sources = (
            NodeId.INTENT_RECOGNITION,
            NodeId.ENTRY,
            NodeId.GENERATE_QUESTIONS,
            NodeId.OUTLINE,
            NodeId.FEEDBACK_HANDLER,
            NodeId.OUTLINE_INTERACTION,
            NodeId.REPORTER,
            NodeId.DEPENDENCY_EDITOR_TEAM,
        )
        for source in conditional_sources:
            flow.add_conditional_connection(source.value, router=routers[source.value])

        # Fixed chain: chart generation -> source tracing -> inference -> user feedback.
        fixed_chain = (
            NodeId.VLM_CHART_GENERATOR,
            NodeId.SOURCE_TRACER,
            NodeId.SOURCE_TRACER_INFER,
            NodeId.USER_FEEDBACK_PROCESSOR,
        )
        for upstream, downstream in zip(fixed_chain, fixed_chain[1:]):
            flow.add_connection(upstream.value, downstream.value)

        feedback_id = NodeId.USER_FEEDBACK_PROCESSOR.value
        flow.add_conditional_connection(feedback_id, router=routers[feedback_id])
        return flow
class DeepSearchAgent(BaseAgent):
def __init__(self) -> None:
    """Create an agent with empty per-run state.

    Request-specific fields (query, configs, tools, limits, log directory)
    start out blank; ``run`` fills them in before the search begins.
    """
    # Version stamped onto the sub-workflow cards built in ``_build_agent``.
    self.version: str = "1"

    # Request inputs and configuration.
    self.query: str = ""
    self.gold_answer: str | None = None
    self.agent_config: AgentConfig | None = None
    self.search_config: SearchWorkflowConfig | None = None
    self.per_question_params: PerQuestionParams | None = None
    self.time_limit: int = 0
    self.tool_map: dict[str, Any] = {}
    self.log_dir: str = ""

    # Search bookkeeping.
    self.action_pool: ActionPool = ActionPool()
    self.completed_actions: list[tuple[Action, Result | None]] = []
    self.final_answer: str | None = None
    self.fail_count: int = 0
    self.total_input_tokens: int = 0
    self.total_output_tokens: int = 0
def setup_log_directory(self, save_as: str) -> None:
    """Create the run's log directory ``<base log dir>/<save_as>``.

    ``Action`` and ``Result`` subfolders are created inside it, the path is
    stored on ``self.log_dir`` and the action pool is pointed at the same
    directory.

    Args:
        save_as: Name of the run directory below the base log directory.
    """
    # get_log_dir() may return None; os.path.join(None, ...) would raise
    # TypeError, so fall back to the default SimpleReactSearchAgent.run uses.
    base_log_dir: str = LogManager.get_log_dir() or "./output/logs"
    self.log_dir = os.path.join(base_log_dir, save_as)
    for subfolder in ("Action", "Result"):
        os.makedirs(os.path.join(self.log_dir, subfolder), exist_ok=True)
    self.action_pool.log_dir = self.log_dir
def _subworkflow_context_inputs(self, workflow_name: str) -> dict[str, Any]:
    """Assemble the per-run inputs shared by every DeepSearch sub-workflow call.

    The sub-workflows are registered globally under fixed IDs, so each call
    carries the active request's configuration explicitly; otherwise model
    settings from an earlier run could be picked up.

    Args:
        workflow_name: Value passed through under the ``workflow_name`` key.

    Returns:
        A dict with ``workflow_name``, ``agent_config`` and ``search_config``.
    """
    agent_cfg = to_dict_safe(self.agent_config)
    search_cfg = to_dict_safe(self.search_config)
    return {
        "workflow_name": workflow_name,
        # Anything that did not convert to a dict is replaced by a safe default.
        "agent_config": agent_cfg if isinstance(agent_cfg, dict) else {},
        "search_config": (
            search_cfg
            if isinstance(search_cfg, dict)
            else Config().service_config.search_workflow.model_dump()
        ),
    }
def _build_init_state_workflow(self) -> Workflow:
    """Build the ``init_state`` sub-workflow: start -> initial state -> end."""
    start_id = NodeId.START_NODE.value
    state_id = NodeId.INITIAL_STATE.value
    end_id = NodeId.END_NODE.value

    workflow = Workflow(card=WorkflowCard(id="init_state", version="1", name="init_state"))
    workflow.set_start_comp(
        start_comp_id=start_id,
        component=SearchStartNode(),
        inputs_schema={"workflow_name": "init_state_workflow"},
    )
    workflow.add_workflow_comp(state_id, InitializeStateNode())
    workflow.set_end_comp(end_id, SearchEndNode())

    # Linear pipeline: each node feeds the next one.
    for source, target in ((start_id, state_id), (state_id, end_id)):
        workflow.add_connection(source, target)
    return workflow
def _build_find_action_workflow(self) -> Workflow:
    """Build the ``find_action`` sub-workflow: start -> find action space -> end."""
    start_id = NodeId.START_NODE.value
    find_id = NodeId.FIND_ACTION_SPACE.value
    end_id = NodeId.END_NODE.value

    workflow = Workflow(card=WorkflowCard(id="find_action", version="1", name="find_action"))
    workflow.set_start_comp(
        start_comp_id=start_id,
        component=SearchStartNode(),
        inputs_schema={"workflow_name": "find_action_workflow"},
    )
    workflow.add_workflow_comp(find_id, FindActionSpaceNode())
    workflow.set_end_comp(end_id, SearchEndNode())

    # Linear pipeline: each node feeds the next one.
    for source, target in ((start_id, find_id), (find_id, end_id)):
        workflow.add_connection(source, target)
    return workflow
def _build_state_creation_workflow(self) -> Workflow:
    """Build the ``state_creation`` sub-workflow.

    Fixed edges: start -> run_action and tool -> run_action.
    Conditional edges: run_action gets a router whose candidate targets are
    tool, run_action, validate_new_state and end; validate_new_state gets a
    router whose candidate targets are run_action and end.
    """
    start_id = NodeId.START_NODE.value
    end_id = NodeId.END_NODE.value
    tool_id = NodeId.TOOL.value
    run_id = NodeId.RUN_ACTION.value
    validate_id = NodeId.VALIDATE_NEW_STATE.value

    workflow = Workflow(
        card=WorkflowCard(id="state_creation", version="1", name="state_creation")
    )
    workflow.set_start_comp(
        start_comp_id=start_id,
        component=SearchStartNode(),
        inputs_schema={"workflow_name": "state_creation_workflow"},
    )
    workflow.add_workflow_comp(tool_id, ToolNode())
    workflow.add_workflow_comp(run_id, RunActionNode())
    workflow.add_workflow_comp(validate_id, ValidateNewStateNode())
    workflow.set_end_comp(end_id, SearchEndNode())

    workflow.add_connection(start_id, run_id)
    workflow.add_connection(tool_id, run_id)

    run_action_router = init_router(run_id, [tool_id, run_id, validate_id, end_id])
    validation_router = init_router(validate_id, [run_id, end_id])
    workflow.add_conditional_connection(run_id, router=run_action_router)
    workflow.add_conditional_connection(validate_id, router=validation_router)
    return workflow
def _build_agent(self):
    """Create ``self.agent`` and register the three DeepSearch sub-workflows.

    ``init_state`` and ``find_action`` are registered as ready-built
    workflows; ``state_creation`` is registered through a ``WorkflowFactory``
    that wraps ``_build_state_creation_workflow``.
    """
    token_params = {"total_input_tokens": int, "total_output_tokens": int}
    init_state_params = {"query": str, **token_params, "log_dir": str}
    find_action_params = {
        "state": State,
        "result": Result,
        "query": str,
        "log_dir": str,
        **token_params,
    }
    state_creation_params = {
        "action": Action,
        "log_dir": str,
        "fail_count": int,
        "retrieval_tool_only": bool,
        **token_params,
    }

    def describe(workflow_id: str, params: dict) -> WorkflowCard:
        # id, name and description are the same string for every sub-workflow.
        return WorkflowCard(
            id=workflow_id,
            version=self.version,
            name=workflow_id,
            description=workflow_id,
            input_params=params,
        )

    cards = [
        describe("init_state", init_state_params),
        describe("find_action", find_action_params),
        describe("state_creation", state_creation_params),
    ]
    self.agent = LegacyWorkflowAgent(
        WorkflowAgentConfig(
            id="deepsearch_sub_workflows",
            description="DeepSearch init/find/state_creation subgraphs",
            workflows=cards,
        )
    )

    state_creation_factory = WorkflowFactory(
        workflow_id="state_creation",
        workflow_version="1",
        factory=self._build_state_creation_workflow,
        workflow_name="state_creation",
        workflow_description="state_creation",
        # Separate dict object so the card and the factory never share state.
        input_schema=dict(state_creation_params),
    )
    init_state_workflow = self._build_init_state_workflow()
    find_action_workflow = self._build_find_action_workflow()
    self.agent.add_workflows(
        [init_state_workflow, find_action_workflow, state_creation_factory]
    )
async def _cancel_running_tasks(
    self,
    running_tasks: set[asyncio.Task],
    task_to_action: dict[asyncio.Task, Action],
) -> None:
    """Cancel every in-flight state_creation task and record it as completed.

    Each task is cancelled and awaited, its action is removed from
    ``task_to_action`` and recorded in the action pool with a ``None``
    result, and ``running_tasks`` is emptied. Does nothing when
    ``running_tasks`` is already empty.
    """
    pending = list(running_tasks)
    if not pending:
        return

    for task in pending:
        task.cancel()
    # return_exceptions=True keeps CancelledError (or any task failure) from
    # propagating out of the gather.
    await asyncio.gather(*pending, return_exceptions=True)

    for task in pending:
        cancelled_action = task_to_action.pop(task, None)
        if cancelled_action is None:
            continue
        self.action_pool.record_completed(cancelled_action, None)

    running_tasks.clear()
    logger.info("[DeepSearchAgent] cancelled %d remaining tasks", len(pending))
async def run_state_creation_workflow(
    self,
    action: Any,
    semaphore: asyncio.Semaphore,
) -> Any:
    """Run the ``state_creation_1`` workflow for one action under ``semaphore``.

    Args:
        action: The action to execute; converted with ``to_dict_safe``.
        semaphore: Limits how many of these workflows run at the same time.

    Returns:
        Whatever ``Runner.run_workflow`` returns for the workflow.
    """
    async with semaphore:
        # Inputs are built only once a slot is held, so the counters reflect
        # the agent's state at the moment the workflow actually starts.
        workflow_inputs: dict[str, Any] = dict(
            self._subworkflow_context_inputs("state_creation_workflow")
        )
        workflow_inputs.update(
            action=to_dict_safe(action),
            # True when a tool named "retrieve" is registered for this run.
            retrieval_tool_only="retrieve" in self.tool_map,
            total_input_tokens=self.total_input_tokens,
            total_output_tokens=self.total_output_tokens,
            log_dir=self.log_dir,
            fail_count=self.fail_count,
        )
        output = await Runner.run_workflow(
            workflow="state_creation_1",
            inputs=workflow_inputs,
        )
    return output
async def _run_internal(self) -> SearchFinalResult:
    """Run the search loop for ``self.query`` and return the saved final result.

    Steps:
      1. Run the ``init_state_1`` workflow to obtain the initial state.
      2. Run ``find_action_1`` on that state to seed the action pool.
      3. Repeatedly sample actions from the pool, run each one through
         ``state_creation_1`` as a concurrent task (at most ``max_workers``
         at a time) and expand any new states into further actions.
      4. Stop when an answer is accepted, when the time / explored-actions /
         failure limit is hit, or when the pool stays empty after the
         configured number of find_action retries.

    After the loop the best collected answer is returned if there is one;
    otherwise, on a time-out with ``provide_best_guess`` enabled, the best
    guess; otherwise a result carrying ``self.final_answer`` (normally None).

    Returns:
        The value produced by ``_save_and_return_search_final_result``.

    Raises:
        CustomValueException: If the init_state workflow fails or its output
            cannot be parsed into an initial state.
    """
    start_time: float = time.time()
    max_workers: int = self.per_question_params.max_workers
    sem = asyncio.Semaphore(max_workers)
    max_tries: int = self.search_config.init_state_agent.max_tries
    # Sanitised config of the most recently processed action; it is attached
    # to whichever final result is produced.
    config: dict = {}

    def finish(
        prediction: Any,
        termination: Termination,
        messages: Any,
        evidence_ids: Any,
    ) -> SearchFinalResult:
        """Save and return the final result (reads ``config`` and totals at call time)."""
        return _save_and_return_search_final_result(
            SaveSearchFinalResultConfig(
                question=self.query,
                messages=messages,
                prediction=prediction,
                gold_answer=self.gold_answer,
                termination=termination,
                retrieved_evidence_ids=evidence_ids,
                params={
                    "total_input_tokens": self.total_input_tokens,
                    "total_output_tokens": self.total_output_tokens,
                    "start_time": start_time,
                    "log_dir": self.log_dir,
                },
                config=config,
            )
        )

    async def find_actions(state: Any, previous_result: Any) -> list[Action]:
        """Run find_action for ``state`` and add its token usage to the run totals."""
        output: WorkflowOutput = await Runner.run_workflow(
            workflow="find_action_1",
            inputs={
                **self._subworkflow_context_inputs("find_action_workflow"),
                "state": state,
                "query": self.query,
                "result": previous_result,
                "log_dir": self.log_dir,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
            },
        )
        parsed: dict[str, Any] = parse_and_validate_find_action_result(output.result)
        # Every find_action call is accounted for, including pool-depleted retries.
        self.total_input_tokens += parsed.get("total_input_tokens", 0)
        self.total_output_tokens += parsed.get("total_output_tokens", 0)
        return parsed.get("actions", [])

    def build_retry_result() -> Result | None:
        """Summarise completed actions for a ``dependent_retry``; None for other strategies."""
        strategy = self.search_config.find_action_agent.action_pool_depleted_strategy
        if strategy != "dependent_retry" or not self.action_pool.completed_actions:
            return None
        summaries: list[str] = []
        last_action = None
        for index, (past_action, past_result) in enumerate(self.action_pool.completed_actions, 1):
            entry = f"{index}. Direction: {past_action.proposal.direction}"
            if isinstance(past_result, Result) and past_result.messages:
                entry += past_result.get_summary()
            summaries.append(entry)
            last_action = past_action
        prompt_content: str = get_prompt_section(
            "deepsearch_dependent_retry_find_action",
            {
                "failed_count": len(summaries),
                "failed_summaries": "\n\n".join(summaries),
            },
        )
        return Result(
            messages=[{"role": "user", "content": prompt_content}],
            new_states=[],
            found_answer=None,
            # The retry is attributed to the most recently completed action.
            previous_action_id=last_action.id,
        )

    # ---- 1. initial state -------------------------------------------------
    init_state_result: dict[str, Any] | None = None
    last_exception: Exception | None = None
    try:
        init_result: WorkflowOutput = await Runner.run_workflow(
            workflow="init_state_1",
            inputs={
                **self._subworkflow_context_inputs("init_state_workflow"),
                "query": self.query,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "log_dir": self.log_dir,
            },
        )
        init_state_result = parse_and_validate_init_state_result(init_result)
    except Exception as e:
        last_exception = e
        logger.error(f"[DeepSearchAgent] init_state error, failed all {max_tries} attempts: {e}")
    if init_state_result is None:
        raise CustomValueException(
            StatusCode.AGENT_INIT_STATE_ERROR.code,
            StatusCode.AGENT_INIT_STATE_ERROR.errmsg,
        ) from last_exception

    self.total_input_tokens += init_state_result.get("total_input_tokens", 0)
    self.total_output_tokens += init_state_result.get("total_output_tokens", 0)
    init_state: State = init_state_result.get("init_state")
    logger.info("[DeepSearchAgent] initial state: %s", "***" if LogManager.is_sensitive() else init_state)

    # ---- 2. seed the action pool -------------------------------------------
    initial_actions = await find_actions(init_state, None)
    logger.info(
        "[DeepSearchAgent] initial actions: %s",
        "***" if LogManager.is_sensitive() else [a.proposal.direction for a in initial_actions],
    )
    self.action_pool.add(initial_actions)

    # ---- 3. exploration loop ------------------------------------------------
    actions_explored: int = 0
    actions_explored_limit: int = self.per_question_params.actions_explored_limit
    fail_limit: int = self.per_question_params.fail_limit
    answer_mode_top_k: int = self.per_question_params.answer_mode_top_k
    provide_best_guess: bool = self.per_question_params.provide_best_guess
    running_tasks: set[asyncio.Task[WorkflowOutput]] = set()
    task_to_action: dict[asyncio.Task[WorkflowOutput], Action] = {}
    termination_reason: Termination = Termination.ACTION_POOL_DEPLETED
    find_action_pool_depleted_retries_left: int = self.per_question_params.retry_count_on_empty_action_space

    try:
        while not self.final_answer:
            if (time.time() - start_time) > self.time_limit:
                termination_reason = Termination.TIME_LIMIT
                logger.info("[DeepSearchAgent] %s", termination_reason.log_message)
                break
            if actions_explored_limit > 0 and actions_explored >= actions_explored_limit:
                termination_reason = Termination.ACTIONS_EXPLORED_LIMIT
                logger.info("[DeepSearchAgent] %s (%d)", termination_reason.log_message, actions_explored_limit)
                break
            if fail_limit > 0 and self.fail_count >= fail_limit:
                termination_reason = Termination.FAIL_LIMIT
                logger.info("[DeepSearchAgent] %s (%d)", termination_reason.log_message, fail_limit)
                break

            # Top up the worker slots with freshly sampled actions.
            available_slots: int = max_workers - len(running_tasks)
            if available_slots > 0 and self.action_pool.size() > 0:
                sampled: list[Action] = self.action_pool.sample(available_slots)
                logger.info(
                    "[DeepSearchAgent] sampled actions (dynamic): %s",
                    "***" if LogManager.is_sensitive() else [a.proposal.direction for a in sampled],
                )
                for action in sampled:
                    task = asyncio.create_task(
                        self.run_state_creation_workflow(
                            action=action,
                            semaphore=sem,
                        )
                    )
                    running_tasks.add(task)
                    task_to_action[task] = action

            # Nothing running and nothing to sample: retry find_action or give up.
            if not running_tasks:
                if find_action_pool_depleted_retries_left <= 0:
                    logger.info(
                        "[DeepSearchAgent] action pool empty and no running tasks; "
                        "find_action retries exhausted (retry_count_on_empty_action_space=%d)",
                        self.per_question_params.retry_count_on_empty_action_space,
                    )
                    break
                find_action_pool_depleted_retries_left -= 1
                logger.info(
                    "[DeepSearchAgent] action pool empty; re-running find_action "
                    "(retries remaining after this call: %d)",
                    find_action_pool_depleted_retries_left,
                )
                self.action_pool.add(await find_actions(init_state, build_retry_result()))
                continue

            done: set[asyncio.Task[WorkflowOutput]]
            done, _ = await asyncio.wait(
                running_tasks,
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                running_tasks.remove(task)
                completed_action: Action = task_to_action.pop(task)
                actions_explored += 1
                states: WorkflowOutput = await task
                logger.info("[DeepSearchAgent] action result: %s", "***" if LogManager.is_sensitive() else states)
                state_result: dict[str, Any] = parse_and_validate_state_creation_result(states.result)
                self.total_input_tokens += state_result.get("total_input_tokens", 0)
                self.total_output_tokens += state_result.get("total_output_tokens", 0)
                result: Result | None = state_result.get("result")
                config = anonymize_config_for_logging(state_result.get("config", {}))
                # A missing fail_count means "no failures reported"; defaulting
                # to the current total would silently double the counter.
                # NOTE(review): this assumes the reported fail_count is a delta,
                # not a running total seeded from the input - confirm.
                self.fail_count += config.get("fail_count", 0)
                self.action_pool.record_completed(completed_action, result)
                if not isinstance(result, Result):
                    continue

                if result.found_answer:
                    if answer_mode_top_k <= 1:
                        # Single-answer mode: the first answer wins.
                        self.final_answer = result.found_answer
                        logger.info(
                            "[DeepSearchAgent] found final answer! %s",
                            "***" if LogManager.is_sensitive() else self.final_answer,
                        )
                        await self._cancel_running_tasks(running_tasks, task_to_action)
                        return finish(
                            result.found_answer,
                            Termination.ANSWER,
                            result.messages,
                            result.retrieved_evidence_ids,
                        )

                    # Top-k mode: keep collecting until k answers are available.
                    self.action_pool.record_successful_answer(completed_action, result)
                    collected = self.action_pool.successful_answer_count()
                    logger.info(
                        "[DeepSearchAgent] top-k mode: collected %d/%d answers%s",
                        collected,
                        answer_mode_top_k,
                        " ***" if LogManager.is_sensitive() else f" (latest: {result.found_answer})",
                    )
                    if collected >= answer_mode_top_k:
                        _, best_result = self.action_pool.get_best_answer()
                        self.final_answer = best_result.found_answer
                        logger.info(
                            "[DeepSearchAgent] top-k mode: returning best answer%s",
                            " ***" if LogManager.is_sensitive() else f": {self.final_answer}",
                        )
                        await self._cancel_running_tasks(running_tasks, task_to_action)
                        return finish(
                            best_result.found_answer,
                            Termination.ANSWER,
                            best_result.messages,
                            best_result.retrieved_evidence_ids,
                        )
                    continue

                if result.new_states:
                    all_new_actions: list[Action] = []
                    for new_state in result.new_states:
                        new_actions = await find_actions(new_state, result)
                        all_new_actions.extend(new_actions)
                        logger.info(
                            "[DeepSearchAgent] new actions: %s",
                            "***" if LogManager.is_sensitive() else [a.proposal.direction for a in new_actions],
                        )
                    self.action_pool.add(all_new_actions)
    finally:
        # Runs on normal loop exit and on errors alike, so a failure in one
        # task or in find_action never leaves other workflows running.
        await self._cancel_running_tasks(running_tasks, task_to_action)

    # ---- 4. no answer accepted inside the loop ------------------------------
    best_pair = self.action_pool.get_best_answer()
    if best_pair is not None:
        _, best_result = best_pair
        self.final_answer = best_result.found_answer
        effective_termination = (
            Termination.TIMEOUT_ANSWER if termination_reason == Termination.TIME_LIMIT else termination_reason
        )
        logger.info(
            "[DeepSearchAgent] %s: returning best collected answer%s",
            effective_termination.name,
            " ***" if LogManager.is_sensitive() else f": {self.final_answer}",
        )
        return finish(
            best_result.found_answer,
            effective_termination,
            best_result.messages,
            best_result.retrieved_evidence_ids,
        )

    if termination_reason == Termination.TIME_LIMIT and provide_best_guess:
        guess_triple = self.action_pool.get_best_guess()
        if guess_triple is not None:
            _, guess_result, guess_candidate = guess_triple
            self.final_answer = guess_candidate
            logger.info(
                "[DeepSearchAgent] TIMEOUT_GUESS: returning best-guess candidate%s",
                " ***" if LogManager.is_sensitive() else f": {self.final_answer}",
            )
            return finish(
                guess_candidate,
                Termination.TIMEOUT_GUESS,
                guess_result.messages if guess_result else [],
                guess_result.retrieved_evidence_ids if guess_result else [],
            )

    return finish(self.final_answer, termination_reason, [], [])
async def run(
    self,
    message: str,
    conversation_id: str,
    agent_config: dict,
    report_template: str = "",
    interrupt_feedback: str = "",
) -> AsyncGenerator[str, None]:
    """Run DeepSearch with the same surface as ``BaseAgent.run``.

    Optional per-run fields may be supplied inside ``agent_config`` and are
    removed before ``AgentConfig`` validation: ``service_config`` (dict, e.g.
    ``{"search_workflow": ...}``) and ``gold_answer`` (str | None).

    Args:
        message: The user query to answer.
        conversation_id: Used to name the run's log directory.
        agent_config: Raw agent configuration (plus the optional fields above).
        report_template: Only passed to ``validate_run_agent_params``.
        interrupt_feedback: Only passed to ``validate_run_agent_params``.

    Yields:
        A single JSON string holding the final search result.

    Raises:
        CustomValueException: If ``agent_config`` is not a dict, the general
            LLM config is missing, or ``tool_map`` names an unknown tool set.
    """
    validate_run_agent_params(message, conversation_id, report_template, interrupt_feedback)
    if not isinstance(agent_config, dict):
        raise CustomValueException(
            StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
            StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(
                e="agent_config must be a dict",
            ),
        )
    # Work on a shallow copy so the caller's dict keeps its optional fields.
    agent_config_for_model = dict(agent_config)
    service_config: Optional[dict] = agent_config_for_model.pop("service_config", None)
    gold_answer: str | None = agent_config_for_model.pop("gold_answer", None)
    validate_agent_required_field(agent_config_for_model)

    llm_token = None
    tool_token = None
    try:
        # Clear results left by a previous run on this instance; a stale
        # final_answer would make the search loop exit immediately.
        # NOTE(review): self.action_pool is not recreated here - confirm
        # whether agent instances are ever reused across requests.
        self.final_answer = None
        self.fail_count = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0

        session_agent_config = AgentConfig.model_validate(agent_config_for_model)
        self.agent_config = session_agent_config.model_copy(deep=True)
        self.setup_log_directory(f"result_{conversation_id}")
        # The config carries API keys; mask it in sensitive mode like other logs.
        logger.info(
            "[DeepSearchAgent] agent_config: %s",
            "***" if LogManager.is_sensitive() else self.agent_config,
        )
        try:
            self.search_config = SearchWorkflowConfig.model_validate(
                (service_config or {}).get("search_workflow", {})
            )
        except Exception as e:
            # Deliberate best effort: a bad service_config falls back to defaults.
            logger.warning(
                "[DeepSearchAgent] Invalid or missing search_workflow in service_config; "
                "using default SearchWorkflowConfig. Error: %s",
                "*" if LogManager.is_sensitive() else e,
                exc_info=not LogManager.is_sensitive(),
            )
            self.search_config = SearchWorkflowConfig()

        self.per_question_params = self.agent_config.search_workflow_per_question_params
        self.time_limit = int(self.per_question_params.time_limit)
        # NOTE(review): process-wide setting; concurrent runs overwrite each other.
        os.environ["WORKFLOW_EXECUTE_TIMEOUT"] = str(self.time_limit)
        logger.info("[DeepSearchAgent] per_question_params: %s", self.per_question_params)

        llm_configs = session_agent_config.llm_config
        if LlmConfigCategory.GENERAL.value not in llm_configs:
            raise CustomValueException(
                error_code=StatusCode.LLM_CONFIG_NONE.code, message=StatusCode.LLM_CONFIG_NONE.errmsg
            )
        all_llms = {
            llm_config.model_name: create_llm_obj(llm_config)
            for llm_config in llm_configs.values()
        }

        # If init_state_agent names a model different from the general one,
        # register that name as an alias backed by the general model object.
        general_cfg = llm_configs[LlmConfigCategory.GENERAL.value]
        init_llm_map = self.search_config.init_state_agent.llm_config
        init_llm = init_llm_map.get("general") if init_llm_map else None
        if init_llm is not None and init_llm.model_name and init_llm.model_name != general_cfg.model_name:
            base_llm = all_llms.get(general_cfg.model_name)
            if base_llm is not None:
                all_llms[init_llm.model_name] = {
                    "model": base_llm["model"],
                    "model_name": init_llm.model_name,
                }
                logger.info(
                    "[DeepSearchAgent] registered extra LLM for init_state_agent: %s",
                    init_llm.model_name,
                )
        llm_token = llm_context.set(all_llms)

        # Build the tool set, then wipe the secrets that were handed to it.
        tool_class: list[Any] = []
        if self.per_question_params.tool_map == "search_fetch":
            tool_class.append(WebFetch({"jina_api_key": self.agent_config.jina_api_key}))
            tool_class.append(WebSearch({"serper_api_key": self.agent_config.serper_api_key}))
            zero_secret(self.agent_config.jina_api_key)
            zero_secret(self.agent_config.serper_api_key)
        elif self.per_question_params.tool_map == "retrieve":
            milvus_cfg = self.agent_config.search_workflow_milvus_config
            tool_class.append(_build_retrieve_tool(milvus_cfg))
            zero_secret(milvus_cfg.embedder_api_key)
        else:
            raise CustomValueException(
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(
                    e=f"Invalid tool map: {self.per_question_params.tool_map}"
                ),
            )
        self.tool_map = {tool.name: tool for tool in tool_class}
        tool_token = tool_context.set(self.tool_map)

        self.query = message
        self.gold_answer = gold_answer
        self._build_agent()
        result: SearchFinalResult = await self._run_internal()

        if hasattr(result, "model_dump"):
            yield json.dumps(to_json_safe(result.model_dump()), ensure_ascii=False)
        elif isinstance(result, dict):
            yield json.dumps(to_json_safe(result), ensure_ascii=False)
        else:
            yield json.dumps({"result": str(result)}, ensure_ascii=False)
    finally:
        # Restore the context variables even if setup or the run failed.
        if llm_token is not None:
            llm_context.reset(llm_token)
        if tool_token is not None:
            tool_context.reset(tool_token)
class SimpleReactSearchAgent(BaseAgent):
_cached_system_prompt: str | None = None
@classmethod
def _load_system_prompt(cls) -> str:
    """Return the system prompt, reading ``simple_react_search.md`` on first use.

    The stripped file content is cached on the class, so the file is read at
    most once per process.
    """
    cached = cls._cached_system_prompt
    if cached is not None:
        return cached
    prompt_file = Path(__file__).resolve().parents[3].joinpath(
        "algorithm", "prompts", "simple_react_search.md"
    )
    cls._cached_system_prompt = prompt_file.read_text(encoding="utf-8").strip()
    return cls._cached_system_prompt
async def run(
    self,
    message: str,
    conversation_id: str,
    agent_config: dict,
    *,
    report_template: str = "",
    interrupt_feedback: str = "",
    service_config: Optional[dict] = None,
    gold_answer: str | None = None,
) -> AsyncGenerator[str, None]:
    """Answer ``message`` with a single LLM tool-calling loop.

    The LLM is called with the available tool definitions; every requested
    tool call is executed and its output appended to the conversation, until
    the LLM replies without tool calls, the LLM call fails, or ``max_steps``
    turns have been used.

    Args:
        message: The user query.
        conversation_id: Used to name the run's log directory and tag events.
        agent_config: Raw agent configuration, validated as ``AgentConfig``.
        report_template: Only passed to ``validate_run_agent_params``.
        interrupt_feedback: Only passed to ``validate_run_agent_params``.
        service_config: Optional dict; its ``search_workflow`` entry is
            validated as ``SearchWorkflowConfig`` (defaults on failure).
        gold_answer: Reference answer stored with the final result.

    Yields:
        A single JSON string holding the final search result.

    Raises:
        CustomValueException: If the general LLM config is missing or
            ``tool_map`` names an unknown tool set.
    """

    def tool_call_name(call: Any) -> Any:
        """Best-effort tool name from a raw tool call; None when it has none."""
        if not isinstance(call, dict):
            return None
        direct = call.get("name")
        if direct:
            return direct
        function = call.get("function")
        # Malformed calls may carry a non-dict "function"; never call .get on it.
        return function.get("name") if isinstance(function, dict) else None

    validate_run_agent_params(
        message, conversation_id, report_template, interrupt_feedback
    )
    validate_agent_required_field(agent_config)
    _parsed_agent_cfg = AgentConfig.model_validate(agent_config)
    session_agent_config = _parsed_agent_cfg.model_copy(deep=True)
    try:
        search_config = SearchWorkflowConfig.model_validate(
            (service_config or {}).get("search_workflow", {})
        )
    except Exception as e:
        # Best-effort fallback, but leave a trace like DeepSearchAgent.run does.
        logger.warning(
            "[SimpleReactSearchAgent] Invalid or missing search_workflow in service_config; "
            "using default SearchWorkflowConfig. Error: %s",
            "*" if LogManager.is_sensitive() else e,
        )
        search_config = SearchWorkflowConfig()

    general = session_agent_config.llm_config.get("general")
    if general is None:
        raise CustomValueException(
            error_code=StatusCode.LLM_CONFIG_NONE.code,
            message=StatusCode.LLM_CONFIG_NONE.errmsg,
        )
    llm_registry = {general.model_name: create_llm_obj(general.model_copy(deep=True))}
    llm_token = llm_context.set(llm_registry)
    try:
        per_question_params: PerQuestionParams = (
            session_agent_config.search_workflow_per_question_params
        )
        # Build the tool set, then wipe the secrets that were handed to it.
        if per_question_params.tool_map == "search_fetch":
            tool_class = [
                WebFetch({"jina_api_key": session_agent_config.jina_api_key}),
                WebSearch({"serper_api_key": session_agent_config.serper_api_key}),
            ]
            zero_secret(session_agent_config.jina_api_key)
            zero_secret(session_agent_config.serper_api_key)
        elif per_question_params.tool_map == "retrieve":
            milvus_cfg = session_agent_config.search_workflow_milvus_config
            tool_class = [_build_retrieve_tool(milvus_cfg)]
            zero_secret(milvus_cfg.embedder_api_key)
        else:
            raise CustomValueException(
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(
                    e=f"Invalid tool map: {per_question_params.tool_map}"
                ),
            )
        tool_map = {tool.name: tool for tool in tool_class}
        retrieval_only = per_question_params.tool_map == "retrieve"
        tools_list = get_tool_definitions(retrieval_tool_only=retrieval_only)

        sc_dict = to_dict_safe(search_config.state_creation_agent) or {}
        retrieval_settings = to_dict_safe(sc_dict.get("retrieval_settings") or {})
        tool_exec_config_base = dict(sc_dict)

        base_log_dir = LogManager.get_log_dir() or "./output/logs"
        log_dir = ensure_safe_directory(
            os.path.join(base_log_dir, f"result_{conversation_id}"),
            base_log_dir,
        )
        emit(
            "react_run_started",
            {
                "conversation_id": conversation_id,
                "tool_map": {k: v.__class__.__name__ for k, v in tool_map.items()},
                "model_name": general.model_name,
                "log_dir": "***" if LogManager.is_sensitive() else log_dir,
            },
            source="workflow.SimpleReactSearchAgent",
            action_id=None,
        )

        max_steps = 1000
        llm_invoke_cfg = {
            "model_name": general.model_name,
            "max_tries": getattr(general, "max_tries", 4),
            "append_think_tags_to_messages": getattr(
                general, "append_think_tags_to_messages", False
            ),
        }
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": self._load_system_prompt()},
            {"role": "user", "content": message},
        ]
        start_time = time.time()
        total_in_tok = 0
        total_out_tok = 0
        prediction: str | None = None
        new_found_evidence_ids: list[Any] = []

        for _step in range(max_steps):
            try:
                raw, _reasoning, in_tok, out_tok = await _run_llm_via_ainvoke(
                    messages=messages,
                    config=llm_invoke_cfg,
                    agent_name=AgentLlmName.SIMPLE_REACT_SEARCH.value,
                    tools=tools_list,
                )
            except Exception as e:
                # The run ends without an answer; record why in the log as well
                # as in the transcript.
                logger.warning(
                    "[SimpleReactSearchAgent] LLM call failed at step %d: %s",
                    _step,
                    "***" if LogManager.is_sensitive() else e,
                )
                messages.append(
                    {
                        "role": "assistant",
                        "content": f"Error: {e}",
                    }
                )
                break
            total_in_tok += int(in_tok or 0)
            total_out_tok += int(out_tok or 0)

            if isinstance(raw, dict):
                resp_content = raw.get("content") or ""
                tool_calls = raw.get("tool_calls") or []
            else:
                resp_content = raw if isinstance(raw, str) else str(raw)
                tool_calls = []

            n_tools = len(tool_calls) if isinstance(tool_calls, list) else 0
            # Tool names are only reported when logging is not in sensitive mode.
            tool_names: list[str] = []
            if isinstance(tool_calls, list) and not LogManager.is_sensitive():
                for tc in tool_calls:
                    nm = tool_call_name(tc)
                    if nm:
                        tool_names.append(str(nm))
            emit(
                "react_llm_turn",
                {
                    "conversation_id": conversation_id,
                    "step": _step,
                    "tool_call_count": n_tools,
                    "tool_names": tool_names[:16],
                    "finished_after_turn": n_tools == 0,
                },
                source="workflow.SimpleReactSearchAgent",
                action_id=None,
            )

            if not tool_calls:
                # A reply without tool calls is the final answer.
                prediction = (resp_content or "").strip() or None
                messages.append(
                    {
                        "role": "assistant",
                        "content": resp_content or "",
                    }
                )
                break

            messages.append(
                {
                    "role": "assistant",
                    "content": resp_content or "",
                    "tool_calls": tool_calls,
                }
            )
            for tc in tool_calls:
                parsed, err = _parse_one_native_tool_call(tc)
                if err or not parsed:
                    # Answer the malformed call so the transcript stays paired.
                    raw_id = tc.get("id") if isinstance(tc, dict) else None
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": raw_id or uuid.uuid4().hex[:24],
                            "name": tool_call_name(tc) or "unknown",
                            "content": f"Tool call error: {err or 'parse failed'}",
                        }
                    )
                    continue

                tool_name = parsed["name"]
                tool_args = dict(parsed["arguments"])
                call_id = parsed.get("tool_call_id") or uuid.uuid4().hex[:24]
                try:
                    tool_result, new_found_evidence_ids = await execute_tool(
                        ExecuteToolConfig(
                            tool_map=tool_map,
                            tool_name=tool_name,
                            tool_args=tool_args,
                            config=tool_exec_config_base,
                            retrieval_settings=retrieval_settings,
                            action={},
                            new_found_evidence_ids=new_found_evidence_ids,
                        )
                    )
                    content = format_tool_result_for_message(tool_result)
                except CustomValueException as e:
                    content = str(e)
                except Exception as e:
                    # Tool failures are reported back to the LLM, not raised.
                    content = f"Tool execution error: {e}"
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "name": tool_name,
                        "content": content,
                    }
                )

        emit_messages_updated(
            source="workflow.SimpleReactSearchAgent",
            messages=messages,
            action_id=None,
            extra={
                "conversation_id": conversation_id,
                "phase": "final",
                "agent": "simple_react_search",
            },
        )
        result = _save_and_return_search_final_result(
            SaveSearchFinalResultConfig(
                question=message,
                messages=messages,
                prediction=prediction or "No answer found",
                gold_answer=gold_answer,
                termination=Termination.ANSWER if prediction else Termination.FAIL_LIMIT,
                retrieved_evidence_ids=new_found_evidence_ids,
                params={
                    "total_input_tokens": total_in_tok,
                    "total_output_tokens": total_out_tok,
                    "start_time": start_time,
                    "log_dir": log_dir,
                },
                config={"agent": "simple_react_search"},
            )
        )
        yield json.dumps(to_json_safe(result.model_dump()), ensure_ascii=False)
    finally:
        llm_context.reset(llm_token)
def parse_endnode_content(chunk: CustomSchema) -> dict | None:
    """Extract exception information from an EndNode stream chunk.

    Only chunks whose ``agent`` equals ``NodeId.END.value`` are inspected, and
    the marker contents "ALL END" and "SECTION END" are ignored.

    Args:
        chunk: A streamed chunk, either a ``CustomSchema`` or a plain dict.

    Returns:
        The parsed content when it is a JSON object containing an
        ``exception_info`` key; otherwise an empty dict.
    """
    if isinstance(chunk, CustomSchema):
        payload = chunk.model_dump()
    elif isinstance(chunk, dict):
        payload = chunk
    else:
        return {}

    if payload.get("agent", None) != NodeId.END.value:
        return {}

    content = payload.get("content", "")
    if not content or content in ("ALL END", "SECTION END"):
        return {}

    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        logger.debug("[DeepResearchAgent.run] EndNode returned non-JSON content.")
        return {}
    except Exception as parse_err:
        # e.g. content that is neither str nor bytes; the detail is only
        # logged when logging is not in sensitive mode.
        if LogManager.is_sensitive():
            logger.warning("[DeepResearchAgent.run] Failed to parse endnode content.")
        else:
            logger.warning(f"[DeepResearchAgent.run] Failed to parse endnode content: {parse_err}")
        return {}

    if isinstance(parsed, dict) and "exception_info" in parsed:
        return parsed
    return {}