"""内存引擎 - 核心 API"""
import asyncio
import logging
import os
import uuid
from pathlib import Path
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from ..storage.vector_db import OpenGaussVectorDB, VectorRecord
from .document_processor import DocumentProcessor, TextChunk
from ..embeddings.base import EmbeddingBase, EmbeddingFactory
from ..embeddings import openai
@dataclass
class SearchResult:
    """One hit returned by a search.

    Attributes:
        text: Text of the matched record.
        score: Score reported for this hit.
        metadata: Metadata dictionary stored with the record.
        source: Origin of the record (e.g. the indexed file path).
        id: Record identifier; empty string when none was supplied.
    """

    text: str
    score: float
    metadata: Dict[str, Any]
    source: str
    id: str = ""
class MemoryEngine:
    """Memory engine - the main API entry point.

    Provides document indexing, semantic search and single-record insertion
    on top of a vector database client (``self.db``), an embedding backend
    (``self.embedder``) and a document processor (``self.processor``).

    NOTE(review): the ``async`` methods call ``self.db`` and
    ``self.processor`` without awaiting them, so those calls run
    synchronously inside the coroutine. If they block on I/O they will stall
    the event loop; moving them to a worker thread needs the DB client's
    thread-safety to be confirmed first.
    """

    def __init__(
        self,
        db_host: str = "localhost",
        db_port: int = 5432,
        db_name: str = "memory_db",
        db_user: str = "postgres",
        db_password: str = "",
        db_sslmode: Optional[str] = None,
        db_gssencmode: Optional[str] = None,
        embedding_provider: str = "openai",
        embedding_model: str = "text-embedding-3-small",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        vector_index_type: str = "ivfflat",
        index_lists: int = 100,
        index_m: int = 16,
        index_ef_construction: int = 64,
    ):
        """Create the embedder, the vector database client and the processor.

        Environment variables override the corresponding arguments when set:
        ``OG_DB_HOST``, ``OG_DB_PORT``, ``OG_DB_NAME``, ``OG_DB_USER``,
        ``OG_DB_PASSWORD`` and ``OG_EMBEDDING_PROVIDER``. The embedding model
        is taken from the first non-empty value among ``OG_EMBEDDING_MODEL``,
        ``OPENAI_EMBEDDING_MODEL`` and ``embedding_model``.

        Args:
            db_host: Database host.
            db_port: Database port.
            db_name: Database name.
            db_user: Database user.
            db_password: Database password.
            db_sslmode: Forwarded to the database client as ``sslmode``.
            db_gssencmode: Forwarded to the database client as ``gssencmode``.
            embedding_provider: Provider name passed to ``EmbeddingFactory``.
            embedding_model: Model name passed to ``EmbeddingFactory``.
            chunk_size: Forwarded to ``DocumentProcessor``.
            chunk_overlap: Forwarded to ``DocumentProcessor``.
            vector_index_type: Forwarded to the database client as
                ``index_type``.
            index_lists: Forwarded to the database client as ``lists``.
            index_m: Forwarded to the database client as ``m``.
            index_ef_construction: Forwarded to the database client as
                ``ef_construction``.

        Raises:
            ValueError: If the effective port cannot be parsed as an integer.
        """
        # Set up logging first so it is available during the rest of setup.
        self.logger = logging.getLogger(__name__)

        db_host = os.getenv("OG_DB_HOST", db_host)
        # An empty OG_DB_PORT is treated as unset instead of failing in int("").
        db_port = int(os.environ.get("OG_DB_PORT") or db_port)
        db_name = os.getenv("OG_DB_NAME", db_name)
        db_user = os.getenv("OG_DB_USER", db_user)
        db_password = os.getenv("OG_DB_PASSWORD", db_password)
        embedding_provider = os.getenv("OG_EMBEDDING_PROVIDER", embedding_provider)
        embedding_model = (
            os.getenv("OG_EMBEDDING_MODEL")
            or os.getenv("OPENAI_EMBEDDING_MODEL")
            or embedding_model
        )

        # Build the embedder once and reuse it for the dimension lookup,
        # rather than constructing a second, throwaway embedder.
        self.embedder = EmbeddingFactory.create(
            embedding_provider, model=embedding_model
        )
        self.db = OpenGaussVectorDB(
            host=db_host,
            port=db_port,
            database=db_name,
            user=db_user,
            password=db_password,
            sslmode=db_sslmode,
            gssencmode=db_gssencmode,
            dimension=self.embedder.get_dimension(),
            index_type=vector_index_type,
            lists=index_lists,
            m=index_m,
            ef_construction=index_ef_construction,
        )
        self.processor = DocumentProcessor(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    def _get_dimension(self, provider: str, model: str) -> int:
        """Return the embedding dimension for ``provider`` / ``model``.

        Builds a fresh embedder through ``EmbeddingFactory`` on every call;
        kept for backward compatibility with existing callers.
        """
        embedder = EmbeddingFactory.create(provider, model=model)
        return embedder.get_dimension()

    async def index_file(self, file_path: str) -> int:
        """Index a single file.

        Args:
            file_path: Path of the file to index.

        Returns:
            The value returned by the database insert (the number of indexed
            chunks), or 0 when the document produced no chunks.

        Raises:
            ValueError: If the embedder returns a different number of vectors
                than there are chunks.
        """
        self.logger.info("Indexing file: %s", file_path)
        text = self.processor.load_document(file_path)
        chunks = self.processor.process(text, source=file_path)
        if not chunks:
            self.logger.warning("No chunks generated from: %s", file_path)
            return 0

        vectors = await self.embedder.encode([c.content for c in chunks])
        # zip() would silently drop chunks on a length mismatch; fail loudly.
        if len(vectors) != len(chunks):
            raise ValueError(
                f"Embedding count mismatch for {file_path}: "
                f"expected {len(chunks)}, got {len(vectors)}"
            )

        records = [
            VectorRecord(
                id=chunk.chunk_id,
                vector=vector,
                # Keys in chunk.metadata take precedence over "source".
                metadata={"source": chunk.source, **chunk.metadata},
                text=chunk.content,
            )
            for chunk, vector in zip(chunks, vectors)
        ]
        count = self.db.insert(records)
        self.logger.info("Indexed %s chunks from %s", count, file_path)
        return count

    async def index_directory(self, directory: str) -> int:
        """Index every Markdown file (``*.md`` / ``*.markdown``) in a directory.

        The directory is searched recursively. Files are processed in sorted
        path order; a failure on one file is logged and does not stop the
        remaining files.

        Args:
            directory: Directory path.

        Returns:
            Total number of indexed chunks across all files.

        Raises:
            FileNotFoundError: If ``directory`` does not exist.
            ValueError: If ``directory`` is not a directory.
        """
        self.logger.info("Indexing directory: %s", directory)
        path = Path(directory)
        if not path.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")
        if not path.is_dir():
            raise ValueError(f"Path is not a directory: {directory}")

        # Keep regular files only (a directory may be named "x.md") and sort
        # so the processing order does not depend on the filesystem.
        md_files = sorted(
            candidate
            for pattern in ("**/*.md", "**/*.markdown")
            for candidate in path.glob(pattern)
            if candidate.is_file()
        )

        total_count = 0
        for md_file in md_files:
            try:
                total_count += await self.index_file(str(md_file))
            except Exception as e:
                # Deliberate best-effort: one bad file must not abort the batch.
                self.logger.error("Failed to index %s: %s", md_file, e)
        self.logger.info("Total indexed chunks: %s", total_count)
        return total_count

    async def search(
        self,
        query: str,
        limit: int = 10,
        filters: Optional[Dict] = None,
        use_hybrid: bool = True,
    ) -> List["SearchResult"]:
        """Run a semantic search.

        Args:
            query: Query text.
            limit: Maximum number of results to return.
            filters: Metadata filter conditions.
            use_hybrid: Whether to use hybrid search (vector + BM25).

        Returns:
            A list of search results.
        """
        self.logger.info("Searching: %s", query)
        query_vector = await self.embedder.encode_single(query)
        results = self.db.search(
            query_vector,
            query_text=query,
            limit=limit,
            filters=filters,
            use_hybrid=use_hybrid,
        )

        search_results = []
        for result in results:
            # Tolerate records stored without metadata (None).
            metadata = result["metadata"] or {}
            search_results.append(
                SearchResult(
                    text=result["text"],
                    score=result["score"],
                    metadata=metadata,
                    source=metadata.get("source", ""),
                    id=result.get("id", ""),
                )
            )
        return search_results

    async def add_memory(self, text: str, metadata: Optional[Dict] = None) -> str:
        """Add a single memory record.

        Args:
            text: Memory text.
            metadata: Metadata to store with the record; defaults to an empty
                dictionary.

        Returns:
            The id of the new record (a random UUID4 string).
        """
        if metadata is None:
            metadata = {}
        vector = await self.embedder.encode_single(text)
        record_id = str(uuid.uuid4())
        record = VectorRecord(
            id=record_id,
            vector=vector,
            metadata=metadata,
            text=text,
        )
        self.db.insert([record])
        return record_id

    def delete_by_source(self, source: str) -> int:
        """Delete all records that belong to a given source.

        Selects the ids whose metadata ``source`` equals ``source`` and hands
        them to ``self.db.delete``.

        Args:
            source: Source file path.

        Returns:
            The value returned by ``self.db.delete`` (number of deleted
            records), or 0 when no record matches.
        """
        self.logger.info("Deleting records from source: %s", source)
        conn = self.db._get_conn()
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"""
                SELECT id FROM {self.db.table_name}
                WHERE metadata->>'source' = %s
                """,
                (source,),
            )
            ids = [row[0] for row in cursor.fetchall()]
        finally:
            # Always release the cursor, even when the query fails.
            cursor.close()

        if not ids:
            return 0
        return self.db.delete(ids)

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics about the engine.

        Returns:
            A dictionary with ``total_records``, ``embedding_model`` and
            ``embedding_dimension``.
        """
        return {
            "total_records": self.db.count(),
            "embedding_model": self.embedder.get_model_name(),
            "embedding_dimension": self.embedder.get_dimension(),
        }

    def clear_all(self):
        """Remove all data from the database."""
        self.logger.warning("Clearing all data")
        self.db.clear()

    def close(self):
        """Close the database connection."""
        self.db.close()
        self.logger.info("Memory engine closed")

    def __enter__(self):
        """Return the engine itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the engine on leaving the ``with`` block; never suppresses."""
        self.close()