from __future__ import annotations
from types import SimpleNamespace
import pytest
from openjiuwen_deepsearch.config.config import AgentConfig, SearchWorkflowConfig
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Result
from openjiuwen_deepsearch.framework.openjiuwen.agent.workflow import DeepSearchAgent
# Tag every test in this module with the ``unit`` marker so the whole file can
# be selected or deselected with ``pytest -m unit``.
pytestmark = pytest.mark.unit
def _make_agent(tmp_log_dir) -> DeepSearchAgent:
    """Return a ``DeepSearchAgent`` wired with small, fixed settings for tests.

    The per-question parameters are a copy of the defaults from a fresh
    ``AgentConfig``. The copy is overridden to use a single worker, zero
    empty-action-space retries, zeroed exploration and fail limits, top-1
    answer mode and no best-guess answer. The agent also gets a default
    ``SearchWorkflowConfig``, a fixed query and an empty tool map.

    Args:
        tmp_log_dir: Directory used for both the agent's and the action pool's
            logs; it is converted with ``str()``. Presumably a pytest
            ``tmp_path``-style fixture; confirm in conftest.

    Returns:
        The configured agent. Tests may further tweak ``per_question_params``.
    """
    time_limit = 120
    test_overrides = {
        "max_workers": 1,
        "retry_count_on_empty_action_space": 0,
        "time_limit": time_limit,
        "actions_explored_limit": 0,
        "fail_limit": 0,
        "answer_mode_top_k": 1,
        "provide_best_guess": False,
    }
    log_dir = str(tmp_log_dir)

    deep_agent = DeepSearchAgent()
    deep_agent.agent_config = AgentConfig()
    default_params = deep_agent.agent_config.search_workflow_per_question_params
    # Work on a copy so the config's own defaults stay untouched. NOTE: pydantic's
    # model_copy does not re-validate values passed through ``update``.
    deep_agent.per_question_params = default_params.model_copy(update=test_overrides)
    deep_agent.search_config = SearchWorkflowConfig()
    deep_agent.query = "capital question"
    deep_agent.log_dir = log_dir
    deep_agent.time_limit = time_limit
    deep_agent.tool_map = {}
    # The action pool keeps its own log directory; point it at the same place.
    deep_agent.action_pool.log_dir = log_dir
    return deep_agent
@pytest.mark.asyncio
async def test_run_internal_terminates_on_fail_limit(
    monkeypatch, tmp_log_dir, base_action, base_state
) -> None:
    """A state-creation step that reports a failure ends the run at ``fail_limit``.

    ``Runner.run_workflow`` is faked to serve the initial state and a single
    action. The state-creation workflow is faked to return no result and a
    ``fail_count`` of 1, matching the ``fail_limit`` of 1 configured below.
    """
    workflow_module = "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow"

    def _with_zero_tokens(**payload):
        # Every fake workflow response carries zeroed token counters, appended
        # after the payload keys; a fresh dict is built on each call.
        payload["total_input_tokens"] = 0
        payload["total_output_tokens"] = 0
        return SimpleNamespace(result=payload)

    async def _fake_run_workflow(*, workflow, inputs):
        if workflow == "init_state_1":
            return _with_zero_tokens(init_state=base_state)
        if workflow == "find_action_1":
            return _with_zero_tokens(actions=[base_action])
        raise AssertionError(f"unexpected workflow: {workflow}")

    async def _fake_state_creation_workflow(*args, **kwargs):
        return _with_zero_tokens(result=None, config={"fail_count": 1})

    agent = _make_agent(tmp_log_dir)
    agent.per_question_params.fail_limit = 1

    monkeypatch.setattr(
        f"{workflow_module}.Runner.run_workflow",
        _fake_run_workflow,
    )
    monkeypatch.setattr(
        f"{workflow_module}.DeepSearchAgent.run_state_creation_workflow",
        _fake_state_creation_workflow,
    )

    final = await agent._run_internal()
    assert final.termination == "fail_limit"
@pytest.mark.asyncio
async def test_run_internal_terminates_when_action_pool_depleted(
    monkeypatch, tmp_log_dir, base_state
) -> None:
    """An empty action list with no retries ends the run as ``action_pool_depleted``.

    Only ``Runner.run_workflow`` is faked here. It serves the initial state,
    then an empty action list for ``find_action_1``. The state-creation
    workflow is not patched.
    """
    agent = _make_agent(tmp_log_dir)
    # Already 0 in ``_make_agent``; restated so this test's precondition is explicit.
    agent.per_question_params.retry_count_on_empty_action_space = 0

    async def _fake_run_workflow(*, workflow, inputs):
        if workflow == "init_state_1":
            payload = {"init_state": base_state}
        elif workflow == "find_action_1":
            payload = {"actions": []}
        else:
            raise AssertionError(f"unexpected workflow: {workflow}")
        payload.update(total_input_tokens=0, total_output_tokens=0)
        return SimpleNamespace(result=payload)

    run_workflow_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow.Runner.run_workflow"
    )
    monkeypatch.setattr(run_workflow_target, _fake_run_workflow)

    final = await agent._run_internal()
    assert final.termination == "action_pool_depleted"
@pytest.mark.asyncio
async def test_run_internal_returns_answer_termination(
    monkeypatch, tmp_log_dir, base_action, base_state
) -> None:
    """A state-creation result with ``found_answer`` ends the run with that answer.

    The faked state-creation workflow returns a ``Result`` for ``base_action``
    whose ``found_answer`` is ``"Paris"``, with a ``fail_count`` of 0. The run
    must terminate as ``"answer"`` and expose ``"Paris"`` as its prediction.
    """
    agent = _make_agent(tmp_log_dir)
    module_path = "openjiuwen_deepsearch.framework.openjiuwen.agent.workflow"
    # Unpacked into each response so every call still gets its own fresh dict.
    no_tokens = {"total_input_tokens": 0, "total_output_tokens": 0}

    async def _fake_run_workflow(*, workflow, inputs):
        if workflow == "init_state_1":
            return SimpleNamespace(result={"init_state": base_state, **no_tokens})
        if workflow == "find_action_1":
            return SimpleNamespace(result={"actions": [base_action], **no_tokens})
        raise AssertionError(f"unexpected workflow: {workflow}")

    async def _fake_state_creation_workflow(*args, **kwargs):
        answer = Result(
            previous_action_id=base_action.id,
            messages=[{"role": "assistant", "content": "answer"}],
            new_states=[],
            found_answer="Paris",
        )
        return SimpleNamespace(
            result={"result": answer, "config": {"fail_count": 0}, **no_tokens}
        )

    patches = (
        ("Runner.run_workflow", _fake_run_workflow),
        ("DeepSearchAgent.run_state_creation_workflow", _fake_state_creation_workflow),
    )
    for attribute_path, replacement in patches:
        monkeypatch.setattr(f"{module_path}.{attribute_path}", replacement)

    final = await agent._run_internal()
    assert final.termination == "answer"
    assert final.prediction == "Paris"