"""工作流中步骤相关函数"""
import inspect
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from pydantic import ConfigDict
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.empty import Empty
from apps.scheduler.call.facts.facts import FactsCall
from apps.scheduler.call.slot.schema import SlotOutput
from apps.scheduler.call.slot.slot import Slot
from apps.scheduler.call.summary.summary import Summary
from apps.scheduler.executor.base import BaseExecutor
from apps.scheduler.pool.pool import Pool
from apps.schemas.enum_var import (
EventType,
SpecialCallType,
StepStatus,
)
from apps.schemas.message import TextAddContent
from apps.schemas.scheduler import CallError, CallOutputChunk
from apps.schemas.task import FlowStepHistory, StepQueueItem
from apps.services.node import NodeManager
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class StepExecutor(BaseExecutor):
    """Executor for a single step of a workflow.

    Resolves the Call class behind the step, instantiates it, optionally runs
    automatic slot filling, executes the Call and appends the outcome to the
    task context.
    """

    # The queued step this executor is responsible for.
    step: StepQueueItem

    # Undeclared attributes (``node``, ``obj``, ``_call_id``) are assigned in
    # ``init``; ``extra="allow"`` presumably exists to permit that.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
    )

    def __init__(self, **kwargs: Any) -> None:
        """Build the executor and validate the flow state of its task.

        Args:
            **kwargs: Field values forwarded unchanged to the base executor.
        """
        super().__init__(**kwargs)
        self.validate_flow_state(self.task)

    @staticmethod
    async def check_cls(call_cls: Any) -> bool:
        """Check whether ``call_cls`` exposes the members a Call must have.

        A conforming class has a callable ``_init``, an async-generator
        function ``_exec`` and a callable ``info``.

        Args:
            call_cls: The candidate Call class.

        Returns:
            ``True`` when all three members are present with the right kind.
        """
        return (
            callable(getattr(call_cls, "_init", None))
            and inspect.isasyncgenfunction(getattr(call_cls, "_exec", None))
            and callable(getattr(call_cls, "info", None))
        )

    @staticmethod
    async def get_call_cls(call_id: str) -> type[CoreCall]:
        """Resolve a Call ID to its class, validating pool-provided classes.

        Args:
            call_id: Either one of the ``SpecialCallType`` values or an ID
                known to the Call pool.

        Returns:
            The Call class registered for ``call_id``.

        Raises:
            ValueError: If the class fetched from the pool fails ``check_cls``.
        """
        # Built-in Calls are returned directly and skip the conformance check.
        builtin_calls: dict[str, type[CoreCall]] = {
            SpecialCallType.EMPTY.value: Empty,
            SpecialCallType.SUMMARY.value: Summary,
            SpecialCallType.FACTS.value: FactsCall,
            SpecialCallType.SLOT.value: Slot,
        }
        builtin = builtin_calls.get(call_id)
        if builtin is not None:
            return builtin

        call_cls: type[CoreCall] = await Pool().get_call(call_id)
        if not await StepExecutor.check_cls(call_cls):
            err = f"[FlowExecutor] 工具 {call_id} 不符合Call标准要求"
            raise ValueError(err)
        return call_cls

    async def init(self) -> None:
        """Prepare the step: update task state and instantiate its Call.

        Sets ``self.node`` (``None`` when the node lookup fails),
        ``self._call_id`` and ``self.obj``.

        Raises:
            Exception: Whatever ``call_cls.instance`` raises is logged and
                re-raised unchanged.
        """
        logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name)

        self.task.state.step_id = self.step.step_id
        self.task.state.step_name = self.step.step.name

        node_id = self.step.step.node
        try:
            self.node = await NodeManager.get_node(node_id)
        except ValueError:
            # The node ID is then used directly as the Call ID below.
            logger.info("[StepExecutor] 获取Node失败,为内部Node或ID不存在")
            self.node = None

        self._call_id = self.node.call_id if self.node else node_id
        call_cls = await StepExecutor.get_call_cls(self._call_id)

        # Merge into a fresh dict so the node's own ``known_params`` is never
        # mutated by step-specific overrides.
        params: dict[str, Any] = {}
        if self.node and self.node.known_params:
            params.update(self.node.known_params)
        if self.step.step.params:
            params.update(self.step.step.params)

        try:
            self.obj = await call_cls.instance(self, self.node, **params)
        except Exception:
            logger.exception("[StepExecutor] 初始化Call失败")
            raise

    async def _run_slot_filling(self) -> None:
        """Run automatic parameter (slot) filling before the step executes.

        Acts like a special step whose input and output are pushed as
        messages but which is not appended to the task context. Does nothing
        when ``self.obj.enable_filling`` is falsy. The filled values are
        merged into ``self.obj.input``.

        Raises:
            RuntimeError: If the slot Call finishes without yielding a chunk.
        """
        if not self.obj.enable_filling:
            return

        # Remember the real step identity; it is swapped out while filling runs.
        saved_step_id = self.task.state.step_id
        saved_step_name = self.task.state.step_name

        self.task.state.step_id = str(uuid.uuid4())
        self.task.state.step_name = "自动参数填充"
        self.task.state.step_status = StepStatus.RUNNING
        self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)

        override = self.node.override_input if self.node and self.node.override_input else {}
        slot_obj = await Slot.instance(
            self,
            self.node,
            data=self.obj.input,
            current_schema=self.obj.input_model.model_json_schema(override=override),
        )
        await self.push_message(EventType.STEP_INPUT.value, slot_obj.input)

        # Only the last chunk produced by the slot Call is kept.
        result: SlotOutput | None = None
        async for chunk in slot_obj.exec(self, slot_obj.input):
            result = SlotOutput.model_validate(chunk.content)
        if result is None:
            # Fail with an explicit error instead of an unbound-variable crash.
            err = "[StepExecutor] 自动参数填充未返回任何结果"
            logger.error(err)
            raise RuntimeError(err)

        self.task.tokens.input_tokens += slot_obj.tokens.input_tokens
        self.task.tokens.output_tokens += slot_obj.tokens.output_tokens

        # A non-empty remaining schema means some parameters are still missing.
        if result.remaining_schema:
            self.task.state.step_status = StepStatus.PARAM
        else:
            self.task.state.step_status = StepStatus.SUCCESS
        await self.push_message(
            EventType.STEP_OUTPUT.value,
            result.model_dump(by_alias=True, exclude_none=True),
        )

        self.obj.input.update(result.slot_data)

        # Restore the identity of the real step.
        self.task.state.step_id = saved_step_id
        self.task.state.step_name = saved_step_name
        # NOTE(review): ``run`` adds ``self.obj.tokens`` again after execution;
        # confirm this extra accumulation is intended.
        self.task.tokens.input_tokens += self.obj.tokens.input_tokens
        self.task.tokens.output_tokens += self.obj.tokens.output_tokens

    async def _process_chunk(
        self,
        iterator: AsyncGenerator[CallOutputChunk, None],
        *,
        to_user: bool = False,
    ) -> str | dict[str, Any]:
        """Drain a Call's output stream and collect its final content.

        Consecutive string chunks are concatenated. A non-string chunk
        replaces whatever was collected so far, and a string chunk arriving
        after a non-string one starts a new string.

        Args:
            iterator: Async generator yielding ``CallOutputChunk`` items.
            to_user: When true, each chunk is also pushed as a message, and
                string chunks are appended to ``self.task.runtime.answer``.

        Returns:
            The collected content; an empty string if nothing was yielded.

        Raises:
            TypeError: If the iterator yields something that is not a
                ``CallOutputChunk``.
        """
        content: str | dict[str, Any] = ""
        async for chunk in iterator:
            if not isinstance(chunk, CallOutputChunk):
                err = "[StepExecutor] 返回结果类型错误"
                logger.error(err)
                raise TypeError(err)

            piece = chunk.content
            if isinstance(piece, str):
                # Append to existing text, or start over after structured data.
                content = content + piece if isinstance(content, str) else piece
            else:
                content = piece

            if not to_user:
                continue
            if isinstance(piece, str):
                await self.push_message(EventType.TEXT_ADD.value, piece)
                self.task.runtime.answer += piece
            else:
                await self.push_message(self.step.step.type, piece)
        return content

    async def run(self) -> None:
        """Run the step and record its result.

        Runs slot filling first, then executes the Call. If consuming the
        Call's output raises, the step is marked ``ERROR``, an empty output
        message is pushed, ``error_info`` is stored on the task state and the
        method returns without raising. On success the token and time
        counters are updated, a ``FlowStepHistory`` entry is appended to the
        task context and the output is pushed. Exceptions raised by slot
        filling are not caught here.
        """
        self.validate_flow_state(self.task)
        logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name)

        await self._run_slot_filling()

        self.task.state.step_status = StepStatus.RUNNING
        self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)
        await self.push_message(EventType.STEP_INPUT.value, self.obj.input)

        iterator = self.obj.exec(self, self.obj.input, language=self.task.language)
        try:
            content = await self._process_chunk(iterator, to_user=self.obj.to_user)
        except Exception as e:
            # Failures are recorded on the task state instead of propagating.
            logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
            self.task.state.step_status = StepStatus.ERROR
            await self.push_message(EventType.STEP_OUTPUT.value, {})
            if isinstance(e, CallError):
                self.task.state.error_info = {
                    "err_msg": e.message,
                    "data": e.data,
                }
            else:
                self.task.state.error_info = {
                    "err_msg": str(e),
                    "data": {},
                }
            return

        self.task.state.step_status = StepStatus.SUCCESS
        self.task.tokens.input_tokens += self.obj.tokens.input_tokens
        self.task.tokens.output_tokens += self.obj.tokens.output_tokens
        # Elapsed time since ``tokens.time`` was stamped just before execution.
        self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time

        # Plain text is wrapped in the text message schema; dicts pass through.
        if isinstance(content, str):
            output_data = TextAddContent(text=content).model_dump(exclude_none=True, by_alias=True)
        else:
            output_data = content

        history = FlowStepHistory(
            task_id=self.task.id,
            flow_id=self.task.state.flow_id,
            flow_name=self.task.state.flow_name,
            flow_status=self.task.state.flow_status,
            step_id=self.step.step_id,
            step_name=self.step.step.name,
            step_description=self.step.step.description,
            step_status=self.task.state.step_status,
            input_data=self.obj.input,
            output_data=output_data,
        )
        self.task.context.append(history)
        await self.push_message(EventType.STEP_OUTPUT.value, output_data)