"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
from __future__ import annotations
from typing import List, Dict, Optional, Any, Union
import numpy as np
from loguru import logger
from pymilvus import MilvusClient, DataType
from pymilvus.client.types import ExtraList
from mx_rag.storage.vectorstore.vectorstore import VectorStore, SearchMode
from mx_rag.utils.common import validate_params, MAX_VEC_DIM, MAX_TOP_K, validate_embeddings, \
_check_sparse_and_dense, BOOL_TYPE_CHECK_TIP, MAX_IDS_SIZE, validate_sequence
from mx_rag.utils.common import validate_list_str, validate_embeddings
class MilvusError(Exception):
    """Raised when a MilvusDB operation fails validation or cannot be carried out."""
class SchemaBuilder:
    """Incrementally assemble a Milvus collection schema.

    Every schema starts with an INT64 primary-key field named ``id`` and has
    dynamic fields enabled, so undeclared keys written alongside the vectors
    (MilvusDB writes ``document_id``, ``document`` and ``metadata``) can be stored.
    """

    def __init__(self, auto_id=False):
        # auto_id is forwarded unchanged to MilvusClient.create_schema.
        self.schema = MilvusClient.create_schema(auto_id=auto_id, enable_dynamic_field=True)
        self._add_base_fields()

    def add_vector_field(self, field_name: str, datatype: DataType, dim: int):
        """Declare a fixed-dimension vector field (e.g. FLOAT_VECTOR)."""
        field_spec = {"field_name": field_name, "datatype": datatype, "dim": dim}
        self.schema.add_field(**field_spec)

    def add_sparse_vector_field(self, field_name: str, datatype: DataType):
        """Declare a sparse vector field; unlike dense fields it takes no ``dim``."""
        field_spec = {"field_name": field_name, "datatype": datatype}
        self.schema.add_field(**field_spec)

    def build(self):
        """Return the schema object accumulated so far."""
        return self.schema

    def _add_base_fields(self):
        """Add the INT64 primary key ``id`` that every collection carries."""
        self.schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
class IndexParamsBuilder:
    """Build Milvus index parameters matching a MilvusDB search mode.

    The dense index is created on field ``vector`` and the sparse index on
    field ``sparse_vector``. Per-index options are read from the ``"dense"``
    and ``"sparse"`` entries of the ``params`` dict, defaulting to ``{}``.
    """

    def __init__(self, client, search_mode: SearchMode):
        self.client = client
        self.search_mode = search_mode
        self.index_params = self.client.prepare_index_params()

    def _register(self, field_name, index_name, index_type, metric_type, options):
        """Append one index definition to the accumulated index params."""
        self.index_params.add_index(
            field_name=field_name,
            index_name=index_name,
            index_type=index_type,
            metric_type=metric_type,
            params=options,
        )

    def add_sparse_index(self, params: Dict[str, Any]):
        """Add an inner-product SPARSE_INVERTED_INDEX on ``sparse_vector``."""
        self._register("sparse_vector", "sparse_index", "SPARSE_INVERTED_INDEX", "IP",
                       params.get("sparse", {}))

    def add_dense_index(self, index_type, metric_type, params: Dict[str, Any]):
        """Add an index of the given type and metric on the dense ``vector`` field."""
        self._register("vector", "dense_index", index_type, metric_type,
                       params.get("dense", {}))

    def build(self):
        """Return the index params accumulated so far."""
        return self.index_params

    def prepare_index_params(self, index_type, metric_type, params: Dict[str, Any]):
        """Add the indexes required by ``search_mode`` and return the index params.

        DENSE gets only the dense index and SPARSE only the sparse index. Any
        other mode (HYBRID) gets both, sparse first.
        """
        if self.search_mode != SearchMode.DENSE:
            self.add_sparse_index(params)
        if self.search_mode != SearchMode.SPARSE:
            self.add_dense_index(index_type, metric_type, params)
        return self.build()
class MilvusDB(VectorStore):
    """VectorStore implementation backed by one collection of a MilvusClient."""

    # Upper bounds used by parameter validation.
    MAX_COLLECTION_NAME_LENGTH = 1024
    MAX_URL_LENGTH = 1024
    MAX_QUERY_LENGTH = 1000  # max characters per stored document string
    MAX_DICT_LENGTH = 100000

    # metric -> function turning a raw Milvus search distance into a score.
    # IP/COSINE are capped at 1.0. L2 is mapped so a smaller distance gives a
    # higher score, floored at 0.0.
    SCALE_MAP = {
        "IP": lambda score: min(score, 1.0),
        "L2": lambda dist: max(1.0 - dist / 2.0, 0.0),
        "COSINE": lambda score: min(score, 1.0),
    }
@validate_params(
    client=dict(validator=lambda x: isinstance(x, MilvusClient),
                message="param must be instance of MilvusClient"),
    collection_name=dict(
        validator=lambda x: isinstance(x, str) and 0 < len(x) <= MilvusDB.MAX_COLLECTION_NAME_LENGTH,
        message="param must be str and length range (0, 1024]"),
    search_mode=dict(validator=lambda x: isinstance(x, SearchMode),
                     message="param must be instance of SearchMode"),
    auto_id=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
    index_type=dict(validator=lambda x: isinstance(x, str) and x in ("FLAT", "IVF_FLAT", "IVF_PQ", "HNSW"),
                    message="param must str and one of [FLAT, IVF_FLAT, IVF_PQ, HNSW]"),
    metric_type=dict(validator=lambda x: isinstance(x, str) and x in ("IP", "L2", "COSINE"),
                     message="param must str and one of [IP, L2, COSINE]"),
    auto_flush=dict(validator=lambda x: isinstance(x, bool), message=BOOL_TYPE_CHECK_TIP),
)
def __init__(self, client: MilvusClient, collection_name: str = "rag_sdk",
             search_mode: SearchMode = SearchMode.DENSE, auto_id=False,
             index_type: str = "FLAT", metric_type: str = "L2", auto_flush=True):
    """Wrap ``client`` for a single collection (the collection is not created here).

    Args:
        client: connected MilvusClient.
        collection_name: target collection, 1..1024 characters.
        search_mode: DENSE, SPARSE or HYBRID; decides schema, indexes and allowed calls.
        auto_id: forwarded to the schema when the collection is created.
        index_type: dense index type (FLAT, IVF_FLAT, IVF_PQ or HNSW).
        metric_type: dense metric (IP, L2 or COSINE); also selects the score scaler.
        auto_flush: when True, write operations call flush() afterwards.
    """
    super().__init__()
    # Connection and target collection.
    self._client, self._collection_name = client, collection_name
    # Configuration consumed later by create_collection() and the add/search paths.
    self._search_mode = search_mode
    self._auto_id = auto_id
    self._index_type, self._metric_type = index_type, metric_type
    # Filled in per call by search(); None means "no document filter".
    self._filter_dict = None
    self._auto_flush = auto_flush
    self.score_scale = self.SCALE_MAP.get(metric_type)
@property
def search_mode(self):
return self._search_mode
@property
def collection_name(self):
return self._collection_name
@property
def client(self):
return self._client
@staticmethod
def create(**kwargs):
    """Build a MilvusDB from keyword arguments and create its collection.

    Recognised keys are: ``client`` (required MilvusClient), ``x_dim``,
    ``index_type``, ``metric_type``, ``params``, ``collection_name``,
    ``search_mode``, ``auto_id`` and ``auto_flush``. Any other keys are ignored.

    Returns:
        The new MilvusDB. Returns None (after logging) when ``client``,
        ``index_type``, ``metric_type`` or ``params`` is invalid, or when
        ``create_collection`` raises. Errors raised while constructing MilvusDB
        itself (e.g. by its parameter validation) are not caught here.
    """
    client_field = "client"
    client = kwargs.pop(client_field, None)
    if not isinstance(client, MilvusClient):
        logger.error(f"param error: {client_field} must be specified")
        return None

    x_dim = kwargs.pop("x_dim", None)

    index_type = kwargs.pop("index_type", "FLAT")
    if index_type not in ("FLAT", "IVF_FLAT", "IVF_PQ", "HNSW"):
        logger.error("param error: index_type must be one of [FLAT, IVF_FLAT, IVF_PQ, HNSW]")
        return None

    metric_type = kwargs.pop("metric_type", "L2")
    if metric_type not in ("IP", "L2", "COSINE"):
        logger.error("param error: metric_type must be one of [IP, L2, COSINE]")
        return None

    params = kwargs.pop("params", {})
    if not isinstance(params, dict):
        logger.error("param error: params must be dict. ")
        return None

    milvus_db = MilvusDB(
        client,
        collection_name=kwargs.pop("collection_name", "rag_sdk"),
        search_mode=kwargs.pop("search_mode", SearchMode.DENSE),
        auto_id=kwargs.pop("auto_id", False),
        index_type=index_type,
        metric_type=metric_type,
        auto_flush=kwargs.pop("auto_flush", True),
    )
    try:
        milvus_db.create_collection(x_dim=x_dim, params=params)
    except KeyError:
        logger.error("milvus create collection meet key error")
        return None
    except Exception as e:
        logger.error(f"milvus create collection failed: {e}")
        return None
    return milvus_db
@validate_params(collection_name=dict(
    # Type-check first, as __init__ does, so non-str values are rejected
    # cleanly instead of passing (lists) or raising TypeError from len().
    validator=lambda x: isinstance(x, str) and 0 < len(x) <= MilvusDB.MAX_COLLECTION_NAME_LENGTH,
    message="param must be str and length range (0, 1024]"))
def set_collection_name(self, collection_name: str):
    """Point this store at another collection name.

    Only the stored name changes; the collection is neither created nor
    checked for existence here.
    """
    self._collection_name = collection_name
@validate_params(
    x_dim=dict(validator=lambda x: x is None or (isinstance(x, int) and 0 < x <= MAX_VEC_DIM),
               message="param value range (0, 1024 * 1024]"),
    params=dict(validator=lambda x: x is None or (isinstance(x, dict) and validate_sequence(x, max_check_depth=2)),
                message="params requires to be None or dict")
)
def create_collection(self, x_dim: Optional[int] = None, params=None):
    """Create the collection with a schema and indexes matching ``search_mode``.

    If the collection already exists, only a warning is logged.

    Args:
        x_dim: dimension of the dense ``vector`` field. Required in DENSE and
            HYBRID modes; not used by the SPARSE schema.
        params: optional dict of index options, read under its ``"dense"`` and
            ``"sparse"`` keys when the indexes are built.

    Raises:
        MilvusError: if ``x_dim`` is None in DENSE or HYBRID mode.
    """
    if self._client.has_collection(self._collection_name):
        logger.warning(f"Collection {self._collection_name} already exists")
        return
    needs_dense_dim = self.search_mode in (SearchMode.DENSE, SearchMode.HYBRID)
    if needs_dense_dim and x_dim is None:
        raise MilvusError("x_dim can't be None in mode DENSE or HYBRID")
    index_options = {} if params is None else params
    self.client.create_collection(
        collection_name=self._collection_name,
        schema=self._create_schema(x_dim),
        index_params=self._prepare_index_params(index_options),
    )
def drop_collection(self):
if not self.client.has_collection(self._collection_name):
logger.warning(f"collection {self._collection_name} does not existed")
else:
self.client.drop_collection(self._collection_name)
@validate_params(
    ids=dict(validator=lambda x: all(isinstance(it, int) for it in x) and 0 <= len(x) < MAX_IDS_SIZE,
             message="param must be List[int]")
)
def delete(self, ids: List[int]):
    """Delete entities by primary key.

    Returns:
        0 when ``ids`` is empty. Otherwise the client's delete result, unwrapped
        to its ``"delete_count"`` value when the client returns a dict.

    Raises:
        MilvusError: if the collection is missing or ``len(ids) >= MAX_VEC_NUM``.
    """
    id_count = len(ids)
    if id_count == 0:
        logger.warning("no id need be deleted")
        return 0
    self._validate_collection_existence()
    if id_count >= self.MAX_VEC_NUM:
        raise MilvusError(f"Length of ids is over limit, {id_count} >= {self.MAX_VEC_NUM}")
    outcome = self.client.delete(collection_name=self._collection_name, ids=ids)
    deleted = outcome.get("delete_count") if isinstance(outcome, dict) else outcome
    if self._auto_flush:
        self.flush()
    # NOTE(review): this logs the number of ids requested, not ``deleted``.
    logger.info(f"success delete {id_count} vectors in MilvusDB.")
    return deleted
@validate_params(
    embeddings=dict(validator=lambda x: validate_embeddings(x)[0],
                    message="param must be Union[List[List[float]], List[Dict[int, float]]]"),
    k=dict(validator=lambda x: isinstance(x, int) and 0 < x <= MAX_TOP_K, message="param length range (0, 10000]"),
    filter_dict=dict(validator=lambda x: isinstance(x, dict) or x is None,
                     message="param filter_dict must be dict or None"))
def search(self, embeddings: Union[List[List[float]], List[Dict[int, float]]],
           k: int = 3, filter_dict=None, **kwargs):
    """Search the collection with dense or sparse query vectors.

    A list whose items are all dicts is treated as sparse queries. Anything
    else is converted with ``np.array`` and searched as dense vectors.

    Args:
        embeddings: dense query vectors or sparse (dict) query vectors.
        k: number of hits per query.
        filter_dict: optional filter. A non-empty ``"document_id"`` entry
            restricts hits to those document ids.
        **kwargs: ``output_fields`` selects extra entity fields to return; the
            remaining keys are forwarded to ``MilvusClient.search``.

    Returns:
        ``(scores, ids, extra)`` as produced by ``_process_search_results``.
    """
    # Kept on the instance because the _perform_* helpers read it from there.
    # NOTE(review): concurrent search() calls on one instance share this state.
    self._filter_dict = filter_dict
    self._validate_collection_existence()
    output_fields = kwargs.pop("output_fields", [])
    looks_sparse = isinstance(embeddings, list) and all(isinstance(item, dict) for item in embeddings)
    if looks_sparse:
        return self._perform_sparse_search(embeddings, k, output_fields, **kwargs)
    return self._perform_dense_search(np.array(embeddings), k, output_fields, **kwargs)
@validate_params(
    ids=dict(validator=lambda x: all(isinstance(it, int) for it in x) and 0 <= len(x) < MAX_IDS_SIZE,
             message="param must be List[int]"),
    document_id=dict(validator=lambda x: isinstance(x, int) and x >= 0,
                     message="param must greater equal than 0")
)
def add(self, ids: List[int], embeddings: np.ndarray, document_id: int = 0, docs: Optional[List[str]] = None,
        metadatas: Optional[List[Dict]] = None):
    """Add dense vectors to the vector database; only valid in DENSE mode.

    Args:
        ids: primary keys, one per row of ``embeddings``.
        embeddings: 2-D array of dense vectors (stored as float32).
        document_id: value stored as ``document_id`` on every inserted entity.
        docs: optional document strings, one per id.
        metadatas: optional metadata dicts, one per id.

    Raises:
        MilvusError: collection missing, mode is not DENSE, or length mismatches.
        ValueError: malformed embeddings or docs.
    """
    self._validate_collection_existence()
    if self.search_mode != SearchMode.DENSE:
        raise MilvusError("search mode needs to be DENSE")
    rows = self._init_insert_data(ids, docs, metadatas, document_id)
    self._handle_dense_input(embeddings, ids, rows)
    self.client.insert(collection_name=self._collection_name, data=rows)
    if self._auto_flush:
        self.flush()
    logger.info(f"success add {len(ids)} ids in MilvusDB.")
@validate_params(
    ids=dict(validator=lambda x: all(isinstance(it, int) for it in x) and 0 <= len(x) < MAX_IDS_SIZE,
             message="param must be List[int]"),
    document_id=dict(validator=lambda x: isinstance(x, int) and x >= 0,
                     message="param must greater equal than 0")
)
def add_sparse(self, ids, sparse_embeddings, document_id: int = 0, docs: Optional[List[str]] = None,
               metadatas: Optional[List[Dict]] = None):
    """Add sparse vectors to the collection; only valid in SPARSE mode.

    Args:
        ids: primary keys, one per sparse vector.
        sparse_embeddings: list of sparse vectors (dicts), one per id.
        document_id: value stored as ``document_id`` on every inserted entity.
        docs: optional document strings, one per id.
        metadatas: optional metadata dicts, one per id.

    Raises:
        MilvusError: collection missing, mode is not SPARSE, or length mismatches.
        ValueError: malformed sparse embeddings or docs.
    """
    self._validate_collection_existence()
    if self.search_mode != SearchMode.SPARSE:
        raise MilvusError("search mode must be SPARSE")
    rows = self._init_insert_data(ids, docs, metadatas, document_id)
    self._handle_sparse_input(sparse_embeddings, ids, rows)
    self.client.insert(collection_name=self._collection_name, data=rows)
    if self._auto_flush:
        self.flush()
    logger.info(f"successfully add {len(ids)} vectors in MilvusDB.")
@validate_params(
    ids=dict(validator=lambda x: all(isinstance(it, int) for it in x) and 0 <= len(x) < MAX_IDS_SIZE,
             message="param must be List[int]")
)
def add_dense_and_sparse(self, ids: List[int], dense_embeddings: np.ndarray,
                         sparse_embeddings: List[Dict[int, float]], docs: Optional[List[str]] = None,
                         metadatas: Optional[List[Dict]] = None, **kwargs):
    """Insert entities carrying both a dense and a sparse vector (HYBRID mode only).

    Args:
        ids: primary keys, one per entity.
        dense_embeddings: 2-D array with one row per id (stored as float32).
        sparse_embeddings: one sparse vector (dict) per id.
        docs: optional document strings, one per id.
        metadatas: optional metadata dicts, one per id.
        **kwargs: ``document_id`` (int >= 0, default 0) is stored on every
            inserted entity; other keys are ignored.

    Raises:
        MilvusError: collection missing, mode is not HYBRID, invalid
            ``document_id``, or length mismatches between ids and other inputs.
        ValueError: malformed embeddings or docs.
    """
    self._validate_collection_existence()
    if self.search_mode != SearchMode.HYBRID:
        raise MilvusError("search mode must be HYBRID")
    document_id = kwargs.pop("document_id", 0)
    if not isinstance(document_id, int):
        raise MilvusError("param document_id must be int")
    if document_id < 0:
        # add()/add_sparse() reject negative document ids in their validators;
        # enforce the same constraint here.
        raise MilvusError("param document_id must greater equal than 0")
    data = self._init_insert_data(ids, docs, metadatas, document_id=document_id)
    self._handle_sparse_input(sparse_embeddings, ids, data)
    self._handle_dense_input(dense_embeddings, ids, data)
    self.client.insert(collection_name=self._collection_name, data=data)
    if self._auto_flush:
        self.flush()
    logger.info(f"successfully add {len(ids)} vectors in MilvusDB.")
def get_all_ids(self) -> List[int]:
all_id = self.client.query(self._collection_name, filter="id == 0 or id != 0", output_fields=["id"])
ids = [idx['id'] for idx in all_id]
return ids
@validate_params(
    ids=dict(validator=lambda x: all(isinstance(it, int) for it in x) and 0 <= len(x) < MAX_IDS_SIZE,
             message="param must be List[int]"),
    dense=dict(validator=lambda x: x is None or isinstance(x, np.ndarray),
               message="dense must be Optional[np.ndarray]"),
    sparse=dict(validator=lambda x: x is None or validate_embeddings(x)[0],
                message="sparse must to be Optional[List[Dict[int, float]]]")
)
def update(self, ids: List[int], dense: Optional[np.ndarray] = None,
           sparse: Optional[List[Dict[int, float]]] = None):
    """Overwrite the dense and/or sparse vectors of existing entities.

    Each entity is fetched with ``client.get``. Its ``vector`` /
    ``sparse_vector`` fields are replaced wherever a new value is given, and
    the full rows are written back with ``client.upsert``, so other stored
    fields are kept.

    Args:
        ids: primary keys of the entities to update.
        dense: optional dense vectors; row i belongs to ``ids[i]``.
        sparse: optional sparse vectors; item i belongs to ``ids[i]``.

    Raises:
        MilvusError: the collection does not exist, some id is not found, or
            the number of fetched entities differs from ``len(ids)``.
        Whatever ``_check_sparse_and_dense`` raises for inconsistent inputs.
    """
    _check_sparse_and_dense(ids, dense, sparse)
    self._validate_collection_existence()
    responses = self.client.get(
        collection_name=self.collection_name,
        ids=ids
    )
    # Check for missing ids even when the counts match: duplicated primary
    # keys in Milvus could otherwise mask a missing id.
    found_ids = {res.get("id") for res in responses}
    missing_ids = set(ids) - found_ids
    if missing_ids:
        raise MilvusError(f"the input id {missing_ids} in ids not exists in milvus")
    if len(responses) != len(ids):
        # Every id was found, so the mismatch comes from duplicates in ``ids``
        # or duplicated primary keys stored in Milvus.
        raise MilvusError(
            f"got {len(responses)} entities from milvus for {len(ids)} input ids; "
            f"ids may contain duplicates")
    if dense is None:
        dense = [None] * len(ids)
    if sparse is None:
        sparse = [None] * len(ids)
    # id -> first position in ``ids`` (same result as ids.index), built once
    # so each lookup is O(1) instead of O(len(ids)).
    position = {}
    for idx, entity_id in enumerate(ids):
        position.setdefault(entity_id, idx)
    for response in responses:
        idx = position[response.get("id")]
        if dense[idx] is not None:
            response["vector"] = dense[idx]
        if sparse[idx] is not None:
            response["sparse_vector"] = sparse[idx]
    if responses:
        self.client.upsert(collection_name=self.collection_name, data=responses)
        if self._auto_flush:
            self.flush()
    logger.info(f"Successfully updated chunk ids {ids}")
@validate_params(collection_name=dict(
    # Type-check first, as __init__ does, so non-str values are rejected
    # cleanly instead of passing (lists) or raising TypeError from len().
    validator=lambda x: isinstance(x, str) and 0 < len(x) <= MilvusDB.MAX_COLLECTION_NAME_LENGTH,
    message="param must be str and length range (0, 1024]"))
def has_collection(self, collection_name):
    """Return the client's answer to whether ``collection_name`` exists."""
    return self.client.has_collection(collection_name)
def flush(self):
self.client.refresh_load(collection_name=self.collection_name)
def _create_schema_dense(self, x_dim):
    """Schema with the base fields plus a FLOAT_VECTOR ``vector`` of ``x_dim`` dims."""
    schema_builder = SchemaBuilder(self._auto_id)
    schema_builder.add_vector_field("vector", DataType.FLOAT_VECTOR, x_dim)
    return schema_builder.build()
def _create_schema_sparse(self):
    """Schema with the base fields plus a SPARSE_FLOAT_VECTOR ``sparse_vector``."""
    schema_builder = SchemaBuilder(self._auto_id)
    schema_builder.add_sparse_vector_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
    return schema_builder.build()
def _create_schema_hybrid(self, x_dim):
    """Schema with both a dense ``vector`` (``x_dim`` dims) and a ``sparse_vector`` field."""
    schema_builder = SchemaBuilder(self._auto_id)
    schema_builder.add_vector_field("vector", DataType.FLOAT_VECTOR, x_dim)
    schema_builder.add_sparse_vector_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
    return schema_builder.build()
def _create_schema(self, x_dim: Optional[int] = None):
    """Pick the schema for ``search_mode``: SPARSE, DENSE, or hybrid for any other mode."""
    if self.search_mode == SearchMode.SPARSE:
        return self._create_schema_sparse()
    if self.search_mode == SearchMode.DENSE:
        return self._create_schema_dense(x_dim)
    return self._create_schema_hybrid(x_dim)
def _prepare_index_params(self, params):
    """Return Milvus index params for this store's mode, index type and metric."""
    return IndexParamsBuilder(self.client, self.search_mode).prepare_index_params(
        self._index_type, self._metric_type, params)
def _init_insert_data(self, ids, docs, metadatas, document_id):
data = [{"id": i, "document_id": document_id} for i in ids]
if docs is not None:
self._validate_docs(docs)
if len(ids) != len(docs):
raise MilvusError("#id must be equal #doc")
for i, doc in enumerate(docs):
data[i]["document"] = doc
if metadatas is not None:
self._validate_metadatas(metadatas)
if len(ids) != len(metadatas):
raise MilvusError("#id must be equal #metadata")
for i, metadata in enumerate(metadatas):
data[i]["metadata"] = metadata
return data
def _validate_collection_existence(self):
if not self.client.has_collection(self._collection_name):
raise MilvusError(f"collection {self._collection_name} is not existed")
def _validate_dense_input(self, data, search=True):
if not isinstance(data, np.ndarray):
raise ValueError("param must be np.ndarray")
if len(data.shape) != 2:
raise ValueError("shape of embedding must equal to 2")
limit = self.MAX_SEARCH_BATCH if search else self.MAX_VEC_NUM
if data.shape[0] >= limit:
raise ValueError(f"num of embeddings must less {limit}")
def _validate_sparse_input(self, data):
if not (isinstance(data, list) and all(isinstance(x, dict) for x in data)):
raise ValueError(
f"param must be List[Dict] with max length {self.MAX_DICT_LENGTH}"
)
def _validate_docs(self, data):
    """Require a list of 1..MAX_SEARCH_BATCH strings, each 1..MAX_QUERY_LENGTH long.

    The check is delegated to ``validate_list_str``.

    Raises:
        ValueError: if the check fails.
    """
    count_bounds = [1, self.MAX_SEARCH_BATCH]
    length_bounds = [1, self.MAX_QUERY_LENGTH]
    if validate_list_str(data, count_bounds, length_bounds):
        return
    raise ValueError(
        f"param must be List[str] with max length {self.MAX_SEARCH_BATCH} "
        f"and each string length in [1, {self.MAX_QUERY_LENGTH}]"
    )
def _validate_metadatas(self, data):
if not (isinstance(data, list) and all(isinstance(it, dict) for it in data)):
raise MilvusError("param error: param must be list[dict]")
if len(data) > self.MAX_SEARCH_BATCH:
raise MilvusError(f"param error: length of list must be less or equal {self.MAX_SEARCH_BATCH}")
def _handle_dense_input(self, embeddings: Optional[np.ndarray], ids: List[int], data: List[Dict]):
self._validate_dense_input(embeddings, search=False)
if embeddings.shape[0] != len(ids):
raise MilvusError("Length of embeddings is not equal to number of ids")
for i, e in enumerate(embeddings.astype(np.float32)):
data[i]["vector"] = e
def _handle_sparse_input(self, sparse_embeddings: Optional[List[Dict[int, float]]],
ids: List[int], data: List[Dict]):
self._validate_sparse_input(sparse_embeddings)
if len(sparse_embeddings) != len(ids):
raise MilvusError("Length of sparse_embeddings is not equal to number of ids")
for i, x in enumerate(sparse_embeddings):
data[i]["sparse_vector"] = x
def _perform_dense_search(self, embeddings: np.ndarray, k: int, output_fields: list, **kwargs):
    """Run a dense ANN search on the ``vector`` field.

    Args:
        embeddings: 2-D array of query vectors (cast to float32 before sending).
        k: number of hits per query.
        output_fields: entity fields to return with each hit.
        **kwargs: forwarded to ``MilvusClient.search``.

    Returns:
        ``(scores, ids, extra)`` from ``_process_search_results``.

    Raises:
        ValueError: if the store is not in DENSE/HYBRID mode, or the query
            embeddings fail ``_validate_dense_input``.
    """
    if self.search_mode not in (SearchMode.DENSE, SearchMode.HYBRID):
        # Only the DENSE and HYBRID schemas declare a "vector" field.
        raise ValueError("Dense search only supports DENSE or HYBRID mode")
    self._validate_dense_input(embeddings)
    embeddings = embeddings.astype(np.float32)
    self._validate_filter_dict(self._filter_dict)
    doc_filter = self._filter_dict.get("document_id", []) if self._filter_dict else []
    search_kwargs = {
        "collection_name": self._collection_name,
        "anns_field": "vector",
        "limit": k,
        "data": embeddings,
        "output_fields": output_fields
    }
    if doc_filter:
        # Templated filter: Milvus substitutes {document_list} from filter_params.
        search_kwargs["filter"] = "document_id IN {document_list}"
        search_kwargs["filter_params"] = {"document_list": doc_filter}
    res = self.client.search(**search_kwargs, **kwargs)
    return self._process_search_results(res, output_fields)
def _perform_sparse_search(self, sparse_embeddings: List[Dict[int, float]], k: int, output_fields: list, **kwargs):
    """Run a sparse search on the ``sparse_vector`` field.

    Args:
        sparse_embeddings: list of sparse query vectors (dicts).
        k: number of hits per query.
        output_fields: entity fields to return with each hit.
        **kwargs: forwarded to ``MilvusClient.search``.

    Returns:
        ``(scores, ids, extra)`` from ``_process_search_results``.

    Raises:
        ValueError: if the store is not in SPARSE/HYBRID mode, or the queries
            are not a list of dicts.
    """
    if self.search_mode not in (SearchMode.SPARSE, SearchMode.HYBRID):
        raise ValueError("Sparse search only supports SPARSE or HYBRID mode")
    self._validate_sparse_input(sparse_embeddings)
    self._validate_filter_dict(self._filter_dict)
    doc_filter = self._filter_dict.get("document_id", []) if self._filter_dict else []
    request = dict(
        collection_name=self._collection_name,
        anns_field="sparse_vector",
        limit=k,
        data=sparse_embeddings,
        output_fields=output_fields,
    )
    if doc_filter:
        # Templated filter: Milvus substitutes {document_list} from filter_params.
        request.update(filter="document_id IN {document_list}",
                       filter_params={"document_list": doc_filter})
    res = self.client.search(**request, **kwargs)
    return self._process_search_results(res, output_fields)
def _process_search_results(self, data: ExtraList, output_fields: Optional[List[str]] = None):
if output_fields is None:
output_fields = []
filtered_fields = [field for field in output_fields if field not in ["id", "distance"]]
scores, ids, extra_data = [], [], []
for top_k in data:
k_score, k_id = [], []
k_extra = [[] for _ in filtered_fields]
for entity in top_k:
k_score.append(entity["distance"])
k_id.append(entity["id"])
self._append_search_outfields(entity, filtered_fields, k_extra)
scores.append(k_score)
ids.append(k_id)
extra_data.append(k_extra)
return self._score_scale(scores), ids, extra_data
def _append_search_outfields(self, entity, filtered_fields, k_extra):
for idx, field in enumerate(filtered_fields):
k_extra_value = entity["entity"].get(field, None)
if k_extra_value is not None:
k_extra[idx].append({field: k_extra_value})