"""Node管理器"""
import logging
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from apps.common.postgres import postgres
from apps.models import NodeInfo
from apps.schemas.enum_var import SpecialCallType
if TYPE_CHECKING:
from pydantic import BaseModel
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class NodeManager:
    """Node manager: looks up Node records and builds their parameter schemas."""

    @staticmethod
    async def get_node(node_id: str) -> NodeInfo:
        """Fetch a single Node record by its ID.

        Args:
            node_id: Value compared against ``NodeInfo.id``.

        Returns:
            The matching ``NodeInfo`` row.

        Raises:
            ValueError: If no row matches ``node_id``.
        """
        async with postgres.session() as session:
            node = (await session.scalars(
                select(NodeInfo).where(NodeInfo.id == node_id),
            )).one_or_none()
            if not node:
                err = f"[NodeManager] Node {node_id} not found."
                raise ValueError(err)
            return node

    @staticmethod
    def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]:
        """Recursively fill ``known_params`` values into a JSON schema as defaults.

        For every ``"object"`` schema, each property whose name appears in
        ``known_params`` gets its ``"default"`` set to the known value. The walk
        then descends into each property, and into the ``"items"`` of every
        ``"array"`` schema. The same flat ``known_params`` mapping is used at
        every nesting level, so matching is by property name only.

        The schema is modified in place and the same object is returned.

        Args:
            params_schema: JSON-schema-like dict to update. A non-dict value is
                returned unchanged.
            known_params: Mapping of property name to the default value to set.

        Returns:
            ``params_schema`` itself, after the defaults have been filled in.
        """
        if not isinstance(params_schema, dict):
            return params_schema

        schema_type = params_schema.get("type")
        if schema_type == "object":
            properties = params_schema.get("properties", {})
            for key, value in properties.items():
                if not isinstance(value, dict):
                    # Boolean sub-schemas (True/False) are valid JSON Schema but
                    # cannot hold a "default"; leave them as they are instead
                    # of failing on item assignment.
                    continue
                if key in known_params:
                    value["default"] = known_params[key]
                # Recursion mutates ``value`` in place, so no re-assignment
                # into ``properties`` is needed.
                NodeManager.merge_params_schema(value, known_params)
        elif schema_type == "array":
            params_schema["items"] = NodeManager.merge_params_schema(
                params_schema.get("items", {}), known_params,
            )
        return params_schema

    @staticmethod
    async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
        """Build the input and output JSON schemas for a Node.

        Args:
            node_id: ID of the Node. The ``SpecialCallType.EMPTY`` value
                short-circuits to two empty dicts without any lookup.

        Returns:
            A ``(input_schema, output_schema)`` tuple. The input schema is the
            Call class's JSON schema with the Node's ``knownParams`` merged in
            as defaults; the output schema comes from the Call's
            ``output_model`` with the Node's ``overrideOutput`` passed as
            ``override``.

        Raises:
            ValueError: If the Node does not exist, or if the pool returns no
                Call class for the Node's ``callId``.
        """
        # Imported at call time rather than at module import time.
        # NOTE(review): presumably this avoids a circular import - confirm.
        from apps.scheduler.pool.pool import pool

        if node_id == SpecialCallType.EMPTY.value:
            return {}, {}

        logger.info("[NodeManager] 获取节点 %s", node_id)
        node_data = await NodeManager.get_node(node_id)
        call_id = node_data.callId

        logger.info("[NodeManager] 获取Call %s", call_id)
        call_class: type[BaseModel] = await pool.get_call(str(call_id))
        if not call_class:
            err = f"[NodeManager] Call {call_id} 不存在"
            logger.error(err)
            raise ValueError(err)

        return (
            NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.knownParams or {}),
            call_class.output_model.model_json_schema(
                override=node_data.overrideOutput or {},
            ),
        )