from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import pytest

from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import FindActionSpaceNode

# Apply the ``unit`` marker to every test in this module (select with ``pytest -m unit``).
pytestmark = pytest.mark.unit


class _Runtime:
    def __init__(self, state: dict[str, Any]) -> None:
        self.state = state

    def get_global_state(self, key: str) -> Any:
        return self.state.get(key)

    def update_global_state(self, values: dict[str, Any]) -> None:
        self.state.update(values)


def _action_files(log_dir: Path) -> list[Path]:
    return sorted((log_dir / "Action").glob("action_*.json"))


@pytest.mark.asyncio
async def test_find_action_space_node_success_creates_action_file(
    monkeypatch, tmp_log_dir: Path, base_action, base_result
) -> None:
    """A successful search returns its actions, stores them, and logs them to disk.

    ``run_find_action_space`` is replaced by a stub that reports one action, so
    only the node's handling of that result is exercised: the return value, the
    ``actions`` global state, the trimming of ``base_result.messages``, and the
    ``Action/action_*.json`` log file.
    """
    # Two assistant messages; the node is expected to keep only the latest one.
    base_result.messages = [
        {"role": "assistant", "content": "first"},
        {"role": "assistant", "content": "latest"},
    ]

    # Signature-agnostic stub: whatever the node passes, report one action.
    async def _fake_run_find_action_space(*args, **kwargs):
        return {
            "actions": [base_action],
            "total_input_tokens": 22,
            "total_output_tokens": 11,
            "success": True,
        }

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.run_find_action_space",
        _fake_run_find_action_space,
    )
    runtime = _Runtime(
        {
            "query": base_action.question,
            "state": base_action.state,
            "result": base_result,
            "config": {"llm_config": {"max_tries": 2}},
            "log_dir": str(tmp_log_dir),
            "total_input_tokens": 0,
            "total_output_tokens": 0,
        }
    )
    node = FindActionSpaceNode()

    actions = await node._do_invoke({}, runtime, None)

    # Separate asserts so a failure reports which condition broke.
    assert actions is not None
    assert len(actions) == 1
    assert actions[0].id == base_action.id
    assert runtime.get_global_state("actions")[0].id == base_action.id
    assert base_result.messages == [{"role": "assistant", "content": "latest"}]

    # Fail with a clear message (instead of a bare IndexError) if no log was written.
    action_files = _action_files(tmp_log_dir)
    assert action_files, "FindActionSpaceNode did not write an Action/action_*.json file"
    payload = json.loads(action_files[0].read_text(encoding="utf-8"))
    assert payload["proposals"] == [base_action.proposal.direction]
    assert payload["action_ids"] == [base_action.id]


@pytest.mark.asyncio
async def test_find_action_space_node_failure_writes_error_action_file(
    monkeypatch, tmp_log_dir: Path, base_action
) -> None:
    """A failed search makes the node return ``None`` and log the error to disk.

    ``run_find_action_space`` is stubbed to report failure with the error
    ``"llm down"``. The test expects the written ``Action/action_*.json`` file
    to carry that error as its single proposal, prefixed with ``"Error: "``.
    """

    # Signature-agnostic stub: always report a failed search with no actions.
    async def _fake_run_find_action_space(*args, **kwargs):
        return {"success": False, "error": "llm down", "actions": []}

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.run_find_action_space",
        _fake_run_find_action_space,
    )
    # base_action supplies only the query/state inputs here; no result object is set.
    runtime = _Runtime(
        {
            "query": base_action.question,
            "state": base_action.state,
            "result": None,
            "config": {"llm_config": {"max_tries": 2}},
            "log_dir": str(tmp_log_dir),
            "total_input_tokens": 0,
            "total_output_tokens": 0,
        }
    )
    node = FindActionSpaceNode()

    result = await node._do_invoke({}, runtime, None)

    assert result is None
    # Fail with a clear message (instead of a bare IndexError) if no log was written.
    action_files = _action_files(tmp_log_dir)
    assert action_files, "FindActionSpaceNode did not write an Action/action_*.json file"
    payload = json.loads(action_files[0].read_text(encoding="utf-8"))
    assert payload["proposals"] == ["Error: llm down"]