import json
import logging
import os
from typing import TypedDict
import httpx
from langchain_core.tools import tool as make_tool, Tool
from langchain_core.runnables import RunnableConfig
from deepinsight.core.utils.research_utils import parse_research_config
from deepinsight.core.types.graph_config import RetrievalType
import requests
# Public API of this module: only the tool container class is exported.
__all__ = ["KnowledgeTool"]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _create_tool_description(f):
    """Derive the tool description and argument schema from a function.

    The function is wrapped with ``make_tool(..., parse_docstring=True)`` and
    only the resulting ``description`` and ``args_schema`` are kept.

    Args:
        f: The callable whose signature and docstring describe the tool.

    Returns:
        A dict with the keys ``description`` and ``args_schema``, suitable for
        unpacking into ``Tool.from_function``.
    """
    parsed = make_tool(f, parse_docstring=True)
    return {
        "description": parsed.description,
        "args_schema": parsed.args_schema,
    }
class KnowledgeTool:
    """Container for the LangChain tool that queries a RagFlow knowledge base.

    ``knowledge_retrieve`` is the ``Tool`` object; it is backed by a blocking
    implementation (``requests``) and an async one (``httpx``).
    """

    @staticmethod
    async def async_knowledge_retrieve(question: str, config: RunnableConfig):
        """Async version of `KnowledgeTool.sync_knowledge_retrieve`.

        Posts the retrieval request through ``httpx.AsyncClient``. Any
        exception raised while building the request, sending it, or handling
        the response is logged via ``_log_exception`` and re-raised.
        """
        logger.info(f"开始执行知识检索流程,待检索的问题: {question}")
        ragflow_base = _get_api_base(question)
        try:
            async with httpx.AsyncClient() as http:
                request_kwargs = _make_request_args(question, ragflow_base, config)
                reply = await http.post(**request_kwargs)
                return _handle_response(reply)
        except Exception as exc:
            _log_exception(exc, question)
            raise

    # NOTE: the docstring below is runtime data, not just documentation. It is
    # parsed by `_create_tool_description` (parse_docstring=True) and becomes
    # the description of the `knowledge_retrieve` tool, so its text is kept
    # verbatim. In English it says: core RAG retrieval tool that extracts
    # highly relevant knowledge chunks from the configured knowledge bases for
    # a given question; `question` is a required string that should be a
    # complete, specific query containing key entities and qualifiers.
    # NOTE(review): the layout is reproduced as found; confirm whether the
    # `Args:` section should be separated by a blank line and indented.
    @staticmethod
    def sync_knowledge_retrieve(question: str, config: RunnableConfig):
        """
        RAG流程核心检索工具:根据输入问题,从指定知识库中精准提取高相关性知识片段,
        为后续回答生成提供时效性、准确性、领域针对性的事实支撑,解决LLM知识过时、事实偏差问题。
        适用场景:
        - 领域专属问答(如法律条文查询、医疗指南解读、企业产品手册咨询)
        - 时效性问题检索(如最新行业数据、政策文件、赛事结果)
        - 长文档关键信息提取(如学术论文结论、白皮书核心观点)
        - 多轮对话上下文补充(关联历史提问的知识溯源与扩展)
        Args:
        question: str,必选参数
        - 功能:待检索知识的问题,支持单个问题检索
        - 约束:问题需为完整表意字符串,需包含关键实体(如"2024年 新能源汽车")、
        明确限定词(如"同比增长 监管政策"),避免模糊表述(如"这个怎么操作?")
        """
        logger.info(f"开始执行知识检索流程,待检索的问题: {question}")
        ragflow_base = _get_api_base(question)
        try:
            request_kwargs = _make_request_args(question, ragflow_base, config)
            reply = requests.post(**request_kwargs)
            return _handle_response(reply)
        except Exception as exc:
            # Log with the query for context, then let the caller see the error.
            _log_exception(exc, question)
            raise

    # Description and argument schema come from the sync function's docstring
    # and signature; the async function is registered as the coroutine.
    knowledge_retrieve = Tool.from_function(
        func=sync_knowledge_retrieve,
        name="knowledge_retrieve",
        coroutine=async_knowledge_retrieve,
        **_create_tool_description(sync_knowledge_retrieve),
    )
def _get_api_base(question: str) -> str:
api_base = os.environ.get("RAGFLOW_API_BASE")
if not api_base:
logging.error(f"[EnvironError] 未配置RagFlow环境,检索终止 | Query: {question}")
raise ValueError("RagFlow host information is not configured. Retrieval terminated.")
return api_base
def _make_request_args(question: str, api_base: str, config: RunnableConfig) -> dict:
    """Build the keyword arguments for the POST to ``{api_base}/retrieval``.

    The result can be unpacked into ``requests.post`` or
    ``httpx.AsyncClient.post``.

    Args:
        question: The query text to retrieve knowledge for.
        api_base: Base URL of the RagFlow API (no trailing slash expected).
        config: Runnable config from which the research config is parsed.

    Returns:
        A dict with the keys ``url``, ``headers``, ``json`` and ``timeout``.

    Raises:
        ValueError: If no RagFlow retrieval config is present.
        RuntimeError: If the configured knowledge-base ids are not a
            non-empty list.
    """
    rc = parse_research_config(config)
    retrieval_config = rc.retrieval_config
    if not retrieval_config or RetrievalType.RAGFLOW not in retrieval_config:
        raise ValueError("RagFlow retrieval config is not configured.")

    ragflow_retrieval_config = retrieval_config[RetrievalType.RAGFLOW]
    args = ragflow_retrieval_config.args
    logger.info(f"对话ID: {args.dialog_id or '未知'}, 知识库IDs: {args.kb_ids}")

    kbs = args.kb_ids
    # Reject anything that is not a non-empty list. The previous condition
    # (`not isinstance(kbs, list) and kbs`) let None and [] through, which
    # then failed later on len(kbs) or sent an empty dataset list.
    if not (isinstance(kbs, list) and kbs):
        logger.error(f"未找到与对话 {args.dialog_id!r} 关联的知识库")
        raise RuntimeError(f"No knowledge bases found for dialog {args.dialog_id!r}")

    similarity_threshold = args.similarity_threshold
    rerank_enabled = bool(args.rerank_id)
    logger.info(f"调用检索器进行知识检索,{len(kbs)}个知识库,top_n={args.top_n},"
                f"相似度阈值={similarity_threshold},"
                f"{'' if rerank_enabled else '未'}启用重排。")

    headers = {"Authorization": f"Bearer {ragflow_retrieval_config.api_key}"}
    # NOTE(review): top_n is logged above but not sent; page_size is fixed at
    # 20 — confirm this is intended.
    params = dict(
        question=question,
        dataset_ids=kbs,
        document_ids=[],
        page=1,
        page_size=20,
        similarity_threshold=similarity_threshold,
        vector_similarity_weight=args.vector_similarity_weight,
        top_k=args.top_k or 1024,
        rerank_id=args.rerank_id,
        keyword=False,
    )
    # Allow twice the time budget when a rerank model is configured.
    return dict(url=f"{api_base}/retrieval", headers=headers, json=params,
                timeout=60 if rerank_enabled else 30)
def _handle_response(response: httpx.Response | requests.Response) -> str:
    """Validate a retrieval response and convert its chunks to a JSON string.

    Args:
        response: The HTTP response returned by the retrieval request.

    Returns:
        A JSON array (4-space indent, non-ASCII characters kept as-is) with
        one object per retrieved chunk; ``[]`` when no chunk was returned.

    Raises:
        RuntimeError: If the body's ``code`` is not 0, or ``data`` is not a
            dict, or ``data["chunks"]`` is not a list.
        Exception: Whatever ``response.raise_for_status()`` or
            ``response.json()`` raise is propagated unchanged.
    """
    response.raise_for_status()
    response_body: dict = response.json()

    data = response_body.get("data")
    well_formed = (
        response_body.get("code") == 0
        and isinstance(data, dict)
        and isinstance(data.get("chunks"), list)
    )
    if not well_formed:
        raise RuntimeError(response_body.get("message") or f"连接到知识库时出现未知问题。{response_body=}")

    raw_chunks: list[dict] = data["chunks"]
    if not raw_chunks:
        logger.warning("未检索到任何知识片段")

    # (output key, key read from the RagFlow chunk), in output order.
    field_map = (
        ("title", "content"),
        ("url", "document_id"),
        ("chunk_id", "id"),
        ("content_with_weight", "content"),
        ("doc_id", "document_id"),
        ("docnm_kwd", "document_keyword"),
        ("kb_id", "dataset_id"),
        ("image_id", "image_id"),
        ("similarity", "similarity"),
        ("positions", "positions"),
    )
    returns = [
        {out_key: chunk.get(src_key) for out_key, src_key in field_map}
        for chunk in raw_chunks
    ]
    return json.dumps(returns, indent=4, ensure_ascii=False)
def _log_exception(e: Exception, question: str) -> None:
    """Log a retrieval failure under exactly one category.

    Categories are checked from most to least specific: timeout, connection
    failure, other HTTP error, and finally anything else.

    Args:
        e: The exception raised during retrieval.
        question: The query being processed, appended to the log line.
    """
    # Timeouts first: requests.ConnectTimeout is also a ConnectionError, and
    # httpx timeouts are also httpx.HTTPError. The base timeout classes are
    # matched so that read timeouts are covered, as the message states.
    if isinstance(e, (httpx.TimeoutException, requests.Timeout)):
        logger.error(f"[TimeoutError] 搜索请求超时: 连接或读取超时 - {e} | Query: {question}")
    elif isinstance(e, (httpx.ConnectError, requests.ConnectionError)):
        logger.error(f"[NetworkError] 网络连接失败: 无法连接到RagFlow服务 - {e} | Query: {question}")
    elif isinstance(e, (httpx.HTTPError, requests.HTTPError)):
        logger.error(f"[HTTPError] 未知HTTP错误: {e} | Query: {question}")
    else:
        # Only reached when no category above matched.
        logger.error(f"[UnknownError] 搜索处理过程中发生未知错误: {e} | Query: {question}")