from __future__ import annotations

from typing import Any

import pytest

from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import SearchStartNode

# Apply the ``unit`` marker to every test collected from this module.
pytestmark = pytest.mark.unit


class _Runtime:
    def __init__(self) -> None:
        self.state: dict[str, Any] = {}

    def update_global_state(self, values: dict[str, Any]) -> None:
        self.state.update(values)

    def get_global_state(self, key: str) -> Any:
        return self.state.get(key)


def _node() -> SearchStartNode:
    """Create a brand-new ``SearchStartNode`` so each test gets its own instance."""
    node = SearchStartNode()
    return node


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "workflow_name, expected_timeout, expected_max_tries",
    [
        ("init_state_workflow", 600, None),
        ("find_action_workflow", 600, 10),
        ("state_creation_workflow", 1200, 20),
    ],
)
async def test_search_start_node_applies_workflow_specific_llm_defaults(
    workflow_name,
    expected_timeout,
    expected_max_tries,
    agent_config_dict,
    search_config_dict,
    tmp_path,
) -> None:
    """SearchStartNode must stamp per-workflow LLM defaults onto the config.

    After invoking the node, ``config["llm_config"]["general"]`` stored in the
    runtime's global state must carry ``expected_timeout``. ``max_tries`` is
    only checked when ``expected_max_tries`` is not ``None``.
    ``append_think_tags_to_messages`` must be ``True`` exactly for
    ``state_creation_workflow``. For that workflow the caller-supplied
    ``retrieval_settings``, ``log_dir`` and ``fail_count`` must also appear in
    the resulting config unchanged.
    """
    runtime = _Runtime()
    # Per-test temp dir instead of a hard-coded "/tmp/..." path: portable across
    # platforms and cannot collide with other runs if the node uses the path.
    log_dir = str(tmp_path / "run-x")
    # Look up the fixture's llm_config once and layer the overrides on top of it.
    base_llm_config = agent_config_dict.get("llm_config", {})
    ac = {
        **agent_config_dict,
        "llm_config": {
            **base_llm_config,
            "general": {
                **base_llm_config.get("general", {}),
                "model_name": "m1",
                "model_type": "openai",
                "base_url": "https://example.com",
                # NOTE(review): bytes-like api_key presumably exercises the
                # node's secret handling path — confirm intent.
                "api_key": bytearray(b"x"),
            },
        },
        "retrieval_settings": {"top_k": 7},
        "log_dir": log_dir,
        "fail_count": 4,
    }
    inputs = {
        "workflow_name": workflow_name,
        "agent_config": ac,
        "search_config": search_config_dict,
    }

    await _node().invoke(inputs, runtime, None)
    config = runtime.get_global_state("config")

    llm = config["llm_config"]["general"]
    assert llm["timeout"] == expected_timeout
    assert llm["append_think_tags_to_messages"] is (
        workflow_name == "state_creation_workflow"
    )
    if expected_max_tries is not None:
        assert llm["max_tries"] == expected_max_tries
    if workflow_name == "state_creation_workflow":
        assert config["retrieval_settings"]["top_k"] == 7
        assert config["log_dir"] == log_dir
        assert config["fail_count"] == 4


@pytest.mark.asyncio
async def test_search_start_node_rejects_unknown_workflow(agent_config_dict, search_config_dict) -> None:
    """Invoking SearchStartNode with an unrecognised workflow name must raise."""
    runtime = _Runtime()
    bad_inputs = {
        "workflow_name": "unknown_workflow",
        "agent_config": agent_config_dict,
        "search_config": search_config_dict,
    }
    # NOTE(review): ``Exception`` is very broad — it would also pass on an
    # unrelated wiring error (e.g. TypeError). Narrow it to the node's actual
    # error type (and/or add ``match=``) once that is confirmed.
    with pytest.raises(Exception):
        # Node construction stays inside the context so any failure there is
        # treated exactly as before.
        await _node().invoke(bad_inputs, runtime, None)