"""Flow执行Executor"""
import logging
import uuid
from collections import deque
from datetime import UTC, datetime
from pydantic import Field
from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT
from apps.scheduler.executor.base import BaseExecutor
from apps.scheduler.executor.step import StepExecutor
from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus, LanguageType
from apps.schemas.flow import Flow, Step
from apps.schemas.request_data import RequestDataApp
from apps.schemas.task import ExecutorState, StepQueueItem
from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
# Steps queued before the flow's own entry step in FlowExecutor.run.
# Each entry maps a UI language to the localized variant of the same step;
# callers fall back to the Chinese variant when the task language is missing.
# The single entry is a SUMMARY-type step which, per its description, uses the
# LLM to understand the dialogue context.
FIXED_STEPS_BEFORE_START = [
    {
        LanguageType.CHINESE: Step(
            type=SpecialCallType.SUMMARY.value,
            node=SpecialCallType.SUMMARY.value,
            name="理解上下文",
            description="使用大模型,理解对话上下文",
        ),
        LanguageType.ENGLISH: Step(
            type=SpecialCallType.SUMMARY.value,
            node=SpecialCallType.SUMMARY.value,
            name="Understand context",
            description="Use large model to understand the context of the dialogue",
        ),
    },
]
# Steps queued after the flow has finished (successfully or not) in
# FlowExecutor.run. Each entry maps a UI language to the localized variant of
# the same step. The single entry is a FACTS-type step which, per its
# description, extracts facts from the dialogue answer and stores them in memory.
FIXED_STEPS_AFTER_END = [
    {
        LanguageType.CHINESE: Step(
            type=SpecialCallType.FACTS.value,
            node=SpecialCallType.FACTS.value,
            name="记忆存储",
            description="理解对话答案,并存储到记忆中",
        ),
        LanguageType.ENGLISH: Step(
            type=SpecialCallType.FACTS.value,
            node=SpecialCallType.FACTS.value,
            name="Memory storage",
            description="Understand the answer of the dialogue and store it in the memory",
        ),
    },
]
class FlowExecutor(BaseExecutor):
    """Executor that runs a workflow (Flow) step by step.

    Results are not returned; progress is reported by pushing events through
    ``push_message`` and by the per-step ``StepExecutor`` instances.
    ``load_state`` must be awaited before ``run``: it creates the ``step_queue``
    and ``_reached_end`` attributes that ``run`` relies on.
    """

    flow: Flow
    flow_id: str = Field(description="Flow ID")
    # The user's input question.
    question: str = Field(description="用户输入")
    # App information taken from the request body.
    post_body_app: RequestDataApp = Field(description="请求体中的app信息")
    # The step currently being executed (None until the first step runs).
    current_step: StepQueueItem | None = Field(
        description="当前执行的步骤",
        default=None,
    )

    async def load_state(self) -> None:
        """Load the executor state for this task, or initialise a fresh one.

        If the task already has a state whose flow status is neither INIT nor
        UNKNOWN, the step context is reloaded via ``TaskManager``. Otherwise a
        new RUNNING state positioned at the ``"start"`` step is created. The
        state is then validated and the run-time bookkeeping attributes are reset.
        """
        logger.info("[FlowExecutor] 加载Executor状态")
        state = self.task.state
        if state and state.flow_status not in (FlowStatus.INIT, FlowStatus.UNKNOWN):
            # Resuming an existing run: restore the recorded step context.
            self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
        else:
            self.task.state = ExecutorState(
                flow_id=str(self.flow_id),
                flow_name=self.flow.name,
                flow_status=FlowStatus.RUNNING,
                description=str(self.flow.description),
                step_status=StepStatus.RUNNING,
                app_id=str(self.post_body_app.app_id),
                step_id="start",
                step_name="开始" if self.task.language == LanguageType.CHINESE else "Start",
            )
        self.validate_flow_state(self.task)
        self._reached_end: bool = False
        self.step_queue: deque[StepQueueItem] = deque()

    async def _invoke_runner(self) -> None:
        """Execute ``self.current_step`` and adopt the task state it produces."""
        step_runner = StepExecutor(
            msg_queue=self.msg_queue,
            task=self.task,
            step=self.current_step,
            background=self.background,
            question=self.question,
        )
        await step_runner.init()
        await step_runner.run()
        self.task = step_runner.task

    async def _step_process(self) -> None:
        """Run every step currently queued (seen by the user as a single step).

        Steps are consumed first-in, first-out, so successors run in the order
        they were enqueued (edge declaration order).
        """
        while self.step_queue:
            self.current_step = self.step_queue.popleft()
            await self._invoke_runner()

    async def _find_next_id(self, step_id: str) -> list[str]:
        """Return the IDs of the steps targeted by edges leaving ``step_id``, in edge order."""
        return [edge.edge_to for edge in self.flow.edges if edge.edge_from == step_id]

    async def _find_flow_next(self) -> list[StepQueueItem]:
        """Compute the queue items that follow the step recorded in the task state.

        Returns an empty list when the flow is at ``"end"``, has no current step
        ID, or a CHOICE step produced no branch ID. A non-CHOICE step without
        outgoing edges is followed by the flow's ``"end"`` step. For a CHOICE
        step, the successors are looked up under ``"<step_id>.<branch_id>"``.
        """
        if self.task.state.step_id == "end" or not self.task.state.step_id:
            return []

        if self.current_step.step.type == SpecialCallType.CHOICE.value:
            # The choice result is read from the latest context entry; a missing
            # context entry or key (e.g. the choice step failed) means "no branch".
            last_output = self.task.context[-1].output_data if self.task.context else None
            branch_id = (last_output or {}).get("branch_id")
            if not branch_id:
                logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表")
                return []
            next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id)
            logger.info("[FlowExecutor] 分支ID:%s", branch_id)
        else:
            next_steps = await self._find_next_id(self.task.state.step_id)

        if not next_steps:
            return [
                StepQueueItem(
                    step_id="end",
                    step=self.flow.steps["end"],
                ),
            ]

        logger.info("[FlowExecutor] 下一步:%s", next_steps)
        return [
            StepQueueItem(
                step_id=next_step,
                step=self.flow.steps[next_step],
            )
            for next_step in next_steps
        ]

    def _localized_step(self, variants: dict[LanguageType, Step]) -> Step:
        """Pick the step variant for the task language, falling back to Chinese."""
        return variants.get(self.task.language, variants[LanguageType.CHINESE])

    def _build_error_step(self) -> StepQueueItem:
        """Build the hidden LLM step that handles the error recorded in the task state.

        The error message is substituted into ``LLM_ERROR_PROMPT`` for the task
        language; a missing ``err_msg`` is substituted as an empty string.
        """
        label = "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling"
        error_info = self.task.state.error_info or {}
        err_msg = str(error_info.get("err_msg", ""))
        return StepQueueItem(
            step_id=str(uuid.uuid4()),
            step=Step(
                name=label,
                description=label,
                node=SpecialCallType.LLM.value,
                type=SpecialCallType.LLM.value,
                params={
                    "user_prompt": LLM_ERROR_PROMPT[self.task.language].replace(
                        "{{ error_info }}",
                        err_msg,
                    ),
                },
            ),
            enable_filling=False,
            to_user=False,
        )

    async def run(self) -> None:
        """Run the flow until it reaches its end or a step fails.

        Order of execution: FLOW_START event, the fixed pre-steps, the flow's
        steps starting from the state's current step, an error-handling step if
        any step failed, the fixed post-steps, and finally a FLOW_SUCCESS or
        FLOW_FAILED event. All data is delivered through the message queue.
        """
        logger.info("[FlowExecutor] 运行工作流")
        await self.push_message(EventType.FLOW_START.value)

        # Resolve the entry step before the pre-steps run, since running them
        # may update the state's step_id.
        first_step = StepQueueItem(
            step_id=self.task.state.step_id,
            step=self.flow.steps[self.task.state.step_id],
        )

        for variants in FIXED_STEPS_BEFORE_START:
            self.step_queue.append(
                StepQueueItem(
                    step_id=str(uuid.uuid4()),
                    step=self._localized_step(variants),
                    enable_filling=False,
                    to_user=False,
                ),
            )
        await self._step_process()

        self.step_queue.append(first_step)
        self.task.state.flow_status = FlowStatus.RUNNING
        is_error = False

        while True:
            # Checked before every round, including after the final step ran, so
            # a failure in a pre-step or in the last flow step is still handled.
            if self.task.state.step_status == StepStatus.ERROR:
                logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤")
                # Discard already-scheduled successors; only the handler runs.
                self.step_queue.clear()
                self.step_queue.append(self._build_error_step())
                is_error = True
                self._reached_end = True
                await self._step_process()
                # The error handler is terminal: never schedule flow successors
                # after it (they would otherwise run with the post-steps below).
                break
            if self._reached_end:
                break

            await self._step_process()
            next_steps = await self._find_flow_next()
            if not next_steps:
                self._reached_end = True
            self.step_queue.extend(next_steps)

        self.task.state.flow_status = FlowStatus.ERROR if is_error else FlowStatus.SUCCESS

        for variants in FIXED_STEPS_AFTER_END:
            self.step_queue.append(
                StepQueueItem(
                    step_id=str(uuid.uuid4()),
                    step=self._localized_step(variants),
                ),
            )
        await self._step_process()

        # Elapsed time since tokens.full_time; round the difference (not the
        # timestamp) so the stored value carries no floating-point noise.
        self.task.tokens.time = round(
            datetime.now(UTC).timestamp() - self.task.tokens.full_time, 2,
        )

        if is_error:
            await self.push_message(EventType.FLOW_FAILED.value)
        else:
            await self.push_message(EventType.FLOW_SUCCESS.value)