import logging
import json
from typing import Callable, Awaitable
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import MessageLikeRepresentation, filter_messages, ToolMessage
from langchain.agents.middleware import AgentMiddleware
from deepinsight.core.types.research import ResearchComplete, think_tool
from deepinsight.core.types.graph_config import SearchAPI, RetrievalType
from deepinsight.core.tools.tavily_search import tavily_search
from deepinsight.core.tools.ragflow_retrival import KnowledgeTool
from deepinsight.core.utils.mcp_utils import MCPClientUtils
from deepinsight.core.utils.research_utils import parse_research_config
from deepinsight.service.rag.engine import RAGEngine
def create_retrieval_tool(retrieval_type: RetrievalType, config: RunnableConfig):
    """Build the retrieval tool that matches ``retrieval_type``.

    Args:
        retrieval_type: Retrieval engine to build a tool for (RAGFLOW,
            LLAMAINDEX or LIGHTRAG).
        config: Runtime configuration from which the retrieval configs are
            parsed.

    Returns:
        The tool object for the requested retrieval type.

    Raises:
        ValueError: If there is no retrieval config entry for
            ``retrieval_type``, or the type is not one of the supported ones.
    """
    configs_by_type = parse_research_config(config).retrieval_config
    if not (configs_by_type and retrieval_type in configs_by_type):
        raise ValueError(f"{retrieval_type} retrieval config is not configured.")

    # RAGFlow is served by a ready-made tool; no engine needs to be built.
    if retrieval_type == RetrievalType.RAGFLOW:
        return KnowledgeTool.knowledge_retrieve

    if retrieval_type not in [RetrievalType.LLAMAINDEX, RetrievalType.LIGHTRAG]:
        raise ValueError(f"Unsupported retrieval type: {retrieval_type}")

    # LlamaIndex and LightRAG share the RAGEngine code path.
    engine_config = configs_by_type[retrieval_type]
    return RAGEngine.from_retrieval_config(engine_config).as_tool(engine_config)
async def get_search_tools(search_apis: list[SearchAPI], config: RunnableConfig = None):
    """Build the list of search tools for the requested API providers.

    Args:
        search_apis: Search API providers to enable. A provider that matches
            none of the branches below contributes no tools.
        config: Runtime configuration. It is only read for
            ``SearchAPI.RAG_RETRIVAL``, where the retrieval configs are parsed
            from it.

    Returns:
        A list holding the tool definitions/objects of every requested
        provider, in the order the providers were given.
    """
    tools = []
    for api in search_apis:
        if api == SearchAPI.ANTHROPIC:
            tools.append({
                "type": "web_search_20250305",
                "name": "web_search",
                "max_uses": 5,
            })
        elif api == SearchAPI.GEMINI:
            tools.append(GenAITool(google_search={}))
        elif api == SearchAPI.OPENAI:
            tools.append({"type": "web_search_preview"})
        elif api == SearchAPI.TAVILY:
            # NOTE: ``tavily_search`` is the imported, shared tool object, so
            # this metadata update is visible to every other user of it.
            search_tool = tavily_search
            search_tool.metadata = {
                **(search_tool.metadata or {}),
                "type": "search",
                "name": "web_search",
            }
            tools.append(search_tool)
        elif api == SearchAPI.PAPER_STATIC_DATA:
            query_tools = await MCPClientUtils.get_tools(
                tools_name_list=[
                    "get_institution_stats",
                    "get_proceedings_keyword_frequency",
                    "get_author_coauthorship",
                    "generate_bar_chart",
                ],
                server_name="conference-static",
            )
            tools.extend(query_tools)
        elif api == SearchAPI.RAG_RETRIVAL:
            retrieval_config = parse_research_config(config).retrieval_config
            if not retrieval_config:
                # A missing/empty retrieval config used to crash on ``.keys()``;
                # treat it like any other unavailable retrieval tool instead.
                logging.warning("RAG retrieval requested but no retrieval config is configured.")
                continue
            for retrieval_type in retrieval_config.keys():
                try:
                    tools.append(create_retrieval_tool(retrieval_type, config))
                except ValueError as e:
                    # Best effort: one misconfigured engine must not prevent
                    # the remaining tools from being returned.
                    logging.warning(f"Failed to create retrieval tool for {retrieval_type}: {e}")
    return tools
async def get_all_tools(config: RunnableConfig):
    """Collect every tool available to a research run.

    The result starts with the ``ResearchComplete`` tool and ``think_tool``,
    followed by the search tools for the configured search APIs, followed by
    any extra tools carried on the parsed research config.

    Args:
        config: Runtime configuration from which the search APIs and the
            optional extra tools are parsed.

    Returns:
        List of all assembled tools.
    """
    toolkit = [tool(ResearchComplete), think_tool]

    research_config = parse_research_config(config)
    requested_apis = [SearchAPI(value) for value in research_config.search_api]
    toolkit += await get_search_tools(requested_apis, config)

    # The parsed config may or may not carry a ``tools`` attribute.
    extra_tools = getattr(research_config, "tools", None)
    if extra_tools:
        toolkit.extend(extra_tools)
    return toolkit
def get_notes_from_tool_calls(messages: list[MessageLikeRepresentation]):
    """Return the ``content`` of every tool-type message in ``messages``."""
    tool_messages = filter_messages(messages, include_types="tool")
    return [entry.content for entry in tool_messages]
def anthropic_websearch_called(response):
    """Report whether the response's usage metadata records a web search.

    Looks up ``response.response_metadata["usage"]["server_tool_use"]
    ["web_search_requests"]`` and checks that the count is positive.

    Args:
        response: Response object exposing a ``response_metadata`` mapping.

    Returns:
        True if the recorded web search request count is greater than zero,
        False if any level of the metadata is missing or malformed.
    """
    try:
        usage = response.response_metadata.get("usage")
        server_tool_use = usage.get("server_tool_use") if usage else None
        if not server_tool_use:
            return False
        request_count = server_tool_use.get("web_search_requests")
        return request_count is not None and request_count > 0
    except (AttributeError, TypeError):
        # Unexpected metadata shape: treat as "no web search happened".
        return False
def openai_websearch_called(response):
    """Report whether the response's tool outputs contain a web search call.

    Scans ``response.additional_kwargs["tool_outputs"]`` for an entry whose
    ``"type"`` is ``"web_search_call"``.

    Args:
        response: Response object exposing an ``additional_kwargs`` mapping.

    Returns:
        True if such an entry is found, False otherwise. A response without
        ``additional_kwargs``, or with malformed tool outputs, also yields
        False instead of raising, matching ``anthropic_websearch_called``.
    """
    try:
        tool_outputs = response.additional_kwargs.get("tool_outputs")
        if not tool_outputs:
            return False
        return any(
            tool_output.get("type") == "web_search_call"
            for tool_output in tool_outputs
        )
    except (AttributeError, TypeError):
        # Unexpected response shape: treat as "no web search happened".
        return False
class CoerceToolOutput(AgentMiddleware):
    """Middleware that coerces tool-call message contents and tool results to strings.

    After the wrapped handler has run, both hooks:

    1. stringify the ``content`` of every message found under
       ``request.tool_call["args"]["messages"]`` (in place), and
    2. if the handler returned a ``ToolMessage``, stringify the ``content``
       of every message under ``result.content["messages"]`` (when the
       content is a dict) and finally ``result.content`` itself.

    Dicts and lists are rendered as JSON (``ensure_ascii=False``); anything
    that cannot be serialized falls back to ``str()``. Results that are not a
    ``ToolMessage`` are returned untouched.
    """

    @staticmethod
    def _to_text(value, *, json_always: bool = False) -> str:
        """Return ``value`` as a string, preferring JSON for structured data.

        Args:
            value: The content to convert; strings are returned unchanged.
            json_always: Try JSON for every non-string value, not only for
                dicts and lists.
        """
        if isinstance(value, str):
            return value
        if json_always or isinstance(value, (dict, list)):
            try:
                return json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                # Not JSON-serializable (or circular): fall through to str().
                pass
        return str(value)

    @classmethod
    def _coerce_request_messages(cls, request: ToolCallRequest) -> None:
        """Stringify, in place, the contents of the tool call's ``messages`` argument."""
        args = request.tool_call.get("args") or {}
        messages = args.get("messages") if isinstance(args, dict) else None
        if not isinstance(messages, list):
            return
        for message in messages:
            # Entries that are not dicts, or carry no "content" key, have
            # nothing to coerce; indexing them used to raise.
            if not isinstance(message, dict) or "content" not in message:
                continue
            if not isinstance(message["content"], str):
                message["content"] = cls._to_text(message["content"])

    @classmethod
    def _coerce_result(cls, result):
        """Stringify a ``ToolMessage`` result in place and return it."""
        if not isinstance(result, ToolMessage):
            return result

        if isinstance(result.content, dict):
            # Nested messages may be plain dicts or objects with ``.content``.
            for message in result.content.get("messages") or []:
                if isinstance(message, dict):
                    if "content" in message and not isinstance(message["content"], str):
                        message["content"] = cls._to_text(message["content"])
                elif hasattr(message, "content") and not isinstance(message.content, str):
                    message.content = cls._to_text(message.content)

        if not isinstance(result.content, str):
            result.content = cls._to_text(result.content, json_always=True)
        return result

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage],
    ) -> ToolMessage:
        """Run ``handler`` and coerce the request messages and its result to strings.

        Args:
            request: The tool call being executed.
            handler: Callable that executes the tool call.

        Returns:
            The handler's result; a ``ToolMessage`` comes back with string
            ``content``, any other result is returned unchanged.
        """
        result = handler(request)
        self._coerce_request_messages(request)
        return self._coerce_result(result)

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage]],
    ) -> ToolMessage:
        """Async counterpart of ``wrap_tool_call``; awaits ``handler`` first.

        Args:
            request: The tool call being executed.
            handler: Awaitable callable that executes the tool call.

        Returns:
            The handler's result; a ``ToolMessage`` comes back with string
            ``content``, any other result is returned unchanged.
        """
        result = await handler(request)
        self._coerce_request_messages(request)
        return self._coerce_result(result)