import asyncio
import json
import logging
from collections.abc import AsyncIterator
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse

from app.core.config import rate_limit_chat
from app.core.deps import get_current_user
from app.core.rate_limit import limiter
from app.schemas.chat import ChatRequest
from app.services import conversation_store
from app.services.chat_runtime import chat_runtime
from app.services.settings_store import resolve_agent

# Module-level logger for the chat routes.
_log = logging.getLogger(__name__)
# Chat router: `get_current_user` is attached as a dependency for every
# endpoint registered on it.
router = APIRouter(
    dependencies=[Depends(get_current_user)],
    tags=["chat"],
)


def _sse(payload: dict) -> str:
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


async def _aclose_quietly(stream: object) -> None:
    """Close *stream* via ``aclose()`` if it has one; log instead of raising."""
    aclose = getattr(stream, "aclose", None)
    if aclose is None:
        return
    try:
        await aclose()
    except Exception:
        # Closing is cleanup only; never let it mask the real outcome.
        _log.warning("failed to close chat stream", exc_info=True)


async def _chat_event_stream(
    thread_id: str,
    message: str,
    ai: dict,
    system_prompt: str,
    request: Request,
) -> AsyncIterator[str]:
    """Stream one chat turn to the client as SSE frames.

    Steps:
      1. Touch the thread. If its title is still the placeholder ("新对话"),
         empty, or missing, set it to the first 32 characters of the stripped
         message, falling back to "新对话".
      2. Relay every event from ``chat_runtime.stream_turn`` as an SSE frame,
         remembering the largest ``usage["input"]`` value seen.
      3. If the client is still connected, try to compress the thread history.
         Planning and compression are both best-effort: failures are logged
         and ignored. ``compress_start`` / ``compress_done`` frames are emitted
         around a compression that actually happens.
      4. Emit a final ``{"done": True, "compressed": <bool>}`` frame.

    The generator stops silently if the client disconnects.

    Args:
        thread_id: Conversation thread to run the turn in.
        message: The user's message text.
        ai: Model configuration passed through to ``chat_runtime``.
        system_prompt: System prompt passed through to ``chat_runtime``.
        request: Incoming request, polled for client disconnects.

    Yields:
        SSE-encoded frames produced by ``_sse``. A ``ValueError`` is reported
        to the client verbatim as ``{"error": ...}``. Any other exception is
        logged and reported as a model-call failure. ``asyncio.CancelledError``
        is logged and re-raised.
    """
    cancelled = False
    try:
        thread = conversation_store.get_thread(thread_id)
        title = message.strip()[:32] or "新对话"
        if thread and thread.get("title") in ("新对话", "", None):
            conversation_store.touch_thread(thread_id, title=title)
        else:
            conversation_store.touch_thread(thread_id)

        turn_input_tokens = 0
        stream = chat_runtime.stream_turn(thread_id, message, ai, system_prompt)
        try:
            async for event in stream:
                if await request.is_disconnected():
                    cancelled = True
                    break
                usage = event.get("usage")
                if isinstance(usage, dict):
                    turn_input_tokens = max(turn_input_tokens, int(usage.get("input") or 0))
                yield _sse(event)
        finally:
            # Leaving `async for` early (disconnect, error, cancellation) does
            # not finalize an async generator right away. Close it explicitly so
            # the upstream model stream is released now rather than at GC time.
            await _aclose_quietly(stream)

        if cancelled or await request.is_disconnected():
            return

        compressed = False
        summary: str | None = None
        actual_input = turn_input_tokens if turn_input_tokens > 0 else None
        try:
            plan = await chat_runtime.plan_compression(
                thread_id, ai, system_prompt, actual_input_tokens=actual_input
            )
        except Exception:
            # Compression is optional. A planning failure must not turn an
            # already-delivered reply into an error frame without "done".
            _log.exception("对话历史压缩失败(已忽略)")
            plan = None
        if plan is not None:
            yield _sse({"compress_start": True})
            try:
                compressed, summary = await chat_runtime.maybe_compress_thread(
                    thread_id,
                    ai,
                    system_prompt,
                    plan=plan,
                    actual_input_tokens=actual_input,
                )
            except Exception:
                _log.exception("对话历史压缩失败(已忽略)")
            if compressed:
                yield _sse({"compress_done": True, "summary": summary})
        yield _sse({"done": True, "compressed": compressed})
    except asyncio.CancelledError:
        _log.info("chat stream cancelled thread_id=%s", thread_id)
        raise
    except ValueError as error:
        yield _sse({"error": str(error)})
    except Exception as error:
        _log.exception("chat stream error")
        yield _sse({"error": f"模型调用失败: {error!s}"})


@router.post("/api/chat")
async def chat(
    payload: ChatRequest,
    request: Request,
    user: Annotated[dict[str, Any], Depends(get_current_user)],
):
    # Start one streamed chat turn and return it as a text/event-stream response.
    # Documented with comments rather than a docstring on purpose: FastAPI would
    # publish a docstring as this endpoint's OpenAPI description.
    #
    # Error responses, checked in this order:
    #   429 if the per-user rate limit rejects the call,
    #   503 if chat_runtime is not ready,
    #   404 if the thread does not exist,
    #   400 if the thread belongs to a different agent,
    #   404 if resolve_agent raises ValueError.

    # Per-user throttling runs before any other check.
    max_requests, per_window = rate_limit_chat()
    bucket = f"chat:{user['id']}"
    if not limiter.allow(bucket, max_requests, per_window):
        raise HTTPException(status_code=429, detail="对话请求过于频繁,请稍后再试")
    if not chat_runtime.ready:
        raise HTTPException(status_code=503, detail="对话服务尚未就绪")

    # NOTE(review): the thread is looked up by id only; no ownership check
    # against `user` is visible here. Confirm conversation_store scopes threads
    # per user.
    conversation = conversation_store.get_thread(payload.thread_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="未找到该会话")
    if conversation.get("agent_id") != payload.agent_id:
        raise HTTPException(status_code=400, detail="会话与智能体不匹配")

    try:
        ai, agent = resolve_agent(payload.agent_id)
    except ValueError as err:
        raise HTTPException(status_code=404, detail=str(err)) from err

    prompt = str(agent.get("system_prompt") or "")
    events = _chat_event_stream(
        payload.thread_id,
        payload.message,
        ai,
        prompt,
        request,
    )
    # Disable caching and proxy buffering (X-Accel-Buffering) so frames reach
    # the client as soon as they are produced.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(events, media_type="text/event-stream", headers=sse_headers)