import asyncio
import json
import logging
from collections.abc import AsyncIterator
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.core.config import rate_limit_chat
from app.core.deps import get_current_user
from app.core.rate_limit import limiter
from app.schemas.chat import ChatRequest
from app.services import conversation_store
from app.services.chat_runtime import chat_runtime
from app.services.settings_store import resolve_agent
# Chat API router. The router-level dependency runs get_current_user for every
# route registered here (presumably rejecting unauthenticated requests — confirm
# in app.core.deps).
router = APIRouter(tags=["chat"], dependencies=[Depends(get_current_user)])
# Module logger used for stream cancellation, stream errors and compression failures.
_log = logging.getLogger(__name__)
def _sse(payload: dict) -> str:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
async def _close_stream(stream: Any) -> None:
    """Close ``stream`` right away if it exposes ``aclose`` (e.g. an async generator).

    Leaving an ``async for`` early does not finalize an async generator
    synchronously. The event loop's async-generator finalizer only closes it
    later, once it is garbage-collected. Closing it here releases whatever the
    stream holds (presumably the upstream model request — confirm in
    chat_runtime) as soon as the turn ends or the client leaves. Close errors
    are logged, not raised, so they cannot mask the turn's real outcome.
    """
    aclose = getattr(stream, "aclose", None)
    if aclose is None:
        return
    try:
        await aclose()
    except Exception:
        _log.exception("chat stream close failed")


async def _chat_event_stream(
    thread_id: str,
    message: str,
    ai: dict,
    system_prompt: str,
    request: Request,
) -> AsyncIterator[str]:
    """Run one chat turn on ``thread_id`` and yield it as SSE frames.

    Before streaming, the thread is touched via
    ``conversation_store.touch_thread``. If its title is still the default
    placeholder (or empty/missing), the title is set to the first 32
    characters of the stripped message.

    Frames yielded, in order:

    1. every event from ``chat_runtime.stream_turn``, passed through unchanged;
    2. if ``plan_compression`` returns a plan: ``{"compress_start": True}``,
       followed by ``{"compress_done": True, "summary": ...}`` when
       compression reports success;
    3. ``{"done": True, "compressed": <bool>}``.

    If the client disconnects, the generator stops without a ``done`` frame.
    Compression is best-effort: failures while planning or running it are
    logged and the turn still ends with ``done`` (``compressed: False``).
    A ``ValueError`` elsewhere is reported as ``{"error": <message>}``. Any
    other exception is logged and reported as a model-call failure.
    ``asyncio.CancelledError`` is logged and re-raised.

    Args:
        thread_id: Conversation thread to append the turn to.
        message: The user's message text.
        ai: Model configuration from ``resolve_agent`` (opaque here).
        system_prompt: The agent's system prompt (may be empty).
        request: Incoming request, polled for client disconnection.
    """
    cancelled = False
    try:
        # NOTE(review): these store calls are synchronous; if they hit a DB or
        # disk they block the event loop — consider asyncio.to_thread.
        thread = conversation_store.get_thread(thread_id)
        title = message.strip()[:32] or "新对话"
        if thread and thread.get("title") in ("新对话", "", None):
            conversation_store.touch_thread(thread_id, title=title)
        else:
            conversation_store.touch_thread(thread_id)

        # Largest input-token count reported by any usage event in this turn;
        # passed to compression planning as the real prompt size.
        turn_input_tokens = 0
        events = chat_runtime.stream_turn(thread_id, message, ai, system_prompt)
        try:
            async for event in events:
                # Polled per event: stop relaying once the client has gone away.
                if await request.is_disconnected():
                    cancelled = True
                    break
                usage = event.get("usage")
                if isinstance(usage, dict):
                    turn_input_tokens = max(
                        turn_input_tokens, int(usage.get("input") or 0)
                    )
                yield _sse(event)
        finally:
            # Deterministically release the upstream stream (also on break).
            await _close_stream(events)

        if cancelled or await request.is_disconnected():
            return

        compressed = False
        summary: str | None = None
        actual_input = turn_input_tokens if turn_input_tokens > 0 else None
        # Compression is best-effort. The turn already succeeded, so a
        # planning failure must not be reported to the client as a model error.
        try:
            plan = await chat_runtime.plan_compression(
                thread_id, ai, system_prompt, actual_input_tokens=actual_input
            )
        except Exception:
            _log.exception("对话历史压缩失败(已忽略)")
            plan = None

        if plan is not None:
            yield _sse({"compress_start": True})
            try:
                compressed, summary = await chat_runtime.maybe_compress_thread(
                    thread_id,
                    ai,
                    system_prompt,
                    plan=plan,
                    actual_input_tokens=actual_input,
                )
            except Exception:
                _log.exception("对话历史压缩失败(已忽略)")
            if compressed:
                yield _sse({"compress_done": True, "summary": summary})
        yield _sse({"done": True, "compressed": compressed})
    except asyncio.CancelledError:
        _log.info("chat stream cancelled thread_id=%s", thread_id)
        raise
    except ValueError as error:
        yield _sse({"error": str(error)})
    except Exception as error:
        _log.exception("chat stream error")
        yield _sse({"error": f"模型调用失败: {error!s}"})
@router.post("/api/chat")
async def chat(
    payload: ChatRequest,
    request: Request,
    user: Annotated[dict[str, Any], Depends(get_current_user)],
):
    """Start one chat turn and stream the reply as Server-Sent Events.

    The request is validated in this order, each failure raising an
    ``HTTPException`` before any streaming starts:

    * per-user rate limit exceeded -> 429
    * chat runtime not ready -> 503
    * thread not found -> 404
    * thread bound to a different agent -> 400
    * agent cannot be resolved (``resolve_agent`` raises ``ValueError``) -> 404

    On success, a ``text/event-stream`` response wraps the turn's event
    generator.
    """
    limit, window = rate_limit_chat()
    rate_key = f"chat:{user['id']}"
    if not limiter.allow(rate_key, limit, window):
        raise HTTPException(status_code=429, detail="对话请求过于频繁,请稍后再试")

    if not chat_runtime.ready:
        raise HTTPException(status_code=503, detail="对话服务尚未就绪")

    # NOTE(review): the thread is looked up by id only. Nothing here checks
    # that it belongs to ``user`` — confirm ownership is enforced elsewhere.
    thread = conversation_store.get_thread(payload.thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="未找到该会话")
    if thread.get("agent_id") != payload.agent_id:
        raise HTTPException(status_code=400, detail="会话与智能体不匹配")

    try:
        ai, agent = resolve_agent(payload.agent_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err)) from err

    agent_prompt = str(agent.get("system_prompt") or "")
    event_stream = _chat_event_stream(
        payload.thread_id,
        payload.message,
        ai,
        agent_prompt,
        request,
    )
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Asks nginx not to buffer the response, so frames reach the client promptly.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers=sse_headers,
    )