"""加载配置文件夹的Service部分;下属于Pool,单例"""
import logging
import shutil
import uuid
from anyio import Path
from sqlalchemy import delete, inspect, select
from apps.common.config import config
from apps.common.postgres import postgres
from apps.llm import embedding
from apps.models import (
NodeInfo,
PermissionType,
Service,
ServiceACL,
ServiceHashes,
)
from apps.scheduler.pool.check import FileChecker
from apps.schemas.flow import ServiceMetadata
from .metadata import MetadataLoader, MetadataType
from .openapi import OpenAPILoader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Root directory for service files: "<data_dir>/semantics/service". Each
# service is stored in a subdirectory named after its id.
BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "service"
async def _table_exists(table_name: str) -> bool:
    """Report whether a table called ``table_name`` exists in the database.

    SQLAlchemy's inspector is a synchronous API, so the check is executed
    through ``run_sync`` on a connection taken from the async engine.

    Args:
        table_name: Name of the table to look up.

    Returns:
        ``True`` if the inspector reports the table, ``False`` otherwise.
    """
    def _has_table(sync_conn) -> bool:
        # Runs against the synchronous connection handed over by run_sync.
        return inspect(sync_conn).has_table(table_name)

    async with postgres.engine.connect() as connection:
        return await connection.run_sync(_has_table)
class ServiceLoader:
    """Loader that syncs services between the filesystem and the database.

    Every method is a static method; the class is used as a namespace.
    Service files live under ``BASE_PATH / <service_id>``.

    Note: the ``delete`` method shares its name with SQLAlchemy's ``delete``
    imported at module level. Inside method bodies the bare name ``delete``
    still resolves to the module-level import, because class scope is not an
    enclosing scope for functions.
    """

    @staticmethod
    async def _load_all_services() -> list[tuple[Service, list[NodeInfo]]]:
        """Load every service together with its nodes from the database.

        Returns:
            A list of ``(service, nodes)`` pairs, where ``nodes`` are the
            ``NodeInfo`` rows whose ``serviceId`` equals the service's id.
        """
        async with postgres.session() as session:
            services = list((await session.scalars(select(Service))).all())
            service_nodes: list[tuple[Service, list[NodeInfo]]] = []
            for service in services:
                nodes_query = select(NodeInfo).where(NodeInfo.serviceId == service.id)
                nodes = list((await session.scalars(nodes_query)).all())
                service_nodes.append((service, nodes))
            return service_nodes

    @staticmethod
    async def load(service_id: uuid.UUID, hashes: dict[str, str]) -> None:
        """Load a single service from disk and write it to the database.

        Reads ``metadata.yaml`` and every ``*.yaml`` file below ``openapi/``
        in the service directory, then replaces the service's database rows.

        Args:
            service_id: Id of the service; also the name of its directory.
            hashes: Mapping of file path to hash, stored on the metadata and
                persisted as ``ServiceHashes`` rows.

        Raises:
            TypeError: If the loaded metadata is not a ``ServiceMetadata``.
            ValueError: If ``hashes`` is empty (raised by ``_update_db``).
        """
        service_path = BASE_PATH / str(service_id)
        metadata_path = service_path / "metadata.yaml"
        if not await metadata_path.exists():
            logger.error("[ServiceLoader] Service %s 的元数据不存在", service_id)
            return

        metadata = await MetadataLoader().load_one(metadata_path)
        if not isinstance(metadata, ServiceMetadata):
            err = f"[ServiceLoader] 元数据类型错误: {service_path}/metadata.yaml"
            logger.error(err)
            raise TypeError(err)
        metadata.hashes = hashes

        # Best effort: a broken OpenAPI file aborts loading of this service
        # only; the error is logged and not propagated to the caller.
        try:
            nodes: list[NodeInfo] = []
            async for yaml_path in (service_path / "openapi").rglob("*.yaml"):
                nodes.extend(await OpenAPILoader().load_one(service_id, yaml_path, metadata.api.server))
        except Exception:
            logger.exception("[ServiceLoader] 服务 %s 文件损坏", service_id)
            return

        await ServiceLoader._update_db(nodes, metadata)

    @staticmethod
    async def save(service_id: uuid.UUID, metadata: ServiceMetadata, data: dict) -> None:
        """Save a service to the filesystem and refresh the database.

        Args:
            service_id: Id of the service; also the name of its directory.
            metadata: Service metadata, written through ``MetadataLoader``.
            data: OpenAPI document, written to ``openapi/api.yaml``.
        """
        service_path = BASE_PATH / str(service_id)
        openapi_dir = service_path / "openapi"
        # parents=True also creates the service directory itself, and
        # exist_ok=True makes a separate existence check unnecessary.
        await openapi_dir.mkdir(parents=True, exist_ok=True)

        await MetadataLoader().save_one(MetadataType.SERVICE, metadata, service_id)
        await OpenAPILoader().save_one(openapi_dir / "api.yaml", data)

        # Recompute the file hashes, then reload the service from disk so the
        # database matches what was just written.
        file_checker = FileChecker()
        await file_checker.diff_one(service_path)
        await ServiceLoader.load(service_id, file_checker.hashes[f"service/{service_id}"])

    @staticmethod
    async def delete(service_id: uuid.UUID, *, is_reload: bool = False) -> None:
        """Delete a service from the database and, optionally, from disk.

        Args:
            service_id: Id of the service to remove.
            is_reload: When ``True`` only the database rows are removed and
                the service directory is left in place.
        """
        async with postgres.session() as session:
            await session.execute(delete(Service).where(Service.id == service_id))
            await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == service_id))
            await session.execute(delete(ServiceACL).where(ServiceACL.serviceId == service_id))
            await session.execute(delete(ServiceHashes).where(ServiceHashes.serviceId == service_id))
            # The vector models are None until embedding is initialised.
            if embedding.ServicePoolVector is not None:
                await session.execute(
                    delete(embedding.ServicePoolVector).where(embedding.ServicePoolVector.id == service_id),
                )
            if embedding.NodePoolVector is not None:
                await session.execute(
                    delete(embedding.NodePoolVector).where(embedding.NodePoolVector.serviceId == service_id),
                )
            await session.commit()

        if is_reload:
            return

        path = BASE_PATH / str(service_id)
        if await path.exists():
            # shutil.rmtree blocks; run it in a worker thread so the event
            # loop stays responsive while the directory tree is removed.
            from anyio import to_thread

            await to_thread.run_sync(shutil.rmtree, path)

    @staticmethod
    async def _update_db(nodes: list[NodeInfo], metadata: ServiceMetadata) -> None:
        """Replace the database rows of a service with fresh data.

        Existing ``Service``, ``NodeInfo`` and ``ServiceHashes`` rows of the
        service are deleted and re-inserted within one session.

        Args:
            nodes: Node rows to insert for the service.
            metadata: Service metadata; ``metadata.hashes`` must be non-empty.

        Raises:
            ValueError: If ``metadata.hashes`` is empty.
        """
        if not metadata.hashes:
            err = f"[ServiceLoader] 服务 {metadata.id} 的哈希值为空"
            logger.error(err)
            raise ValueError(err)

        async with postgres.session() as session:
            await session.execute(delete(Service).where(Service.id == metadata.id))
            await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == metadata.id))
            await session.execute(delete(ServiceHashes).where(ServiceHashes.serviceId == metadata.id))

            # The service row is always written with PRIVATE permission here.
            session.add(
                Service(
                    id=metadata.id,
                    name=metadata.name,
                    description=metadata.description,
                    authorId=metadata.author,
                    permission=PermissionType.PRIVATE,
                ),
            )
            for node in nodes:
                session.add(node)
            for file_path, hash_value in metadata.hashes.items():
                session.add(
                    ServiceHashes(
                        serviceId=metadata.id,
                        filePath=file_path,
                        hash=hash_value,
                    ),
                )
            await session.commit()

    @staticmethod
    async def set_vector() -> None:
        """Store vector data for every service and its nodes in the database."""
        for service, nodes in await ServiceLoader._load_all_services():
            await ServiceLoader._update_vector(
                nodes,
                service.id,
                service.description,
            )

    @staticmethod
    async def _update_vector(
        nodes: list[NodeInfo],
        service_id: uuid.UUID,
        service_description: str,
    ) -> None:
        """Recompute and store the embeddings of a service and its nodes.

        Does nothing (apart from logging a warning) when the embedding models
        are not initialised or the service vector table does not exist.

        Args:
            nodes: Nodes of the service; each node's description is embedded.
            service_id: Id of the service whose vectors are replaced.
            service_description: Text embedded as the service vector.
        """
        if embedding.ServicePoolVector is None or embedding.NodePoolVector is None:
            logger.warning(
                "[ServiceLoader] Embedding 未初始化,跳过向量数据更新",
            )
            return

        if not await _table_exists(embedding.ServicePoolVector.__tablename__):
            logger.warning(
                "[ServiceLoader] 表 %s 不存在,跳过向量数据更新",
                embedding.ServicePoolVector.__tablename__,
            )
            return

        service_vecs = await embedding.get_embedding([service_description])
        node_descriptions = [node.description for node in nodes]
        # Skip the embedding call entirely when the service has no nodes.
        node_vecs = await embedding.get_embedding(node_descriptions) if node_descriptions else []

        async with postgres.session() as session:
            await session.execute(
                delete(embedding.ServicePoolVector).where(embedding.ServicePoolVector.id == service_id),
            )
            await session.execute(
                delete(embedding.NodePoolVector).where(embedding.NodePoolVector.serviceId == service_id),
            )
            session.add(
                embedding.ServicePoolVector(
                    id=service_id,
                    embedding=service_vecs[0],
                ),
            )
            # Pair each vector with the node it was computed from. Vectors are
            # assumed to come back in the same order as the descriptions.
            # zip stops at the shorter sequence if the lengths ever differ.
            for node, vec in zip(nodes, node_vecs):
                session.add(
                    embedding.NodePoolVector(
                        id=node.id,
                        serviceId=service_id,
                        embedding=vec,
                    ),
                )
            await session.commit()