import html
import logging

from openjiuwen_deepsearch.algorithm.source_trace.add_source import (
    SourceReferenceProcessor,
    add_source_references,
    generate_source_datas,
    merge_source_datas,
    insert_source_info,
    extract_source_item_info,
    _remove_md_references_from_chunk,
    _merge_source_infos,
    _filter_valid_references,
    remove_trailing_spaces_and_punctuation,
    _escape_html_special_chars
)
class TestSourceReferenceProcessor:
    """Unit tests for the SourceReferenceProcessor class."""

    # Sentence that appears verbatim in the report used by most tests below.
    AI_SENTENCE = "这是一个关于人工智能的测试句子。"
    # Sentence that never appears in any report used here.
    MISSING_SENTENCE = "这个句子不在报告中。"

    @staticmethod
    def _single_ai_record():
        """Build a fresh search record containing one 'web_search' hit."""
        return {
            "web_search": [
                {"title": "AI研究", "url": "http://ai.com", "content": "AI很有趣"}
            ]
        }

    def test_init(self):
        """The constructor keeps its arguments and starts with no data items."""
        report = "这是一个示例报告。"
        record = {"web_search": [{"title": "测试标题", "url": "http://test.com"}]}
        proc = SourceReferenceProcessor(report, record)
        assert proc.preprocessed_report == report
        assert proc.search_record == record
        assert proc.all_data_items == []

    def test_extract_source_info(self):
        """A valid trace result yields non-empty source info and the matched item."""
        record = {
            "web_search": [
                {"title": "AI研究", "url": "http://ai.com", "content": "AI很有趣"},
                {"title": "机器学习基础", "url": "http://ml.com", "content": "机器学习很重要"},
            ]
        }
        proc = SourceReferenceProcessor(self.AI_SENTENCE, record)
        trace = {
            "sentence": self.AI_SENTENCE,
            "matched_source_indices": [0],
            "source": "web_search",
        }
        info, items = proc.extract_source_info(trace)
        assert info != ""
        assert len(items) == 1
        matched = items[0]
        assert matched["title"] == "AI研究"
        assert matched["url"] == "http://ai.com"

    def test_extract_source_info_invalid_trace_result(self):
        """A traced sentence absent from the report produces empty results."""
        proc = SourceReferenceProcessor(self.AI_SENTENCE, self._single_ai_record())
        trace = {
            "sentence": self.MISSING_SENTENCE,
            "matched_source_indices": [0],
            "source": "web_search",
        }
        info, items = proc.extract_source_info(trace)
        assert info == ""
        assert items == []

    def test_validate_trace_result(self):
        """_validate_trace_result rejects an unknown sentence or an unknown source."""
        proc = SourceReferenceProcessor(self.AI_SENTENCE, self._single_ai_record())
        validate = proc._validate_trace_result
        assert validate(self.AI_SENTENCE, [0], "web_search") is True
        assert validate(self.MISSING_SENTENCE, [0], "web_search") is False
        assert validate(self.AI_SENTENCE, [0], "nonexistent_source") is False
class TestAddSourceReferences:
    """Unit tests for the add_source_references function."""

    @staticmethod
    def _apply(report, references):
        """Call add_source_references and return its (report, references) pair."""
        return add_source_references(report, references)

    def test_add_source_references_basic(self):
        """A chunk found in the report gets its markdown source link inserted."""
        refs = [
            {
                "chunk": "这是一个测试句子",
                "title": "测试来源",
                "url": "http://test.com",
            }
        ]
        new_report, kept = self._apply("这是一个测试句子。这是另一个句子。", refs)
        assert "[source_tracer_result][测试来源](http://test.com)" in new_report
        assert len(kept) == 1

    def test_add_source_references_empty_input(self):
        """With no references the report is returned untouched."""
        report = "这是一个测试句子。"
        new_report, kept = self._apply(report, [])
        assert new_report == report
        assert kept == []

    def test_add_source_references_origin_data(self):
        """References flagged as origin data are kept with their flag intact."""
        refs = [
            {
                "_is_origin_data": True,
                "chunk": "这是一个测试句子",
                "title": "测试来源",
            }
        ]
        _, kept = self._apply("这是一个测试句子。", refs)
        assert len(kept) == 1
        assert kept[0]["_is_origin_data"] is True

    def test_add_source_references_sentence_not_found(self):
        """A chunk missing from the report is dropped and the report is unchanged."""
        report = "这是一个测试句子。"
        refs = [
            {
                "chunk": "这个句子不在报告中",
                "title": "测试来源",
                "url": "http://test.com",
            }
        ]
        new_report, kept = self._apply(report, refs)
        assert new_report == report
        assert len(kept) == 0

    def test_add_source_references_multiple_references_same_sentence(self):
        """Two references on one sentence both appear in the output report."""
        refs = [
            {
                "chunk": "这是一个关于人工智能的测试句子",
                "title": "AI来源1",
                "url": "http://ai1.com",
            },
            {
                "chunk": "这是一个关于人工智能的测试句子",
                "title": "AI来源2",
                "url": "http://ai2.com",
            },
        ]
        new_report, kept = self._apply("这是一个关于人工智能的测试句子。", refs)
        for link in ("[AI来源1](http://ai1.com)", "[AI来源2](http://ai2.com)"):
            assert link in new_report
        assert len(kept) == 2
class TestRemoveMdReferencesFromChunk:
    """Unit tests for _remove_md_references_from_chunk, which edits the dict in place."""

    @staticmethod
    def _cleaned(item):
        """Run the function on *item* and hand the same (mutated) dict back."""
        _remove_md_references_from_chunk(item)
        return item

    def test_remove_md_references_from_chunk_basic(self):
        """A single trailing markdown reference is stripped from the chunk."""
        item = self._cleaned(
            {"chunk": "这是一个测试句子 [source_tracer_result][测试](http://test.com)"}
        )
        assert item["chunk"] == "这是一个测试句子"

    def test_remove_md_references_from_chunk_multiple_refs(self):
        """Every markdown reference is stripped, leaving the plain text between them."""
        item = self._cleaned(
            {
                "chunk": "这是一个测试句子 [source_tracer_result][测试1](http://test1.com) 和 [测试2](http://test2.com)"
            }
        )
        assert item["chunk"] == "这是一个测试句子 和"

    def test_remove_md_references_from_chunk_no_refs(self):
        """A chunk without references is left exactly as it was."""
        text = "这是一个没有引用的测试句子。"
        item = self._cleaned({"chunk": text})
        assert item["chunk"] == "这是一个没有引用的测试句子。"

    def test_remove_md_references_from_chunk_no_chunk_field(self):
        """A dict lacking 'chunk' is not given one."""
        item = self._cleaned({"title": "测试标题"})
        assert "title" in item
        assert "chunk" not in item

    def test_remove_md_references_from_chunk_non_string_chunk(self):
        """A non-string chunk value is passed through untouched."""
        item = self._cleaned({"chunk": 123})
        assert item["chunk"] == 123
class TestMergeSourceInfos:
    """Unit tests for the _merge_source_infos function."""

    PREFIX = "[source_tracer_result]"

    def test_merge_source_infos_basic(self):
        """Each title/url pair is rendered as a prefixed markdown link."""
        merged = _merge_source_infos(
            [
                {"title": "Source 1", "url": "http://source1.com"},
                {"title": "Source 2", "url": "http://source2.com"},
            ]
        )
        assert self.PREFIX + "[Source 1](http://source1.com)" in merged
        assert self.PREFIX + "[Source 2](http://source2.com)" in merged

    def test_merge_source_infos_with_title_only(self):
        """An empty url falls back to the title as the link target."""
        merged = _merge_source_infos(
            [
                {"title": "Source 1", "url": ""},
                {"title": "Source 2", "url": "http://source2.com"},
            ]
        )
        assert self.PREFIX + "[Source 1](Source 1)" in merged
        assert self.PREFIX + "[Source 2](http://source2.com)" in merged

    def test_merge_source_infos_no_title(self):
        """Entries with an empty title are left out entirely."""
        merged = _merge_source_infos(
            [
                {"title": "", "url": "http://source1.com"},
                {"title": "Source 2", "url": "http://source2.com"},
            ]
        )
        assert self.PREFIX + "[Source 2](http://source2.com)" in merged
        assert "source1.com" not in merged

    def test_merge_source_infos_empty_list(self):
        """An empty list merges into an empty string."""
        assert _merge_source_infos([]) == ""
class TestGenerateSourceDatas:
    """Unit tests for the generate_source_datas function."""

    REPORT = "这是一个关于人工智能的测试句子。"

    @staticmethod
    def _record():
        """Build a fresh search record containing one 'web_search' hit."""
        return {
            "web_search": [
                {"title": "AI研究", "url": "http://ai.com", "content": "AI很有趣"}
            ]
        }

    @staticmethod
    def _trace(sentence):
        """Build a one-element trace list pointing *sentence* at web_search[0]."""
        return [
            {
                "sentence": sentence,
                "matched_source_indices": [0],
                "source": "web_search",
            }
        ]

    def test_generate_source_datas_basic(self):
        """A traced sentence present in the report yields its source item."""
        datas = generate_source_datas(
            self.REPORT, self._record(), self._trace("这是一个关于人工智能的测试句子")
        )
        assert len(datas) == 1
        only = datas[0]
        assert only["title"] == "AI研究"
        assert only["url"] == "http://ai.com"

    def test_generate_source_datas_sentence_not_found(self):
        """A traced sentence absent from the report yields nothing."""
        datas = generate_source_datas(
            self.REPORT, self._record(), self._trace("这个句子不在报告中")
        )
        assert len(datas) == 0
class TestFilterValidReferences:
    """Unit tests for the _filter_valid_references function."""

    def test_filter_valid_references_basic(self):
        """Only references whose chunk occurs in the report survive."""
        kept = _filter_valid_references(
            "这是一个测试句子。这是另一个句子。",
            [
                {"chunk": "这是一个测试句子", "title": "测试来源"},
                {"chunk": "这个不在报告中", "title": "无效来源"},
            ],
        )
        assert len(kept) == 1
        survivor = kept[0]
        assert survivor["chunk"] == "这是一个测试句子"
        assert survivor["title"] == "测试来源"

    def test_filter_valid_references_origin_data(self):
        """Origin-data references are kept even when their chunk is not in the report."""
        kept = _filter_valid_references(
            "这是一个测试句子。",
            [{"chunk": "这个不在报告中", "title": "无效来源", "_is_origin_data": True}],
        )
        assert len(kept) == 1
        assert kept[0]["_is_origin_data"] is True

    def test_filter_valid_references_no_chunk(self):
        """References without a 'chunk' key are discarded."""
        kept = _filter_valid_references(
            "这是一个测试句子。",
            [{"title": "没有chunk字段的测试来源"}],
        )
        assert len(kept) == 0
class TestMergeSourceDatas:
    """Unit tests for the merge_source_datas function."""

    REPORT = "这是一个测试句子。"

    def test_merge_source_datas_basic(self):
        """Generated and origin entries are both present after merging."""
        generated = [{"chunk": "这是一个测试句子", "title": "生成的来源"}]
        origin = [{"chunk": "这是一个测试句子", "title": "原始来源", "_is_origin_data": True}]
        merged = merge_source_datas(self.REPORT, generated, origin)
        assert len(merged) == 2
        titles = {entry["title"] for entry in merged}
        assert "生成的来源" in titles
        assert "原始来源" in titles

    def test_merge_source_datas_empty_inputs(self):
        """Merging two empty lists gives an empty list."""
        assert merge_source_datas(self.REPORT, [], []) == []

    def test_merge_source_datas_one_empty(self):
        """With no origin data, the generated entries come through alone."""
        merged = merge_source_datas(
            self.REPORT, [{"chunk": "这是一个测试句子", "title": "生成的来源"}], []
        )
        assert len(merged) == 1
        assert merged[0]["title"] == "生成的来源"
class TestInsertSourceInfo:
    """Unit tests for the insert_source_info function."""

    SOURCE_INFO = "[source_tracer_result][测试](http://test.com)"

    def test_insert_source_info_basic(self):
        """Inserting next to an existing sentence succeeds and embeds the link."""
        ok, updated = insert_source_info(
            "这是一个测试句子。这是另一个句子。", "这是一个测试句子", self.SOURCE_INFO
        )
        assert ok is True
        assert "[source_tracer_result][测试](http://test.com)" in updated

    def test_insert_source_info_sentence_not_found(self):
        """A sentence missing from the report fails and leaves the report untouched."""
        report = "这是一个测试句子。"
        ok, updated = insert_source_info(report, "这个句子不在报告中", self.SOURCE_INFO)
        assert ok is False
        assert updated == report

    def test_insert_source_info_empty_inputs(self):
        """An empty report or an empty sentence both cause a no-op failure."""
        ok, updated = insert_source_info("", "这是一个测试句子", self.SOURCE_INFO)
        assert ok is False
        assert updated == ""

        ok, updated = insert_source_info("这是一个测试句子。", "", self.SOURCE_INFO)
        assert ok is False
        assert updated == "这是一个测试句子。"
class TestExtractSourceItemInfo:
    """Unit tests for the extract_source_item_info function."""

    SENTENCE = "这是一个测试句子。"

    def _extract(self, item, index=0):
        """Run extract_source_item_info on a one-element list holding *item*."""
        return extract_source_item_info([item], index, self.SENTENCE)

    def _assert_rejected(self, item, index=0):
        """Assert the item (at *index*) produces empty info and an empty dict."""
        info, data = self._extract(item, index)
        assert info == ""
        assert data == {}

    def test_extract_source_item_info_basic(self):
        """A complete item yields source info and a data dict carrying the chunk."""
        info, data = self._extract(
            {"title": "测试标题", "url": "http://test.com", "content": "测试内容"}
        )
        assert info != ""
        expected = {
            "title": "测试标题",
            "url": "http://test.com",
            "content": "测试内容",
            "chunk": "这是一个测试句子。",
        }
        for key, value in expected.items():
            assert data[key] == value

    def test_extract_source_item_info_index_out_of_range(self):
        """An index past the end of the list is rejected."""
        self._assert_rejected(
            {"title": "测试标题", "url": "http://test.com", "content": "测试内容"},
            index=5,
        )

    def test_extract_source_item_info_missing_title(self):
        """An item without a title is rejected."""
        self._assert_rejected({"url": "http://test.com", "content": "测试内容"})

    def test_extract_source_item_info_missing_content(self):
        """An item without content is rejected."""
        self._assert_rejected({"title": "测试标题", "url": "http://test.com"})

    def test_extract_source_item_info_no_url(self):
        """An item without a url uses its title in place of the url."""
        info, data = self._extract({"title": "测试标题", "content": "测试内容"})
        assert info != ""
        assert data["title"] == "测试标题"
        assert data["url"] == "测试标题"
        assert data["content"] == "测试内容"
class TestRemoveTrailingSpacesAndPunctuation:
    """Unit tests for the remove_trailing_spaces_and_punctuation function."""

    ZH = "这是一个测试句子"
    EN = "This is a test sentence"

    @staticmethod
    def _check_all(inputs, expected):
        """Assert that every string in *inputs* (checked in order) becomes *expected*."""
        for raw in inputs:
            assert remove_trailing_spaces_and_punctuation(raw) == expected

    def test_remove_trailing_spaces_and_punctuation_basic(self):
        """A trailing full stop plus space is removed."""
        self._check_all(["这是一个测试句子。 "], self.ZH)

    def test_remove_trailing_chinese_punctuation(self):
        """Each single trailing mark in the Chinese-sentence cases is removed."""
        self._check_all(
            [
                "这是一个测试句子。",
                "这是一个测试句子!",
                "这是一个测试句子?",
                "这是一个测试句子;",
                "这是一个测试句子:",
            ],
            self.ZH,
        )

    def test_remove_trailing_english_punctuation(self):
        """Each single trailing English mark (optionally followed by a space) is removed."""
        self._check_all(
            [
                "This is a test sentence. ",
                "This is a test sentence!",
                "This is a test sentence?",
                "This is a test sentence;",
                "This is a test sentence:",
            ],
            self.EN,
        )

    def test_remove_multiple_trailing_punctuation(self):
        """Runs of repeated trailing marks are removed as a whole."""
        self._check_all(["这是一个测试句子。。。 "], self.ZH)
        self._check_all(["This is a test sentence!!!"], self.EN)

    def test_remove_citation_marks(self):
        """Trailing [citation: N] markers are removed regardless of inner spacing."""
        self._check_all(
            [
                "这是一个测试句子 [citation: 123]",
                "这是一个测试句子[citation:123] ",
                "这是一个测试句子 [ citation: 456 ]",
            ],
            self.ZH,
        )

    def test_no_trailing_punctuation(self):
        """Text with nothing to strip is returned as-is."""
        self._check_all(["这是一个测试句子"], self.ZH)

    def test_empty_string(self):
        """The empty string maps to itself."""
        self._check_all([""], "")

    def test_none_input(self):
        """None is passed straight through."""
        assert remove_trailing_spaces_and_punctuation(None) is None

    def test_non_string_input(self):
        """Non-string values are passed straight through."""
        assert remove_trailing_spaces_and_punctuation(123) == 123

    def test_only_punctuation(self):
        """A string made solely of punctuation and spaces collapses to empty."""
        self._check_all(["。!? ;:"], "")

    def test_mixed_punctuation_and_spaces(self):
        """Interleaved trailing spaces and marks are all removed."""
        self._check_all(["这是一个测试句子 !。 "], self.ZH)
class TestEscapeHtmlSpecialChars:
    """Test cases for the _escape_html_special_chars function.

    ``&``, ``<`` and ``>`` are pinned to the standard entities ``&amp;``,
    ``&lt;`` and ``&gt;``.  Quote characters are checked by property rather
    than by exact entity spelling, so the tests hold whether the implementation
    writes them as ``&quot;``/``&#34;`` and ``&#39;``/``&#x27;``.  The
    properties are:

    * no raw special character survives;
    * every ``&`` in the output opens a well-formed entity;
    * there is exactly one entity per special input character (no double
      escaping);
    * ``html.unescape`` recovers the original text.
    """

    # Characters that must be replaced by an HTML entity.
    SPECIAL_CHARS = "&<>\"'"

    def _assert_properly_escaped(self, original, result):
        """Assert *result* is a correct, single-pass HTML escape of *original*."""
        for char in "<>\"'":
            assert char not in result, f"raw {char!r} left in {result!r}"
        # Every '&' must begin an entity such as '&amp;' or '&#x27;'.  A bare
        # ampersand (e.g. "a & b" passed through unchanged) fails this check.
        for tail in result.split("&")[1:]:
            name, sep, _ = tail.partition(";")
            well_formed = bool(sep) and bool(name) and (
                name.isalnum() or (name.startswith("#") and name[1:].isalnum())
            )
            assert well_formed, f"malformed entity in {result!r}"
        # One entity per special input character rules out double escaping
        # such as '&amp;lt;'.
        expected_entities = sum(original.count(char) for char in self.SPECIAL_CHARS)
        assert result.count("&") == expected_entities
        assert html.unescape(result) == original

    def test_escape_html_special_chars_basic(self):
        """Script tags, including their quoted argument, are neutralised."""
        text = "<script>alert('xss')</script>"
        result = _escape_html_special_chars(text)
        assert result.startswith("&lt;script&gt;")
        assert result.endswith("&lt;/script&gt;")
        self._assert_properly_escaped(text, result)

    def test_escape_ampersand(self):
        """A bare ampersand becomes &amp;."""
        result = _escape_html_special_chars("a & b")
        assert result == "a &amp; b"

    def test_escape_less_than_greater_than(self):
        """Angle brackets become &lt; and &gt;."""
        result = _escape_html_special_chars("5 < 10 > 3")
        assert result == "5 &lt; 10 &gt; 3"

    def test_escape_quotes(self):
        """Double and single quotes are both replaced by entities."""
        for text in ('He said "Hello"', "It's a test"):
            self._assert_properly_escaped(text, _escape_html_special_chars(text))

    def test_escape_mixed_special_chars(self):
        """All special characters in one string are escaped in a single pass."""
        text = 'He said "5 < 10 & 10 > 5" and it\'s true'
        result = _escape_html_special_chars(text)
        assert "5 &lt; 10 &amp; 10 &gt; 5" in result
        self._assert_properly_escaped(text, result)

    def test_escape_empty_string(self):
        """The empty string maps to itself."""
        result = _escape_html_special_chars("")
        assert result == ""

    def test_escape_none_input(self):
        """None is normalised to an empty string."""
        result = _escape_html_special_chars(None)
        assert result == ""

    def test_escape_no_special_chars(self):
        """Plain text passes through unchanged."""
        result = _escape_html_special_chars("This is a normal text")
        assert result == "This is a normal text"

    def test_escape_all_special_chars(self):
        """Each of the five HTML special characters is escaped."""
        text = '& < > " \''
        result = _escape_html_special_chars(text)
        assert result.startswith("&amp; &lt; &gt; ")
        self._assert_properly_escaped(text, result)