import base64
import json
from unittest.mock import patch, AsyncMock
import pytest
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer import SourceTracerInfer
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_call_model import (
GraphInfo,
)
class TestSourceTracerInfer:
    """Test cases for SourceTracerInfer core functionality.

    Tests that would otherwise reach the LLM patch ``call_model`` inside the
    ``infer`` module; pipeline-level tests (``async_run`` / ``run``) stub the
    individual stages on the instance created in ``setup_method``.
    """

    # Dotted path of the module under test; every ``patch`` target that
    # replaces a name looked up by that module is derived from it.
    INFER_MODULE = "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer"

    @classmethod
    def _patch_call_model(cls, **mock_kwargs):
        """Return a patcher that swaps ``call_model`` for an ``AsyncMock``.

        ``mock_kwargs`` (typically ``return_value=...``) are forwarded by
        ``patch`` to the ``AsyncMock`` constructor, so the stub is fully
        configured before the ``with`` body runs.
        """
        return patch(
            f"{cls.INFER_MODULE}.call_model",
            new_callable=AsyncMock,
            **mock_kwargs,
        )

    def setup_method(self):
        """Build a fresh context dict and SourceTracerInfer before each test."""
        self.context = {
            "language": "zh-CN",
            "llm_model_name": "mock_model",
            "source_tracer_response": "test response",
            "conclusion_with_records": None,
        }
        self.source_tracer_infer = SourceTracerInfer(self.context)

    def test_init(self):
        """The constructor copies the context values onto the instance."""
        tracer = self.source_tracer_infer

        assert tracer.context == self.context
        assert tracer.language == "zh-CN"
        assert tracer.model_name == "mock_model"
        assert tracer.response == "test response"
        assert tracer.conclusion_with_records is None
        assert tracer.checker_infos == {
            "graph_infos": [],
            "search_records": [],
        }
        assert hasattr(tracer, "node_number")
        assert hasattr(tracer, "supplement_graph")

    def test_encode_html_to_base64_valid(self):
        """The encoded HTML decodes back to the original markup."""
        html_content = "<html><body>test</body></html>"

        encoded = self.source_tracer_infer._encode_html_to_base64(html_content)

        assert base64.b64decode(encoded).decode("utf-8") == html_content

    def test_encode_html_to_base64_invalid(self):
        """A failure inside ``base64.b64encode`` surfaces as an exception."""
        with patch(
            "base64.b64encode", side_effect=Exception("Encoding error")
        ), pytest.raises(Exception):
            self.source_tracer_infer._encode_html_to_base64("test")

    @pytest.mark.asyncio
    async def test_get_conclusion_and_records_new(self):
        """The preprocessor's ``run`` result is stored on the instance."""
        expected_result = [{"conclusion": "test", "search_records": []}]
        # The patched class returns this instance, whose ``run`` is awaitable.
        preprocessor = AsyncMock()
        preprocessor.run.return_value = expected_result

        with patch(
            f"{self.INFER_MODULE}.ResearchInferPreprocess",
            return_value=preprocessor,
        ):
            await self.source_tracer_infer.get_conclusion_and_records()

        assert self.source_tracer_infer.conclusion_with_records == expected_result

    @pytest.mark.asyncio
    async def test_extract_reference_no_results(self):
        """An empty model answer yields an empty dict."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [{"content": "test"}],
        }

        with self._patch_call_model(return_value=[]):
            result = await self.source_tracer_infer.extract_reference(datas)

        assert result == {}

    @pytest.mark.asyncio
    async def test_extract_reference_valid(self):
        """In-range indices from the model are turned into reference entries."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [
                {"content": "test content 1"},
                {"content": "test content 2"},
            ],
        }

        with self._patch_call_model(return_value=[0, 1]):
            result = await self.source_tracer_infer.extract_reference(datas)

        assert result["conclusion"] == "test conclusion"
        assert len(result["reference"]) == 2
        first_reference = result["reference"][0]
        assert first_reference["id"] == 0
        assert first_reference["content"] == "test content 1"

    @pytest.mark.asyncio
    async def test_extract_reference_invalid_index(self):
        """An index with no matching search record produces no reference."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [{"content": "test content"}],
        }

        # Index 5 does not exist: there is only one search record.
        with self._patch_call_model(return_value=[5]):
            result = await self.source_tracer_infer.extract_reference(datas)

        assert result["conclusion"] == "test conclusion"
        assert len(result["reference"]) == 0

    @pytest.mark.asyncio
    async def test_infer_basic(self):
        """The first element of the model answer becomes the inference."""
        evidences = {
            "conclusion": "test conclusion",
            "reference": [{"content": "test"}],
        }

        with self._patch_call_model(return_value=["test inference"]):
            result = await self.source_tracer_infer.infer(evidences)

        assert result["conclusion"] == "test conclusion"
        assert result["inference"] == "test inference"

    @pytest.mark.asyncio
    async def test_infer_empty_result(self):
        """An empty model answer results in an empty inference string."""
        evidences = {"conclusion": "test conclusion", "reference": []}

        with self._patch_call_model(return_value=[]):
            result = await self.source_tracer_infer.infer(evidences)

        assert result["conclusion"] == "test conclusion"
        assert result["inference"] == ""

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_valid(self):
        """A ``"true"`` model answer passes the inference through unchanged."""
        inferences = {"inference": "valid inference"}

        with self._patch_call_model(return_value="true"):
            result = await self.source_tracer_infer.filter_invalid_infer(inferences)

        assert result == inferences

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_empty_result(self):
        """An empty model answer raises ValueError mentioning the inference."""
        inferences = {"inference": "invalid inference"}

        with self._patch_call_model(return_value=""), pytest.raises(
            ValueError
        ) as exc_info:
            await self.source_tracer_infer.filter_invalid_infer(inferences)

        assert "invalid inference" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_false_result(self):
        """A non-empty ``"false"`` answer still returns the inference as-is."""
        inferences = {"inference": "invalid inference"}

        with self._patch_call_model(return_value="false"):
            result = await self.source_tracer_infer.filter_invalid_infer(inferences)

        assert result == inferences

    @pytest.mark.asyncio
    async def test_structured_infer_valid(self):
        """The structured relations returned by the model are handed back."""
        inference = {"inference": "test"}
        expected_result = [([0], "relation", 1)]

        with self._patch_call_model(return_value=expected_result):
            result = await self.source_tracer_infer.structured_infer(inference)

        assert result == expected_result

    @pytest.mark.asyncio
    async def test_structured_infer_empty_result(self):
        """An empty model answer raises ValueError about the inference."""
        inference = {"inference": "test"}

        with self._patch_call_model(return_value=[]), pytest.raises(
            ValueError
        ) as exc_info:
            await self.source_tracer_infer.structured_infer(inference)

        assert "unstructured inference" in str(exc_info.value)

    def test_generate_html_basic(self):
        """``generate_html.run`` renders a single-edge graph as an HTML string."""
        relations = [[[0], "relation", 1]]
        nodes = {
            0: {"label": "node0", "url": "https://www.example.com"},
            1: {"label": "node1"},
        }
        checked_infer_graph = GraphInfo(relations, nodes, [0], [1])

        html = self.source_tracer_infer.generate_html.run(checked_infer_graph)

        assert isinstance(html, str)
        assert "<html>" in html
        assert "</html>" in html

    def test_generate_html_multiple_heads(self):
        """A relation with two head nodes is rendered as well."""
        relations = [[[0, 1], "relation", 2]]
        nodes = {
            0: {"label": "node0"},
            1: {"label": "node1"},
            2: {"label": "node2"},
        }
        checked_infer_graph = GraphInfo(relations, nodes, [0], [2])

        html = self.source_tracer_infer.generate_html.run(checked_infer_graph)

        assert isinstance(html, str)
        assert "<html>" in html

    def test_mark_conclusion_in_report_basic(self):
        """A conclusion is inserted as a markdown link to its inference id."""
        infer_messages = [{"id": 0, "conclusion": "test conclusion"}]
        conclusion_infos = [{"start_pos": 0, "end_pos": 5}]
        self.source_tracer_infer.response = "original"

        marked = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        assert "[test conclusion](#inference:0)" in marked

    def test_mark_conclusion_in_report_multiple(self):
        """Marking two conclusions returns a non-empty string."""
        infer_messages = [
            {"id": 0, "conclusion": "conclusion1"},
            {"id": 1, "conclusion": "conclusion2"},
        ]
        conclusion_infos = [
            {"start_pos": 0, "end_pos": 5},
            {"start_pos": 10, "end_pos": 15},
        ]
        self.source_tracer_infer.response = "original text here"

        marked = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        assert isinstance(marked, str)
        assert len(marked) > 0

    def test_mark_conclusion_in_report_error(self):
        """An ``end_pos`` past the end of the response still alters the report."""
        infer_messages = [{"id": 0, "conclusion": "test"}]
        # end_pos (100) deliberately exceeds the length of the response below.
        conclusion_infos = [{"start_pos": 0, "end_pos": 100}]
        original_response = "original"
        self.source_tracer_infer.response = original_response

        marked = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        assert marked != original_response

    @pytest.mark.asyncio
    async def test_async_run_success(self):
        """With every stage stubbed, ``async_run`` returns message and graph."""
        tracer = self.source_tracer_infer
        datas = {"conclusion": ["test"], "search_records": [{"content": "test"}]}

        def graph_parts():
            # Built anew on each call so the two graph stubs never share
            # mutable lists or dicts.
            return ([([0], "relation", 1)], {0: {"label": "test"}}, [0], [1])

        # One flat ``with`` replaces seven nested ones; the patches are still
        # entered in the same order and undone in reverse.
        with patch.object(
            tracer,
            "extract_reference",
            new_callable=AsyncMock,
            return_value={"conclusion": "test", "reference": []},
        ), patch.object(
            tracer,
            "infer",
            new_callable=AsyncMock,
            return_value={"conclusion": "test", "inference": "inference"},
        ), patch.object(
            tracer,
            "filter_invalid_infer",
            new_callable=AsyncMock,
            return_value={"conclusion": "test", "inference": "inference"},
        ), patch.object(
            tracer,
            "structured_infer",
            new_callable=AsyncMock,
            return_value=[([0], "relation", 1)],
        ), patch.object(
            tracer.node_number, "number_node", return_value=graph_parts()
        ), patch.object(
            tracer.supplement_graph,
            "run",
            new_callable=AsyncMock,
            return_value=graph_parts(),
        ), patch.object(
            tracer.generate_html, "run", return_value="<html>test</html>"
        ):
            infer_message, checked_graph = await tracer.async_run(datas)

        assert infer_message["conclusion"] == "test"
        assert infer_message["inference"] == "inference"
        assert "html_base64" in infer_message
        assert checked_graph is not None

    @pytest.mark.asyncio
    async def test_async_run_extract_reference_empty(self):
        """An empty ``extract_reference`` result yields ``({}, None)``."""
        datas = {"conclusion": [], "search_records": []}

        with patch.object(
            self.source_tracer_infer,
            "extract_reference",
            new_callable=AsyncMock,
            return_value={},
        ):
            infer_message, checked_graph = await self.source_tracer_infer.async_run(
                datas
            )

        assert infer_message == {}
        assert checked_graph is None

    @pytest.mark.asyncio
    async def test_async_run_error_handling(self):
        """An exception in ``extract_reference`` is absorbed as ``({}, None)``."""
        datas = {"conclusion": ["test"], "search_records": [{"content": "test"}]}

        with patch.object(
            self.source_tracer_infer,
            "extract_reference",
            new_callable=AsyncMock,
            side_effect=Exception("Test error"),
        ):
            infer_message, checked_graph = await self.source_tracer_infer.async_run(
                datas
            )

        assert infer_message == {}
        assert checked_graph is None

    @pytest.mark.asyncio
    async def test_run_success(self):
        """``run`` collects one message, one graph and one record set."""
        tracer = self.source_tracer_infer
        tracer.conclusion_with_records = [
            {
                "conclusion": ["test"],
                "search_records": [{"content": "test"}],
                "start_pos": 0,
                "end_pos": 5,
            }
        ]
        stub_message = {
            "id": 0,
            "conclusion": "test",
            "inference": "inference",
            "html_base64": "base64",
        }
        stub_graph = ([([0], "relation", 1)], {0: {"label": "test"}}, [0], [1])

        with patch.object(
            tracer,
            "async_run",
            new_callable=AsyncMock,
            return_value=(stub_message, stub_graph),
        ), patch.object(
            tracer, "mark_conclusion_in_report", return_value="marked response"
        ):
            response, infer_messages, checker_infos, error = await tracer.run()

        assert response == "marked response"
        assert len(infer_messages) == 1
        assert infer_messages[0]["id"] == 0
        assert len(checker_infos["graph_infos"]) == 1
        assert len(checker_infos["search_records"]) == 1
        assert error is None

    @pytest.mark.asyncio
    async def test_run_no_conclusion_with_records(self):
        """Without conclusions ``run`` returns the raw response and an error."""
        tracer = self.source_tracer_infer

        # The stub leaves ``conclusion_with_records`` at its initial ``None``.
        with patch.object(
            tracer,
            "get_conclusion_and_records",
            new_callable=AsyncMock,
            return_value=None,
        ):
            response, infer_messages, checker_infos, error = await tracer.run()

        assert response == "test response"
        assert infer_messages == []
        assert checker_infos["graph_infos"] == []
        assert checker_infos["search_records"] == []
        assert error is not None

    @pytest.mark.asyncio
    async def test_run_error_handling(self):
        """A failing ``async_run`` is reported through the error element."""
        tracer = self.source_tracer_infer
        tracer.conclusion_with_records = [
            {"conclusion": ["test"], "search_records": [{"content": "test"}]}
        ]

        with patch.object(
            tracer,
            "async_run",
            new_callable=AsyncMock,
            side_effect=Exception("Test error"),
        ):
            response, infer_messages, checker_infos, error = await tracer.run()

        assert response == "test response"
        assert infer_messages == []
        assert checker_infos["graph_infos"] == []
        assert checker_infos["search_records"] == []
        assert error == "Test error"

    @pytest.mark.asyncio
    async def test_run_empty_html_base64(self):
        """Messages whose ``html_base64`` is empty are left out of the result."""
        tracer = self.source_tracer_infer
        tracer.conclusion_with_records = [
            {"conclusion": ["test"], "search_records": [{"content": "test"}]}
        ]

        with patch.object(
            tracer,
            "async_run",
            new_callable=AsyncMock,
            return_value=({"id": 0, "html_base64": ""}, None),
        ):
            # Only the collected messages and graphs matter for this check.
            _response, infer_messages, checker_infos, _error = await tracer.run()

        assert len(infer_messages) == 0
        assert len(checker_infos["graph_infos"]) == 0