from __future__ import annotations
from typing import Any
import pytest
from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import RunActionNode
# Apply the ``unit`` marker to every test collected from this module.
pytestmark = [pytest.mark.unit]
class _Runtime:
def __init__(self, state: dict[str, Any]) -> None:
self.state = state
def get_global_state(self, key: str) -> Any:
return self.state.get(key)
def update_global_state(self, values: dict[str, Any]) -> None:
self.state.update(values)
@pytest.mark.asyncio
async def test_run_action_node_llm_budget_exhausted_saves_error(
    monkeypatch: pytest.MonkeyPatch, base_action: Any
) -> None:
    """A zero LLM-call budget makes the node save a termination error and end.

    ``_save_result`` is swapped for a recorder so the test can inspect exactly
    what the node tried to persist.
    """
    saved: list[dict[str, Any]] = []

    def _record_save(config, action, result_to_save, time_taken):
        # Parameter names mirror the patched helper in case it is called by keyword.
        saved.append({"config": config, "action": action, "result": result_to_save})
        return config

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes._save_result",
        _record_save,
    )

    llm_budget = 0  # budget of zero LLM calls, set both in config and top-level state
    runtime = _Runtime(
        {
            "action_start_time": 0.0,
            "config": {"max_llm_calls_per_run": llm_budget, "validator_agent": {}},
            "retrieval_settings": {},
            "new_found_evidence_ids": [],
            "messages": [{"role": "user", "content": "q"}],
            "action": base_action.model_dump(),
            "max_llm_calls_per_run": llm_budget,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
        }
    )

    outcome = await RunActionNode()._do_invoke({}, runtime, None)

    assert outcome == {"next_node": "end_node"}
    assert len(saved) == 1
    termination = saved[0]["result"]["termination"]
    assert "Exceeded number of llm calls" in termination
@pytest.mark.asyncio
async def test_run_action_node_try_again_routes_back_to_run_action(
    monkeypatch: pytest.MonkeyPatch, base_action: Any
) -> None:
    """A retryable failure routes back to ``run_action`` with the returned messages.

    The stubbed ``run_action`` reports ``try_again`` and hands back a replacement
    message list. The node must send control to ``run_action`` again and store
    those messages in global state.
    """

    async def _retryable_run_action(*args, **kwargs):
        replacement = [{"role": "user", "content": "trimmed"}]
        return {"success": False, "try_again": True, "messages": replacement}

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.run_action",
        _retryable_run_action,
    )

    llm_budget = 5
    runtime = _Runtime(
        {
            "action_start_time": 1.0,
            "config": {"max_llm_calls_per_run": llm_budget, "validator_agent": {}},
            "retrieval_settings": {},
            "new_found_evidence_ids": ["a"],
            "messages": [{"role": "user", "content": "q"}],
            "action": base_action.model_dump(),
            "max_llm_calls_per_run": llm_budget,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
        }
    )

    outcome = await RunActionNode()._do_invoke({}, runtime, None)

    assert outcome == {"next_node": "run_action"}
    first_message = runtime.get_global_state("messages")[0]
    assert first_message["content"] == "trimmed"
@pytest.mark.asyncio
async def test_run_action_node_failure_saves_error_result(
    monkeypatch: pytest.MonkeyPatch, base_action: Any
) -> None:
    """A non-retryable ``run_action`` failure is saved as ``"Error: <msg>"``.

    Both ``run_action`` and ``_save_result`` are stubbed. The node should
    persist exactly one result whose termination text wraps the reported
    error, then route to ``end_node``.
    """
    error_text = "boom"
    persisted: list[dict[str, Any]] = []

    async def _failing_run_action(*args, **kwargs):
        return {
            "success": False,
            "error": error_text,
            "messages": [{"role": "assistant", "content": "failed turn"}],
            "next_node": "end_node",
        }

    def _record_save(config, action, result_to_save, time_taken):
        # Parameter names mirror the patched helper in case it is called by keyword.
        persisted.append(result_to_save)
        return config

    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.run_action",
        _failing_run_action,
    )
    monkeypatch.setattr(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes._save_result",
        _record_save,
    )

    llm_budget = 3
    runtime = _Runtime(
        {
            "action_start_time": 0.0,
            "config": {"max_llm_calls_per_run": llm_budget, "validator_agent": {}},
            "retrieval_settings": {},
            "new_found_evidence_ids": [],
            "messages": [{"role": "user", "content": "q"}],
            "action": base_action.model_dump(),
            "max_llm_calls_per_run": llm_budget,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
        }
    )

    outcome = await RunActionNode()._do_invoke({}, runtime, None)

    assert outcome == {"next_node": "end_node"}
    assert len(persisted) == 1
    assert persisted[0]["termination"] == f"Error: {error_text}"