"""MessageHandler unit tests."""

from collections.abc import AsyncIterator
from types import SimpleNamespace

import pytest

from jiuwenswarm.common.schema import Message
from jiuwenswarm.common.schema.message import ReqMethod
from jiuwenswarm.gateway.message_handler.message_handler import MessageHandler


class _FakeAgentClient:
    """Stand-in agent client that records requests and replays a canned reply.

    State is kept on the class rather than on instances, so tests can inspect
    what was sent without holding the client object the handler received.
    """

    # Every envelope passed to ``send_request``, in call order. Tests are
    # expected to reset this before use.
    sent_requests: list[object] = []
    # Payload returned (as a shallow copy) from every ``send_request`` call.
    response_payload: dict[str, object] = {
        "event_type": "chat.interrupt_result",
        "message": "当前没有可取消的团队任务",
        "success": False,
    }

    @staticmethod
    async def send_request(env: object) -> SimpleNamespace:
        """Record ``env`` and return an ``ok`` response carrying the canned payload."""
        recorder = _FakeAgentClient.sent_requests
        recorder.append(env)
        reply_fields = {
            "request_id": "interrupt-1",
            "channel_id": "feishu_enterprise",
            "ok": True,
            # Copy so a caller mutating the reply cannot alter the fixture.
            "payload": {**_FakeAgentClient.response_payload},
            "metadata": None,
        }
        return SimpleNamespace(**reply_fields)

    @staticmethod
    async def send_request_stream(env: object) -> AsyncIterator[object]:
        """Async generator that finishes immediately without yielding anything."""
        return
        # Never reached; the ``yield`` only makes this an async generator
        # instead of a plain coroutine.
        yield env  # pylint: disable=unreachable


class _TestMessageHandler(MessageHandler):
    """``MessageHandler`` subclass that exposes private hooks to the tests.

    Each public wrapper forwards to a private member of the handler. Several
    of them look the member up by name with ``getattr``.
    """

    @classmethod
    def create(cls) -> "_TestMessageHandler":
        """Build a fresh handler wired to a ``_FakeAgentClient``.

        Clears ``_instance`` on ``MessageHandler`` and then on this subclass
        (presumably a singleton cache -- confirm against ``MessageHandler``).
        It also empties the fake client's request log before constructing.
        """
        for owner in (MessageHandler, cls):
            setattr(owner, "_instance", None)
        _FakeAgentClient.sent_requests = []
        fake_client = _FakeAgentClient()
        return cls(fake_client)

    def seed_pending_evolution_approval(
        self,
        session_id: str,
        request_id: str,
    ) -> None:
        """Mark ``request_id`` as the pending evolution approval for a session."""
        getattr(self, "_mark_pending_evolution_approval")(session_id, request_id)

    def seed_session_evolution_in_progress(self, session_id: str) -> None:
        """Flag ``session_id`` as having an evolution in progress."""
        getattr(self, "_mark_session_evolution_in_progress")(session_id)

    def seed_queued_supplement_input(
        self,
        session_id: str,
        payload: dict[str, object],
    ) -> None:
        """Store ``payload`` as the queued supplement input for ``session_id``."""
        getattr(self, "_queued_supplement_input")[session_id] = payload

    async def handle_evolution_chunk(
        self,
        chunk: SimpleNamespace,
        session_id: str,
        request_metadata: dict[str, object] | None = None,
    ) -> None:
        """Forward ``chunk`` to the handler's private evolution-chunk hook."""
        await getattr(self, "_handle_evolution_chunk")(
            chunk, session_id, request_metadata
        )

    def finish_evolution_approval_if_current(
        self,
        session_id: str,
        answered_request_id: str,
    ) -> dict[str, object] | None:
        """Delegate to the private finisher and return whatever it returns."""
        return getattr(self, "_finish_evolution_approval_if_current")(
            session_id, answered_request_id
        )

    def pending_evolution_approval(self, session_id: str) -> str | None:
        """Return the pending approval request id for a session, if any."""
        return getattr(self, "_pending_evolution_approval").get(session_id)

    def has_session_evolution_in_progress(self, session_id: str) -> bool:
        """Report whether the handler considers an evolution in progress."""
        return getattr(self, "_is_session_evolution_in_progress")(session_id)

    def queued_supplement_input(self, session_id: str) -> dict[str, object] | None:
        """Return the queued supplement payload for a session, if any."""
        return getattr(self, "_queued_supplement_input").get(session_id)

    def pop_user_message_nowait(self):
        """Take the next queued user message without waiting.

        NOTE(review): whatever ``get_nowait`` raises on an empty queue
        propagates (``asyncio.QueueEmpty`` if it is an asyncio queue -- confirm).
        """
        return getattr(self, "_user_messages").get_nowait()

    def should_emit_processing_status_for_stream(self, msg: Message) -> bool:
        """Public alias for the private processing-status predicate."""
        return self._should_emit_processing_status_for_stream(msg)

    async def cancel_agent_work_for_session(
        self,
        msg: Message,
        old_sid: str | None,
        *,
        publish_interrupt_result: bool = True,
    ) -> None:
        """Public alias for the private cancel routine."""
        await self._cancel_agent_work_for_session(
            msg, old_sid, publish_interrupt_result=publish_interrupt_result
        )

    def build_queued_chat_send_message(
        self,
        msg: Message,
        new_input: str,
        original_request: str = "",
    ) -> Message:
        """Public alias for the private queued chat-send builder."""
        return self._build_queued_chat_send_message(
            msg, new_input, original_request=original_request
        )

    def remember_user_query_context(self, msg: Message) -> None:
        """Public alias for the private query-context recorder."""
        self._remember_user_query_context(msg)

    def get_session_last_user_query(self, session_id: str) -> str:
        """Public alias for the private last-user-query lookup."""
        return self._get_session_last_user_query(session_id)


def _message(req_method: ReqMethod) -> Message:
    """Build a streaming ``web`` request on ``sess-1`` with the given method."""
    fields: dict[str, object] = {
        "id": "req-1",
        "type": "req",
        "channel_id": "web",
        "session_id": "sess-1",
        "params": {},
        "timestamp": 0,
        "ok": True,
        "req_method": req_method,
        "is_stream": True,
    }
    return Message(**fields)


def _control_message() -> Message:
    """Build a non-streaming team-mode ``CHAT_SEND`` on ``feishu_enterprise``."""
    team_params = {"mode": "team"}
    return Message(
        req_method=ReqMethod.CHAT_SEND,
        is_stream=False,
        id="control-1",
        type="req",
        channel_id="feishu_enterprise",
        session_id="sess-1",
        params=team_params,
        timestamp=0,
        ok=True,
    )


def test_processing_status_is_only_emitted_for_chat_streams() -> None:
    """A chat-send stream emits processing status; a history fetch does not."""
    handler = _TestMessageHandler.create()
    emits = handler.should_emit_processing_status_for_stream

    chat_stream = _message(ReqMethod.CHAT_SEND)
    history_stream = _message(ReqMethod.HISTORY_GET)

    assert emits(chat_stream) is True
    assert emits(history_stream) is False


def test_queued_supplement_message_instructs_todo_continuation() -> None:
    """A queued supplement keeps the original request and steers the agent back to its todo list."""
    handler = _TestMessageHandler.create()
    msg = _message(ReqMethod.CHAT_CANCEL)

    queued = handler.build_queued_chat_send_message(
        msg,
        "删除 todo 列表里的提出改善意见",
        original_request=r"Analyze C:\repo\src\ui\screen-layout.ts",
    )
    query = queued.params["query"]

    assert queued.params["supplement_input"] == "删除 todo 列表里的提出改善意见"
    assert queued.params["original_request"] == r"Analyze C:\repo\src\ui\screen-layout.ts"
    # The backslash-laden Windows path must appear verbatim in the prompt.
    assert r"C:\repo\src\ui\screen-layout.ts" in query
    # "continue executing the still-unfinished items of this session's todo list"
    assert "继续执行当前会话 todo 列表中仍未完成" in query
    # "don't ask the user for next steps just because the supplement itself is done"
    assert "不要因为补充请求本身处理完成就询问用户下一步" in query
    # "the result being output last round may have been only partially shown"
    assert "上一轮正在输出的任务结果可能只展示了一部分" in query
    # "don't skip merely because the todo status has already become completed"
    assert "不要仅因为 todo 状态已经变为 completed 就跳过" in query


def test_chat_send_query_context_is_remembered_for_supplement() -> None:
    """The last chat-send query, Windows path included, is stored for the session."""
    handler = _TestMessageHandler.create()
    msg = _message(ReqMethod.CHAT_SEND)
    msg.params = {
        "query": r"Read C:\repo\src\ui\screen-layout.ts and summarize it",
    }

    handler.remember_user_query_context(msg)

    # The raw string (backslashes intact) must round-trip unchanged.
    assert (
        handler.get_session_last_user_query("sess-1")
        == r"Read C:\repo\src\ui\screen-layout.ts and summarize it"
    )


@pytest.mark.asyncio
async def test_handle_evolution_chunk_auto_accepts_previous_pending_approval() -> None:
    """A new evolution question replaces the pending approval and auto-accepts the old one."""
    handler = _TestMessageHandler.create()
    handler.seed_pending_evolution_approval("sess-1", "team_skill_evolve_old")
    question_payload = {
        "event_type": "chat.ask_user_question",
        "request_id": "team_skill_evolve_new",
        "questions": [{"header": "x"}],
    }
    incoming = SimpleNamespace(
        channel_id="web",
        request_id="stream-1",
        payload=question_payload,
    )

    await handler.handle_evolution_chunk(incoming, "sess-1", {"k": "v"})

    # The newer request is now the one awaiting approval...
    assert handler.pending_evolution_approval("sess-1") == "team_skill_evolve_new"
    # ...and an auto-answer ("接收" = "accept") for the old one was queued.
    accepted = handler.pop_user_message_nowait()
    assert (accepted.session_id, accepted.channel_id) == ("sess-1", "web")
    assert accepted.params["request_id"] == "team_skill_evolve_old"
    assert accepted.params["answers"] == [{"selected_options": ["接收"]}]
    assert accepted.metadata == {"k": "v"}


def test_finish_evolution_approval_if_current_keeps_newer_pending_request() -> None:
    """Answering a stale approval id leaves the newer request's state untouched."""
    handler = _TestMessageHandler.create()
    sid = "sess-2"
    handler.seed_pending_evolution_approval(sid, "team_skill_evolve_new")
    handler.seed_session_evolution_in_progress(sid)
    handler.seed_queued_supplement_input(sid, {"new_input": "follow up"})

    released = handler.finish_evolution_approval_if_current(
        sid, "team_skill_evolve_old"
    )

    # The answered id is stale: nothing is returned and no state is cleared.
    assert released is None
    assert handler.pending_evolution_approval(sid) == "team_skill_evolve_new"
    assert handler.has_session_evolution_in_progress(sid) is True
    assert handler.queued_supplement_input(sid) == {"new_input": "follow up"}


@pytest.mark.asyncio
async def test_control_command_cancel_suppresses_interrupt_result() -> None:
    """With ``publish_interrupt_result=False`` the request is sent but nothing is published."""
    handler = _TestMessageHandler.create()
    control = _control_message()

    await handler.cancel_agent_work_for_session(
        control, "sess-1", publish_interrupt_result=False
    )

    # Exactly one request reached the fake agent client...
    assert len(_FakeAgentClient.sent_requests) == 1
    # ...but no robot message is waiting to be consumed.
    published = await handler.consume_robot_messages(timeout=0)
    assert published is None


@pytest.mark.asyncio
async def test_default_cancel_publishes_interrupt_result() -> None:
    """By default the agent's cancel response is published as a robot message."""
    handler = _TestMessageHandler.create()
    await handler.cancel_agent_work_for_session(_control_message(), "sess-1")

    published = await handler.consume_robot_messages(timeout=0)

    assert published is not None
    assert published.payload == _FakeAgentClient.response_payload