from __future__ import annotations
import pytest
from openjiuwen_deepsearch.algorithm.search_nodes.run_action import (
RunActionConfig,
run_action,
)
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
# Module-level marker: pytest applies ``unit`` to every test in this file, so the
# whole module can be selected or excluded with ``-m unit``.
pytestmark = pytest.mark.unit
def _params(
    base_action,
    *,
    strategy: str = "fail",
    retrieval_tool_only: bool = False,
    retrieval_settings: dict | None = None,
) -> RunActionConfig:
    """Build a ``RunActionConfig`` for ``run_action`` tests from ``base_action``.

    Args:
        base_action: Fixture object exposing ``model_dump()``, ``state.model_dump()``
            and ``question`` (presumably a pydantic model from a conftest not
            visible here — TODO confirm).
        strategy: Used as ``context_limit_reached_strategy`` both inside
            ``config`` and as the top-level field, so the two never disagree.
        retrieval_tool_only: Forwarded to ``RunActionConfig.retrieval_tool_only``.
        retrieval_settings: Retrieval settings placed both inside ``config`` and
            in the top-level field. Defaults to
            ``{"top_k": 10, "top_k_multiply_factor": 8}``.

    Returns:
        A config whose token counters start at 1 input / 2 output tokens, so
        tests can check how ``run_action`` adds the LLM call's usage on top.
    """
    # Single source of truth for the retrieval settings; previously the same
    # literal was duplicated and could silently drift between the two places.
    if retrieval_settings is None:
        retrieval_settings = {"top_k": 10, "top_k_multiply_factor": 8}
    return RunActionConfig(
        llm_config={"model_name": "mock"},
        config={
            "use_candidate_strength": True,
            "discovered_clues_mode": "report",
            "context_limit_reached_strategy": strategy,
            # Separate copy (as the original separate literals were), in case
            # run_action mutates one of the settings dicts in place.
            "retrieval_settings": dict(retrieval_settings),
        },
        action=base_action.model_dump(),
        state=base_action.state.model_dump(),
        query=base_action.question,
        messages=[{"role": "user", "content": "q"}],
        new_found_evidence_ids=[],
        validate_new_states=True,
        validate_answer=True,
        action_start_time=0.0,
        retrieval_tool_only=retrieval_tool_only,
        retrieval_settings=dict(retrieval_settings),
        context_limit_reached_strategy=strategy,
        total_input_tokens=1,
        total_output_tokens=2,
    )
@pytest.mark.asyncio
async def test_run_run_action_success_path(monkeypatch, base_action) -> None:
    """A successful LLM round-trip yields the parsed mode/data and summed token usage.

    ``run_llm``, ``apply_system_prompt`` and ``parse_and_apply_llm_result_safe``
    are replaced inside the ``run_action`` module so no real model is contacted.
    """
    target = "openjiuwen_deepsearch.algorithm.search_nodes.run_action"

    async def _fake_run_llm(*args, **kwargs):
        # The last two items are presumably (input_tokens, output_tokens); the
        # token assertions below only hold under that reading.
        return "<state>{}</state>", None, 3, 4

    def _fake_system_prompt(*args, **kwargs):
        return [{"role": "system", "content": "s"}]

    def _fake_parse(*args, **kwargs):
        return "state", {"messages": []}, {"k": "v"}

    monkeypatch.setattr(f"{target}.run_llm", _fake_run_llm)
    monkeypatch.setattr(f"{target}.apply_system_prompt", _fake_system_prompt)
    monkeypatch.setattr(f"{target}.parse_and_apply_llm_result_safe", _fake_parse)

    result = await run_action(_params(base_action))

    assert result["success"] is True
    assert result["mode"] == "state"
    assert result["data"] == {"messages": []}
    # _params starts the counters at 1 input / 2 output tokens.
    assert result["total_input_tokens"] == 1 + 3
    assert result["total_output_tokens"] == 2 + 4
@pytest.mark.asyncio
async def test_run_action_retry_exhausted_returns_failure(monkeypatch, base_action) -> None:
    """An exhausted-retries error from ``run_llm`` becomes a failure routed to ``end_node``."""
    target = "openjiuwen_deepsearch.algorithm.search_nodes.run_action"

    async def _raise_retry(*args, **kwargs):
        raise CustomValueException(
            StatusCode.AGENT_RETRY_FAILED_ALL_ATTEMPTS.code,
            "all retries exhausted",
        )

    def _fake_system_prompt(*args, **kwargs):
        return [{"role": "system", "content": "s"}]

    monkeypatch.setattr(f"{target}.run_llm", _raise_retry)
    monkeypatch.setattr(f"{target}.apply_system_prompt", _fake_system_prompt)

    result = await run_action(_params(base_action))
    error_text = result["error"]

    assert result["success"] is False
    assert result["next_node"] == "end_node"
    assert "retries exhausted" in error_text
    assert isinstance(result["messages"], list)
@pytest.mark.asyncio
async def test_run_action_context_limit_reduced_retrieval_strategy(
    monkeypatch, base_action
) -> None:
    """Under ``reduced_retrieval_request``, a context-limit error requests a retry.

    The returned retrieval settings are expected to drop from top_k=10 /
    top_k_multiply_factor=8 (the ``_params`` defaults) to 5 / 4.
    """
    target = "openjiuwen_deepsearch.algorithm.search_nodes.run_action"

    async def _raise_context_limit(*args, **kwargs):
        raise CustomValueException(
            StatusCode.LLM_CALL_FAILED.code,
            "context length exceeded",
        )

    def _fake_system_prompt(*args, **kwargs):
        return [{"role": "system", "content": "s"}]

    def _always_context_limit(*args, **kwargs):
        # Force the context-limit branch regardless of how the error is classified.
        return True

    monkeypatch.setattr(f"{target}.run_llm", _raise_context_limit)
    monkeypatch.setattr(f"{target}.apply_system_prompt", _fake_system_prompt)
    monkeypatch.setattr(f"{target}._is_context_limit_error", _always_context_limit)

    config = _params(
        base_action,
        strategy="reduced_retrieval_request",
        retrieval_tool_only=True,
    )
    result = await run_action(config)
    settings = result["retrieval_settings"]

    assert result["success"] is False
    assert result["try_again"] is True
    assert settings["top_k"] == 5
    assert settings["top_k_multiply_factor"] == 4