from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from openjiuwen_deepsearch.algorithm.source_trace.source_tracer import SourceTracer
from openjiuwen_deepsearch.framework.openjiuwen.agent.reasoning_writing_graph.editor_team_nodes import \
SubSourceTracerNode
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import SubReportContent
from openjiuwen_deepsearch.utils.constants_utils.node_constants import NodeId
class ExposedSubSourceTracerNode(SubSourceTracerNode):
    """Test-only subclass exposing the node's protected hooks under public names (rule G.CLS.11).

    Each protected hook ``_name`` is redirected to a public ``name`` method, and ``name`` runs the
    parent class's ``_name`` implementation. A call to a protected hook on an instance therefore
    passes through the public method, so a test can patch the public name to intercept it.
    """

    # --- Public entry points: each one runs the parent's protected implementation. ---

    def pre_handle(self, *args, **kwargs):
        """Run the parent class's ``_pre_handle`` with the given arguments."""
        return super()._pre_handle(*args, **kwargs)

    def skip_trace_source_handle(self, *args, **kwargs):
        """Run the parent class's ``_skip_trace_source_handle`` with the given arguments."""
        return super()._skip_trace_source_handle(*args, **kwargs)

    async def do_invoke(self, *args, **kwargs):
        """Await the parent class's ``_do_invoke`` with the given arguments."""
        return await super()._do_invoke(*args, **kwargs)

    def post_handle(self, *args, **kwargs):
        """Run the parent class's ``_post_handle`` with the given arguments."""
        return super()._post_handle(*args, **kwargs)

    # --- Protected hooks: redirected to the public entry points above. ---

    def _pre_handle(self, *args, **kwargs):
        """Forward to the public ``pre_handle``."""
        return self.pre_handle(*args, **kwargs)

    def _skip_trace_source_handle(self, *args, **kwargs):
        """Forward to the public ``skip_trace_source_handle``."""
        return self.skip_trace_source_handle(*args, **kwargs)

    async def _do_invoke(self, *args, **kwargs):
        """Forward to the public ``do_invoke`` and await its result."""
        return await self.do_invoke(*args, **kwargs)

    def _post_handle(self, *args, **kwargs):
        """Forward to the public ``post_handle``."""
        return self.post_handle(*args, **kwargs)
class TestSubSourceTracerNode:
    """Test cases for ``SubSourceTracerNode``, exercised through ``ExposedSubSourceTracerNode``.

    The tests patch the public wrappers (``post_handle``, ``skip_trace_source_handle``) on the node
    instance. ``ExposedSubSourceTracerNode`` routes the protected hooks to those public names, so a
    patched wrapper also intercepts calls made to the protected hook.
    """

    # Global-state keys served by the ``get_global_state`` stand-ins used in these tests.
    TRACE_SWITCH_KEY = "config.source_tracer_research_trace_source_switch"
    CITATION_SWITCH_KEY = "config.source_tracer_generated_citation_switch"
    SUB_REPORT_KEY = "section_context.sub_report_content"
    SEARCH_RECORD_KEY = "section_context.search_record"
    LANGUAGE_KEY = "section_context.language"
    SECTION_IDX_KEY = "section_context.section_idx"

    @staticmethod
    def _install_global_state(session, state):
        """Make ``session.get_global_state(key)`` answer from the ``state`` mapping.

        Keys that are missing from ``state`` resolve to ``None``. The mapping is consulted on every
        call, so changes made to ``state`` afterwards are visible to the code under test.

        Args:
            session: The mock session whose ``get_global_state`` is stubbed.
            state: Mapping from global-state key to the value to return.
        """

        def lookup(key):
            return state.get(key)

        session.get_global_state.side_effect = lookup

    @classmethod
    def _stub_section_state(cls, session, search_context, trace_switch, extra_state=None):
        """Stub ``session.get_global_state`` with section context built from ``search_context``.

        Args:
            session: The mock session to stub.
            search_context: Dict with ``sub_report_content``, ``search_record`` and ``language``.
            trace_switch: Value served for the research trace-source switch key.
            extra_state: Optional additional key/value pairs to serve.

        Returns:
            The dict backing the stub; mutate it to change what later lookups return.
        """
        state = {
            cls.TRACE_SWITCH_KEY: trace_switch,
            cls.SUB_REPORT_KEY: SubReportContent(
                sub_report_content_text=search_context["sub_report_content"],
                classified_content=[],
            ),
            cls.SEARCH_RECORD_KEY: search_context["search_record"],
            cls.LANGUAGE_KEY: search_context["language"],
        }
        if extra_state:
            state.update(extra_state)
        cls._install_global_state(session, state)
        return state

    @pytest.fixture
    def sub_source_tracer_node(self):
        """Fixture to create the node under test (the exposing test subclass)."""
        return ExposedSubSourceTracerNode()

    @pytest.fixture
    def mock_session(self):
        """Fixture to create a mock Session with explicit global-state accessors."""
        session = MagicMock()
        session.get_global_state = MagicMock()
        session.update_global_state = MagicMock()
        return session

    @pytest.fixture
    def mock_search_context(self):
        """Fixture to provide mock search context data."""
        return {
            "sub_report_content": "This is a test sub report content.",
            "search_record": {
                "web_page_search_record": [
                    {"title": "example", "url": "https://example.com", "content": "test content"},
                ],
                "web_image_search_record": [],
                "local_text_search_record": [],
                "local_image_search_record": [],
            },
            "language": "zh-CN",
        }

    @staticmethod
    def test_pre_handle(sub_source_tracer_node, mock_session, mock_search_context):
        """Test _pre_handle with the trace-source switch enabled and then disabled."""
        cls = TestSubSourceTracerNode
        state = cls._stub_section_state(mock_session, mock_search_context, trace_switch=True)

        result = sub_source_tracer_node.pre_handle(None, mock_session, None)
        assert result["research_trace_source_switch"] is True
        assert result["report"] == mock_search_context["sub_report_content"]
        assert result["language"] == mock_search_context["language"]

        # Only the switch differs in the second scenario; the stub reads ``state`` on each call.
        state[cls.TRACE_SWITCH_KEY] = False
        result = sub_source_tracer_node.pre_handle(None, mock_session, None)
        assert result["research_trace_source_switch"] is False

    @staticmethod
    def test_skip_trace_source_handle(sub_source_tracer_node, mock_session):
        """Test _skip_trace_source_handle passes the unmodified report and no trace data on."""
        current_inputs = {"report": "Test report"}
        with patch.object(sub_source_tracer_node, "post_handle") as mock_post_handle:
            mock_post_handle.return_value = {"next_node": NodeId.END.value}
            sub_source_tracer_node.skip_trace_source_handle(None, mock_session, None, current_inputs)

        mock_post_handle.assert_called_once()
        # The algorithm output is the second positional argument handed to post_handle.
        algorithm_output = mock_post_handle.call_args[0][1]
        assert algorithm_output["trace_source_datas"] == []
        assert algorithm_output["modified_report"] == current_inputs["report"]

    @pytest.mark.asyncio
    async def test_do_invoke_with_trace_enabled(self, sub_source_tracer_node, mock_session, mock_search_context):
        """Test _do_invoke method with trace source enabled."""
        self._stub_section_state(mock_session, mock_search_context, trace_switch=True)
        research_output = {"datas": [{"id": "test_id", "content": "Test content"}]}
        add_source_output = {
            "modified_report": "Test modified report",
            "datas": [{"id": "test_id", "content": "Test content"}],
        }

        with patch.object(SourceTracer, "__init__", return_value=None) as mock_init, \
                patch.object(SourceTracer, "research_trace_source", new_callable=AsyncMock) as mock_research, \
                patch.object(SourceTracer, "add_source_to_report") as mock_add_source, \
                patch.object(sub_source_tracer_node, "post_handle") as mock_post_handle:
            mock_research.return_value = research_output
            mock_add_source.return_value = add_source_output
            mock_post_handle.return_value = {"next_node": NodeId.END.value}
            await sub_source_tracer_node.do_invoke(None, mock_session, None)

        mock_init.assert_called_once()
        mock_research.assert_called_once()
        mock_add_source.assert_called_once()
        mock_post_handle.assert_called_once()
        algorithm_output = mock_post_handle.call_args[0][1]
        assert algorithm_output["trace_source_datas"] == add_source_output["datas"]
        assert algorithm_output["modified_report"] == add_source_output["modified_report"]

    @pytest.mark.asyncio
    async def test_do_invoke_with_trace_disabled(self, sub_source_tracer_node, mock_session, mock_search_context):
        """Test _do_invoke method with trace source disabled."""
        self._stub_section_state(mock_session, mock_search_context, trace_switch=False)

        with patch.object(sub_source_tracer_node, "skip_trace_source_handle") as mock_skip:
            mock_skip.return_value = {"next_node": NodeId.END.value}
            await sub_source_tracer_node.do_invoke(None, mock_session, None)

        mock_skip.assert_called_once()

    @pytest.mark.asyncio
    async def test_do_invoke_skips_generated_citations_when_switch_disabled(
            self, sub_source_tracer_node, mock_session, mock_search_context):
        """Test _do_invoke skips new citation generation when the fine-grained switch is disabled."""
        self._stub_section_state(
            mock_session,
            mock_search_context,
            trace_switch=True,
            extra_state={self.CITATION_SWITCH_KEY: False},
        )
        add_source_output = {
            "modified_report": "Test modified report",
            "datas": [{"id": "test_id", "content": "Test content"}],
        }

        with patch.object(SourceTracer, "__init__", return_value=None) as mock_init, \
                patch.object(SourceTracer, "research_trace_source", new_callable=AsyncMock) as mock_research, \
                patch.object(SourceTracer, "add_source_to_report") as mock_add_source, \
                patch.object(sub_source_tracer_node, "post_handle") as mock_post_handle:
            mock_add_source.return_value = add_source_output
            mock_post_handle.return_value = {"next_node": NodeId.END.value}
            result = await sub_source_tracer_node.do_invoke(None, mock_session, None)

        mock_init.assert_called_once()
        mock_research.assert_not_called()
        mock_add_source.assert_called_once()
        mock_post_handle.assert_called_once()
        algorithm_output = mock_post_handle.call_args[0][1]
        assert algorithm_output["trace_source_datas"] == add_source_output["datas"]
        assert algorithm_output["modified_report"] == add_source_output["modified_report"]
        assert result == {"next_node": NodeId.END.value}

    @staticmethod
    def test_post_handle(sub_source_tracer_node, mock_session):
        """Test _post_handle with a populated algorithm output and with an empty one."""
        cls = TestSubSourceTracerNode
        scenarios = (
            {
                "trace_source_datas": [
                    {"id": "test_id", "content": "Test content", "url": "https://example.com"},
                ],
                "modified_report": "Test modified report",
            },
            {"trace_source_datas": [], "modified_report": ""},
        )

        for algorithm_output in scenarios:
            # Each scenario starts from clean call records and a fresh original sub report.
            mock_session.reset_mock()
            cls._install_global_state(mock_session, {
                cls.SUB_REPORT_KEY: SubReportContent(
                    sub_report_content_text="Original content",
                    classified_content=[],
                ),
                cls.SECTION_IDX_KEY: 1,
            })

            result = sub_source_tracer_node.post_handle(None, algorithm_output, mock_session, None)

            assert result["next_node"] == NodeId.END.value
            mock_session.update_global_state.assert_called_once()
            # update_global_state receives the changed entries as its first positional argument.
            updated_state = mock_session.update_global_state.call_args[0][0]
            assert "section_context.sub_report_content" in updated_state
            updated_sub_report = updated_state["section_context.sub_report_content"]
            assert isinstance(updated_sub_report, SubReportContent)
            assert updated_sub_report.sub_report_content_text == algorithm_output["modified_report"]
            assert updated_sub_report.sub_report_trace_source_datas == algorithm_output["trace_source_datas"]