import json
from unittest.mock import Mock, AsyncMock, patch
import pytest
from openjiuwen_deepsearch.algorithm.query_understanding.planner import Planner, PlannerResult, create_plan_tool
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Plan, StepType
async def _make_async_iter(chunks: list):
    """Expose *chunks* as an async generator, yielding its items in order.

    Lets tests stand in for streaming async APIs (such as an HTTP response's
    ``aiter_bytes``) with a plain in-memory sequence.
    """
    exhausted = object()  # sentinel: cannot collide with any real chunk
    source = iter(chunks)
    while (chunk := next(source, exhausted)) is not exhausted:
        yield chunk
# Planner input shared by most tests below.
test_data = dict(
    input='test input',
    section_idx=1,
    plan_executed_num=0,
    max_plan_executed_num=3,
)

# The plan a successful generate_plan call is expected to produce:
# a single info-collecting step.
plan_response = Plan(
    language="zh-CN",
    title="Test Plan",
    thought="This is a test thought",
    is_research_completed=False,
    steps=[
        dict(
            title='Step 1',
            description='Description 1',
            type=StepType.INFO_COLLECTING.value,
            step_result=None,
        ),
    ],
)

tool_call_id = '123'
# Name of the plan tool that create_plan_tool builds for this input; the
# mocked LLM response below "calls" the tool under this name.
tool_name = create_plan_tool(test_data, 'planner').card.name

# Assistant message returned by the mocked ainvoke_llm_with_stats: one tool
# call whose arguments carry the same plan as ``plan_response``.
functioncall_response = dict(
    content='',
    name=None,
    raw_content=None,
    reason_content=None,
    role='assistant',
    tool_calls=[
        dict(
            args=dict(
                is_research_completed=False,
                language='zh-CN',
                steps=[
                    dict(
                        description='Description 1',
                        title='Step 1',
                        type='info_collecting',
                    ),
                ],
                thought='This is a test thought',
                title='Test Plan',
            ),
            id=tool_call_id,
            name=tool_name,
            type='tool_call',
        ),
    ],
    usage_metadata=None,
)

# Tool message expected to follow ``functioncall_response`` in the planner's
# response_messages; its content is ``plan_response`` serialized to JSON.
functioncall_result = dict(
    name=tool_name,
    role='tool',
    content=plan_response.model_dump_json(),
    tool_call_id=tool_call_id,
)
class TestPlanner:
    """Tests for ``Planner.generate_plan`` with the LLM call (and HTTP layer) mocked."""

    # Patch targets, collected here instead of repeating the dotted paths in
    # every test.
    _PLANNER_MODULE = 'openjiuwen_deepsearch.algorithm.query_understanding.planner'
    _LLM_TARGET = _PLANNER_MODULE + '.ainvoke_llm_with_stats'
    _LLM_CONTEXT_TARGET = _PLANNER_MODULE + '.llm_context'
    _RUNTIME_API_MODULE = 'openjiuwen_deepsearch.framework.openjiuwen.tools.runtime_api.runtime_api'

    @staticmethod
    def _failure_result(retry: str) -> "PlannerResult":
        """Build the failed PlannerResult expected when the LLM keeps raising ``Test Exception``.

        Args:
            retry: The ``attempt/limit`` text embedded in the error message, e.g. ``'1/1'``.

        Returns:
            A PlannerResult with no plan, no messages, and the planner's error message.
        """
        return PlannerResult(
            plan_success=False,
            plan=None,
            response_messages=[],
            error_msg=(
                'section_idx: 1 | Round 1/3 | Error when Planner generating a plan. '
                f'retry ({retry}).error: Test Exception'
            ),
        )

    @staticmethod
    def _tool_name(tool):
        """Return a tool's name, whether it was passed to the LLM as a dict or as an object.

        Objects without a ``name`` attribute yield ``None``.
        """
        if isinstance(tool, dict):
            return tool["name"]
        return getattr(tool, "name", None)

    @pytest.fixture
    def mock_llm(self):
        return Mock()

    @pytest.fixture
    def setup_planner(self, mock_llm):
        # NOTE(review): llm_context is patched only while Planner() is constructed;
        # the patch is undone before the test body runs.
        with patch(self._LLM_CONTEXT_TARGET, return_value=mock_llm):
            planner = Planner()
        return planner

    @pytest.mark.asyncio
    async def test_generate_plan_success(self, setup_planner):
        """A well-formed plan tool call produces a successful PlannerResult."""
        expected = PlannerResult(
            plan_success=True,
            plan=plan_response,
            response_messages=[functioncall_response, functioncall_result],
            error_msg='',
        )
        with patch(
            self._LLM_TARGET,
            new_callable=AsyncMock,
            return_value=functioncall_response,
        ):
            result = await setup_planner.generate_plan(test_data)
        assert result == expected

    @pytest.mark.asyncio
    async def test_generate_plan_retry_failure(self, setup_planner):
        """With the fixture's default retry budget (reported as 1/1), an LLM error fails the plan."""
        with patch(
            self._LLM_TARGET,
            new_callable=AsyncMock,
            side_effect=Exception("Test Exception"),
        ):
            result = await setup_planner.generate_plan(test_data)
        assert result == self._failure_result('1/1')

    @pytest.mark.asyncio
    async def test_generate_plan_exception(self, setup_planner):
        """An exception subclass raised by the LLM call is handled the same way as ``Exception``."""
        # Previously an exact duplicate of test_generate_plan_retry_failure;
        # a subclass makes this case cover something distinct.
        with patch(
            self._LLM_TARGET,
            new_callable=AsyncMock,
            side_effect=RuntimeError("Test Exception"),
        ):
            result = await setup_planner.generate_plan(test_data)
        assert result == self._failure_result('1/1')

    @pytest.mark.asyncio
    async def test_generate_plan_max_retries(self, setup_planner):
        """When every attempt fails, the error reports the configured retry limit (3/3)."""
        config = setup_planner.config
        original_retry_num = config.max_retry_num
        config.max_retry_num = 3
        try:
            with patch(
                self._LLM_TARGET,
                new_callable=AsyncMock,
                side_effect=Exception("Test Exception"),
            ):
                result = await setup_planner.generate_plan(test_data)
        finally:
            # Restore the setting in case ``config`` is shared beyond this
            # Planner instance, so other tests are not affected.
            config.max_retry_num = original_retry_num
        assert result == self._failure_result('3/3')

    @pytest.mark.asyncio
    async def test_generate_plan_with_runtime_api_tool(self, setup_planner):
        """The planner merges runtime API tools into the LLM tool list and executes them."""
        custom_input = {
            **test_data,
            "api_tools_config": {
                "query_understanding_tools": [
                    {
                        "tool_id": "tool-1",
                        "name": "runtime_plan_tool",
                        "description": "Runtime planner tool",
                        "path": "https://example.com/plan",
                        "http_method": "post",
                        "request_params": [
                            {
                                "name": "title",
                                "description": "plan title",
                                "send_method": "body",
                                "required": True,
                            },
                            {
                                "name": "language",
                                "description": "language",
                                "send_method": "body",
                                "required": False,
                            },
                        ],
                    }
                ]
            },
        }
        custom_response = {
            **functioncall_response,
            "tool_calls": [
                {
                    "args": {
                        "language": "zh-CN",
                        "title": "Runtime Plan",
                    },
                    "id": tool_call_id,
                    "name": "runtime_plan_tool",
                    "type": "tool_call",
                }
            ],
        }

        json_data = json.dumps({
            "code": 0,
            "message": "ok",
            "data": {
                "language": "zh-CN",
                "title": "Runtime Plan",
                "thought": "Generated by runtime api",
                "is_research_completed": True,
                "steps": [],
            },
        }).encode("utf-8")

        mock_http_response = Mock()
        mock_http_response.headers = {}
        mock_http_response.encoding = "utf-8"
        mock_http_response.raise_for_status = Mock()
        # An async generator can be consumed only once, so build a fresh one
        # per call instead of handing back the same (possibly exhausted) one.
        mock_http_response.aiter_bytes = Mock(
            side_effect=lambda *args, **kwargs: _make_async_iter([json_data])
        )

        mock_stream_cm = Mock()
        mock_stream_cm.__aenter__ = AsyncMock(return_value=mock_http_response)
        mock_stream_cm.__aexit__ = AsyncMock(return_value=None)

        mock_client = Mock()
        mock_client.stream = Mock(return_value=mock_stream_cm)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=None)

        with patch(
            self._LLM_TARGET,
            new_callable=AsyncMock,
            return_value=custom_response,
        ) as mock_invoke, patch(
            self._RUNTIME_API_MODULE + '.validate_runtime_request_url',
            return_value=None,
        ), patch(
            self._RUNTIME_API_MODULE + '.httpx.AsyncClient',
            return_value=mock_client,
        ):
            result = await setup_planner.generate_plan(custom_input)

        tools = mock_invoke.await_args.kwargs["tools"]
        tool_names = [self._tool_name(tool) for tool in tools]
        assert result.plan_success is True
        assert result.plan.title == "Runtime Plan"
        assert result.response_messages[1]["name"] == "runtime_plan_tool"
        assert tool_name in tool_names
        assert "runtime_plan_tool" in tool_names

    @pytest.mark.asyncio
    async def test_generate_plan_rejects_steps_exceeding_max_step_num(self, setup_planner):
        """The planner fails when the model returns more steps than allowed."""
        base_call = functioncall_response["tool_calls"][0]
        excessive_response = {
            **functioncall_response,
            "tool_calls": [
                {
                    **base_call,
                    "args": {
                        **base_call["args"],
                        "steps": [
                            {
                                "description": f"Description {idx}",
                                "title": f"Step {idx}",
                                "type": "info_collecting",
                            }
                            for idx in range(3)
                        ],
                    },
                }
            ],
        }
        with patch(
            self._LLM_TARGET,
            new_callable=AsyncMock,
            return_value=excessive_response,
        ):
            result = await setup_planner.generate_plan({**test_data, "max_step_num": 2})
        assert result.plan_success is False
        assert result.plan is None
        assert "steps" in result.error_msg.lower()