import asyncio
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
from openjiuwen_deepsearch.algorithm.search_nodes.run_action import resolve_native_tool_call_name
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
def format_tool_result_for_message(tool_result: Any) -> str:
    """Render a tool result as text suitable for a chat message.

    Strings are returned unchanged. Any other value is JSON-encoded with
    non-ASCII characters kept as-is; objects JSON cannot encode natively are
    converted with ``str()``. If JSON encoding still fails (for example on a
    circular reference), the plain ``str()`` of the value is returned instead.
    """
    if not isinstance(tool_result, str):
        try:
            serialized = json.dumps(tool_result, ensure_ascii=False, default=str)
        except Exception:
            # Best-effort fallback: never let formatting break the caller.
            serialized = str(tool_result)
        return serialized
    return tool_result
@dataclass
class ExecuteToolConfig:
    """Bundle of inputs for a single ``execute_tool`` call.

    Field order is part of the positional constructor signature; keep it stable.
    """

    # Canonical tool name -> tool object (exposing ``acall`` or ``call``).
    tool_map: Dict[str, Any]
    # Tool name as requested by the caller, before canonical-name resolution.
    tool_name: str
    # Arguments for the tool; execute_tool copies this dict before adding settings.
    tool_args: Dict[str, Any]
    # General settings; execute_tool reads "log_fetch", "log_search" and "llm_config".
    config: Dict[str, Any]
    # Retrieval settings; execute_tool reads "top_k", "add_instruction", "mode"
    # and "top_k_multiply_factor".
    retrieval_settings: Dict[str, Any]
    # Current action; execute_tool reads action["state"]["retrieved_evidence_ids"].
    action: Dict[str, Any]
    # Output accumulator: execute_tool appends new evidence entries here in place.
    new_found_evidence_ids: List[Any]
async def _invoke_registered_tool(tool_map: dict, canonical_name: str, tool_args: dict):
    """Look up ``canonical_name`` in ``tool_map`` and await the tool's result.

    A tool that has an ``acall`` attribute is awaited directly. Otherwise the
    tool's synchronous ``call`` is run in a worker thread via
    ``asyncio.to_thread`` so it does not block the event loop.

    Raises:
        CustomValueException: with LOAD_EXTEND_TOOLS_FAILED when
            ``canonical_name`` is not a key of ``tool_map``.
    """
    if canonical_name in tool_map:
        tool = tool_map[canonical_name]
        if not hasattr(tool, "acall"):
            return await asyncio.to_thread(tool.call, tool_args)
        return await tool.acall(tool_args)
    failure = StatusCode.LOAD_EXTEND_TOOLS_FAILED
    raise CustomValueException(
        failure.code,
        failure.errmsg.format(tool_name=canonical_name),
    )
def _resolve_execute_tool_name(execute_config: ExecuteToolConfig) -> str:
    """Resolve the requested tool name to a canonical name registered in ``tool_map``.

    Raises:
        CustomValueException: with TOOL_EXEC_ERROR when the name cannot be
            resolved or the resolved name is not in ``tool_map``.
    """
    tool_map = execute_config.tool_map
    # Resolution runs in retrieval-only mode when "retrieve" is available but "web_search" is not.
    retrieval_only = "retrieve" in tool_map and "web_search" not in tool_map
    canonical = resolve_native_tool_call_name(execute_config.tool_name, retrieval_only)
    if not canonical or canonical not in tool_map:
        allowed = ", ".join(sorted(tool_map.keys()))
        raise CustomValueException(
            StatusCode.TOOL_EXEC_ERROR.code,
            StatusCode.TOOL_EXEC_ERROR.errmsg.format(
                e=f"Tool '{execute_config.tool_name}' is not allowed. Allowed tools: {allowed}."
            ),
        )
    return canonical


def _build_execute_tool_args(canonical: str, execute_config: ExecuteToolConfig) -> Dict[str, Any]:
    """Return a shallow copy of the caller's tool arguments plus per-tool runtime settings."""
    tool_args = dict(execute_config.tool_args)
    if canonical == "web_fetch":
        config = execute_config.config
        tool_args["log_fetch"] = config.get("log_fetch", False)
        # `or {}` treats an explicit None like a missing section instead of crashing.
        llm_config = config.get("llm_config") or {}
        general = llm_config.get("general") or {}
        tool_args["fetch_tool_model"] = general.get("model_name", None)
    elif canonical == "web_search":
        tool_args["log_search"] = execute_config.config.get("log_search", True)
    elif canonical == "retrieve":
        settings = execute_config.retrieval_settings
        tool_args["top_k"] = settings.get("top_k", 5)
        tool_args["add_instruction"] = settings.get("add_instruction", True)
        tool_args["mode"] = settings.get("mode", "dense")
        tool_args["top_k_multiply_factor"] = settings.get("top_k_multiply_factor", 10)
    return tool_args


async def execute_tool(execute_config: ExecuteToolConfig) -> Tuple[Any, List[Any]]:
    """Resolve, invoke and post-process a single tool call.

    Steps:
    1. Resolve ``execute_config.tool_name`` to a canonical name.
    2. Add per-tool runtime settings to a copy of the arguments.
    3. Await the tool.
    4. Record evidence in ``execute_config.new_found_evidence_ids``, which is
       mutated in place.

    Evidence recording depends on the tool:

    * ``web_fetch`` appends a ``{"url": ..., "goal": ...}`` entry. If either key
      is missing from the arguments, an error is logged first.
    * ``retrieve`` appends each returned id that is neither already in the list
      nor in ``action["state"]["retrieved_evidence_ids"]``.

    Args:
        execute_config: tool map, requested tool name/arguments and settings.

    Returns:
        A pair ``(result, new_found_evidence_ids)``.

        * For ``retrieve``, ``result`` is the first element of the tool's
          ``(results, ids)`` pair.
        * For other tools, ``result`` is the raw tool result.
        * The list is the same object as ``execute_config.new_found_evidence_ids``.

    Raises:
        CustomValueException: with TOOL_EXEC_ERROR when the tool is not allowed
            or when the tool call itself raises.
    """
    canonical = _resolve_execute_tool_name(execute_config)
    tool_args = _build_execute_tool_args(canonical, execute_config)

    try:
        tool_result = await _invoke_registered_tool(execute_config.tool_map, canonical, tool_args)
    except Exception as e:
        raise CustomValueException(
            StatusCode.TOOL_EXEC_ERROR.code,
            StatusCode.TOOL_EXEC_ERROR.errmsg.format(e=e),
        ) from e

    found_ids = execute_config.new_found_evidence_ids

    if canonical == "web_fetch":
        if "url" not in tool_args or "goal" not in tool_args:
            logger.error(
                "[tool_node] fetch tool_args missing expected keys. "
                "Present keys: %s. tool_args: %s. tool_result: %s",
                list(tool_args.keys()),
                tool_args,
                tool_result,
            )
        found_ids.append(
            {
                "url": tool_args.get("url"),
                "goal": tool_args.get("goal"),
            }
        )

    if canonical == "retrieve":
        # NOTE(review): assumes the retrieve tool returns a (results, ids) pair.
        results, id_list = tool_result
        action_state = execute_config.action.get("state", {}) or {}
        # Treat an explicit None like a missing key, mirroring the `or {}` guard above.
        retrieved_ids = action_state.get("retrieved_evidence_ids")
        existing_ids = set(retrieved_ids) if retrieved_ids is not None else set()
        for id_ in id_list if id_list is not None else ():
            # List check first (as before), so ids already collected skip the hash lookup.
            if id_ not in found_ids and id_ not in existing_ids:
                found_ids.append(id_)
        return results, found_ids

    return tool_result, found_ids