from __future__ import annotations

from types import SimpleNamespace
from typing import Any

import pytest

from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import RunActionNode, SearchEndNode
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Result

pytestmark = pytest.mark.unit


class _Runtime:
    """In-memory test double handed to the nodes under test as their runtime.

    ``state`` backs the global-state accessors, while ``outputs`` accumulates
    every payload passed to ``base().state().set_outputs(...)``.
    """

    def __init__(self, state: dict[str, Any]) -> None:
        self.outputs: dict[str, Any] = {}
        # The caller's dict is stored by reference (no copy), so later
        # updates are visible through the object the test passed in.
        self.state = state

    def get_global_state(self, key: str) -> Any:
        """Return the value stored under *key*, or ``None`` when it is absent."""
        return self.state.get(key)

    def update_global_state(self, values: dict[str, Any]) -> None:
        """Merge *values* into the global state, overwriting existing keys."""
        self.state.update(values)

    def base(self) -> Any:
        """Return an object whose ``state()`` exposes ``set_outputs(payload)``.

        Payloads given to ``set_outputs`` are merged into ``self.outputs``.
        """
        sink = self.outputs

        def _set_outputs(payload: dict[str, Any]) -> None:
            sink.update(payload)

        def _make_state() -> Any:
            # A fresh holder per call; every holder writes to the same sink.
            return SimpleNamespace(set_outputs=_set_outputs)

        return SimpleNamespace(state=_make_state)


def test_run_action_post_handle_handoff_to_tool(base_action) -> None:
    """A ``tool_calls`` result must route to the tool node.

    Expects ``_post_handle`` to return ``{"next_node": "tool"}`` and to leave
    the requested calls under ``pending_tool_calls`` in the global state.
    """
    requested_call = {"name": "web_search", "arguments": {"query": ["x"]}}
    action_output = {
        "success": True,
        "mode": "tool_calls",
        "data": {"tool_calls": [requested_call], "messages": []},
        "config": {},
        "total_input_tokens": 1,
        "total_output_tokens": 1,
        "validate_new_states": False,
        "validate_answer": False,
    }
    fake_runtime = _Runtime({})

    handoff = RunActionNode()._post_handle({}, action_output, fake_runtime, None)

    assert handoff == {"next_node": "tool"}
    first_pending = fake_runtime.get_global_state("pending_tool_calls")[0]
    assert first_pending["name"] == "web_search"


def test_run_action_post_handle_handoff_to_validate_for_state_result(base_action) -> None:
    """A ``state`` result flagged for validation must route to validation.

    Expects ``_post_handle`` to return ``{"next_node": "validate_new_state"}``
    and to store the result under ``result`` in the global state.
    """
    state_result = Result(
        previous_action_id=base_action.id,
        messages=[{"role": "assistant", "content": "state done"}],
        new_states=[],
        found_answer=None,
    )
    action_output = {
        "success": True,
        "mode": "state",
        "data": state_result,
        "config": {},
        "total_input_tokens": 2,
        "total_output_tokens": 3,
        "validate_new_states": True,
        "validate_answer": False,
    }
    fake_runtime = _Runtime({})

    handoff = RunActionNode()._post_handle({}, action_output, fake_runtime, None)

    assert handoff == {"next_node": "validate_new_state"}
    stored_result = fake_runtime.get_global_state("result")
    assert stored_result.previous_action_id == base_action.id


@pytest.mark.asyncio
async def test_search_end_node_handoff_payload_for_state_creation(base_result) -> None:
    """``SearchEndNode.invoke`` must return a ``final_result`` payload.

    The payload is expected to carry the result and both token totals that
    were seeded into the global state.
    """
    seeded_state = {
        "workflow_name": "state_creation_workflow",
        "result": base_result,
        "config": {"fail_count": 0},
        "messages": [{"role": "assistant", "content": "done"}],
        "total_input_tokens": 9,
        "total_output_tokens": 5,
    }
    fake_runtime = _Runtime(seeded_state)

    out = await SearchEndNode().invoke({}, fake_runtime, None)

    assert "final_result" in out
    final_payload = out["final_result"]
    assert final_payload["result"] == base_result
    assert final_payload["total_input_tokens"] == 9
    assert final_payload["total_output_tokens"] == 5