"""flow拓扑相关函数"""
import logging
import uuid
from typing import TYPE_CHECKING
from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError
from apps.scheduler.pool.pool import pool
from apps.schemas.enum_var import NodeType, SpecialCallType
from apps.schemas.flow_topology import EdgeItem, FlowItem
if TYPE_CHECKING:
from pydantic import BaseModel
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Substrings that must not appear in a branch id. Both the choice-branch check
# and the edge check in this module test branch ids against every entry.
BRANCH_ILLEGAL_CHARS = ["."]
class FlowServiceManager:
    """Helpers for cleaning up and validating the topology of a flow graph."""

    @staticmethod
    def _mark_node_unavailable(node) -> None:
        """Turn ``node`` into an empty placeholder and prepend the "tool deleted" notice."""
        node.node_id = SpecialCallType.EMPTY.value
        # ``or ""`` keeps the concatenation from raising when no description is set.
        node.description = "【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n" + (node.description or "")

    @staticmethod
    def _collect_choice_branches(node) -> set[str]:
        """
        Return the branch ids declared by a choice node.

        Raises:
            FlowBranchValidationError: the ``choices`` field is missing or empty,
                or a choice has no ``branch_id``.
            ValueError: a branch id is duplicated or contains an illegal character.

        """
        # A choice node without ``input_parameters`` is reported like one without
        # ``choices`` instead of surfacing a raw KeyError/TypeError.
        parameters = node.parameters or {}
        input_parameters = parameters.get("input_parameters") or {}
        if "choices" not in input_parameters:
            err = f"[FlowService] 节点{node.name}的分支字段缺失"
            logger.error(err)
            raise FlowBranchValidationError(err)
        if not input_parameters["choices"]:
            err = f"[FlowService] 节点{node.name}的分支字段为空"
            logger.error(err)
            raise FlowBranchValidationError(err)

        branch_ids: set[str] = set()
        for choice in input_parameters["choices"]:
            if "branch_id" not in choice:
                err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段"
                logger.error(err)
                raise FlowBranchValidationError(err)
            branch_id = choice["branch_id"]
            # NOTE(review): the two checks below raise ValueError while the checks
            # above raise FlowBranchValidationError. Kept as-is so that existing
            # callers' ``except`` clauses keep working.
            if branch_id in branch_ids:
                err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复"
                logger.error(err)
                raise ValueError(err)
            if any(illegal_char in branch_id for illegal_char in BRANCH_ILLEGAL_CHARS):
                err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符"
                logger.error(err)
                raise ValueError(err)
            branch_ids.add(branch_id)
        return branch_ids

    @staticmethod
    async def remove_excess_structure_from_flow(flow_item: FlowItem) -> FlowItem:
        """
        Remove superfluous structure from a flow graph.

        For every node whose ``node_id`` is not "start", "end" or the empty
        placeholder, the call is looked up in the pool; if the lookup returns a
        falsy value or raises, the node is turned into an empty placeholder and a
        notice is prepended to its description. Edges are then dropped when their
        source or target node is unknown, or when their ``branch_id`` is not one
        of the branches of the source node.

        Args:
            flow_item: the flow to clean up; it is modified in place.

        Returns:
            The same ``flow_item`` object, with nodes updated and edges filtered.

        Raises:
            FlowBranchValidationError: a choice node has a missing or empty
                ``choices`` field, or a choice without ``branch_id``.
            ValueError: a choice node has a duplicated branch id or a branch id
                containing an illegal character.

        """
        # Maps node step id -> set of branch ids that an outgoing edge may use.
        node_branch_map: dict = {}
        for node in flow_item.nodes:
            if node.node_id not in {"start", "end", SpecialCallType.EMPTY.value}:
                call_class: type[BaseModel] | None = None
                try:
                    call_class = await pool.get_call(node.call_id)
                except Exception:
                    # Best effort: a failing lookup is treated like a missing call.
                    logger.exception("[FlowService] 获取步骤的call_id失败%s", node.call_id)
                if not call_class:
                    FlowServiceManager._mark_node_unavailable(node)

            if node.call_id == NodeType.CHOICE.value:
                node_branch_map[node.step_id] = FlowServiceManager._collect_choice_branches(node)
            else:
                # Non-choice nodes have a single, unnamed branch.
                node_branch_map[node.step_id] = {""}

        flow_item.edges = [
            edge
            for edge in flow_item.edges
            if edge.source_branch in node_branch_map
            and edge.target_branch in node_branch_map
            and edge.branch_id in node_branch_map[edge.source_branch]
        ]
        return flow_item

    @staticmethod
    async def _validate_node_ids(flow_item: FlowItem) -> tuple[uuid.UUID, uuid.UUID]:
        """
        Check that node ids are unique and that the configured start/end steps exist.

        Args:
            flow_item: the flow whose nodes and ``basic_config`` are checked.

        Returns:
            ``(start_id, end_id)`` as read from ``basic_config.startStep`` and
            ``basic_config.endStep``.

        Raises:
            FlowNodeValidationError: a step id is duplicated, ``basic_config`` is
                ``None``, or the start/end step id is not among the nodes.

        """
        seen_ids = set()
        for node in flow_item.nodes:
            if node.step_id in seen_ids:
                err = f"[FlowService] 节点{node.name}的id重复"
                logger.error(err)
                raise FlowNodeValidationError(err)
            seen_ids.add(node.step_id)

        if flow_item.basic_config is None:
            err = "[FlowService] Flow的基本配置为空"
            logger.error(err)
            raise FlowNodeValidationError(err)

        start_id = flow_item.basic_config.startStep
        end_id = flow_item.basic_config.endStep
        if start_id not in seen_ids:
            err = f"[FlowService] 起始节点ID {start_id} 在节点列表中不存在"
            logger.error(err)
            raise FlowNodeValidationError(err)
        if end_id not in seen_ids:
            err = f"[FlowService] 终止节点ID {end_id} 在节点列表中不存在"
            logger.error(err)
            raise FlowNodeValidationError(err)
        return start_id, end_id

    @staticmethod
    async def validate_flow_illegal(flow_item: FlowItem) -> tuple[uuid.UUID, uuid.UUID]:
        """
        Validate a flow graph, raising when it is not legal.

        Runs the node-id validation followed by the edge validation.

        Returns:
            ``(start_id, end_id)`` as returned by the node-id validation.

        Raises:
            FlowNodeValidationError: node validation failed.
            FlowEdgeValidationError: edge validation failed.

        """
        start_id, end_id = await FlowServiceManager._validate_node_ids(flow_item)
        await FlowServiceManager._validate_edges(flow_item.edges)
        return start_id, end_id

    @staticmethod
    async def _validate_edges(edges: list[EdgeItem]) -> None:
        """
        Validate the edges of a flow.

        Raises:
            FlowEdgeValidationError: a branch id contains an illegal character,
                an edge id is duplicated, an edge starts and ends at the same
                node, or two edges leave the same source with the same branch id.

        """
        seen_edge_ids = set()
        # Maps source node -> branch ids already used by an outgoing edge.
        used_branches: dict = {}
        for e in edges:
            if any(illegal_char in e.branch_id for illegal_char in BRANCH_ILLEGAL_CHARS):
                err = f"[FlowService] 边{e.edge_id}的分支{e.branch_id}名称中含有非法字符"
                logger.error(err)
                raise FlowEdgeValidationError(err)
            if e.edge_id in seen_edge_ids:
                err = f"[FlowService] 边{e.edge_id}的id重复"
                logger.error(err)
                raise FlowEdgeValidationError(err)
            seen_edge_ids.add(e.edge_id)
            if e.source_branch == e.target_branch:
                err = f"[FlowService] 边{e.edge_id}的起始节点和终止节点相同"
                logger.error(err)
                raise FlowEdgeValidationError(err)
            source_branches = used_branches.setdefault(e.source_branch, set())
            if e.branch_id in source_branches:
                err = f"[FlowService] 边{e.edge_id}的分支{e.branch_id}重复"
                logger.error(err)
                raise FlowEdgeValidationError(err)
            source_branches.add(e.branch_id)

    @staticmethod
    async def validate_flow_connectivity(flow_item: FlowItem) -> bool:
        """
        Check the connectivity of a flow graph.

        Returns ``True`` only when all of the following hold:

        1. the flow contains both a start node and an end node;
        2. the end node is reachable from the start node;
        3. every node reachable from the start node, other than the end node,
           has at least one outgoing edge.

        NOTE(review): nodes that are not reachable from the start node are not
        inspected, so an isolated node does not make this return ``False``.
        """
        start_id = None
        end_id = None
        for node in flow_item.nodes:
            if node.call_id == NodeType.START.value:
                start_id = node.step_id
            if node.call_id == NodeType.END.value:
                end_id = node.step_id

        # Without both endpoints there is nothing to connect. This also stops the
        # final membership test from matching ``None`` against a ``None`` start.
        if start_id is None or end_id is None:
            return False

        adjacency: dict = {}
        for edge in flow_item.edges:
            adjacency.setdefault(edge.source_branch, []).append(edge.target_branch)

        visited = {start_id}
        # Visit order does not affect the result, so a stack (O(1) pop) is used.
        pending = [start_id]
        while pending:
            current = pending.pop()
            successors = adjacency.get(current)
            if successors is None:
                if current != end_id:
                    # Dead end: a reachable non-terminal node without outgoing edges.
                    return False
                continue
            for successor in successors:
                if successor not in visited:
                    visited.add(successor)
                    pending.append(successor)
        return end_id in visited