from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
from typing import List, Optional, Tuple
from pymilvus import MilvusClient
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.embedder import AbstractEmbedder
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass()
class RetrieveConfig:
    """Arguments for a single ``BaseRetriever.retrieve`` call.

    Only ``query`` must be supplied; every other field carries a default.
    """

    # Query strings to retrieve for (required, no default).
    query: List[str]
    # Number of results requested; 5 unless overridden.
    top_k: int = 5
    # Boolean switch, off by default; presumably controls whether an
    # instruction prefix is added to queries before embedding — TODO confirm.
    add_instruction: bool = False
    # Retrieval mode, "dense" by default; the set of accepted values is up to
    # the implementing retriever (presumably sparse/hybrid too — verify).
    mode: str = "dense"
    # Integer multiplier, 1 by default; presumably scales top_k to over-fetch
    # candidates before trimming — TODO confirm against implementations.
    top_k_multiply_factor: int = 1
    # Optional save target, None by default — semantics defined by callers.
    save_as: Optional[str] = None
class BaseRetriever(ABC):
    """Interface for retrievers.

    Implementations turn a ``RetrieveConfig`` into aggregated text plus a
    list of ids.
    """

    @abstractmethod
    def retrieve(
        self,
        retrieve_config: RetrieveConfig,
    ) -> Tuple[str, List[str]]:
        """Run the retrieval described by *retrieve_config*.

        Returns:
            A ``(text, ids)`` pair: the aggregated retrieved text and the
            list of ids, as declared by the return annotation.
        """
class MilvusBaseRetriever(BaseRetriever):
    """Shared Milvus setup for retrievers backed by a Milvus collection.

    The constructor records the embedder and the collection field names.
    Unless ``connect_milvus`` is False, it also obtains a ``MilvusClient``,
    switches to the requested database, checks that the collection exists,
    and loads it.
    """

    def __init__(
        self,
        milvus_host: str,
        milvus_port: str,
        database_name: str,
        collection_name: str,
        embedder: AbstractEmbedder,
        vector_field: str = "embedding",
        text_field: str = "content",
        sparse_field: str = "content_sparse",
        title_field: str = "title",
        id_field: str = "id",
        metric_type: str = "COSINE",
        client: Optional[MilvusClient] = None,
        connect_milvus: bool = True,
    ):
        """Store configuration and, optionally, prepare the Milvus collection.

        Args:
            milvus_host: Host used to build ``http://{host}:{port}`` when no
                ``client`` is supplied.
            milvus_port: Port used in that URI.
            database_name: Database to select. An empty value or "default"
                skips the existence check and ``use_database`` call.
            collection_name: Collection that must exist and is then loaded.
            embedder: Embedder stored on the instance for use by retrieval.
            vector_field, text_field, sparse_field, title_field, id_field:
                Collection field names stored on the instance (presumably used
                by subclasses when building queries — verify).
            metric_type: Metric name stored on the instance ("COSINE" default).
            client: Pre-built ``MilvusClient``. When given, it is used instead
                of creating a new one.
            connect_milvus: When False, only stores ``client`` (which may be
                None) and performs no Milvus calls.

        Raises:
            CustomValueException: If the requested database or collection
                does not exist.
        """
        self.embedder = embedder
        self.vector_field = vector_field
        self.text_field = text_field
        self.title_field = title_field
        self.id_field = id_field
        self.metric_type = metric_type
        self.collection_name = collection_name
        self.database_name = database_name
        self.sparse_field = sparse_field

        if not connect_milvus:
            # Offline/deferred mode: keep the supplied client as-is (possibly None).
            self.client = client
            return

        self.client = client if client is not None else MilvusClient(
            uri=f"http://{milvus_host}:{milvus_port}",
        )

        # NOTE(review): for "default" no use_database() call is made, so a
        # supplied client that was switched to another database stays there.
        if database_name and database_name != "default":
            if database_name not in self.client.list_databases():
                raise self._param_check_error(
                    f"[RETRIEVER] Milvus database '{database_name}' does not exist."
                )
            self.client.use_database(database_name)

        if collection_name not in self.client.list_collections():
            raise self._param_check_error(
                f"[RETRIEVER] Milvus collection '{collection_name}' does not exist."
            )
        self.client.load_collection(collection_name)

    @staticmethod
    def _param_check_error(error_msg: str) -> CustomValueException:
        """Log *error_msg* and build the parameter-check exception for it."""
        logger.error(error_msg)
        return CustomValueException(
            StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
            StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(e=error_msg),
        )