import logging
import json
from contextvars import Context
from unittest.mock import AsyncMock, Mock, patch
import pytest
from openjiuwen.core.session.node import Session
from openjiuwen.core.workflow.base import WorkflowCard
from openjiuwen.core.workflow.workflow import Workflow
from openjiuwen_deepsearch.framework.openjiuwen.agent.base_node import BaseNode
from openjiuwen_deepsearch.framework.openjiuwen.agent.editor_team_manager_node import EditorTeamNode
from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import (
EndNode,
EntryNode,
IntentRecognitionNode,
FeedbackHandlerNode,
OutlineInteractionNode,
StartNode,
UserFeedbackProcessorNode,
)
from openjiuwen_deepsearch.algorithm.query_understanding.intent_recognition import IntentRecognitionResult
from openjiuwen_deepsearch.framework.openjiuwen.agent.reasoning_writing_graph.editor_team_nodes import \
build_editor_team_workflow
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import (
Outline,
ResearchIntent,
Section,
)
from openjiuwen_deepsearch.framework.openjiuwen.agent.workflow import DeepresearchAgent
from openjiuwen_deepsearch.utils.constants_utils.node_constants import NodeId
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import session_context
from tests.utils.mock_config import get_default_agent_config
# Module-level logger used by the tests and helpers in this file.
logger = logging.getLogger(__name__)
async def _run_agent_with_mocks(pre_handle_return, sub_graph_return):
    """Run ``MockAgent`` once with two ``EditorTeamNode`` internals patched.

    Args:
        pre_handle_return: Value returned by the patched
            ``EditorTeamNode._pre_handle``.
        sub_graph_return: Value returned by the patched (``AsyncMock``)
            ``EditorTeamNode._run_section_sub_graph_await``.

    Returns:
        list: Every chunk yielded by ``agent.run``, in arrival order.
    """
    pre_handle_patch = patch.object(
        EditorTeamNode, '_pre_handle', return_value=pre_handle_return
    )
    sub_graph_patch = patch.object(
        EditorTeamNode,
        '_run_section_sub_graph_await',
        new_callable=AsyncMock,
        return_value=sub_graph_return,
    )
    # Both patches stay active for the whole streamed run.
    with pre_handle_patch, sub_graph_patch:
        agent = MockAgent()
        agent_config = get_default_agent_config()
        stream = agent.run(
            message='杭州的天气怎么样',
            conversation_id="default_session_id",
            report_template="",
            interrupt_feedback="",
            agent_config=agent_config,
        )
        return [chunk async for chunk in stream]
@pytest.mark.asyncio
async def test_agent_node_missing_outline(caplog):
    """Error case: the pre-handled state carries no ``outline`` entry.

    Expects at least one ERROR-level log record whose message mentions
    ``outline``.
    """
    state_without_outline = {"messages": "杭州的天气怎么样"}
    sub_graph_result = (["info1"], "report", [], [])

    with caplog.at_level(logging.ERROR):
        await _run_agent_with_mocks(state_without_outline, sub_graph_result)

    error_messages = [
        record.message
        for record in caplog.records
        if record.levelno == logging.ERROR
    ]
    assert any("outline" in message for message in error_messages)
@pytest.mark.asyncio
async def test_agent_node_missing_sections(caplog):
    """Error case: an ``outline`` is present but was built without ``sections``.

    Expects at least one ERROR-level log record whose message mentions
    ``sections``.
    """
    outline_without_sections = Outline(
        language='zh-CN',
        thought='...',
        title='报告',
    )
    state = {
        "messages": "杭州的天气怎么样",
        "outline": outline_without_sections,
    }
    sub_graph_result = (["info1"], "report", [], [])

    with caplog.at_level(logging.ERROR):
        await _run_agent_with_mocks(state, sub_graph_result)

    error_messages = [
        record.message
        for record in caplog.records
        if record.levelno == logging.ERROR
    ]
    assert any("sections" in message for message in error_messages)
@pytest.mark.asyncio
async def test_agent_node_with_interrupt_feedback():
    """Smoke test: stream a run with ``interrupt_feedback="accepted"``.

    Each streamed chunk is only logged at DEBUG level; nothing is asserted.
    """
    agent = MockAgent()
    run_kwargs = {
        "message": '杭州的天气怎么样',
        "conversation_id": "default_session_id",
        "report_template": "",
        "interrupt_feedback": "accepted",
        "agent_config": get_default_agent_config(),
    }
    async for streamed_chunk in agent.run(**run_kwargs):
        logger.debug("[Stream message from node: %s]", streamed_chunk)
def test_create_section_state():
    """Smoke test for building a section state from a full search-context dict.

    Calls ``TestEditorTeamNode.create_section_state_from_state`` with a
    hand-written ``search_context`` snapshot, an ``Outline`` and one of its
    ``Section`` objects. The return value is not inspected, so the test only
    fails if the call raises.
    """
    editor_team_node = TestEditorTeamNode()
    outline_section = Section(
        title='当前天气状况',
        description='提供杭州市当前的天气情况,包括天气状况描述、当前温度、风力风向、湿度等基本气象数据。',
        is_core_section=False)
    outline = Outline(
        id="1", thought="mock though", title="mock title", sections=[outline_section]
    )
    # NOTE(review): ``[{...}]`` below is a list holding a *set* whose only element
    # is ``Ellipsis`` -- presumably a placeholder left from a copied state dump;
    # confirm whether real message dicts are needed here.
    search_context = {'session_id': 'default_session_id', 'original_query': '杭州的天气怎么样',
                      'research_query': '杭州的天气怎么样', 'messages': [{...}],
                      'language': 'zh-CN', 'plan_executed_num': 0, 'current_plan': None,
                      'duplicated_search_queries': {}, 'duplicated_search_items': {}, 'final_report_path': '',
                      'final_result': {'response_content': '', 'citation_messages': {}, 'exception_info': ''},
                      'report_generated_num': 0, 'report_evaluation': '',
                      'sub_report_content': '', 'evaluation_details': '', 'sub_evaluation_details': '',
                      'sub_report_evaluate_num': 0, 'sub_evaluation_result': '', 'report_template': '', 'questions': '',
                      'user_feedback': '', 'current_node': None, 'answer': '', 'answer_generated_num': 0,
                      'answer_evaluation': '', 'current_outline': Outline(language='zh-CN',
                                                                          thought='用户需要杭州当前天气情况的详细信息,包括温度、湿度、风力、天气状况等关键数据。根据任务规划文档,我需要创建一个包含这些关键信息的结构化报告大纲。',
                                                                          title='杭州实时天气情况报告',
                                                                          sections=[Section(
                                                                              title='当前天气状况',
                                                                              description='提供杭州当前的天气状况描述,包括天气现象(晴、雨、多云等)和能见度等基本信息',
                                                                              is_core_section=False),
                                                                              Section(title='温度与湿度数据',
                                                                                      description='详细记录杭州当前的气温和相对湿度数据,包括体感温度和舒适度指数',
                                                                                      is_core_section=False),
                                                                              Section(title='风力与风向信息',
                                                                                      description='记录当前风力等级、风向和风速数据,以及阵风情况',
                                                                                      is_core_section=False),
                                                                              Section(title='空气质量指数',
                                                                                      description='提供杭州当前的空气质量指数(AQI)和主要污染物浓度数据',
                                                                                      is_core_section=False),
                                                                              Section(title='天气预报摘要',
                                                                                      description='提供未来24小时的天气预报概要,包括温度变化趋势和天气变化预测',
                                                                                      is_core_section=False)]),
                      'outline_executed_num': 0, 'report_task': '', 'section_task': '', 'section_description': '',
                      'section_idx': 0, 'section_iscore': False, 'sub_section_outline': '',
                      'sub_section_references': [], 'classified_content': [], 'sub_section_core_content': [],
                      'search_mode': 'research', 'current_step': None, 'planner_agent_messages': None,
                      'source_tracer': '', 'trace_source_datas': [], 'merged_trace_source_datas': [],
                      'all_classified_contents': [], 'doc_infos': [], 'gathered_info': [],
                      'debug_pre_step': 'outline-c615f84c-d865-41f6-b7c3-354703c51732', 'go_deepsearch': True,
                      'debug_cur_step': 'outline-c615f84c-d865-41f6-b7c3-354703c51732'}
    # Only exercised for "does not raise"; the produced section state is discarded.
    editor_team_node.create_section_state_from_state(search_context, outline, outline_section)
class TestEditorTeamNode(EditorTeamNode):
    """Test-only subclass exposing ``EditorTeamNode`` private hooks publicly.

    Each method is a thin pass-through to the underscore-prefixed method of
    the same name on ``EditorTeamNode``.
    """

    # The ``Test`` prefix matches pytest's default class-collection pattern.
    # Opt out explicitly so this helper is never collected as a test class.
    __test__ = False

    async def run_section_sub_graph(self, workflow_session, sub_workflow, input_state):
        """Await and return ``_run_section_sub_graph_await`` with the same arguments."""
        return await self._run_section_sub_graph_await(
            workflow_session, sub_workflow, input_state
        )

    def pre_handle(self, inputs, session, context):
        """Return whatever ``_pre_handle`` returns for the same arguments."""
        return self._pre_handle(inputs, session, context)

    def create_section_state_from_state(self, state, outline, section):
        """Return whatever ``_create_section_state_from_state`` returns."""
        return self._create_section_state_from_state(state, outline, section)
class _SessionAwareNode(BaseNode):
    """Probe node for checking that ``BaseNode.invoke`` injects ``session_context``."""

    def _pre_handle(self, inputs, session, context):
        """No preprocessing is needed; hand back an empty input mapping."""
        return {}

    async def _do_invoke(self, inputs, session, context):
        """Report whether the current ``session_context`` value is ``session`` itself."""
        active_session = session_context.get()
        return {"same_session": active_session is session}

    def _post_handle(self, inputs, algorithm_output, session, context):
        """No postprocessing is needed; pass the algorithm output through unchanged."""
        return algorithm_output
@pytest.mark.asyncio
async def test_run_sub_graph():
    """Best-effort smoke test for running the editor-team sub-workflow.

    Builds the sub-workflow, then runs it through
    ``TestEditorTeamNode.run_section_sub_graph`` with a mocked ``Session`` and
    an empty input state.

    NOTE(review): every exception is caught and logged, so this test cannot
    fail; it only records the failure (with traceback) in the log.
    """
    try:
        editor_team_node = TestEditorTeamNode()
        sub_workflow = build_editor_team_workflow()
        workflow_session = AsyncMock(spec=Session)
        await editor_team_node.run_section_sub_graph(workflow_session, sub_workflow, {})
    except Exception as e:
        # ``logger.exception`` keeps the ERROR level and message but also records
        # the traceback; %-style arguments defer formatting to the logging layer.
        logger.exception("fail to test_run_sub_graph: %s", e)
@pytest.mark.asyncio
async def test_base_node_invoke_injects_session_context():
    """Verify ``BaseNode.invoke`` sets ``session_context`` to the invoking session."""
    probe_node = _SessionAwareNode()
    mock_session = AsyncMock(spec=Session)

    # Seed the context variable with an unrelated object so the probe can only
    # report ``True`` if ``invoke`` replaces it with ``mock_session``.
    reset_token = session_context.set(object())
    try:
        result = await probe_node.invoke({}, mock_session, Context())
    finally:
        session_context.reset(reset_token)

    assert result["same_session"] is True
def test_entry_node_routes_to_outline_after_intent_recognition():
    """Verify ``EntryNode._post_handle`` routes to the outline node.

    Every ``get_global_state`` lookup on the mocked session returns ``False``.
    The test also checks that the entry search results are written back to the
    global state.
    """
    mock_session = Mock(spec=Session)
    mock_session.get_global_state.return_value = False
    mock_session.update_global_state = Mock()
    entry_node = EntryNode()

    algorithm_output = {
        "go_deepsearch": True,
        "lang": "zh-CN",
        "llm_result": "",
        "error_msg": "",
        "entry_search_results": [{"title": "result"}],
    }
    debug_wrapper_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.add_debug_log_wrapper"
    )
    with patch(debug_wrapper_target):
        routed = entry_node._post_handle({}, algorithm_output, mock_session, Context())

    assert routed["next_node"] == NodeId.OUTLINE.value
    expected_update = {"search_context.entry_search_results": [{"title": "result"}]}
    mock_session.update_global_state.assert_any_call(expected_update)
@pytest.mark.asyncio
async def test_intent_recognition_node_updates_context_and_routes_to_entry():
    """Verify ``IntentRecognitionNode`` writes the intent back, then routes to entry.

    ``adapt_llm_model_name``, ``recognize_report_intent`` and
    ``apply_web_search_domain_constraints`` are patched inside
    ``main_graph_nodes``. The test checks the arguments passed to the
    recogniser, the payloads written via ``update_global_state`` and the
    domain constraints forwarded to the web-search configuration helper.
    """
    session = AsyncMock(spec=Session)
    original_query = "请写一份正式报告:AI Agent 趋势"
    messages = [{"role": "user", "content": original_query}]
    intent_result = IntentRecognitionResult(
        original_query=original_query,
        research_query="AI Agent 趋势",
        research_intent=ResearchIntent(
            section_count=5,
            audience_role="研发负责人",
            tone="formal",
            include_domains=["example.com"],
            exclude_domains=["bad.com"],
        ),
    )
    web_search_engine_config = Mock()
    web_search_engine_config.search_engine_name = "tavily"

    def _get_global_state(key):
        """Serve the global-state keys this test provides; ``None`` otherwise."""
        return {
            "search_context.original_query": original_query,
            "search_context.messages": messages,
            "config.web_search_engine_config": web_search_engine_config,
            "config.workflow_human_in_the_loop": False,
        }.get(key)

    session.get_global_state.side_effect = _get_global_state
    session.update_global_state = Mock()
    node = IntentRecognitionNode()
    with patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.adapt_llm_model_name",
        return_value="basic",
    ), patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.recognize_report_intent",
        new_callable=AsyncMock,
        return_value=intent_result,
    ) as mock_recognize, patch(
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.apply_web_search_domain_constraints",
    ) as mock_apply_domain_constraints:
        output = await node.invoke({}, session, Context())

    assert output["next_node"] == NodeId.ENTRY.value
    mock_recognize.assert_awaited_once_with({
        "original_query": original_query,
        "messages": messages,
        "llm_model_name": "basic",
    })

    update_payloads = [call.args[0] for call in session.update_global_state.call_args_list]
    # Give ``next`` a default: a bare StopIteration escaping a coroutine is
    # re-raised as an opaque RuntimeError, which would hide the real failure.
    intent_update = next(
        (
            payload for payload in update_payloads
            if "search_context.original_query" in payload and "search_context.research_query" in payload
        ),
        None,
    )
    assert intent_update is not None, (
        "no update_global_state payload contained both original_query and research_query"
    )
    assert intent_update["search_context.original_query"] == original_query
    assert intent_update["search_context.research_query"] == "AI Agent 趋势"
    assert intent_update["search_context.research_intent"] == intent_result.research_intent.model_dump()
    assert "search_context.report_type_policy" in intent_update

    session.update_global_state.assert_any_call({
        "search_context.messages": [{"role": "user", "content": "AI Agent 趋势"}]
    })
    mock_apply_domain_constraints.assert_called_once_with(
        search_engine_name="tavily",
        include_domains=["example.com"],
        exclude_domains=["bad.com"],
    )
@pytest.mark.asyncio
async def test_pre_handle():
    """Best-effort smoke test for ``TestEditorTeamNode.pre_handle``.

    Calls the wrapper with empty inputs, an ``AsyncMock`` session and a fresh
    ``Context``.

    NOTE(review): every exception is caught and logged, so this test cannot
    fail; it only records the failure (with traceback) in the log.
    """
    import inspect

    try:
        editor_team_node = TestEditorTeamNode()
        workflow_session = AsyncMock()
        result = editor_team_node.pre_handle({}, workflow_session, Context())
        # ``pre_handle`` is a plain ``def`` wrapper, so its result may not be
        # awaitable; awaiting unconditionally would raise a TypeError that is
        # then logged as a bogus failure.
        if inspect.isawaitable(result):
            await result
    except Exception as e:
        # ``logger.exception`` keeps the ERROR level and message but also records
        # the traceback; %-style arguments defer formatting to the logging layer.
        logger.exception("fail to test_pre_handle: %s", e)
@pytest.mark.asyncio
async def test_start_node_merges_agent_llm_timeouts_into_session_config():
    """Verify ``StartNode`` merges ``agent_llm_timeouts`` into the session config.

    Also checks that the first ``update_global_state`` call stores the query
    as both ``original_query`` and ``research_query``, with no ``query`` key.

    Returns:
        None.
    """
    start_node = StartNode()
    mock_session = Mock()
    mock_session.update_global_state = Mock()

    timeouts = {"default": 300, "sub_reporter": 120}
    node_inputs = {
        "query": "hello",
        "thread_id": "thread-1",
        "agent_config": {
            "llm_config": {"general": {"model_name": "demo"}},
            "web_search_engine_config": {"search_engine_name": "tavily"},
            "local_search_engine_config": {"search_engine_name": "openapi"},
            "agent_llm_timeouts": timeouts,
        },
    }
    await start_node.invoke(node_inputs, mock_session, Context())

    recorded_calls = mock_session.update_global_state.call_args_list
    # The first recorded call carries the search context.
    first_payload = recorded_calls[0].args[0]
    search_context = first_payload["search_context"]
    assert search_context["original_query"] == "hello"
    assert search_context["research_query"] == "hello"
    assert "query" not in search_context

    # The last recorded call carries the merged config.
    last_payload = recorded_calls[-1].args[0]
    merged_config = last_payload["config"]
    assert merged_config["agent_llm_timeouts"] == {"default": 300, "sub_reporter": 120}
@pytest.mark.asyncio
async def test_end_node_writes_workflow_llm_usage_when_stats_enabled():
    """Verify ``EndNode`` writes the workflow-level token summary when stats are on."""
    mock_session = AsyncMock(spec=Session)
    per_agent_usage = {
        "agent_name": "entry",
        "input_tokens": 10,
        "output_tokens": 6,
        "total_tokens": 16,
        "llm_call_count": 3,
    }
    workflow_usage = {
        "input_tokens": 10,
        "output_tokens": 6,
        "total_tokens": 16,
        "llm_call_count": 3,
        "agent_name_token_usage": [per_agent_usage],
    }
    global_state = {
        "search_context.final_result": {"response_content": "ok", "exception_info": ""},
        "config.stats_info_llm": True,
        "config.thread_id": "test-thread-id",
    }

    def _lookup_state(key):
        """Serve the known global-state keys; anything else yields ``None``."""
        return global_state.get(key)

    mock_session.get_global_state.side_effect = _lookup_state
    mock_session.write_custom_stream = AsyncMock()
    mock_session.update_global_state = Mock()
    end_node = EndNode()

    usage_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.get_effective_workflow_llm_usage"
    )
    with patch(usage_target, return_value=workflow_usage) as mock_get_workflow_usage:
        output = await end_node.invoke({}, mock_session, Context())

    mock_get_workflow_usage.assert_called_once_with(session_id="test-thread-id", session=mock_session)
    mock_session.update_global_state.assert_any_call(
        {"search_context.final_result.workflow_llm_token_usage": workflow_usage}
    )
    decoded_result = json.loads(output["final_result"])
    assert decoded_result["workflow_llm_token_usage"] == workflow_usage
@pytest.mark.asyncio
async def test_end_node_skips_workflow_llm_usage_when_stats_disabled():
    """Verify ``EndNode`` injects no workflow-level token summary when stats are off."""
    mock_session = AsyncMock(spec=Session)
    global_state = {
        "search_context.final_result": {"response_content": "ok", "exception_info": ""},
        "config.stats_info_llm": False,
        "config.thread_id": "test-thread-id",
    }

    def _lookup_state(key):
        """Serve the known global-state keys; anything else yields ``None``."""
        return global_state.get(key)

    mock_session.get_global_state.side_effect = _lookup_state
    mock_session.write_custom_stream = AsyncMock()
    mock_session.update_global_state = Mock()
    end_node = EndNode()

    usage_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.get_effective_workflow_llm_usage"
    )
    with patch(usage_target) as mock_get_workflow_usage:
        output = await end_node.invoke({}, mock_session, Context())

    mock_get_workflow_usage.assert_not_called()
    decoded_result = json.loads(output["final_result"])
    assert "workflow_llm_token_usage" not in decoded_result
@pytest.mark.asyncio
async def test_end_node_falls_back_to_persisted_usage_when_local_empty():
    """Verify ``EndNode`` falls back to the session snapshot when local usage is empty."""
    mock_session = AsyncMock(spec=Session)
    persisted_usage = {
        "input_tokens": 8,
        "output_tokens": 4,
        "total_tokens": 12,
        "llm_call_count": 2,
        "agent_name_token_usage": [
            {
                "agent_name": "outline",
                "input_tokens": 8,
                "output_tokens": 4,
                "total_tokens": 12,
                "llm_call_count": 2,
            }
        ],
    }
    global_state = {
        "search_context.final_result": {"response_content": "ok", "exception_info": ""},
        "config.stats_info_llm": True,
        "config.thread_id": "test-thread-id",
        "search_context.final_result.workflow_llm_token_usage": persisted_usage,
    }

    def _lookup_state(key):
        """Serve the known global-state keys; anything else yields ``None``."""
        return global_state.get(key)

    mock_session.get_global_state.side_effect = _lookup_state
    mock_session.write_custom_stream = AsyncMock()
    mock_session.update_global_state = Mock()
    end_node = EndNode()

    # The locally accumulated usage is all zeroes.
    local_empty_usage = dict.fromkeys(
        ("input_tokens", "output_tokens", "total_tokens", "llm_call_count"), 0
    )
    local_empty_usage["agent_name_token_usage"] = []

    local_usage_target = "openjiuwen_deepsearch.utils.common_utils.llm_utils.get_workflow_llm_usage"
    with patch(local_usage_target, return_value=local_empty_usage):
        output = await end_node.invoke({}, mock_session, Context())

    decoded_result = json.loads(output["final_result"])
    assert decoded_result["workflow_llm_token_usage"] == persisted_usage
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("node_factory", "method_name", "method_args", "interact_payload", "assert_output", "thread_id"),
    [
        # FeedbackHandlerNode: the predicate expects the "feedback" value.
        (
            FeedbackHandlerNode,
            "_get_user_feedback",
            ("web",),
            '{"feedback":"ok"}',
            lambda output: output == "ok",
            "tid-1",
        ),
        # OutlineInteractionNode: the predicate reads "interrupt_feedback" from the output.
        (
            OutlineInteractionNode,
            "_get_user_input",
            ("web", "1"),
            '{"interrupt_feedback":"accepted","feedback":"ok"}',
            lambda output: output["interrupt_feedback"] == "accepted",
            "tid-2",
        ),
        # UserFeedbackProcessorNode: the predicate expects the raw payload string.
        (
            UserFeedbackProcessorNode,
            "_get_user_feedback",
            ("web",),
            '{"action":"finish"}',
            lambda output: output == '{"action":"finish"}',
            "tid-3",
        ),
    ],
)
async def test_web_interaction_nodes_persist_usage_before_interact(
    node_factory,
    method_name,
    method_args,
    interact_payload,
    assert_output,
    thread_id,
):
    """Verify each web interaction node persists accumulated token usage.

    For every parametrised node, the named method is awaited with
    ``method_args`` plus the mocked session. The test then checks the
    method's output and that ``save_workflow_llm_usage_to_session`` was
    called exactly once for the session and its thread id.
    """
    mock_session = AsyncMock(spec=Session)
    global_state = {
        "config.stats_info_llm": True,
        "config.thread_id": thread_id,
    }

    def _lookup_state(key):
        """Serve the known global-state keys; anything else yields ``None``."""
        return global_state.get(key)

    mock_session.get_global_state.side_effect = _lookup_state
    mock_session.interact = AsyncMock(return_value=interact_payload)
    node = node_factory()
    interaction_method = getattr(node, method_name)

    save_target = (
        "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes.save_workflow_llm_usage_to_session"
    )
    with patch(save_target) as mock_save:
        output = await interaction_method(*method_args, mock_session)

    assert assert_output(output)
    mock_save.assert_called_once_with(session=mock_session, session_id=thread_id)
class MockAgent(DeepresearchAgent):
    """``DeepresearchAgent`` variant whose research workflow is START -> EDITOR_TEAM -> END."""

    def __init__(self):
        """Defer entirely to ``DeepresearchAgent.__init__``."""
        super().__init__()

    def _build_research_workflow(self, has_template=False):
        """Build the three-node workflow used by the tests.

        Args:
            has_template: Accepted for signature compatibility; not used here.

        Returns:
            Workflow: Start, editor-team and end components wired linearly.
        """
        start_id = NodeId.START.value
        editor_team_id = NodeId.EDITOR_TEAM.value
        end_id = NodeId.END.value

        # The research name doubles as both the card id and the card name.
        workflow = Workflow(
            card=WorkflowCard(
                id=self.research_name,
                version=self.version,
                name=self.research_name,
            )
        )
        workflow.set_start_comp(
            start_comp_id=start_id,
            component=StartNode(),
            inputs_schema=self.startnode_input_schema,
        )
        workflow.add_workflow_comp(editor_team_id, EditorTeamNode())
        workflow.set_end_comp(end_id, EndNode())

        # Linear wiring: START -> EDITOR_TEAM -> END.
        workflow.add_connection(start_id, editor_team_id)
        workflow.add_connection(editor_team_id, end_id)
        return workflow