"""Call 加载器"""
import importlib
import logging
from sqlalchemy import delete, inspect
from apps.common.postgres import postgres
from apps.llm import embedding
from apps.models import NodeInfo
from apps.schemas.scheduler import CallInfo
_logger = logging.getLogger(__name__)
async def _table_exists(table_name: str) -> bool:
"""检查表是否存在"""
async with postgres.engine.connect() as conn:
return await conn.run_sync(
lambda sync_conn: inspect(sync_conn).has_table(table_name),
)
class CallLoader:
"""Call 加载器"""
async def _load_system_call(self) -> dict[str, CallInfo]:
"""加载系统Call"""
call_metadata = {}
system_call = importlib.import_module("apps.scheduler.call")
for call_id in system_call.__all__:
call_cls = getattr(system_call, call_id)
call_info = call_cls.info()
call_metadata[call_id] = call_info
return call_metadata
async def _add_data_to_db(self, call_metadata: dict[str, CallInfo]) -> None:
"""将数据插入数据库"""
async with postgres.session() as session:
if embedding.NodePoolVector is not None and await _table_exists(embedding.NodePoolVector.__tablename__):
await session.execute(
delete(embedding.NodePoolVector).where(embedding.NodePoolVector.serviceId == None),
)
await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == None))
call_descriptions = []
for call_id, call in call_metadata.items():
session.add(NodeInfo(
id=call_id,
name=call.name,
description=call.description,
serviceId=None,
callId=call_id,
knownParams={},
overrideInput={},
overrideOutput={},
))
call_descriptions.append(call.description)
await session.commit()
async def _add_vector_to_db(self, call_metadata: dict[str, CallInfo]) -> None:
"""将向量化数据存入数据库"""
if embedding.NodePoolVector is None:
_logger.warning(
"[CallLoader] Embedding 未初始化,跳过向量数据插入",
)
return
if not await _table_exists(embedding.NodePoolVector.__tablename__):
_logger.warning(
"[CallLoader] 表 %s 不存在,跳过向量数据插入",
embedding.NodePoolVector.__tablename__,
)
return
async with postgres.session() as session:
await session.execute(
delete(embedding.NodePoolVector).where(embedding.NodePoolVector.serviceId == None),
)
call_vecs = await embedding.get_embedding([call.description for call in call_metadata.values()])
for call_id, vec in zip(call_metadata.keys(), call_vecs, strict=True):
session.add(embedding.NodePoolVector(
id=call_id,
embedding=vec,
))
await session.commit()
async def set_vector(self) -> None:
"""将向量化数据存入数据库"""
call_metadata = await self._load_system_call()
await self._add_vector_to_db(call_metadata)
async def load(self) -> None:
"""初始化Call信息"""
try:
sys_call_metadata = await self._load_system_call()
except Exception as e:
err = "[CallLoader] 载入系统Call信息失败"
_logger.exception(err)
raise RuntimeError(err) from e
await self._add_data_to_db(sys_call_metadata)
await self._add_vector_to_db(sys_call_metadata)