"""openGauss 向量存储后端"""
import json
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
logger = logging.getLogger(__name__)
try:
from psycopg.errors import UniqueViolation
except ImportError:
UniqueViolation = None
@dataclass
class VectorRecord:
    """One item to store: an embedding plus the text and metadata it belongs to.

    ``OpenGaussVectorDB.insert`` reads all four fields of each record and
    serializes ``metadata`` with ``json.dumps``, so ``metadata`` has to be
    JSON-serializable.
    """

    # Primary key of the row.
    id: str
    # Embedding values.
    vector: List[float]
    # Arbitrary key/value attributes stored alongside the vector.
    metadata: Dict[str, Any]
    # Source text the embedding was computed from.
    text: str
class VectorDatabase:
    """Interface that vector database backends implement.

    Every operation here raises ``NotImplementedError``; subclasses supply
    the real behavior.
    """

    def __init__(self, config: Dict[str, Any]):
        """Remember the backend configuration; no connection is opened yet."""
        self.config = config
        self._connection = None

    def connect(self):
        """Open the database connection."""
        raise NotImplementedError

    def insert(self, records: List[VectorRecord]) -> int:
        """Store the given records and return how many were written."""
        raise NotImplementedError

    def search(
        self,
        query_vector: List[float],
        query_text: str = "",
        limit: int = 10,
        filters: Optional[Dict] = None,
        use_hybrid: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` records similar to ``query_vector``.

        ``query_text``, ``filters`` and ``use_hybrid`` are optional hints
        whose handling is left to the implementing backend.
        """
        raise NotImplementedError

    def delete(self, record_ids: List[str]) -> int:
        """Remove the records with the given ids and return the count removed."""
        raise NotImplementedError

    def close(self):
        """Close the database connection."""
        raise NotImplementedError
class OpenGaussVectorDB(VectorDatabase):
"""openGauss 向量数据库实现"""
def __init__(
    self,
    host: str = "localhost",
    port: int = 5432,
    database: str = "memory_db",
    user: str = "postgres",
    password: str = "",
    table_name: str = "vectors",
    dimension: int = 1536,
    index_type: str = "ivfflat",
    lists: int = 100,
    m: int = 16,
    ef_construction: int = 64,
    sslmode: Optional[str] = None,
    gssencmode: Optional[str] = None,
    bm25_parallel_workers: Optional[int] = None,
):
    """Store connection and table settings, then make sure the table exists.

    Args:
        host, port, database, user, password: Connection settings.
        table_name: Table holding the vectors. It is interpolated into SQL
            statements as an identifier, so it must come from trusted
            configuration.
        dimension: Dimension of the ``vector`` column.
        index_type: ``"ivfflat"``, ``"hnsw"`` or ``"diskann"``; any other
            value results in no vector index being created.
        lists: ``lists`` option of the ivfflat index.
        m, ef_construction: Options of the hnsw index.
        sslmode, gssencmode: Optional connection options, sent only when
            not ``None``.
        bm25_parallel_workers: Optional ``parallel_workers`` table setting
            applied before the BM25 index is built; clamped to 1..32.

    Constructing the object connects to the database through
    ``_initialize_schema``.
    """
    self.config = {
        "host": host,
        "port": port,
        "database": database,
        "user": user,
        "password": password,
        "sslmode": sslmode,
        "gssencmode": gssencmode,
    }
    # The base-class initializer is not invoked, so set the attribute here
    # instead of relying on _is_connection_closed() swallowing the
    # AttributeError raised by the missing attribute.
    self._connection = None

    self.table_name = table_name
    self.dimension = dimension
    self.index_type = index_type
    self.lists = lists
    self.m = m
    self.ef_construction = ef_construction

    if bm25_parallel_workers is None:
        self.bm25_parallel_workers = None
    else:
        # Keep the value inside the 1..32 range.
        self.bm25_parallel_workers = max(1, min(32, bm25_parallel_workers))

    self._initialize_schema()
def connect(self):
    """Open a connection to openGauss and remember it on the instance.

    The optional ``sslmode`` and ``gssencmode`` settings are sent only when
    configured.

    Returns:
        The new connection object (also stored in ``self._connection``).

    Raises:
        ImportError: If the ``psycopg`` package is not installed.
    """
    try:
        import psycopg
    except ImportError as exc:
        raise ImportError("Please install: pip install psycopg") from exc

    # Pass settings as keyword arguments so the driver quotes each value.
    # A hand-built "key=value ..." string breaks on passwords containing
    # spaces or quotes, and an empty password followed by another option
    # is mis-parsed.
    params = {
        "host": self.config["host"],
        "port": self.config["port"],
        "dbname": self.config["database"],
        "user": self.config["user"],
        "password": self.config["password"],
    }
    for option in ("sslmode", "gssencmode"):
        value = self.config.get(option)
        if value is not None:
            params[option] = value

    self._connection = psycopg.connect(**params)
    logger.info("Connected to openGauss")
    return self._connection
def _initialize_schema(self):
    """Create the vector table if needed.

    When the table already exists and the type modifier of its ``vector``
    column differs from ``self.dimension``, the table is DROPPED (all rows
    are lost) and created again. When it exists and the modifier matches,
    or cannot be read, nothing is changed.
    """
    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        # Bind the table name as a parameter, like the pg_attribute lookup
        # below, instead of interpolating it into the statement.
        cursor.execute(
            """
            SELECT EXISTS (
                SELECT 1 FROM pg_tables
                WHERE tablename = %s
            )
            """,
            (self.table_name,),
        )
        exists = cursor.fetchone()[0]

        if exists:
            cursor.execute(
                """
                SELECT a.atttypmod FROM pg_attribute a
                JOIN pg_class c ON a.attrelid = c.oid
                WHERE c.relname = %s AND a.attname = 'vector' AND NOT a.attisdropped
                """,
                (self.table_name,),
            )
            row = cursor.fetchone()
            mismatch = bool(row) and row[0] is not None and row[0] != self.dimension
            if not mismatch:
                return
            logger.warning(
                "Dropping table %s: vector column modifier %s does not match dimension %s",
                self.table_name,
                row[0],
                self.dimension,
            )
            cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
            conn.commit()

        self._create_table(cursor)
        conn.commit()
        logger.info("Created table: %s", self.table_name)
    finally:
        # Close the cursor on every path, including failures.
        cursor.close()
def _create_table(self, cursor):
    """Create the vector table plus its vector and BM25 indexes.

    Per the original note, openGauss 7+ ships a built-in ``vector`` type, so
    no separate extension is installed here. The vector index is built only
    for the ``ivfflat``, ``hnsw`` and ``diskann`` index types; the BM25
    index on ``text_content`` is always built.

    Args:
        cursor: Open cursor used to run the DDL; the caller commits.
    """
    table = self.table_name

    cursor.execute(
        f"""
        CREATE TABLE {table} (
            id VARCHAR(64) PRIMARY KEY,
            vector vector({self.dimension}),
            text_content TEXT,
            metadata JSONB,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
    )

    # Pick the WITH (...) options for the configured access method; an
    # unrecognized index type means no vector index at all.
    kind = self.index_type
    if kind == "ivfflat":
        index_options = f"lists = {self.lists}"
    elif kind == "hnsw":
        index_options = f"m = {self.m}, ef_construction = {self.ef_construction}"
    elif kind == "diskann":
        index_options = "index_size = 50"
    else:
        index_options = None

    if index_options is not None:
        cursor.execute(
            f"""
            CREATE INDEX idx_{table}_vector
            ON {table} USING {kind} (vector vector_cosine_ops)
            WITH ({index_options})
            """
        )

    workers = self.bm25_parallel_workers
    if workers is not None:
        # Applied before the BM25 index is created.
        cursor.execute(f"ALTER TABLE {table} SET(parallel_workers={workers});")

    cursor.execute(
        f"""
        CREATE INDEX idx_{table}_bm25
        ON {table} USING bm25(text_content)
        """
    )
def _get_conn(self):
    """Return the current connection, reconnecting first if it is closed."""
    if not self._is_connection_closed():
        return self._connection
    self.connect()
    return self._connection
def _is_connection_closed(self) -> bool:
    """Tell whether there is no usable connection.

    A missing or ``None`` connection counts as closed. Otherwise the
    connection's ``closed`` attribute decides, defaulting to "open" when the
    object has no such attribute.
    """
    try:
        connection = self._connection
        if connection is None:
            return True
        return getattr(connection, "closed", False)
    except (AttributeError, TypeError):
        # E.g. the attribute was never assigned on this instance.
        return True
def insert(self, records: List[VectorRecord]) -> int:
    """Insert records in one transaction, updating rows whose id already exists.

    Each record is written under a savepoint. When the INSERT fails with a
    duplicate-key error, only that statement is undone and the row is
    updated instead, so records written earlier in the same batch are kept.
    Any other error rolls the whole batch back and is re-raised.

    Args:
        records: Records to store; ``metadata`` must be JSON-serializable.

    Returns:
        Number of records inserted or updated.
    """
    if not records:
        return 0

    insert_sql = f"""
        INSERT INTO {self.table_name} (id, vector, text_content, metadata)
        VALUES (%s, %s::vector, %s, %s)
    """
    update_sql = f"""
        UPDATE {self.table_name}
        SET vector = %s::vector, text_content = %s, metadata = %s, updated_at = CURRENT_TIMESTAMP
        WHERE id = %s
    """

    conn = self._get_conn()
    cursor = conn.cursor()
    count = 0
    try:
        for record in records:
            meta_json = json.dumps(record.metadata)
            cursor.execute("SAVEPOINT vector_upsert")
            try:
                cursor.execute(
                    insert_sql,
                    (record.id, record.vector, record.text, meta_json),
                )
            except Exception as exc:
                # SQLSTATE 23505 is unique_violation; the message checks
                # are kept for drivers that do not expose ``sqlstate``.
                message = str(exc).lower()
                duplicate = (
                    getattr(exc, "sqlstate", None) == "23505"
                    or "unique" in message
                    or "duplicate" in message
                    or "26000" in message
                )
                if not duplicate:
                    raise
                # Undo only the failed INSERT; a full rollback here would
                # also discard the rows already written in this batch.
                cursor.execute("ROLLBACK TO SAVEPOINT vector_upsert")
                cursor.execute(
                    update_sql,
                    (record.vector, record.text, meta_json, record.id),
                )
            cursor.execute("RELEASE SAVEPOINT vector_upsert")
            count += 1
        conn.commit()
    except Exception:
        # Leave the connection usable for the next call.
        conn.rollback()
        raise
    finally:
        cursor.close()
    return count
def search(
    self,
    query_vector: List[float],
    query_text: str = "",
    limit: int = 10,
    filters: Optional[Dict] = None,
    use_hybrid: bool = True,
) -> List[Dict[str, Any]]:
    """Search by vector similarity, optionally fused with BM25 text search.

    Args:
        query_vector: Embedding to compare against stored vectors.
        query_text: Text for the BM25 part; hybrid search is used only when
            this is non-empty and ``use_hybrid`` is true.
        limit: Maximum number of results.
        filters: Optional metadata key/value pairs that must all match.
        use_hybrid: Enable hybrid search when ``query_text`` is given.

    Returns:
        List of dicts with ``id``, ``text``, ``metadata`` and ``score``.
    """
    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        if use_hybrid and query_text:
            return self._hybrid_search(
                cursor, query_vector, query_text, limit, filters
            )
        return self._vector_search(cursor, query_vector, limit, filters)
    finally:
        # Close the cursor even when the query fails.
        cursor.close()
def _vector_search(
    self,
    cursor,
    query_vector: List[float],
    limit: int,
    filters: Optional[Dict],
) -> List[Dict[str, Any]]:
    """Run a pure vector similarity query.

    Rows are ordered by ``vector <=> query`` ascending and the reported
    score is ``1 - (vector <=> query)``.

    Args:
        cursor: Open cursor to run the query on.
        query_vector: Embedding to compare against.
        limit: Maximum number of rows.
        filters: Optional metadata key/value pairs; values are compared as
            text against ``metadata->>key`` and all must match.

    Returns:
        List of dicts with ``id``, ``text``, ``metadata`` and ``score``.
    """
    query = f"""
        SELECT id, text_content, metadata,
               1 - (vector <=> %s::vector) as similarity
        FROM {self.table_name}
    """
    params: List[Any] = [query_vector]

    if filters:
        conditions = []
        for key, value in filters.items():
            # Bind the key as well as the value: interpolating the key into
            # the statement allowed SQL injection and broke on keys
            # containing a quote or a percent sign.
            conditions.append("metadata->>%s::text = %s")
            params.extend([str(key), str(value)])
        query += " WHERE " + " AND ".join(conditions)

    query += " ORDER BY vector <=> %s::vector LIMIT %s"
    params.extend([query_vector, limit])

    cursor.execute(query, params)
    return [
        {"id": row[0], "text": row[1], "metadata": row[2], "score": row[3]}
        for row in cursor.fetchall()
    ]
def _hybrid_search(
    self,
    cursor,
    query_vector: List[float],
    query_text: str,
    limit: int,
    filters: Optional[Dict],
) -> List[Dict[str, Any]]:
    """Combine vector and BM25 results into one ranked list.

    Each search is asked for twice ``limit`` candidates, and the two lists
    are then fused by ``_rrf_rerank`` down to ``limit`` results.
    """
    candidates_per_source = limit * 2
    from_vectors = self._vector_search(
        cursor, query_vector, candidates_per_source, filters
    )
    from_text = self._bm25_search(
        cursor, query_text, candidates_per_source, filters
    )
    return self._rrf_rerank(from_vectors, from_text, limit)
def _bm25_search(
    self,
    cursor,
    query_text: str,
    limit: int,
    filters: Optional[Dict],
) -> List[Dict[str, Any]]:
    """Run a BM25 full-text query on ``text_content``.

    Following the openGauss BM25 usage guide cited by the original author,
    the query carries an ``indexscan`` hint for the BM25 index, scores rows
    with the ``<&>`` operator and orders by that score descending.

    Args:
        cursor: Open cursor to run the query on.
        query_text: Text to search for.
        limit: Maximum number of rows.
        filters: Optional metadata key/value pairs; values are compared as
            text against ``metadata->>key`` and all must match.

    Returns:
        List of dicts with ``id``, ``text``, ``metadata`` and ``score``.
    """
    bm25_index_name = f"idx_{self.table_name}_bm25"
    query = f"""
        SELECT /*+ indexscan ({self.table_name} {bm25_index_name}) */
               id, text_content, metadata,
               text_content <&> %s AS score
        FROM {self.table_name}
    """
    params: List[Any] = [query_text]

    if filters:
        conditions = []
        for key, value in filters.items():
            # Bind the key as well as the value: interpolating the key into
            # the statement allowed SQL injection and broke on keys
            # containing a quote or a percent sign.
            conditions.append("metadata->>%s::text = %s")
            params.extend([str(key), str(value)])
        query += " WHERE " + " AND ".join(conditions)

    query += " ORDER BY text_content <&> %s DESC LIMIT %s"
    params.extend([query_text, limit])

    cursor.execute(query, params)
    return [
        {"id": row[0], "text": row[1], "metadata": row[2], "score": row[3]}
        for row in cursor.fetchall()
    ]
def _rrf_rerank(
    self,
    vector_results: List[Dict[str, Any]],
    bm25_results: List[Dict[str, Any]],
    limit: int,
) -> List[Dict[str, Any]]:
    """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

    A document at 1-based rank ``r`` in a list contributes ``1 / (60 + r)``;
    contributions from both lists are summed per document id. Ties keep
    first-seen order (vector results before BM25-only results).

    Args:
        vector_results: Ranked results of the vector search.
        bm25_results: Ranked results of the BM25 search.
        limit: Maximum number of fused results; values <= 0 yield ``[]``.

    Returns:
        Up to ``limit`` dicts with ``id``, ``text``, ``metadata`` and the
        fused ``score``, best first. Text and metadata come from the vector
        result when a document appears in both lists.
    """
    # The previous loop checked the length only after appending, so a limit
    # of 0 still returned one result.
    if limit <= 0:
        return []

    k = 60
    scores: Dict[str, float] = {}
    for ranked in (vector_results, bm25_results):
        for rank, result in enumerate(ranked, 1):
            doc_id = result["id"]
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)

    by_id: Dict[str, Dict[str, Any]] = {r["id"]: r for r in vector_results}
    for result in bm25_results:
        by_id.setdefault(result["id"], result)

    # sorted() is stable, so equal scores keep their insertion order.
    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]
    return [
        {
            "id": doc_id,
            "text": by_id[doc_id].get("text", ""),
            "metadata": by_id[doc_id].get("metadata", {}),
            "score": fused,
        }
        for doc_id, fused in best
    ]
def delete(self, record_ids: List[str]) -> int:
    """Delete the rows with the given ids in one statement.

    Args:
        record_ids: Ids to remove; an empty collection is a no-op.

    Returns:
        Number of rows the DELETE reported as affected.
    """
    ids = list(record_ids)
    if not ids:
        return 0

    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        placeholders = ",".join(["%s"] * len(ids))
        cursor.execute(
            f"""
            DELETE FROM {self.table_name}
            WHERE id IN ({placeholders})
            """,
            ids,
        )
        deleted = cursor.rowcount
        conn.commit()
    except Exception:
        # Leave the connection usable for the next call.
        conn.rollback()
        raise
    finally:
        cursor.close()
    return deleted
def get_by_id(self, record_id: str) -> Optional[Dict]:
    """Fetch one record by id.

    Args:
        record_id: Id of the row to load.

    Returns:
        Dict with ``id``, ``vector``, ``text`` and ``metadata``, or ``None``
        when no row has that id.
    """
    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        cursor.execute(
            f"""
            SELECT id, vector, text_content, metadata
            FROM {self.table_name}
            WHERE id = %s
            """,
            (record_id,),
        )
        row = cursor.fetchone()
    finally:
        # Close the cursor even when the query fails.
        cursor.close()

    if row:
        return {"id": row[0], "vector": row[1], "text": row[2], "metadata": row[3]}
    return None
def count(self) -> int:
    """Return the total number of rows in the table."""
    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
        total = cursor.fetchone()[0]
    finally:
        # Close the cursor even when the query fails.
        cursor.close()
    return total
def clear(self):
    """Remove every row from the table.

    Uses TRUNCATE rather than DROP; the original author's note says this is
    to avoid putting too much pressure on the container.
    """
    conn = self._get_conn()
    cursor = conn.cursor()
    try:
        cursor.execute(f"TRUNCATE TABLE {self.table_name}")
        conn.commit()
    except Exception:
        # Leave the connection usable for the next call.
        conn.rollback()
        raise
    finally:
        cursor.close()
def close(self):
    """Close the connection if one is currently open; otherwise do nothing."""
    if not self._connection:
        return
    if self._is_connection_closed():
        return
    self._connection.close()
    logger.info("Database connection closed")