"""Embedding模型"""
import logging
from typing import Any
from pgvector.sqlalchemy import Vector
from sqlalchemy import BigInteger, Column, ForeignKey, Index, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase
from apps.common.postgres import postgres
from apps.models import Base, LLMData, LLMProvider
from .providers import BaseProvider, OpenAIProvider, TEIProvider
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Maps each supported LLMProvider member to the provider class that is
# instantiated for it in Embedding.init.
_CLASS_DICT: dict[LLMProvider, type[BaseProvider]] = {
    LLMProvider.OPENAI: OpenAIProvider,
    LLMProvider.TEI: TEIProvider,
}
class VectorBase(DeclarativeBase):
    """Declarative base for the vector tables.

    It reuses the MetaData object of the main application ``Base``, so tables
    mapped on this base are registered in the same metadata collection.
    """

    # Share (not copy) the application's metadata.
    metadata = Base.metadata
class VectorTableManager:
    """Create, validate and rebuild the pgvector-backed embedding tables.

    The embedding dimension and the model fingerprint are stored in the
    database as the server-side column defaults of ``embedding`` and
    ``model_fingerprint``. On start-up those defaults are read back from
    ``framework_flow_vector`` and compared with the current configuration; any
    mismatch causes all vector tables to be dropped and recreated.
    """

    # Every table owned by this manager, in the order they are dropped.
    _VECTOR_TABLE_NAMES: tuple[str, ...] = (
        "framework_flow_vector",
        "framework_service_vector",
        "framework_node_vector",
        "framework_mcp_vector",
        "framework_mcp_tool_vector",
    )

    # State recorded by the last successful ensure_tables() call.
    _current_dim: int | None = None
    _current_model_fingerprint: str | None = None
    _tables_created: bool = False

    # Dynamically built ORM classes; None until _define_orm_classes() has run.
    VectorBase: Any = None
    NodePoolVector: Any = None
    FlowPoolVector: Any = None
    ServicePoolVector: Any = None
    MCPVector: Any = None
    MCPToolVector: Any = None

    @staticmethod
    def _parse_vector_dim(default_expr: str | None) -> int | None:
        """Return the dimension encoded in a vector column default.

        Args:
            default_expr: Default expression text such as ``'[0,0,0]'::vector``,
                or None when the column has no default.

        Returns:
            The number of comma-separated components, or None when the
            expression is missing or is not a quoted ``[...]`` literal.
        """
        if not default_expr or not default_expr.startswith("'["):
            return None
        vector_literal = default_expr.split("'")[1]
        return vector_literal.count(",") + 1

    @staticmethod
    def _parse_string_default(default_expr: str | None) -> str | None:
        """Return the value of a leading single-quoted SQL string literal.

        Handles the empty literal (``''``) and doubled quotes (``''`` inside
        the literal stands for one ``'``), and ignores anything after the
        closing quote such as a ``::character varying`` cast.

        Args:
            default_expr: Default expression text, or None.

        Returns:
            The unescaped literal, or None when the expression is missing,
            does not start with a quote, or is unterminated.
        """
        if not default_expr or not default_expr.startswith("'"):
            return None
        chars: list[str] = []
        pos = 1
        end = len(default_expr)
        while pos < end:
            ch = default_expr[pos]
            if ch != "'":
                chars.append(ch)
                pos += 1
                continue
            if pos + 1 < end and default_expr[pos + 1] == "'":
                # Doubled quote: an escaped quote inside the literal.
                chars.append("'")
                pos += 2
                continue
            # Lone quote: end of the literal.
            return "".join(chars)
        return None

    async def _get_table_metadata(self) -> tuple[bool, int | None, str | None]:
        """Read the stored dimension and model fingerprint from the database.

        Returns:
            ``(True, dim, fingerprint)`` when the ``embedding`` default of
            ``framework_flow_vector`` could be read and parsed; ``fingerprint``
            is None when its default is missing or unparseable.
            ``(False, None, None)`` when the query fails (for example because
            the table does not exist) or no dimension could be determined.
        """
        async with postgres.session() as session:
            try:
                result = await session.execute(
                    text(
                        "SELECT "
                        "  pg_get_expr(adbin, adrelid) as default_value, "
                        "  attname "
                        "FROM pg_attribute "
                        "JOIN pg_attrdef ON adrelid = attrelid AND adnum = attnum "
                        "WHERE attrelid = 'framework_flow_vector'::regclass "
                        "  AND attname IN ('embedding', 'model_fingerprint')",
                    ),
                )
                # Column name -> default expression text.
                defaults = {row[1]: row[0] for row in result.all()}
            except Exception as e:
                # Best effort: a failed lookup is reported as "no table".
                _logger.debug("[VectorTableManager] 无法读取表元数据: %s", e)
                return False, None, None

        existing_dim = self._parse_vector_dim(defaults.get("embedding"))
        if existing_dim is None:
            return False, None, None
        existing_fingerprint = self._parse_string_default(defaults.get("model_fingerprint"))
        return True, existing_dim, existing_fingerprint

    async def ensure_tables(self, dim: int, llm_config: LLMData) -> None:
        """Make sure the vector tables exist and match ``dim`` and the model.

        The tables are dropped and recreated when they are missing, when the
        stored dimension differs from ``dim``, when no fingerprint is stored,
        or when the stored fingerprint differs from ``llm_config.modelName``.
        Otherwise only the ORM classes are (re)defined.

        Args:
            dim: Embedding dimension; must be positive.
            llm_config: Embedding model configuration; ``modelName`` (or an
                empty string when unset) is used as the model fingerprint.

        Raises:
            RuntimeError: If ``dim`` is not positive.
        """
        if dim <= 0:
            err = "[VectorTableManager] 检测到的Embedding维度为0,无法创建Vector表"
            _logger.error(err)
            raise RuntimeError(err)

        model_fingerprint = llm_config.modelName or ""
        table_exists, existing_dim, existing_fingerprint = await self._get_table_metadata()

        rebuild_reason: str | None = None
        if not table_exists:
            rebuild_reason = "表不存在"
            _logger.info("[VectorTableManager] 数据库中不存在向量表,将创建新表")
        elif existing_dim != dim:
            rebuild_reason = f"维度不匹配:数据库={existing_dim}, 需要={dim}"
            _logger.warning("[VectorTableManager] %s", rebuild_reason)
        elif existing_fingerprint is None:
            # Only a missing fingerprint counts as "absent"; a stored empty
            # string is a valid fingerprint for a config without a model name.
            rebuild_reason = "表中无数据或缺少模型指纹"
            _logger.warning("[VectorTableManager] %s", rebuild_reason)
        elif existing_fingerprint != model_fingerprint:
            rebuild_reason = (
                f"模型变化:{llm_config.provider}/{llm_config.modelName} "
                f"(数据库指纹={existing_fingerprint}, 当前指纹={model_fingerprint})"
            )
            _logger.warning("[VectorTableManager] %s", rebuild_reason)

        if rebuild_reason is None:
            _logger.debug(
                "[VectorTableManager] 向量表已存在且有效,维度=%d,模型=%s/%s,指纹=%s",
                dim,
                llm_config.provider,
                llm_config.modelName,
                model_fingerprint,
            )
            self._define_orm_classes(dim, llm_config)
        else:
            _logger.warning("[VectorTableManager] 将删除旧表并重建。原因: %s", rebuild_reason)
            if table_exists:
                await self._drop_tables()
            _logger.info(
                "[VectorTableManager] 开始创建向量表,维度=%d,模型=%s/%s,指纹=%s",
                dim,
                llm_config.provider,
                llm_config.modelName,
                model_fingerprint,
            )
            await self._create_tables(dim, llm_config)

        self._current_dim = dim
        self._current_model_fingerprint = model_fingerprint
        self._tables_created = True
        if rebuild_reason is not None:
            _logger.info("[VectorTableManager] 向量表创建成功")

    async def _drop_tables(self) -> None:
        """Drop every vector table (if present) and clear ``_tables_created``."""
        _logger.info("[VectorTableManager] 删除所有向量表")
        async with postgres.session() as session:
            for table_name in self._VECTOR_TABLE_NAMES:
                # Table names are fixed constants, not external input.
                await session.execute(text(f"DROP TABLE IF EXISTS {table_name}"))
            await session.commit()
        self._tables_created = False

    def _build_vector_model(  # noqa: PLR0913
        self,
        class_name: str,
        table_name: str,
        index_name: str,
        key_columns: dict[str, Any],
        dim: int,
        zero_vector: str,
        model_fingerprint: str,
    ) -> Any:
        """Create one mapped class with the shared embedding columns and index.

        Args:
            class_name: Name of the generated ORM class.
            table_name: Name of the database table.
            index_name: Name of the HNSW index on ``embedding``.
            key_columns: Table-specific columns (primary / foreign keys),
                already constructed by the caller so they precede the shared
                columns in creation order.
            dim: Vector dimension of the ``embedding`` column.
            zero_vector: Server default for ``embedding``; it encodes ``dim``.
            model_fingerprint: Server default for ``model_fingerprint``.

        Returns:
            The new class, derived from ``self.VectorBase``.
        """
        attrs: dict[str, Any] = {
            "__tablename__": table_name,
            **key_columns,
            "embedding": Column(Vector(dim), nullable=False, server_default=zero_vector),
            "model_fingerprint": Column(String(300), nullable=False, server_default=model_fingerprint),
            "__table_args__": (
                Index(
                    index_name,
                    "embedding",
                    postgresql_using="hnsw",
                    postgresql_with={"m": 16, "ef_construction": 200},
                    postgresql_ops={"embedding": "vector_cosine_ops"},
                ),
            ),
        }
        return type(class_name, (self.VectorBase,), attrs)

    def _define_orm_classes(self, dim: int, llm_config: LLMData) -> None:
        """Define the ORM classes for all vector tables and store them on self.

        Safe to call repeatedly: table definitions left in the shared metadata
        by an earlier call are removed first.

        Args:
            dim: Vector dimension of every ``embedding`` column.
            llm_config: Configuration whose ``modelName`` (or an empty string
                when unset) becomes the ``model_fingerprint`` server default.
        """
        self.VectorBase = VectorBase
        model_fingerprint = llm_config.modelName or ""
        zero_vector = "[" + ",".join(["0"] * dim) + "]"

        # Mapping a table name that is already present in the MetaData raises
        # InvalidRequestError, so forget any definitions from a previous call.
        metadata = self.VectorBase.metadata
        for table_name in self._VECTOR_TABLE_NAMES:
            stale_table = metadata.tables.get(table_name)
            if stale_table is not None:
                metadata.remove(stale_table)

        self.FlowPoolVector = self._build_vector_model(
            "FlowPoolVector",
            "framework_flow_vector",
            "flow_vector_hnsw_index",
            {
                "appId": Column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False),
                "id": Column(String(255), ForeignKey("framework_flow.id"), primary_key=True),
            },
            dim,
            zero_vector,
            model_fingerprint,
        )
        self.ServicePoolVector = self._build_vector_model(
            "ServicePoolVector",
            "framework_service_vector",
            "service_vector_hnsw_index",
            {
                "id": Column(UUID(as_uuid=True), ForeignKey("framework_service.id"), primary_key=True),
            },
            dim,
            zero_vector,
            model_fingerprint,
        )
        self.NodePoolVector = self._build_vector_model(
            "NodePoolVector",
            "framework_node_vector",
            "node_vector_hnsw_index",
            {
                "id": Column(String(255), ForeignKey("framework_node.id"), primary_key=True),
                "serviceId": Column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=True),
            },
            dim,
            zero_vector,
            model_fingerprint,
        )
        self.MCPVector = self._build_vector_model(
            "MCPVector",
            "framework_mcp_vector",
            "mcp_vector_hnsw_index",
            {
                "id": Column(String(255), ForeignKey("framework_mcp.id"), primary_key=True),
            },
            dim,
            zero_vector,
            model_fingerprint,
        )
        self.MCPToolVector = self._build_vector_model(
            "MCPToolVector",
            "framework_mcp_tool_vector",
            "mcp_tool_vector_hnsw_index",
            {
                "id": Column(BigInteger, ForeignKey("framework_mcp_tools.id"), primary_key=True),
                "mcpId": Column(String(255), ForeignKey("framework_mcp.id"), nullable=False),
            },
            dim,
            zero_vector,
            model_fingerprint,
        )

    async def _create_tables(self, dim: int, llm_config: LLMData) -> None:
        """Define the ORM classes and create the tables in the database.

        ``create_all`` runs against the metadata shared with the application
        ``Base``, so it covers every table registered there.
        """
        self._define_orm_classes(dim, llm_config)
        async with postgres.engine.begin() as conn:
            await conn.run_sync(self.VectorBase.metadata.create_all)
# Module-level VectorTableManager instance; Embedding uses it as its table manager.
vector_table_manager = VectorTableManager()
class Embedding:
    """Embedding model facade: provider access plus vector-table ORM classes.

    ``init`` must be awaited before ``get_embedding`` is used. The ORM class
    properties forward to the module-level table manager and return ``None``
    until that manager has defined the classes.
    """

    _table_manager = vector_table_manager
    _llm_config: LLMData | None = None
    _provider: BaseProvider | None = None

    @property
    def VectorBase(self) -> Any:  # noqa: N802
        """Declarative base held by the table manager."""
        return self._table_manager.VectorBase

    @property
    def NodePoolVector(self) -> Any:  # noqa: N802
        """ORM class for the node vector table."""
        return self._table_manager.NodePoolVector

    @property
    def FlowPoolVector(self) -> Any:  # noqa: N802
        """ORM class for the flow vector table."""
        return self._table_manager.FlowPoolVector

    @property
    def ServicePoolVector(self) -> Any:  # noqa: N802
        """ORM class for the service vector table."""
        return self._table_manager.ServicePoolVector

    @property
    def MCPVector(self) -> Any:  # noqa: N802
        """ORM class for the MCP vector table."""
        return self._table_manager.MCPVector

    @property
    def MCPToolVector(self) -> Any:  # noqa: N802
        """ORM class for the MCP tool vector table."""
        return self._table_manager.MCPToolVector

    async def _get_embedding_dimension(self) -> int:
        """Detect the embedding dimension by embedding one probe text.

        Returns:
            The length of the first returned vector.

        Raises:
            RuntimeError: If the provider is not initialised or returns no
                vectors for the probe text.
        """
        vectors = await self.get_embedding(["测试文本"])
        if not vectors:
            # Without this guard an empty result surfaces as a bare IndexError.
            err = "[Embedding] Provider未返回任何向量,无法检测Embedding维度"
            _logger.error(err)
            raise RuntimeError(err)
        return len(vectors[0])

    async def init(self, llm_config: LLMData | None) -> None:
        """Initialise the provider and make sure the vector tables are usable.

        Args:
            llm_config: Embedding model configuration.

        Raises:
            RuntimeError: If ``llm_config`` is None, or if the embedding
                dimension cannot be determined or is not positive.
            KeyError: If ``llm_config.provider`` has no entry in ``_CLASS_DICT``.
        """
        if llm_config is None:
            err = "[Embedding] 未设置LLM配置"
            _logger.error(err)
            raise RuntimeError(err)

        _logger.info("[Embedding] 初始化Embedding,模型=%s/%s", llm_config.provider, llm_config.modelName)

        # Resolve the provider class before touching instance state, so an
        # unsupported provider does not leave a half-applied configuration.
        try:
            provider_cls = _CLASS_DICT[llm_config.provider]
        except KeyError:
            _logger.error("[Embedding] 不支持的Embedding Provider: %s", llm_config.provider)
            raise

        self._llm_config = llm_config
        self._provider = provider_cls(llm_config)

        dim = await self._get_embedding_dimension()
        _logger.info("[Embedding] 检测到向量维度: %d", dim)

        await self._table_manager.ensure_tables(dim, llm_config)
        _logger.info("[Embedding] 向量表检查完成")

    async def get_embedding(self, text: list[str]) -> list[list[float]]:
        """Return the embedding vectors for the given texts.

        Args:
            text: Texts to embed.

        Returns:
            Whatever the provider's ``embedding`` coroutine returns for ``text``.

        Raises:
            RuntimeError: If ``init`` has not set a provider yet.
        """
        if not self._provider:
            err = "[Embedding] Provider未初始化,无法获取embedding"
            _logger.error(err)
            raise RuntimeError(err)
        return await self._provider.embedding(text)
# Module-level Embedding instance; its init() must be awaited before get_embedding() works.
embedding = Embedding()