"""Orchestrator-level integration tests (mocked Runner / state_creation)."""

from __future__ import annotations

import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import pytest

from openjiuwen_deepsearch.config.config import AgentConfig, SearchWorkflowConfig
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import (
    Action,
    ActionProposal,
    Result,
    State,
)
from openjiuwen_deepsearch.framework.openjiuwen.agent.workflow import DeepSearchAgent

# Module-level marker: applies ``@pytest.mark.integration`` to every test in
# this file, so the whole module can be selected or skipped via ``-m integration``.
pytestmark = pytest.mark.integration


def _make_agent(tmp_log_dir: Path, **pqp_updates: Any) -> DeepSearchAgent:
    """Build a ``DeepSearchAgent`` configured for orchestrator-level tests.

    The per-question params are derived from the ``AgentConfig`` defaults in
    two layers:

    1. Fixed test values: 2 workers, no empty-action-space retries, a 300s
       time limit, explored/fail limits of 0, top-1 answer mode, and no
       best-guess answer.
    2. Any ``pqp_updates`` supplied by the caller, which win over layer 1.

    All log output is pointed at ``tmp_log_dir``.
    """
    test_overrides: dict[str, Any] = {
        "max_workers": 2,
        "retry_count_on_empty_action_space": 0,
        "time_limit": 300,
        "actions_explored_limit": 0,
        "fail_limit": 0,
        "answer_mode_top_k": 1,
        "provide_best_guess": False,
    }

    agent = DeepSearchAgent()
    agent.agent_config = AgentConfig()
    defaults = agent.agent_config.search_workflow_per_question_params
    # Apply the test overrides first, then the caller's, so the caller always wins.
    agent.per_question_params = defaults.model_copy(update=test_overrides).model_copy(
        update=pqp_updates
    )
    agent.search_config = SearchWorkflowConfig()
    agent.query = "integration query"
    agent.log_dir = str(tmp_log_dir)
    agent.time_limit = 120
    agent.tool_map = {}
    agent.action_pool.log_dir = str(tmp_log_dir)
    return agent


def _second_action(base_action: Action, *, action_id: str, strength: float, answer: str) -> Action:
    """Return a variant of ``base_action`` that proposes ``answer`` at ``strength``.

    The variant has:

    - id ``action_id``;
    - a proposal with direction ``"dir-<action_id>"`` and score 0.5;
    - a state whose ``state`` list holds a single entry. That entry is a copy
      of the base action's first entry with ``candidate`` and
      ``candidate_strength`` replaced; any further entries are dropped.
    """
    first_entry = base_action.state.state[0]
    tweaked_entry = first_entry.model_copy(
        update={"candidate": answer, "candidate_strength": strength}
    )
    tweaked_state = base_action.state.model_copy(update={"state": [tweaked_entry]})
    proposal = ActionProposal(direction=f"dir-{action_id}", score=0.5)
    return base_action.model_copy(
        update={"id": action_id, "proposal": proposal, "state": tweaked_state}
    )


@pytest.mark.asyncio
async def test_actions_explored_limit_terminates(
    monkeypatch: pytest.MonkeyPatch, tmp_log_dir: Path, base_action, base_state
) -> None:
    """The run should end with ``actions_explored_limit`` once that budget is hit.

    ``find_action_1`` proposes two actions against a limit of 2. Every
    state-creation call returns ``"result": None``. The expected termination
    reason is therefore ``"actions_explored_limit"``.
    """
    agent = _make_agent(tmp_log_dir, actions_explored_limit=2, max_workers=2)
    sibling = base_action.model_copy(
        update={"id": "action-2", "proposal": ActionProposal(direction="b", score=0.5)}
    )
    # Token counters reported by every mocked workflow response.
    zero_tokens = {"total_input_tokens": 0, "total_output_tokens": 0}

    async def _fake_run_workflow(*, workflow: str, inputs: dict) -> SimpleNamespace:
        if workflow == "init_state_1":
            payload = {"init_state": base_state}
        elif workflow == "find_action_1":
            payload = {"actions": [base_action, sibling]}
        else:
            # Any other workflow name means the orchestrator took an unexpected path.
            raise AssertionError(workflow)
        return SimpleNamespace(result={**payload, **zero_tokens})

    async def _fake_state_creation(*args: Any, **kwargs: Any) -> SimpleNamespace:
        return SimpleNamespace(result={"result": None, "config": {}, **zero_tokens})

    patches = (
        (
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_workflow",
            _fake_run_workflow,
        ),
        (
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepSearchAgent.run_state_creation_workflow",
            _fake_state_creation,
        ),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    outcome = await agent._run_internal()
    assert outcome.termination == "actions_explored_limit"


@pytest.mark.asyncio
async def test_new_states_triggers_second_find_action_then_answer(
    monkeypatch: pytest.MonkeyPatch, tmp_log_dir: Path, base_action, base_state
) -> None:
    """A result carrying new states should drive another ``find_action_1`` round.

    State creation for ``base_action`` yields ``branch_state`` as a new state
    and no answer. The test expects the next ``find_action_1`` call to receive
    that branch state, and the follow-up action to produce the answer ``"Lyon"``.
    """
    agent = _make_agent(tmp_log_dir, max_workers=1)
    # ``state`` input received by each ``find_action_1`` call, in call order.
    find_calls: list[Any] = []

    branch_state = base_state.model_copy(update={"id": "branch-1", "depth": 1})
    follow_action = base_action.model_copy(
        update={
            "id": "follow-up",
            "state": branch_state,
            "proposal": ActionProposal(direction="follow", score=0.6),
        }
    )

    def _action_id(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Return the id of the action passed to the patched state-creation call."""
        action = kwargs.get("action")
        if action is None:
            # The fake replaces a method on the class. A positional call made
            # through an instance therefore passes the agent itself as the
            # first argument, so skip it rather than treating it as the action.
            positional = [a for a in args if not isinstance(a, DeepSearchAgent)]
            action = positional[0]
        return action.get("id") if isinstance(action, dict) else action.id

    async def _fake_run_workflow(*, workflow: str, inputs: dict) -> SimpleNamespace:
        if workflow == "init_state_1":
            return SimpleNamespace(
                result={"init_state": base_state, "total_input_tokens": 0, "total_output_tokens": 0}
            )
        if workflow == "find_action_1":
            find_calls.append(inputs.get("state"))
            # The first round explores ``base_action``; later rounds get the follow-up.
            actions = [base_action] if len(find_calls) == 1 else [follow_action]
            return SimpleNamespace(
                result={"actions": actions, "total_input_tokens": 0, "total_output_tokens": 0}
            )
        raise AssertionError(workflow)

    async def _fake_state_creation(*args: Any, **kwargs: Any) -> SimpleNamespace:
        aid = _action_id(args, kwargs)
        if aid == base_action.id:
            # No answer yet, but hand back ``branch_state`` as a new state to explore.
            result = Result(
                previous_action_id=str(aid),
                messages=[{"role": "assistant", "content": "branch"}],
                new_states=[branch_state],
                found_answer=None,
            )
        else:
            result = Result(
                previous_action_id=str(aid),
                messages=[{"role": "assistant", "content": "done"}],
                new_states=[],
                found_answer="Lyon",
            )
        return SimpleNamespace(
            result={
                "result": result,
                "config": {},
                "total_input_tokens": 0,
                "total_output_tokens": 0,
            }
        )

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_workflow",
        _fake_run_workflow,
    )
    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepSearchAgent.run_state_creation_workflow",
        _fake_state_creation,
    )

    final = await agent._run_internal()
    assert len(find_calls) >= 2, find_calls
    second_find_state = find_calls[1]
    # The state may reach the workflow either as a model or in serialized dict form.
    second_id = (
        second_find_state.id
        if isinstance(second_find_state, State)
        else second_find_state["id"]
    )
    assert second_id == branch_state.id
    assert final.termination == "answer"
    assert final.prediction == "Lyon"


@pytest.mark.asyncio
async def test_answer_mode_top_k_returns_best_strength(
    monkeypatch: pytest.MonkeyPatch, tmp_log_dir: Path, base_action, base_state
) -> None:
    """With ``answer_mode_top_k=2`` the higher-strength answer should be returned.

    Two actions are proposed:

    - ``weak``: candidate ``"Paris"``, strength 0.2;
    - ``strong``: candidate ``"Lyon"``, strength 0.95.

    Each state creation reports its own candidate as the found answer. The
    expected prediction is ``"Lyon"``.
    """
    agent = _make_agent(tmp_log_dir, answer_mode_top_k=2, max_workers=2)
    weak = _second_action(base_action, action_id="weak", strength=0.2, answer="Paris")
    strong = _second_action(base_action, action_id="strong", strength=0.95, answer="Lyon")

    def _action_id(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Return the id of the action passed to the patched state-creation call."""
        action = kwargs.get("action")
        if action is None:
            # The fake replaces a method on the class. A positional call made
            # through an instance therefore passes the agent itself as the
            # first argument, so skip it rather than treating it as the action.
            positional = [a for a in args if not isinstance(a, DeepSearchAgent)]
            action = positional[0]
        return action.get("id") if isinstance(action, dict) else action.id

    async def _fake_run_workflow(*, workflow: str, inputs: dict) -> SimpleNamespace:
        if workflow == "init_state_1":
            return SimpleNamespace(
                result={"init_state": base_state, "total_input_tokens": 0, "total_output_tokens": 0}
            )
        if workflow == "find_action_1":
            return SimpleNamespace(
                result={"actions": [weak, strong], "total_input_tokens": 0, "total_output_tokens": 0}
            )
        raise AssertionError(workflow)

    async def _fake_state_creation(*args: Any, **kwargs: Any) -> SimpleNamespace:
        aid = _action_id(args, kwargs)
        # Each action "finds" the candidate it was built with.
        ans = "Paris" if aid == "weak" else "Lyon"
        return SimpleNamespace(
            result={
                "result": Result(
                    previous_action_id=str(aid),
                    messages=[{"role": "assistant", "content": ans}],
                    new_states=[],
                    found_answer=ans,
                ),
                "config": {},
                "total_input_tokens": 0,
                "total_output_tokens": 0,
            }
        )

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_workflow",
        _fake_run_workflow,
    )
    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepSearchAgent.run_state_creation_workflow",
        _fake_state_creation,
    )

    final = await agent._run_internal()
    assert final.termination == "answer"
    assert final.prediction == "Lyon"


@pytest.mark.asyncio
async def test_answer_writes_final_result_json(
    monkeypatch: pytest.MonkeyPatch, tmp_log_dir: Path, base_action, base_state
) -> None:
    """An answered run should leave ``final_result.json`` in the log directory.

    The single proposed action finds ``"Paris"``. The JSON file written under
    ``tmp_log_dir`` must record termination ``"answer"`` and that prediction.
    """
    agent = _make_agent(tmp_log_dir)
    # Token counters reported by every mocked workflow response.
    zero_tokens = {"total_input_tokens": 0, "total_output_tokens": 0}

    async def _fake_run_workflow(*, workflow: str, inputs: dict) -> SimpleNamespace:
        if workflow == "init_state_1":
            payload = {"init_state": base_state}
        elif workflow == "find_action_1":
            payload = {"actions": [base_action]}
        else:
            # Any other workflow name means the orchestrator took an unexpected path.
            raise AssertionError(workflow)
        return SimpleNamespace(result={**payload, **zero_tokens})

    async def _fake_state_creation(*args: Any, **kwargs: Any) -> SimpleNamespace:
        answered = Result(
            previous_action_id=base_action.id,
            messages=[{"role": "assistant", "content": "ok"}],
            new_states=[],
            found_answer="Paris",
        )
        return SimpleNamespace(result={"result": answered, "config": {}, **zero_tokens})

    patches = (
        (
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_workflow",
            _fake_run_workflow,
        ),
        (
            "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.DeepSearchAgent.run_state_creation_workflow",
            _fake_state_creation,
        ),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    await agent._run_internal()

    result_path = tmp_log_dir / "final_result.json"
    assert result_path.is_file()
    written = json.loads(result_path.read_text(encoding="utf-8"))
    assert written["termination"] == "answer"
    assert written["prediction"] == "Paris"