"""FastAPI 聊天接口"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import JSONResponse, StreamingResponse
from apps.common.queue import MessageQueue
from apps.common.wordscheck import WordsCheck
from apps.dependency import get_session, get_user
from apps.schemas.enum_var import FlowStatus
from apps.scheduler.scheduler import Scheduler
from apps.scheduler.scheduler.context import save_data
from apps.schemas.request_data import RequestData, RequestDataApp
from apps.schemas.response_data import ResponseData
from apps.schemas.enum_var import LanguageType
from apps.schemas.task import Task
from apps.services.activity import Activity
from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager
from apps.services.flow import FlowManager
from apps.services.conversation import ConversationManager
from apps.services.record import RecordManager
from apps.services.task import TaskManager
# NOTE(review): not referenced anywhere in the visible part of this file; the
# name suggests a recommendation threshold — confirm before relying on it.
RECOMMEND_TRES = 5
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the chat endpoints; every route below is served under "/api"
# and tagged "chat".
router = APIRouter(
    prefix="/api",
    tags=["chat"],
)
async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task:
    """Create a new Task for this request, or reload the one named by ``post_body.task_id``.

    ``post_body`` is mutated in place: a ``group_id`` is generated when it is
    missing, and when an existing task is resumed the app, group, conversation,
    language and question fields are overwritten with the values stored on
    that task.

    Args:
        post_body: Incoming chat request.
        user_sub: Identifier of the requesting user.
        session_id: Session identifier, forwarded to ``TaskManager.init_new_task``.

    Returns:
        The newly initialised task, or the reloaded existing task.

    Raises:
        HTTPException: 403 when the conversation lookup for this user returns
            nothing; 400 when ``task_id`` is present but empty; 404 when no
            task matches ``task_id``.
    """
    # Make sure the request always carries a group id.
    if not post_body.group_id:
        post_body.group_id = str(uuid.uuid4())

    if post_body.task_id is None:
        # New task: the conversation must be visible to this user.
        conversation = await ConversationManager.get_conversation_by_conversation_id(
            user_sub=user_sub,
            conversation_id=post_body.conversation_id,
        )
        if not conversation:
            err = "[Chat] 用户没有权限访问该对话!"
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err)

        # Drop the conversation's previous tasks and mark their records as
        # cancelled before starting a fresh one.
        task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id)
        await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids)

        task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body)
        task.runtime.question = post_body.question
        task.ids.group_id = post_body.group_id
        task.state.app_id = post_body.app.app_id if post_body.app else ""
    else:
        # Resuming: an empty (but non-None) task_id is a malformed request.
        if not post_body.task_id:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty")

        task = await TaskManager.get_task_by_task_id(post_body.task_id)
        # The lookup can come back empty (the /stop handler guards for the same
        # case); fail explicitly instead of with an AttributeError below.
        if not task:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="task not found")
        # NOTE(review): the reloaded task is not checked against user_sub here —
        # confirm that ownership is enforced elsewhere.

        # The stored task is the source of truth for a resumed request.
        post_body.app = RequestDataApp(appId=task.state.app_id)
        post_body.group_id = task.ids.group_id
        post_body.conversation_id = task.ids.conversation_id
        post_body.language = task.language
        post_body.question = task.runtime.question

    task.language = post_body.language
    return task
async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]:
    """Run one chat turn and stream the result as ``data: ...`` event lines.

    Steps: mark the user active, check the question with ``WordsCheck``,
    initialise the task, start the scheduler in a background asyncio task and
    relay every message from the queue until one starts with ``[DONE]``.
    Afterwards the answer is checked, the data is saved and a final
    ``[DONE]`` event is emitted.

    Special events emitted instead of content:
        ``[SENSITIVE]`` when the question or the answer fails the word check,
        ``[ERROR]`` when the flow ends in ``FlowStatus.ERROR`` or any exception
        is raised, ``[DONE]`` on success.

    Args:
        post_body: Incoming chat request (may be mutated by ``init_task``).
        user_sub: Identifier of the requesting user.
        session_id: Session identifier forwarded to ``init_task``.

    Yields:
        Event strings of the form ``"data: <payload>\\n\\n"``.
    """
    # Track activation with a flag so the finally clause never touches an
    # unbound name when set_active itself fails.
    active_id = None
    activated = False
    try:
        active_id = await Activity.set_active(user_sub)
        activated = True

        # Reject the question up front when it fails the sensitive-word check.
        if await WordsCheck().check(post_body.question) != 1:
            yield "data: [SENSITIVE]\n\n"
            logger.info("[Chat] 问题包含敏感词!")
            return

        task = await init_task(post_body, user_sub, session_id)
        # Stored on the task so the /stop endpoint can clear the same marker.
        task.ids.active_id = active_id

        queue = MessageQueue()
        await queue.init()

        # The scheduler runs concurrently and feeds the queue consumed below.
        scheduler = Scheduler(task, queue, post_body)
        scheduler_task = asyncio.create_task(scheduler.run())
        # NOTE(review): if the client disconnects mid-stream this background
        # task is never awaited or cancelled here — confirm the scheduler stops
        # on its own once the activity marker is removed.

        async for content in queue.get():
            if content.startswith("[DONE]"):
                break
            yield "data: " + content + "\n\n"

        # Wait for the scheduler to finish, then read back its final task state.
        await scheduler_task
        task = scheduler.task

        if task.state.flow_status == FlowStatus.ERROR:
            logger.error("[Chat] 生成答案失败")
            yield "data: [ERROR]\n\n"
            return

        # The full answer is checked as well before anything is saved.
        if await WordsCheck().check(task.runtime.answer) != 1:
            yield "data: [SENSITIVE]\n\n"
            logger.info("[Chat] 答案包含敏感词!")
            return

        await save_data(task, user_sub, post_body)

        if post_body.app and post_body.app.flow_id:
            await FlowManager.update_flow_debug_by_app_and_flow_id(
                post_body.app.app_id,
                post_body.app.flow_id,
                debug=True,
            )

        yield "data: [DONE]\n\n"
    except Exception:
        logger.exception("[Chat] 生成答案失败")
        yield "data: [ERROR]\n\n"
    finally:
        # Single cleanup point: every exit path clears the marker exactly once.
        if activated:
            await Activity.remove_active(active_id)
@router.post("/chat")
async def chat(
    post_body: RequestData,
    user_sub: Annotated[str, Depends(get_user)],
    session_id: Annotated[str, Depends(get_session)],
) -> StreamingResponse:
    """Streaming LLM chat endpoint.

    Normalises the request language, rejects questions that fail the question
    blacklist check, and otherwise returns the ``chat_generator`` output as a
    ``text/event-stream`` response.

    Raises:
        HTTPException: 500 with detail ``"question is blacklisted"`` when the
            blacklist check does not pass.
    """
    # Anything other than "zh" / LanguageType.CHINESE is treated as English.
    if post_body.language in {"zh", LanguageType.CHINESE}:
        post_body.language = LanguageType.CHINESE
    else:
        post_body.language = LanguageType.ENGLISH

    # The blacklist check only runs when a question was actually supplied.
    if post_body.question is not None:
        question_allowed = await QuestionBlacklistManager.check_blacklisted_questions(
            input_question=post_body.question,
        )
        if not question_allowed:
            # NOTE(review): -10 is presumably a credit/score penalty for the
            # user — confirm against UserBlacklistManager.
            await UserBlacklistManager.change_blacklisted_users(user_sub, -10)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="question is blacklisted",
            )

    # "X-Accel-Buffering: no" is the conventional way to ask a buffering
    # reverse proxy (e.g. nginx) to pass the stream through unbuffered.
    stream_headers = {"X-Accel-Buffering": "no"}
    return StreamingResponse(
        content=chat_generator(post_body, user_sub, session_id),
        media_type="text/event-stream",
        headers=stream_headers,
    )
@router.post("/stop", response_model=ResponseData)
async def stop_generation(
    user_sub: Annotated[str, Depends(get_user)],
    task_id: Annotated[str, Query(..., alias="taskId")] = "",
) -> JSONResponse:
    """Stop an in-progress generation.

    Looks up the task given by the ``taskId`` query parameter and, when one is
    found, removes the activity marker stored on it. A success payload is
    returned whether or not a task was found.
    """
    # NOTE(review): user_sub is only used to require authentication; the task's
    # owner is not compared with it — confirm this is intended.
    found_task = await TaskManager.get_task_by_task_id(task_id)
    if found_task:
        await Activity.remove_active(found_task.ids.active_id)

    payload = ResponseData(
        code=status.HTTP_200_OK,
        message="stop generation success",
        result={},
    )
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content=payload.model_dump(exclude_none=True, by_alias=True),
    )