import os
import sys
from typing import List, Tuple, Type
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.base_retriever import (
BaseRetriever,
RetrieveConfig,
)
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.embedder import (
OpenJiuwenAPIEmbedder,
)
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.retriever import (
KnowledgeBaseRetriever,
)
# Append the parent of this file's directory to the module search path.
# This runs after the imports above, so it only affects imports that are
# resolved later (e.g. lazily or by other modules), not the ones in this file.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
class Retrieve:
    """Tool wrapper that forwards a parameter dict to a retriever.

    The wrapped retriever is stored on ``self.retriever`` and is invoked
    through its ``retrieve`` method with a ``RetrieveConfig`` built from the
    caller-supplied parameters.
    """

    # Identifier and human-readable description of the tool.
    name = "retrieve"
    description = ""

    def __init__(self, retriever: BaseRetriever):
        """Store the retriever that ``call`` delegates to."""
        self.retriever = retriever

    def call(self, params: dict) -> Tuple[str, List[str]]:
        """Run one retrieval described by ``params``.

        Expected layout of ``params``::

            {
                "query": List[str],
                "top_k": int,
                "add_instruction": bool,
                "mode": str,
                "top_k_multiply_factor": int,
                "save_as": None | str,
            }

        Every key except ``"save_as"`` is read by subscript, so a missing
        key raises ``KeyError``; ``"save_as"`` falls back to ``None``.

        Returns:
            Whatever ``self.retriever.retrieve`` returns for the built
            ``RetrieveConfig`` (annotated as ``Tuple[str, List[str]]``).
        """
        # Resolve the bound method first, then assemble the request.
        run_retrieval = self.retriever.retrieve
        request = RetrieveConfig(
            query=params["query"],
            top_k=params["top_k"],
            add_instruction=params["add_instruction"],
            mode=params["mode"],
            top_k_multiply_factor=params["top_k_multiply_factor"],
            save_as=params.get("save_as"),
        )
        return run_retrieval(request)
class RetrieveTool(Retrieve):
    """``Retrieve`` tool that builds its own embedder and retriever from a config dict."""

    # Identifier and human-readable description of the tool.
    name = "retrieve"
    description = "Retrieve evidence paragraphs from the knowledge base via Milvus."

    def __init__(
        self,
        config: dict,
        retriever_class: Type[BaseRetriever] = KnowledgeBaseRetriever,
    ):
        """Build the embedder and retriever described by ``config``.

        Keys read from ``config`` (all optional, looked up with ``.get``):

        * ``milvus_host`` -- falls back to ``"localhost"`` when falsy.
        * ``milvus_port`` -- converted with ``str``; falls back to
          ``"19530"`` when falsy.
        * ``database_name`` -- falls back to ``"deepsearch_benchmarks"``
          when falsy.
        * ``collection_name`` -- falls back to
          ``"browsecompplus_with_bm25"`` when falsy.
        * ``embedder_model_name``, ``embedder_api_key``,
          ``embedder_base_url``, ``embedder_timeout`` -- passed to
          ``OpenJiuwenAPIEmbedder`` unchanged (``None`` when absent).

        Args:
            config: Mapping holding the settings listed above.
            retriever_class: Class instantiated with the Milvus settings and
                the embedder; defaults to ``KnowledgeBaseRetriever``.
        """
        setting = config.get

        # The embedder receives the raw config values, with no fallbacks.
        embedder = OpenJiuwenAPIEmbedder(
            pretrained_model=setting("embedder_model_name"),
            api_token=setting("embedder_api_key"),
            api_url=setting("embedder_base_url"),
            timeout=setting("embedder_timeout"),
        )

        # Any falsy port (missing, None, 0, "") selects the default port string.
        raw_port = setting("milvus_port")
        port = "19530" if not raw_port else str(raw_port)

        retriever_kwargs = {
            "milvus_host": setting("milvus_host") or "localhost",
            "milvus_port": port,
            "database_name": setting("database_name") or "deepsearch_benchmarks",
            "collection_name": setting("collection_name")
            or "browsecompplus_with_bm25",
            "embedder": embedder,
        }
        super().__init__(retriever_class(**retriever_kwargs))