"""工作流中步骤相关函数"""
import inspect
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from pydantic import ConfigDict
from apps.models import ExecutorHistory, ExecutorStatus, StepStatus
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.empty import Empty
from apps.scheduler.call.facts.facts import FactsCall
from apps.scheduler.call.slot.schema import SlotOutput
from apps.scheduler.call.slot.slot import Slot
from apps.scheduler.call.summary.summary import Summary
from apps.scheduler.pool.pool import pool
from apps.schemas.enum_var import (
EventType,
SpecialCallType,
)
from apps.schemas.message import TextAddContent
from apps.schemas.scheduler import CallError, CallOutputChunk, ExecutorBackground
from apps.schemas.task import StepQueueItem
from apps.services.node import NodeManager
from .base import BaseExecutor
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class StepExecutor(BaseExecutor):
    """Executor for a single step of a workflow.

    Typical usage is ``await executor.init()`` followed by
    ``await executor.run()``: ``init`` resolves and instantiates the Call
    behind the step, ``run`` optionally fills missing parameters, executes
    the Call, streams its output and appends an ``ExecutorHistory`` record
    to the task context.
    """

    # The queued step to execute; provides ``step_id`` and the step definition.
    step: StepQueueItem
    # Background information for the executor (not read directly in this class).
    background: ExecutorBackground

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )

    @staticmethod
    async def check_cls(call_cls: Any) -> bool:
        """Check that a Call class exposes the interface this executor needs.

        A class passes when ``_init`` and ``info`` are callable and ``_exec``
        is an async generator function.

        Args:
            call_cls: The candidate Call class (or any object).

        Returns:
            ``True`` if all three requirements hold, ``False`` otherwise.
        """
        # ``getattr(..., None)`` folds the "attribute missing" case into the
        # callable / async-generator checks, which both reject ``None``.
        has_init = callable(getattr(call_cls, "_init", None))
        has_exec = inspect.isasyncgenfunction(getattr(call_cls, "_exec", None))
        has_info = callable(getattr(call_cls, "info", None))
        return has_init and has_exec and has_info

    @staticmethod
    async def get_call_cls(call_id: str) -> type[CoreCall]:
        """Resolve a Call ID to its Call class and validate it.

        Special Call IDs map directly to their built-in classes; any other ID
        is looked up in the pool and checked with :meth:`check_cls`.

        Args:
            call_id: ID of the Call to resolve.

        Returns:
            The Call class for ``call_id``.

        Raises:
            ValueError: If the class obtained from the pool fails
                :meth:`check_cls`.
        """
        special_calls: dict[str, type[CoreCall]] = {
            SpecialCallType.EMPTY.value: Empty,
            SpecialCallType.SUMMARY.value: Summary,
            SpecialCallType.FACTS.value: FactsCall,
            SpecialCallType.SLOT.value: Slot,
        }
        if call_id in special_calls:
            return special_calls[call_id]

        call_cls: type[CoreCall] = await pool.get_call(call_id)
        if not await StepExecutor.check_cls(call_cls):
            err = f"[StepExecutor] 工具 {call_id} 不符合Call标准要求"
            logger.error(err)
            raise ValueError(err)
        return call_cls

    async def init(self) -> None:
        """Prepare the step for execution.

        Records the step's ID, type and name on the task state, resolves the
        node and its Call class, merges the call parameters and instantiates
        the Call into ``self.obj``.

        Raises:
            RuntimeError: If the task has no state.
            Exception: Whatever the Call's ``instance`` raises; it is logged
                and re-raised unchanged.
        """
        logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name)

        if not self.task.state:
            err = "[StepExecutor] 任务状态不存在"
            logger.error(err)
            raise RuntimeError(err)

        self.task.state.stepId = self.step.step_id
        self.task.state.stepType = self.step.step.type
        self.task.state.stepName = self.step.step.name

        node_id = self.step.step.node
        try:
            self.node = await NodeManager.get_node(node_id)
        except ValueError:
            # Not fatal: the ID is then used directly as a Call ID below.
            logger.info("[StepExecutor] 获取Node失败,为内部Node或ID不存在")
            self.node = None

        # With a node, its callId names the Call; otherwise the node ID does.
        call_id = self.node.callId if self.node else node_id
        call_cls = await StepExecutor.get_call_cls(call_id)
        self._call_id = call_id

        # Copy the node's known parameters before merging: updating the
        # original dict in place would write this step's parameters into the
        # node object itself.
        params: dict[str, Any] = (
            dict(self.node.knownParams) if self.node and self.node.knownParams else {}
        )
        if self.step.step.params:
            # Step-level parameters override the node's known parameters.
            params.update(self.step.step.params)

        try:
            self.obj = await call_cls.instance(self, self.node, **params)
        except Exception:
            logger.exception("[StepExecutor] 初始化Call失败")
            raise

    async def _run_slot_filling(self) -> None:
        """Run automatic parameter filling for the current Call.

        Behaves like a special step that is not appended to the task context:
        the task state temporarily carries a fresh step ID and the name
        "自动参数填充", and the original ID and name are restored once filling
        has finished. Does nothing when ``self.obj.enable_filling`` is falsy.

        Raises:
            RuntimeError: If the task has no state, or if the Slot call
                yields no output at all.
        """
        if not self.obj.enable_filling:
            return

        await self._check_cancelled()

        if not self.task.state:
            err = "[StepExecutor] 任务状态不存在"
            logger.error(err)
            raise RuntimeError(err)

        # Remember the real step identity so it can be restored afterwards.
        current_step_id = self.task.state.stepId
        current_step_name = self.task.state.stepName

        self.task.state.stepId = uuid.uuid4()
        self.task.state.stepName = "自动参数填充"
        self.task.state.stepStatus = StepStatus.RUNNING
        self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2)

        slot_obj = await Slot.instance(
            self,
            self.node,
            data=self.obj.input,
            current_schema=self.obj.input_model.model_json_schema(
                override=self.node.overrideInput if self.node and self.node.overrideInput else {},
            ),
        )
        await self._push_message(EventType.STEP_INPUT.value, slot_obj.input)

        # Only the last chunk is kept as the filling result.
        result: SlotOutput | None = None
        async for chunk in slot_obj.exec(self, slot_obj.input):
            await self._check_cancelled()
            result = SlotOutput.model_validate(chunk.content)

        if result is None:
            # Without this guard an empty iterator would surface as an
            # UnboundLocalError on the first use of ``result``.
            err = "[StepExecutor] 自动参数填充未返回任何结果"
            logger.error(err)
            raise RuntimeError(err)

        # A non-empty remaining schema means some parameters are still missing.
        if result.remaining_schema:
            self.task.state.stepStatus = StepStatus.PARAM
        else:
            self.task.state.stepStatus = StepStatus.SUCCESS
        await self._push_message(
            EventType.STEP_OUTPUT.value,
            result.model_dump(by_alias=True, exclude_none=True),
        )

        # Merge the filled values into the Call's input.
        self.obj.input.update(result.slot_data)

        self.task.state.stepId = current_step_id
        self.task.state.stepName = current_step_name

    async def _process_chunk(
        self,
        iterator: AsyncGenerator[CallOutputChunk, None],
        *,
        to_user: bool = False,
    ) -> str | dict[str, Any]:
        """Consume the Call's output chunks and assemble the step output.

        String chunks are concatenated; a non-string chunk replaces whatever
        has been collected so far. When ``to_user`` is true, every chunk is
        also pushed as a message, and string chunks are appended to
        ``self.task.runtime.fullAnswer``.

        Args:
            iterator: Async generator of output chunks.
            to_user: Whether chunks are forwarded as messages.

        Returns:
            The concatenated text, or the most recent non-string content.

        Raises:
            TypeError: If the iterator yields something that is not a
                ``CallOutputChunk``.
        """
        content: str | dict[str, Any] = ""
        chunk_count = 0
        async for chunk in iterator:
            chunk_count += 1
            # Check for cancellation on every 10th chunk rather than on each one.
            if chunk_count % 10 == 0:
                await self._check_cancelled()

            if not isinstance(chunk, CallOutputChunk):
                err = "[StepExecutor] 返回结果类型错误"
                logger.error(err)
                raise TypeError(err)

            if isinstance(chunk.content, str):
                # Start a fresh string if the previous chunk was structured.
                if not isinstance(content, str):
                    content = ""
                content += chunk.content
            else:
                content = chunk.content

            if to_user:
                if isinstance(chunk.content, str):
                    await self._push_message(EventType.TEXT_ADD.value, chunk.content)
                    self.task.runtime.fullAnswer += chunk.content
                else:
                    # Structured chunks are pushed under the step's own type.
                    await self._push_message(self.step.step.type, chunk.content)

        return content

    async def run(self) -> None:
        """Run the step.

        Runs parameter filling, executes the Call and processes its output.
        On success an ``ExecutorHistory`` entry is appended to the task
        context and the output is pushed as a STEP_OUTPUT message. If
        processing the output raises an ``Exception``, it is logged, the step
        and executor are marked ERROR, an empty STEP_OUTPUT is pushed, the
        error details are stored on the task state and the method returns
        without re-raising.

        Raises:
            RuntimeError: If the task has no state.
        """
        logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name)

        await self._check_cancelled()

        if not self.task.state:
            err = "[StepExecutor] 任务状态不存在"
            logger.error(err)
            raise RuntimeError(err)

        await self._run_slot_filling()

        self.task.state.stepStatus = StepStatus.RUNNING
        self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2)
        await self._push_message(EventType.STEP_INPUT.value, self.obj.input)

        await self._check_cancelled()

        iterator = self.obj.exec(self, self.obj.input)
        try:
            content = await self._process_chunk(iterator, to_user=self.obj.to_user)
        except Exception as e:
            logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
            self.task.state.stepStatus = StepStatus.ERROR
            self.task.state.executorStatus = ExecutorStatus.ERROR
            await self._push_message(EventType.STEP_OUTPUT.value, {})
            if isinstance(e, CallError):
                self.task.state.errorMessage = {
                    "err_msg": e.message,
                    "data": e.data,
                }
            else:
                # Keep the same keys as the CallError branch so readers of
                # errorMessage can rely on "err_msg" being present.
                self.task.state.errorMessage = {
                    "err_msg": str(e),
                    "data": {},
                }
            return

        self.task.state.stepStatus = StepStatus.SUCCESS
        self.task.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.time

        # Plain text output is wrapped in TextAddContent; other output is kept as is.
        if isinstance(content, str):
            output_data = TextAddContent(text=content).model_dump(exclude_none=True, by_alias=True)
        else:
            output_data = content

        history = ExecutorHistory(
            taskId=self.task.metadata.id,
            stepId=self.step.step_id,
            stepName=self.step.step.name,
            stepType=self.task.state.stepType,
            stepStatus=self.task.state.stepStatus,
            inputData=self.obj.input,
            outputData=output_data,
        )
        self.task.context.append(history)

        await self._push_message(EventType.STEP_OUTPUT.value, output_data)