import json
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes import SourceTracerNode
from openjiuwen_deepsearch.framework.openjiuwen.agent.search_context import Report
from openjiuwen_deepsearch.utils.constants_utils.node_constants import NodeId
# Dotted path of the module under test; used as the prefix when building
# ``unittest.mock.patch`` targets for names looked up inside that module.
MODULE_PATH = "openjiuwen_deepsearch.framework.openjiuwen.agent.main_graph_nodes"
class ExposedSourceTracerNode(SourceTracerNode):
    """Test-only subclass exposing protected ``SourceTracerNode`` hooks publicly.

    Exists so tests can follow rule G.CLS.11 by calling public names instead
    of protected ones. Every hook comes as a pair:

    * the public method forwards to the parent's protected implementation;
    * the protected override forwards to the public method, so patching the
      public name on an instance also intercepts calls made through the
      protected name.
    """

    # --- pre-handle -----------------------------------------------------

    def pre_handle(self, *args, **kwargs):
        """Run the parent's ``_pre_handle`` with the given arguments."""
        return super()._pre_handle(*args, **kwargs)

    def _pre_handle(self, *args, **kwargs):
        """Redirect protected calls to the public ``pre_handle``."""
        return self.pre_handle(*args, **kwargs)

    # --- skip-trace-source handling -------------------------------------

    def skip_trace_source_handle(self, *args, **kwargs):
        """Run the parent's ``_skip_trace_source_handle`` with the given arguments."""
        return super()._skip_trace_source_handle(*args, **kwargs)

    def _skip_trace_source_handle(self, *args, **kwargs):
        """Redirect protected calls to the public ``skip_trace_source_handle``."""
        return self.skip_trace_source_handle(*args, **kwargs)

    # --- invocation ------------------------------------------------------

    async def do_invoke(self, *args, **kwargs):
        """Await the parent's ``_do_invoke`` with the given arguments."""
        return await super()._do_invoke(*args, **kwargs)

    async def _do_invoke(self, *args, **kwargs):
        """Redirect protected calls to the public ``do_invoke``."""
        return await self.do_invoke(*args, **kwargs)

    # --- post-handle -----------------------------------------------------

    def post_handle(self, *args, **kwargs):
        """Run the parent's ``_post_handle`` with the given arguments."""
        return super()._post_handle(*args, **kwargs)

    def _post_handle(self, *args, **kwargs):
        """Redirect protected calls to the public ``post_handle``."""
        return self.post_handle(*args, **kwargs)
class TestSourceTracerNode:
    """Test cases for SourceTracerNode, driven through ExposedSourceTracerNode."""

    @pytest.fixture
    def source_tracer_node(self):
        """Provide a fresh ExposedSourceTracerNode instance."""
        return ExposedSourceTracerNode()

    @pytest.fixture
    def mock_session(self):
        """Provide a mock session whose global-state accessors are MagicMocks."""
        session = MagicMock()
        session.get_global_state = MagicMock()
        session.update_global_state = MagicMock()
        return session

    @pytest.fixture
    def mock_search_context(self):
        """Provide search-context data: a report, one web search record, empty trace data."""
        return {
            "search_mode": "research",
            "current_report": Report(
                id="test_report_id",
                report_content="This is a test report with some content."
            ),
            "search_record": {
                "web_page_search_record": [
                    {"title": "example", "url": "https://example.com", "content": "test content"}],
                "web_image_search_record": [],
                "local_text_search_record": [],
                "local_image_search_record": []
            },
            "merged_trace_source_datas": [],
            "all_classified_contents": [],
            "language": "zh-CN"
        }

    @pytest.fixture
    def mock_config(self):
        """Provide config data with the research trace-source switch enabled."""
        return {
            "source_tracer_research_trace_source_switch": True
        }

    @staticmethod
    def test_pre_handle(source_tracer_node, mock_session, mock_search_context, mock_config):
        """Check pre_handle with the research trace-source switch on, then off."""
        switch_key = "config.source_tracer_research_trace_source_switch"

        # Explicit table of every global-state key the test answers; any other
        # key yields None, exactly like an unconfigured entry would.
        state = {
            "search_context.search_mode": mock_search_context["search_mode"],
            "search_context.current_report": mock_search_context["current_report"],
            "search_context.search_record": mock_search_context.get("search_record"),
            "search_context.merged_trace_source_datas": mock_search_context.get(
                "merged_trace_source_datas"),
            "search_context.all_classified_contents": mock_search_context.get(
                "all_classified_contents"),
            "search_context.language": mock_search_context.get("language"),
            switch_key: mock_config["source_tracer_research_trace_source_switch"],
        }

        def make_lookup(table):
            """Build a get_global_state side effect backed by ``table``."""
            def lookup(key):
                return table.get(key)
            return lookup

        # Switch enabled: the node must not request an early exit.
        mock_session.get_global_state.side_effect = make_lookup(state)
        result = source_tracer_node.pre_handle(None, mock_session, None)
        assert result["need_exit"] is False
        assert result["search_mode"] == "research"
        assert result["research_trace_source_switch"] is True

        # Switch disabled: same state otherwise, the node must ask to exit.
        mock_session.get_global_state.side_effect = make_lookup({**state, switch_key: False})
        result = source_tracer_node.pre_handle(None, mock_session, None)
        assert result["need_exit"] is True

    @staticmethod
    def test_post_handle_routes_research_need_exit_to_source_tracer_infer(source_tracer_node, mock_session):
        """Check post_handle routes a research-mode need_exit output to SOURCE_TRACER_INFER."""
        algorithm_output = {
            "need_exit": True,
            "origin_report": "test report",
            "search_mode": "research",
        }
        mock_session.get_global_state.return_value = None
        with patch(f"{MODULE_PATH}.add_debug_log_wrapper"):
            result = source_tracer_node.post_handle(None, algorithm_output, mock_session, None)
        assert result["next_node"] == NodeId.SOURCE_TRACER_INFER.value

    @pytest.mark.asyncio
    async def test_build_citation_checker_result(self, source_tracer_node):
        """Check build_citation_checker_result forwards to the citation checker and wraps its output."""
        citation_checker_info = {
            "need_check": True,
            "response_content": {
                "language": "zh-CN",
                "article": "Article with references"
            }
        }
        datas = [{"id": "test_id", "content": "Test content"}]
        expected_citation_result = json.dumps({
            "response_content": "Processed report",
            "citation_messages": {"test_id": "Valid"}
        }, ensure_ascii=False)

        with patch(f'{MODULE_PATH}.postprocess_by_citation_checker',
                   new_callable=AsyncMock) as mock_postprocess:
            mock_postprocess.return_value = expected_citation_result
            result = await SourceTracerNode.build_citation_checker_result(
                citation_checker_info, datas, "mock_model")

        mock_postprocess.assert_called_once_with(
            citation_checker_info["response_content"], datas, "mock_model")
        assert result["check_result"] is True
        assert result["citation_checker_result_str"] == expected_citation_result

    @pytest.mark.asyncio
    async def test_do_invoke(self, source_tracer_node, mock_session, mock_search_context, mock_config):
        """Exercise do_invoke on the normal, early-exit and failing-preprocess paths."""
        global_state = {"search_context": mock_search_context, "config": mock_config}

        def get_global_state_side_effect(key):
            return global_state.get(key)

        mock_session.get_global_state.side_effect = get_global_state_side_effect

        mock_citation_checker_info = {
            "need_check": True, "response_content": {"article": "Test article"}}
        mock_citation_checker_result = {
            "check_result": True, "citation_checker_result_str": '{"response_content": "Valid"}'}

        # Patchers are created up front so they can be entered in one flat
        # ``with`` statement instead of a four-level pyramid.
        pre_handle_patch = patch.object(source_tracer_node, 'pre_handle')
        preprocess_patch = patch(f'{MODULE_PATH}.preprocess_info')
        build_citation_patch = patch.object(
            SourceTracerNode, 'build_citation_checker_result', new_callable=AsyncMock)
        post_handle_patch = patch.object(source_tracer_node, 'post_handle')

        with pre_handle_patch as mock_pre_handle, preprocess_patch as mock_preprocess, \
                build_citation_patch as mock_build_citation, post_handle_patch as mock_post_handle:
            # Scenario 1: normal flow reaches the citation checker.
            mock_pre_handle.return_value = {
                "need_exit": False, "report": "test", "language": "zh-CN"}
            mock_preprocess.return_value = mock_citation_checker_info
            mock_build_citation.return_value = mock_citation_checker_result
            mock_post_handle.return_value = {"next_node": "END"}
            await source_tracer_node.do_invoke(None, mock_session, None)
            mock_build_citation.assert_called_once()

            # Scenario 2: need_exit short-circuits into the skip handler.
            mock_pre_handle.return_value = {"need_exit": True}
            with patch.object(source_tracer_node, 'skip_trace_source_handle') as mock_skip:
                mock_skip.return_value = {"next_node": "END"}
                await source_tracer_node.do_invoke(None, mock_session, None)
                mock_skip.assert_called_once()

            # Scenario 3: preprocessing raises; the output handed to
            # post_handle must report a failed check.
            mock_pre_handle.return_value = {
                "need_exit": False, "report": "test"}
            mock_preprocess.side_effect = Exception("Test exception")
            # Drop the call recorded in scenario 1 so the inspected call_args
            # can only come from this failing invocation.
            mock_post_handle.reset_mock()
            await source_tracer_node.do_invoke(None, mock_session, None)
            mock_post_handle.assert_called()
            args, _ = mock_post_handle.call_args
            algorithm_output = args[1]
            assert algorithm_output["check_result_dict"]["check_result"] is False