"""Flow加载器"""
import asyncio
import logging
from hashlib import sha256
from typing import Any
import aiofiles
import yaml
from anyio import Path
from apps.common.config import Config
from apps.schemas.enum_var import NodeType,EdgeType
from apps.schemas.flow import AppFlow, Flow
from apps.schemas.pool import AppPool
from apps.models.vector import FlowPoolVector
from apps.llm.embedding import Embedding
from apps.services.node import NodeManager
from apps.common.lance import LanceDB
from apps.common.mongo import MongoDB
from apps.scheduler.util import yaml_enum_presenter, yaml_str_presenter
logger = logging.getLogger(__name__)
BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "app"
class FlowLoader:
"""工作流加载器"""
async def _load_yaml_file(self, flow_path: Path) -> dict[str, Any]:
"""从YAML文件加载工作流配置"""
try:
async with aiofiles.open(flow_path, encoding="utf-8") as f:
return yaml.safe_load(await f.read())
except Exception:
logger.exception("[FlowLoader] 加载YAML文件失败:%s", flow_path)
return {}
async def _validate_basic_fields(self, flow_yaml: dict[str, Any], flow_path: Path) -> dict[str, Any]:
"""验证工作流基本字段"""
if "name" not in flow_yaml or not flow_yaml["name"]:
logger.error("[FlowLoader] 工作流名称不能为空:%s", flow_path)
return {}
if "description" not in flow_yaml or not flow_yaml["description"]:
logger.error("[FlowLoader] 工作流描述不能为空:%s", flow_path)
return {}
if "start" not in flow_yaml["steps"] or "end" not in flow_yaml["steps"]:
logger.error("[FlowLoader] 工作流必须包含开始和结束节点:%s", flow_path)
return {}
return flow_yaml
async def _process_edges(self, flow_yaml: dict[str, Any], flow_id: str, app_id: str) -> dict[str, Any]:
"""处理工作流边的转换"""
logger.info("[FlowLoader] 应用 %s:解析工作流 %s 的边", flow_id, app_id)
try:
for edge in flow_yaml["edges"]:
if "from" in edge:
edge["edge_from"] = edge.pop("from")
if "to" in edge:
edge["edge_to"] = edge.pop("to")
if "type" in edge:
edge["edge_type"] = EdgeType[edge.pop("type").upper()]
except Exception:
logger.exception("[FlowLoader] 处理边时发生错误")
return {}
else:
return flow_yaml
async def _process_steps(self, flow_yaml: dict[str, Any], flow_id: str, app_id: str) -> dict[str, Any]:
"""处理工作流步骤的转换"""
logger.info("[FlowLoader] 应用 %s:解析工作流 %s 的步骤", flow_id, app_id)
for key, step in flow_yaml["steps"].items():
if key[0] == "_":
err = f"[FlowLoader] 步骤名称不能以下划线开头:{key}"
logger.error(err)
raise ValueError(err)
if step["type"]==NodeType.START.value or step["type"]==NodeType.END.value:
continue
try:
step["type"] = await NodeManager.get_node_call_id(step["node"])
except ValueError as e:
logger.warning("[FlowLoader] 获取节点call_id失败:%s,错误信息:%s", step["node"], e)
step["type"] = "Empty"
step["name"] = (
(await NodeManager.get_node_name(step["node"]))
if "name" not in step or step["name"] == ""
else step["name"]
)
return flow_yaml
async def load(self, app_id: str, flow_id: str) -> Flow | None:
"""从文件系统中加载【单个】工作流"""
logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id)
flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml"
if not await flow_path.exists():
logger.error("[FlowLoader] 应用 %s:工作流文件 %s 不存在", app_id, flow_path)
return None
try:
flow_yaml = await self._load_yaml_file(flow_path)
if not flow_yaml:
return None
for processor in [
lambda y: self._validate_basic_fields(y, flow_path),
lambda y: self._process_edges(y, flow_id, app_id),
lambda y: self._process_steps(y, flow_id, app_id),
]:
flow_yaml = await processor(flow_yaml)
if not flow_yaml:
return None
flow_config = Flow.model_validate(flow_yaml)
await self._update_db(
app_id,
AppFlow(
id=flow_id,
name=flow_config.name,
description=flow_config.description,
enabled=True,
path=str(flow_path),
debug=flow_config.debug,
),
)
return Flow.model_validate(flow_yaml)
except Exception:
logger.exception("[FlowLoader] 应用 %s:工作流 %s 格式不合法", app_id, flow_id)
return None
async def save(self, app_id: str, flow_id: str, flow: Flow) -> None:
"""保存工作流"""
flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml"
if not await flow_path.parent.exists():
await flow_path.parent.mkdir(parents=True)
flow_dict = flow.model_dump(by_alias=True, exclude_none=True)
async with aiofiles.open(flow_path, mode="w", encoding="utf-8") as f:
yaml.add_representer(str, yaml_str_presenter)
yaml.add_representer(EdgeType, yaml_enum_presenter)
await f.write(
yaml.dump(
flow_dict,
allow_unicode=True,
sort_keys=False,
),
)
await self._update_db(
app_id,
AppFlow(
id=flow_id,
name=flow.name,
description=flow.description,
enabled=True,
path=str(flow_path),
debug=flow.debug,
),
)
async def delete(self, app_id: str, flow_id: str) -> bool:
"""删除指定工作流文件"""
flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml"
if await flow_path.exists():
try:
await flow_path.unlink()
logger.info("[FlowLoader] 成功删除工作流文件:%s", flow_path)
except Exception:
logger.exception("[FlowLoader] 删除工作流文件失败:%s", flow_path)
return False
table = await LanceDB().get_table("flow")
try:
await table.delete(f"id = '{flow_id}'")
except Exception:
logger.exception("[FlowLoader] LanceDB删除flow失败")
return True
logger.warning("[FlowLoader] 工作流文件不存在或不是文件:%s", flow_path)
return True
async def _update_db(self, app_id: str, metadata: AppFlow) -> None:
"""更新数据库"""
try:
app_collection = MongoDB().get_collection("app")
app_data = await app_collection.find_one({"_id": app_id})
if not app_data:
err = f"[FlowLoader] App {app_id} 不存在"
logger.error(err)
return
app_obj = AppPool.model_validate(app_data)
flows = app_obj.flows
for flow in flows:
if flow.id == metadata.id:
flows.remove(flow)
break
flows.append(metadata)
await app_collection.update_one(
filter={
"_id": app_id,
},
update={
"$set": {
"flows": [flow.model_dump(by_alias=True, exclude_none=True) for flow in flows],
},
},
upsert=True,
)
flow_path = BASE_PATH / app_id / "flow" / f"{metadata.id}.yaml"
async with aiofiles.open(flow_path, "rb") as f:
new_hash = sha256(await f.read()).hexdigest()
key = f"hashes.flow/{metadata.id}.yaml"
await app_collection.aggregate(
[
{"$match": {"_id": app_id}},
{"$replaceWith": {"$setField": {"field": key, "input": "$$ROOT", "value": new_hash}}},
],
)
except Exception:
logger.exception("[FlowLoader] 更新 MongoDB 失败")
while True:
try:
table = await LanceDB().get_table("flow")
await table.delete(f"id = '{metadata.id}'")
break
except RuntimeError as e:
if "Commit conflict" in str(e):
logger.error("[FlowLoader] LanceDB删除flow冲突,重试中...")
await asyncio.sleep(0.01)
else:
raise
service_embedding = await Embedding.get_embedding([metadata.description])
vector_data = [
FlowPoolVector(
id=metadata.id,
app_id=app_id,
embedding=service_embedding[0],
),
]
while True:
try:
table = await LanceDB().get_table("flow")
await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute(
vector_data,
)
break
except RuntimeError as e:
if "Commit conflict" in str(e):
logger.error("[FlowLoader] LanceDB插入flow冲突,重试中...")
await asyncio.sleep(0.01)
else:
raise