"""消息队列模块"""
import asyncio
import json
import logging
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any
from apps.llm import LLM
from apps.schemas.enum_var import EventType
from apps.schemas.message import (
HeartbeatData,
MessageBase,
MessageExecutor,
)
from apps.schemas.record import RecordMetadata
from apps.schemas.task import TaskData
from .encoder import UUIDEncoder
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class MessageQueue:
"""使用asyncio.Queue实现的消息队列"""
_heartbeat_interval: float = 3.0
async def init(self) -> None:
"""
异步初始化消息队列
:param task: 任务
"""
self._queue = asyncio.Queue()
self._close = False
self._heartbeat_task = asyncio.get_event_loop().create_task(self._heartbeat())
async def push_output(self, task: TaskData, llm: LLM, event_type: str, data: dict[str, Any]) -> None:
"""组装用于向用户(前端/Shell端)输出的消息"""
if event_type == EventType.DONE.value:
await self._queue.put("[DONE]")
return
step_time = round((datetime.now(UTC).timestamp() - task.runtime.time), 3)
step_time = max(step_time, 0)
metadata = RecordMetadata(
timeCost=step_time,
inputTokens=llm.input_tokens,
outputTokens=llm.output_tokens,
)
if task.state:
flow = MessageExecutor(
appId=task.state.appId,
executorId=task.state.executorId,
executorName=task.state.executorName,
executorStatus=task.state.executorStatus,
stepId=task.state.stepId,
stepName=task.state.stepName,
stepStatus=task.state.stepStatus,
stepType=task.state.stepType,
)
else:
flow = None
message = MessageBase(
event=event_type,
id=task.metadata.id,
conversationId=task.metadata.conversationId,
metadata=metadata,
flow=flow,
content=data,
)
await self._queue.put(
json.dumps(message.model_dump(by_alias=True), ensure_ascii=False, cls=UUIDEncoder),
)
async def get(self) -> AsyncGenerator[str, None]:
"""从Queue中获取消息;变为async generator"""
while True:
if self._close and self._queue.empty():
break
try:
message = self._queue.get_nowait()
yield message
except asyncio.QueueEmpty:
await asyncio.sleep(0.02)
except asyncio.CancelledError:
break
except Exception:
logger.exception("[Queue] 获取消息失败")
break
async def _heartbeat(self) -> None:
"""组装用于向用户(前端/Shell端)输出的心跳"""
heartbeat_template = HeartbeatData()
heartbeat_msg = json.dumps(heartbeat_template.model_dump(by_alias=True), ensure_ascii=False, cls=UUIDEncoder)
while True:
if self._close:
break
await asyncio.sleep(self._heartbeat_interval)
await self._queue.put(heartbeat_msg)
async def close(self) -> None:
"""关闭消息队列"""
self._close = True
self._heartbeat_task.cancel()