import asyncio
import json
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel
from pymilvus import AnnSearchRequest, MilvusClient, WeightedRanker
from pymilvus.client.search_result import SearchResult
from openjiuwen.core.foundation.store.base_embedding import EmbeddingConfig
from openjiuwen.core.retrieval.common.config import (
KnowledgeBaseConfig,
RetrievalConfig,
VectorStoreConfig,
)
from openjiuwen.core.retrieval.knowledge_base import KnowledgeBase
from openjiuwen.core.retrieval.simple_knowledge_base import SimpleKnowledgeBase
from openjiuwen.core.retrieval.vector_store.milvus_store import MilvusVectorStore
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.base_retriever import (
MilvusBaseRetriever,
RetrieveConfig,
)
from openjiuwen_deepsearch.algorithm.search_tools.retrieval.embedder import (
AbstractEmbedder,
OpenJiuwenAPIEmbedder,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
class BrowsecompPlusHitMainBody(BaseModel):
    """Offset information describing where a chunk's main body starts in its ``content``."""

    # Offset into the chunk's ``content``. When chunks of one document are merged, every
    # chunk after the first contributes only ``content[start_idx:]``; None keeps the whole
    # content. NOTE(review): presumably the text before this offset is a repeated prefix
    # (see ``BrowsecompPlusHitSource.prefix``) — confirm against the ingestion code.
    start_idx: Optional[int] = None
class BrowsecompPlusHitSource(BaseModel):
    """Chunk provenance, parsed from the ``source`` entry of a hit's Milvus ``metadata`` field."""

    # NOTE(review): presumably the text prepended to every chunk ahead of its main body — confirm.
    prefix: Optional[str] = None
    # Identifier of the document this chunk belongs to.
    docid: Optional[str] = None
    # Position of the chunk inside its document; merging orders chunks by this value and
    # treats 0 as the document head.
    chunk_idx: Optional[int] = None
    # Where the chunk's main body starts inside its content.
    main_body: Optional[BrowsecompPlusHitMainBody] = None
class BrowsecompPlusHitMetadata(BaseModel):
    """Parsed form of the ``metadata`` field stored alongside each chunk in Milvus."""

    title: Optional[str] = None
    authors: Optional[List[str]] = None
    # Chunk provenance: document id, chunk index and main-body offset.
    source: Optional[BrowsecompPlusHitSource] = None
class BrowsecompPlusHit(BaseModel):
    """A single chunk returned by Milvus, before chunks of the same document are merged."""

    # Chunk primary key; in this module chunk ids appear to follow ``{docid}__{chunk_idx}``.
    id: str
    # Chunk text.
    content: str
    docid: Optional[str] = None
    title: Optional[str] = None
    metadata: Optional[BrowsecompPlusHitMetadata] = None
    # Search score (the Milvus hit ``distance``); None for chunks fetched by id, not by search.
    score: Optional[float] = None
class BrowsecompPlusCombinedHits(BaseModel):
    """All retrieved chunks of one document merged into a single result entry."""

    # Raw text of every merged chunk, ordered by chunk index.
    content_list: List[str]
    # Chunk ids, in the same order as ``content_list``.
    id_list: List[str]
    # Chunk texts joined by a separator; chunks after the first contribute only their main body.
    snippet: str
    doc_id: Optional[str] = None
    title: Optional[str] = None
    metadata: Optional[BrowsecompPlusHitMetadata] = None
    # Best score among the document's retrieved chunks.
    score: Optional[float] = None
class BrowsecompPlusDoc(BaseModel):
    """Retrieval output for one query: the query text and its merged per-document results."""

    query: str
    results: List[BrowsecompPlusCombinedHits]
class BrowsecompPlusMilvusRetriever(MilvusBaseRetriever):
    """
    Milvus retriever for the BrowseComp-Plus corpus.

    The collection stores documents as chunks whose ids look like ``{docid}__{chunk_idx}``.
    ``retrieve`` over-fetches ``top_k * top_k_multiply_factor`` chunks per query, groups them
    by document, keeps the ``top_k`` best-scoring documents and merges each document's chunks
    into a single snippet.

    NOTE(review): documents are ranked "higher score is better". That suits COSINE/IP
    similarity and the hybrid weighted ranker. A dense search with an L2 ``metric_type``
    returns distances where lower is better, so confirm the ranking before using L2 here.
    """

    # Lower bound for the HNSW ``ef`` search parameter (the value previously hard-coded).
    _MIN_EF = 500

    @staticmethod
    def _chunk_idx(hit: BrowsecompPlusHit) -> int:
        """
        Return the chunk index of ``hit``.

        Uses ``metadata.source.chunk_idx`` when present. Otherwise it uses the numeric suffix
        of a ``{docid}__{chunk_idx}`` id, and falls back to 0.
        """
        source = hit.metadata.source if hit.metadata is not None else None
        if source is not None and source.chunk_idx is not None:
            return source.chunk_idx
        _, sep, suffix = hit.id.rpartition("__")
        if sep and suffix.isdigit():
            return int(suffix)
        return 0

    @staticmethod
    def _main_body_start(hit: BrowsecompPlusHit) -> Optional[int]:
        """Return the offset where the hit's main body starts, or None to keep all content."""
        source = hit.metadata.source if hit.metadata is not None else None
        if source is None or source.main_body is None:
            return None
        return source.main_body.start_idx

    @staticmethod
    def _best_score(hits: List[BrowsecompPlusHit]) -> Optional[float]:
        """Return the highest non-None score among ``hits``, or None if no hit has a score."""
        scores = [hit.score for hit in hits if hit.score is not None]
        return max(scores) if scores else None

    def _merge_hits_from_same_doc(
        self,
        hits_list: List[BrowsecompPlusHit],
        content_splitter: str = " [...] ",
        doc_id: Optional[str] = None,
    ) -> BrowsecompPlusCombinedHits:
        """
        Merge hits that come from the same ``docid`` into a single combined entry.

        Chunks are ordered by chunk index. If the first retrieved chunk is not chunk 0 and a
        document id is known, chunk 0 is fetched from Milvus and prepended, so the snippet
        starts at the document head. If chunk 0 cannot be found, the retrieved chunks are
        used alone. The first chunk contributes its full content. Later chunks contribute
        only the text from their main-body offset onwards. The pieces are joined with
        ``content_splitter``.

        Args:
            hits_list: Non-empty list of hits belonging to one document.
            content_splitter: Separator placed between chunk texts in ``snippet``.
            doc_id: Document id to use; defaults to the first hit's ``docid``.

        Returns:
            The merged entry. ``score`` is the best chunk score; ``title`` and ``metadata``
            are taken from the first hit in ``hits_list`` (before re-ordering).
        """
        first = hits_list[0]
        doc_id = doc_id or first.docid
        score = self._best_score(hits_list)
        ordered = sorted(hits_list, key=self._chunk_idx)

        if doc_id and self._chunk_idx(ordered[0]) != 0:
            chunk0_id = f"{doc_id}__0"
            rows = self.client.query(
                collection_name=self.collection_name,
                filter=f"{self.id_field} == '{chunk0_id}'",
                output_fields=[self.text_field],
                limit=1,
            )
            # The head chunk may be missing; do not crash, just merge what was retrieved.
            head_content = rows[0].get(self.text_field) if rows else None
            if head_content is not None:
                ordered.insert(0, BrowsecompPlusHit(id=chunk0_id, content=head_content))

        snippet_parts = [ordered[0].content]
        # Later chunks drop the text before their main body (a None offset keeps everything).
        snippet_parts.extend(
            hit.content[self._main_body_start(hit):] for hit in ordered[1:]
        )

        return BrowsecompPlusCombinedHits(
            doc_id=doc_id,
            title=first.title,
            metadata=first.metadata,
            score=score,
            content_list=[hit.content for hit in ordered],
            snippet=content_splitter.join(snippet_parts),
            id_list=[hit.id for hit in ordered],
        )

    def _return_unique(
        self, hits: List[BrowsecompPlusHit], k: int
    ) -> List[BrowsecompPlusCombinedHits]:
        """
        Collapse chunk-level hits into at most ``k`` document-level results.

        Chunking means ``hits`` can contain several chunks of the same document.

        1. Group hits by ``docid``. When ``docid`` is empty, use the part of the id before ``__``.
        2. Rank documents by their best chunk score, descending; ties keep first-seen order.
        3. Merge the top ``k`` documents with ``_merge_hits_from_same_doc``.
        """
        grouped: Dict[str, List[BrowsecompPlusHit]] = {}
        for hit in hits:
            # BrowsecompPlusHit is a pydantic model: read the ``id`` attribute (it has no ``.get``).
            docid = hit.docid or str(hit.id).split("__")[0]
            grouped.setdefault(docid, []).append(hit)

        def doc_rank(docid: str) -> float:
            best = self._best_score(grouped[docid])
            return float("-inf") if best is None else best

        best_doc_ids = sorted(grouped, key=doc_rank, reverse=True)[:k]
        return [
            self._merge_hits_from_same_doc(grouped[docid], doc_id=docid)
            for docid in best_doc_ids
        ]

    def retrieve(
        self,
        retrieve_config: RetrieveConfig,
    ) -> Tuple[str, List[str]]:
        """
        Search Milvus for every query in ``retrieve_config.query`` and merge hits per document.

        ``retrieve_config.mode`` selects the search type:

        - ``"dense"``: vector search only.
        - ``"hybrid"``: a dense request plus a sparse request over the raw query text,
          combined with ``WeightedRanker(0.6, 0.4)``.

        When ``save_as`` is set, the full structured results are also written there as JSON.

        Returns:
            A tuple of:

            - one JSON object ``{"query": ..., "results": [snippet, ...]}`` per query,
              joined by blank lines;
            - the flat list of chunk ids behind every returned snippet.

        Raises:
            CustomValueException: if ``mode`` is neither ``"dense"`` nor ``"hybrid"``, or if
                Milvus returns a different number of result lists than there are queries.
        """
        mode = retrieve_config.mode
        if mode not in ("dense", "hybrid"):
            # Previously an unknown mode fell through and failed with UnboundLocalError.
            raise CustomValueException(
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(
                    e=f"Unsupported retrieval mode ({mode!r}); expected 'dense' or 'hybrid'"
                ),
            )

        queries = retrieve_config.query
        query_vec = self.embedder.encode(
            queries, is_query=retrieve_config.add_instruction
        )
        search_limit = retrieve_config.top_k * retrieve_config.top_k_multiply_factor
        # Milvus HNSW requires ef >= limit, so widen ef for large over-fetch limits.
        ef = max(self._MIN_EF, search_limit)
        output_fields = [
            self.text_field,
            self.title_field,
            self.id_field,
            "docid",
            "metadata",
        ]

        if mode == "dense":
            results: SearchResult = self.client.search(
                collection_name=self.collection_name,
                data=query_vec,
                anns_field=self.vector_field,
                search_params={"metric_type": self.metric_type, "params": {"ef": ef}},
                limit=search_limit,
                output_fields=output_fields,
            )
        else:
            dense_request = AnnSearchRequest(
                data=query_vec,
                anns_field=self.vector_field,
                param={"ef": ef},
                limit=search_limit,
            )
            # The sparse request is given the raw query strings rather than embeddings.
            sparse_request = AnnSearchRequest(
                data=queries,
                anns_field=self.sparse_field,
                param={},
                limit=search_limit,
            )
            results = self.client.hybrid_search(
                collection_name=self.collection_name,
                reqs=[dense_request, sparse_request],
                ranker=WeightedRanker(0.6, 0.4),
                limit=search_limit,
                output_fields=output_fields,
            )

        if len(results) != len(queries):
            raise CustomValueException(
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.code,
                StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR.errmsg.format(
                    e=f"Results length ({len(results)}) does not match query length "
                    f"({len(queries)})"
                ),
            )

        docs: List[BrowsecompPlusDoc] = []
        for query, result in zip(queries, results):
            hits = [
                BrowsecompPlusHit(
                    id=hit.id,
                    docid=hit.entity.get("docid"),
                    title=hit.entity.get(self.title_field),
                    content=hit.entity.get(self.text_field),
                    # A chunk without a metadata entry gets an empty metadata object.
                    metadata=BrowsecompPlusHitMetadata(
                        **(hit.entity.get("metadata") or {})
                    ),
                    score=hit.distance,
                )
                for hit in result
            ]
            docs.append(
                BrowsecompPlusDoc(
                    query=query,
                    results=self._return_unique(hits, retrieve_config.top_k),
                )
            )

        if retrieve_config.save_as:
            with open(retrieve_config.save_as, "w", encoding="utf-8") as f:
                json.dump(
                    [doc.model_dump() for doc in docs], f, ensure_ascii=False, indent=2
                )

        to_return = "\n\n".join(
            json.dumps(
                {
                    "query": doc.query,
                    "results": [merged.snippet for merged in doc.results],
                },
                ensure_ascii=False,
            )
            for doc in docs
        )
        id_list = [
            chunk_id
            for doc in docs
            for merged in doc.results
            for chunk_id in merged.id_list
        ]
        return to_return, id_list
class KnowledgeBaseRetriever(MilvusBaseRetriever):
    """
    Retrieves via ``KnowledgeBase.retrieve``.

    The constructor matches ``MilvusBaseRetriever`` (same positional and optional field
    names). It builds a ``SimpleKnowledgeBase`` from the Milvus and embedder settings.
    The base class is initialised with ``connect_milvus=False``, so its Milvus connection
    and collection checks are skipped. ``self.client`` is expected to stay ``None``
    (NOTE(review): confirm how ``MilvusBaseRetriever`` treats a caller-supplied ``client``).
    """

    def __init__(
        self,
        milvus_host: str,
        milvus_port: str,
        database_name: str,
        collection_name: str,
        embedder: AbstractEmbedder,
        vector_field: str = "embedding",
        text_field: str = "content",
        sparse_field: str = "content_sparse",
        title_field: str = "title",
        id_field: str = "id",
        metric_type: str = "COSINE",
        client: Optional[MilvusClient] = None,
        *,
        kb_id: Optional[str] = None,
        index_type: str = "vector",
        milvus_token: str = "",
    ):
        """
        Initialise the base retriever without connecting, then build the knowledge base.

        Args:
            kb_id: Knowledge-base id; defaults to ``collection_name``.
            index_type: Passed to ``KnowledgeBaseConfig.index_type``.
            milvus_token: Milvus token; an empty string is passed on as ``None``.
        """
        super().__init__(
            milvus_host,
            milvus_port,
            database_name,
            collection_name,
            embedder,
            vector_field=vector_field,
            text_field=text_field,
            sparse_field=sparse_field,
            title_field=title_field,
            id_field=id_field,
            metric_type=metric_type,
            client=client,
            connect_milvus=False,
        )
        self._knowledge_base = self._create_knowledge_base(
            milvus_host=milvus_host,
            milvus_port=milvus_port,
            database_name=database_name,
            collection_name=collection_name,
            embedder=embedder,
            kb_id=kb_id or collection_name,
            index_type=index_type,
            milvus_token=milvus_token,
            vector_field=vector_field,
            text_field=text_field,
            sparse_field=sparse_field,
        )

    @staticmethod
    def _embed_model_from_embedder(embedder: AbstractEmbedder):
        """
        Return an openjiuwen embedding model equivalent to ``embedder``.

        For an ``OpenJiuwenAPIEmbedder``, its wrapped ``embedder`` is reused. Otherwise a new
        model is built from the embedder's model name, API URL, token and timeout. The
        model class is chosen by ``OpenJiuwenAPIEmbedder.jiuwen_embedding_class_for_url``.
        """
        if isinstance(embedder, OpenJiuwenAPIEmbedder):
            return embedder.embedder
        api_key = embedder.api_token
        # The token is expected as bytes; also accept an already-decoded str instead of
        # failing on ``.decode``.
        if isinstance(api_key, (bytes, bytearray)):
            api_key = api_key.decode("utf-8")
        embed_config = EmbeddingConfig(
            model_name=embedder.model_name,
            base_url=embedder.api_url,
            api_key=api_key,
        )
        embed_cls = OpenJiuwenAPIEmbedder.jiuwen_embedding_class_for_url(embedder.api_url)
        return embed_cls(config=embed_config, timeout=embedder.timeout)

    @classmethod
    def _create_knowledge_base(
        cls,
        *,
        milvus_host: str,
        milvus_port: str,
        database_name: str,
        collection_name: str,
        embedder: AbstractEmbedder,
        kb_id: str,
        index_type: str,
        milvus_token: str,
        vector_field: str,
        text_field: str,
        sparse_field: str,
    ) -> KnowledgeBase:
        """Build a ``SimpleKnowledgeBase`` backed by a ``MilvusVectorStore`` at ``http://host:port``."""
        milvus_uri = f"http://{milvus_host}:{milvus_port}"
        vs_config = VectorStoreConfig(
            store_provider="milvus",
            collection_name=collection_name,
            database_name=database_name,
        )
        vector_store = MilvusVectorStore(
            config=vs_config,
            milvus_uri=milvus_uri,
            milvus_token=milvus_token or None,
            text_field=text_field,
            vector_field=vector_field,
            sparse_vector_field=sparse_field,
        )
        return SimpleKnowledgeBase(
            config=KnowledgeBaseConfig(kb_id=kb_id, index_type=index_type),
            vector_store=vector_store,
            embed_model=cls._embed_model_from_embedder(embedder),
        )

    def _run_retrieve(self, query: str, retrieval_config: RetrievalConfig):
        """Run the async ``KnowledgeBase.retrieve`` to completion from synchronous code."""

        def _run():
            return asyncio.run(
                self._knowledge_base.retrieve(query, config=retrieval_config)
            )

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: the normal synchronous path.
            return _run()
        # asyncio.run() refuses to start while a loop is already running in this thread
        # (e.g. when called from async code), so drive the coroutine on a worker thread.
        # The calling thread still blocks until retrieval finishes.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(_run).result()

    @staticmethod
    def _result_id(res) -> str:
        """
        Return an identifier for a retrieval result, or ``""`` if none is found.

        Lookup order: the ``chunk_id`` attribute, the ``doc_id`` attribute, then the
        ``id``/``chunk_id``/``doc_id`` keys of ``metadata``.
        """
        if getattr(res, "chunk_id", None):
            return str(res.chunk_id)
        if getattr(res, "doc_id", None):
            return str(res.doc_id)
        meta = getattr(res, "metadata", None) or {}
        for key in ("id", "chunk_id", "doc_id"):
            if key in meta and meta[key] is not None:
                return str(meta[key])
        return ""

    def retrieve(
        self,
        retrieve_config: RetrieveConfig,
    ) -> Tuple[str, List[str]]:
        """
        Retrieve results for every query in ``retrieve_config.query`` from the knowledge base.

        Each query requests ``top_k * max(1, top_k_multiply_factor)`` results. Results are
        returned as-is; unlike the BrowseComp-Plus retriever, they are not de-duplicated
        or truncated to ``top_k``.

        When ``save_as`` is set, the per-query entries are written there as JSON, including
        ``raw_results`` (each result's ``model_dump()``).

        Returns:
            A tuple of:

            - one JSON object ``{"query": ..., "results": [text, ...]}`` per query,
              joined by blank lines;
            - the non-empty ids of all results, in order.
        """
        top_k = retrieve_config.top_k * max(1, retrieve_config.top_k_multiply_factor)
        retrieval_config = RetrievalConfig(top_k=top_k)
        docs = []
        id_list: List[str] = []
        for q in retrieve_config.query:
            results = self._run_retrieve(q, retrieval_config)
            entry = {"query": q, "results": [r.text for r in results]}
            if retrieve_config.save_as:
                entry["raw_results"] = [r.model_dump() for r in results]
            docs.append(entry)
            id_list.extend(rid for rid in map(self._result_id, results) if rid)
        if retrieve_config.save_as:
            with open(retrieve_config.save_as, "w", encoding="utf-8") as f:
                json.dump(docs, f, ensure_ascii=False, indent=2)
        to_return = "\n\n".join(
            json.dumps(
                {"query": d["query"], "results": d["results"]},
                ensure_ascii=False,
            )
            for d in docs
        )
        return to_return, id_list