from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import pytest

from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import InitializeStateNode

# Module-level marker: pytest applies ``unit`` to every test collected from this file.
pytestmark = pytest.mark.unit


class _Runtime:
    def __init__(self, state: dict[str, Any]) -> None:
        self.state = state

    def get_global_state(self, key: str) -> Any:
        return self.state.get(key)

    def update_global_state(self, values: dict[str, Any]) -> None:
        self.state.update(values)


@pytest.mark.asyncio
async def test_initialize_state_node_writes_initial_state_file(
    monkeypatch, tmp_log_dir: Path, base_state
) -> None:
    """Invoking the node stores ``init_state`` and dumps it to ``initial_state.json``.

    ``run_initialize_state`` is replaced by a stub, so the test does not depend on
    its real implementation. The stub hands back ``base_state`` together with one
    canned assistant message and fixed token counts.
    """
    stub_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.run_initialize_state"
    )

    async def _fake_run_initialize_state(*args, **kwargs):
        canned_messages = [{"role": "assistant", "content": "ready"}]
        return {
            "init_state": base_state,
            "messages": canned_messages,
            "total_input_tokens": 12,
            "total_output_tokens": 7,
        }

    monkeypatch.setattr(stub_target, _fake_run_initialize_state)

    initial_globals = {
        "query": "Where is the Eiffel Tower?",
        "config": {"llm_config": {"max_tries": 2}},
        "log_dir": str(tmp_log_dir),
        "total_input_tokens": 0,
        "total_output_tokens": 0,
    }
    runtime = _Runtime(initial_globals)
    state_file = tmp_log_dir / "initial_state.json"

    init_node = InitializeStateNode()
    outcome = await init_node._do_invoke({}, runtime, None)

    # The node communicates through global state, not through its return value.
    assert outcome is None
    stored_state = runtime.get_global_state("init_state")
    assert stored_state.id == base_state.id

    written = json.loads(state_file.read_text(encoding="utf-8"))
    assert payload_id_matches(written, base_state) if False else written["id"] == base_state.id
    first_message = written["messages"][0]
    assert first_message["content"] == "ready"


def test_initialize_state_pre_handle_requires_query(tmp_log_dir: Path) -> None:
    """``_pre_handle`` must reject a runtime whose ``query`` is an empty string.

    The node's exact exception type is not pinned here, so any ``Exception`` is
    accepted. The exceptions are ``AttributeError`` and ``TypeError``: those signal
    that the fake ``_Runtime`` or the call signature has drifted from the node.
    They do not signal that the empty query was actually rejected.
    """
    runtime = _Runtime({"query": "", "config": {}, "log_dir": str(tmp_log_dir)})
    node = InitializeStateNode()

    with pytest.raises(Exception) as excinfo:
        node._pre_handle({}, runtime, None)

    # A bare ``pytest.raises(Exception)`` would also pass on harness errors (e.g. the
    # node calling a runtime method that ``_Runtime`` does not implement), hiding a
    # broken test. Reject those explicitly.
    assert not isinstance(excinfo.value, (AttributeError, TypeError)), (
        f"_pre_handle failed with a harness error, not query validation: {excinfo.value!r}"
    )