import json
import logging
from unittest.mock import patch, MagicMock, AsyncMock
import pytest
from openjiuwen_deepsearch.algorithm.report import table_caption_utils
from openjiuwen_deepsearch.algorithm.report.report import Reporter, _get_classified_infos
from openjiuwen_deepsearch.algorithm.report.table_caption_utils import ensure_markdown_table_captions
from openjiuwen_deepsearch.common.common_constants import CHINESE, ENGLISH
from openjiuwen_deepsearch.utils.constants_utils.session_contextvars import session_context
def _classified_doc(title: str, url: str, source_id: str, relevance: float) -> dict:
return {
"title": title,
"url": url,
"source_id": source_id,
"original_content": f"{title} content",
"scores": {"relevance": relevance, "answerability": 0, "authority": 0, "data_density": 0},
}
def _report_doc(idx: int, *, url: str | None = None, content: str | None = None) -> dict:
return {
"title": f"doc-{idx}",
"url": url or f"https://example.com/{idx}",
"original_content": content or f"content-{idx}",
"scores": {"relevance": 9, "answerability": 9, "authority": 9, "data_density": 9},
}
def _centered_caption(caption_text: str) -> str:
return f'<div style="text-align: center;">\n\n**{caption_text}**\n\n</div>'
def test_ensure_markdown_table_captions_adds_contextual_caption():
    """A "下表…" lead-in is expected to supply the caption and be renumbered to 表2-1."""
    markdown = """# 2 整车制造格局
下表汇总了合肥主要整车企业的产能与技术路线:
| 企业 | 产能 | 技术路线 |
|---|---|---|
| 比亚迪合肥基地 | 132万辆 | 纯电/混动 |
表后分析继续展开。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 2)

    expected_caption = _centered_caption("表2-1:合肥主要整车企业的产能与技术路线")
    assert expected_caption in rendered
    assert "表2-1汇总了合肥主要整车企业的产能与技术路线:" in rendered
    # One mention in the rewritten lead-in, one in the caption itself.
    assert rendered.count("表2-1") == 2
def test_ensure_markdown_table_captions_keeps_existing_caption_and_counts_order():
    """Two tables in one section are expected to be numbered 表3-1 then 表3-2.

    The first table has a plain title line beneath it; the second has none,
    and the expected caption combines the heading text with the column names.
    """
    markdown = """# 3 核心零部件产业链
| 企业 | 产品 |
|---|---|
| 国轩高科 | 动力电池 |
核心零部件企业与产品
| 领域 | 代表企业 |
|---|---|
| 电池 | 国轩高科 |
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 3)

    first_caption = _centered_caption("表3-1:核心零部件企业与产品")
    second_caption = _centered_caption("表3-2:核心零部件产业链(领域、代表企业)")

    assert rendered.count("表3-1:核心零部件企业与产品") == 1
    assert first_caption in rendered
    assert "表3-1梳理了核心零部件企业与产品:" in rendered
    assert "表3-2梳理了核心零部件产业链(领域、代表企业):" in rendered
    assert second_caption in rendered
def test_ensure_markdown_table_captions_rewrites_post_table_reference():
    """A "如上表所示" sentence after the table is expected to become "如表3-1所示"."""
    markdown = """# 3 AI芯片
| 类型 | 特征 |
|---|---|
| 端侧SoC | 低功耗 |
端侧SoC与云端AI芯片差异
如上表所示,端侧SoC更强调能效与集成度。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 3)

    assert rendered.count("表3-1:端侧SoC与云端AI芯片差异") == 1
    assert "如表3-1所示,端侧SoC更强调能效与集成度。" in rendered
    # The vague "table above" wording must be gone entirely.
    assert "如上表所示" not in rendered
def test_ensure_markdown_table_captions_rewrites_previous_reference_within_five_context_lines():
    """A "下表…" lead-in separated from the table by two other lines is still rewritten."""
    markdown = """# 4 产业链韧性
下表汇总了关键环节的本地配套进展:
该判断还需要结合企业落地节奏观察。
政策侧的支持力度也会影响后续兑现。
| 环节 | 进展 |
|---|---|
| 功率半导体 | 已导入头部车企 |
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 4)

    assert "表4-1汇总了关键环节的本地配套进展:" in rendered
    assert "下表汇总了关键环节的本地配套进展" not in rendered
    # Rewritten lead-in plus the caption: exactly two numbered mentions.
    assert rendered.count("表4-1") == 2
def test_ensure_markdown_table_captions_rewrites_next_reference_within_five_context_lines():
    """A "如上表所示" sentence a few lines below the table is rewritten; no intro is added."""
    markdown = """# 5 市场格局
| 品类 | 份额 |
|---|---|
| 新能源乘用车 | 42% |
新能源乘用车市场份额
这一结构体现出头部品类的领先优势。
后续仍需关注渗透率变化。
如上表所示,新能源乘用车仍是核心增长来源。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 5)

    assert "如表5-1所示,新能源乘用车仍是核心增长来源。" in rendered
    assert "如上表所示" not in rendered
    # An existing reference was found, so no generated intro paragraph is expected.
    assert "\n\n表5-1梳理了" not in rendered
def test_ensure_markdown_table_captions_inserts_intro_when_reference_missing():
    """With no table reference nearby, an intro sentence and a caption are both expected."""
    markdown = """# 6 产业生态
这一目标的实现,离不开当前已构建的五位一体生态体系。
| 支撑维度 | 具体内容 | 数据/案例 |
|---|---|---|
| 充换电设施 | 充电枪、换电站 | 35万个充电枪 |
综上,合肥通过系统性基础设施布局构建了产业生态圈。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 6)

    expected_caption = _centered_caption("表6-1:产业生态(支撑维度、具体内容、数据/案例)")
    assert "表6-1梳理了产业生态(支撑维度、具体内容、数据/案例):" in rendered
    assert expected_caption in rendered
    assert rendered.count("表6-1") == 2
    # The closing narrative paragraph must survive untouched.
    assert "综上,合肥通过系统性基础设施布局构建了产业生态圈。" in rendered
def test_ensure_markdown_table_captions_merges_intro_into_colon_line():
    """A colon-terminated line before the table is expected to absorb the generated intro."""
    markdown = """# 2 EDA软件与材料领域:自主化进程与供应链安全评估
根据权威梳理,中国在多个核心材料品类已实现进口替代:
| 材料类别 | 国内唯一性突破 | 关键应用与客户 |
|---|---|---|
| 大硅片 | 12英寸量产 | 14nm逻辑芯片 |
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 2)

    merged_intro = (
        "根据权威梳理,中国在多个核心材料品类已实现进口替代,"
        "表2-1梳理了EDA软件与材料领域:自主化进程与供应链安全评估"
        "(材料类别、国内唯一性突破、关键应用与客户):"
    )
    assert merged_intro in rendered
    # The intro is merged into the existing line, not emitted as a new paragraph.
    assert "\n\n表2-1梳理了" not in rendered
    expected_caption = _centered_caption(
        "表2-1:EDA软件与材料领域:自主化进程与供应链安全评估(材料类别、国内唯一性突破、关键应用与客户)"
    )
    assert expected_caption in rendered
def test_ensure_markdown_table_captions_moves_caption_below_table():
    """A numbered caption written above the table is expected to end up after it."""
    markdown = """# 2 产业对比
表2-1:产业指标对比
| 指标 | 数值 |
|---|---|
| 产量 | 100 |
结论延续。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 2)

    expected_caption = _centered_caption("表2-1:产业指标对比")
    assert rendered.count(expected_caption) == 1
    header_pos = rendered.index("| 指标 | 数值 |")
    caption_pos = rendered.index(expected_caption)
    assert header_pos < caption_pos
def test_ensure_markdown_table_captions_normalizes_plain_caption_below_table():
    """A plain numbered caption line under the table is expected to become the centered form."""
    markdown = """# 4 供应链结构
| 环节 | 企业 |
|---|---|
| 电池 | 国轩高科 |
表4-1:供应链代表企业
如上表所示,电池环节已有龙头企业支撑。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 4)

    expected_caption = _centered_caption("表4-1:供应链代表企业")
    assert rendered.count(expected_caption) == 1
    # The bare caption must no longer appear as a standalone line.
    standalone_lines = rendered.splitlines()
    assert standalone_lines.count("表4-1:供应链代表企业") == 0
    assert "如表4-1所示,电池环节已有龙头企业支撑。" in rendered
def test_ensure_markdown_table_captions_normalizes_unnumbered_title_below_table():
    """A "表格标题:…" line is expected to be replaced by a numbered centered caption."""
    markdown = """# 7 创新生态
| 平台 | 方向 |
|---|---|
| 科研院所 | 电池材料 |
表格标题:创新生态科研平台布局
如上表所示,科研平台支撑技术转化。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 7)

    expected_caption = _centered_caption("表7-1:创新生态科研平台布局")
    assert rendered.count(expected_caption) == 1
    assert "表格标题:创新生态科研平台布局" not in rendered
    assert "如表7-1所示,科研平台支撑技术转化。" in rendered
def test_ensure_markdown_table_captions_normalizes_plain_llm_title_below_table():
    """An unprefixed title line under the table is expected to be promoted to the caption."""
    plain_title = "2010-2021年间合肥汽车产业链发展的关键节点"
    markdown = """# 1 产业发展历程
| 时间 | 事件 |
|---|---|
| 2010年 | 产业链起步 |
2010-2021年间合肥汽车产业链发展的关键节点
如上表所示,合肥汽车产业链在关键节点上持续扩展。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 1)

    expected_caption = _centered_caption("表1-1:2010-2021年间合肥汽车产业链发展的关键节点")
    assert rendered.count(expected_caption) == 1
    # The original bare title must not remain as its own line.
    assert rendered.splitlines().count(plain_title) == 0
    assert "如表1-1所示,合肥汽车产业链在关键节点上持续扩展。" in rendered
def test_ensure_markdown_table_captions_accepts_comma_in_plain_title():
    """An English title line containing commas is still expected to be used as the caption."""
    markdown = """# 1 Market structure
| Company | Share |
|---|---|
| A | 40% |
Company, share, and growth
As shown in the table above, company A keeps a leading position.
"""
    rendered = ensure_markdown_table_captions(markdown, ENGLISH, 1)

    expected_caption = _centered_caption("Table 1-1: Company, share, and growth")
    assert rendered.count(expected_caption) == 1
    assert rendered.splitlines().count("Company, share, and growth") == 0
    assert "As shown in the table above" not in rendered
    assert "As shown in Table 1-1" in rendered
def test_ensure_markdown_table_captions_rewrites_english_table_below_reference():
    """"The table below" in an English lead-in is expected to become "Table 2-1"."""
    markdown = """# 2 Market structure
The table below summarizes company share:
| Company | Share |
|---|---|
| A | 40% |
"""
    rendered = ensure_markdown_table_captions(markdown, ENGLISH, 2)

    assert "Table 2-1 summarizes company share:" in rendered
    assert "The table below" not in rendered
def test_ensure_markdown_table_captions_rewrites_english_from_table_above():
    """"From the table above" is expected to be rewritten to "From Table 2-1"."""
    markdown = """# 2 Market structure
| Company | Share |
|---|---|
| A | 40% |
Company share
From the table above, company A keeps a leading position.
"""
    rendered = ensure_markdown_table_captions(markdown, ENGLISH, 2)

    assert "From Table 2-1, company A keeps a leading position." in rendered
    # Case-insensitive check that no variant of the old wording survives.
    lowered = rendered.lower()
    assert "from the table above" not in lowered
def test_ensure_markdown_table_captions_rewrites_english_weak_below_reference():
    """A bare "is below:" lead-in is expected to gain an explicit "Table 2-1" mention."""
    markdown = """# 2 Market structure
The comparison is below:
| Company | Share |
|---|---|
| A | 40% |
"""
    rendered = ensure_markdown_table_captions(markdown, ENGLISH, 2)

    expected_lead_in = "The comparison is Table 2-1 below:"
    assert expected_lead_in in rendered
def test_ensure_markdown_table_captions_prefers_plain_llm_title_over_intro():
    """When both a lead-in and a title line exist, the title line is expected to win as caption."""
    markdown = """# 4 配套体系
下表总结了合肥主要链主企业及其吸引的部分长三角配套企业情况:
| 链主企业 | 配套领域 |
|---|---|
| 比亚迪 | 内外饰、座椅 |
合肥链主企业与长三角配套企业关系
后续分析继续展开。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 4)

    caption_from_title = _centered_caption("表4-1:合肥链主企业与长三角配套企业关系")
    caption_from_intro = _centered_caption("表4-1:合肥主要链主企业及其吸引的部分长三角配套企业情况")
    assert caption_from_title in rendered
    assert caption_from_intro not in rendered
    # The lead-in is still renumbered even though it does not name the caption.
    assert "表4-1总结了合肥主要链主企业及其吸引的部分长三角配套企业情况:" in rendered
def test_ensure_markdown_table_captions_keeps_title_with_driving_effect():
    """A title ending in "带动作用" is expected to be treated as a caption, not narrative."""
    plain_title = "合肥主要整车企业及其供应链带动作用"
    markdown = """# 1 产业基础
为了量化分析合肥主要整车企业及其对供应链的带动作用,下表进行了总结:
| 整车企业 | 引入/强化时间 | 核心特点 | 对供应链的主要带动作用 |
|---|---|---|---|
| 江淮汽车 | 1964年成立 | 本土老牌车企 | 培育早期零部件产业基础 |
合肥主要整车企业及其供应链带动作用
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 1)

    assert _centered_caption("表1-1:合肥主要整车企业及其供应链带动作用") in rendered
    # After stripping whitespace, no line may consist of the bare title alone.
    stripped_lines = [line.strip() for line in rendered.splitlines()]
    assert plain_title not in stripped_lines
    assert "为了量化分析合肥主要整车企业及其对供应链的带动作用,表1-1进行了总结:" in rendered
def test_ensure_markdown_table_captions_does_not_swallow_punctuated_narrative():
    """A full sentence after the table must stay as prose; the caption is generated instead."""
    narrative = "这一数据反映出产业集中度在提升。"
    markdown = """# 3 产业集中度
| 企业 | 份额 |
|---|---|
| A企业 | 40% |
这一数据反映出产业集中度在提升。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 3)

    assert narrative in rendered
    assert _centered_caption("表3-1:产业集中度(企业、份额)") in rendered
def test_ensure_markdown_table_captions_handles_empty_and_none_input():
    """Empty and ``None`` content are expected to be returned as-is."""
    empty_result = ensure_markdown_table_captions("", CHINESE, 1)
    assert empty_result == ""

    none_result = ensure_markdown_table_captions(None, CHINESE, 1)
    assert none_result is None
def test_ensure_markdown_table_captions_normalizes_section_idx_none_and_zero():
    """``section_idx=None`` is expected to yield "表1"; ``section_idx=0`` to yield "表0-1"."""
    markdown = """# 编号
| A | B |
|---|---|
| 1 | 2 |
"""
    rendered_without_idx = ensure_markdown_table_captions(markdown, CHINESE, None)
    rendered_with_zero = ensure_markdown_table_captions(markdown, CHINESE, 0)

    # No section index: the expected label has no section prefix.
    assert _centered_caption("表1:编号(A、B)") in rendered_without_idx
    # Zero is a real index and is expected to be kept as the prefix.
    assert _centered_caption("表0-1:编号(A、B)") in rendered_with_zero
def test_ensure_markdown_table_captions_keeps_plain_text_without_tables():
    """Content with no tables is expected to pass through unchanged."""
    plain_markdown = "# 1 纯文本\n\n这里没有任何表格,只是普通正文。"
    rendered = ensure_markdown_table_captions(plain_markdown, CHINESE, 1)

    assert rendered == plain_markdown
    assert "表1-1" not in rendered
def test_ensure_markdown_table_captions_ignores_tables_inside_code_fences():
    """Pipe rows inside a fenced code block are expected to be left uncaptioned."""
    markdown = """# 1 示例
```python
| A | B |
|---|---|
| 1 | 2 |
```
正文继续。
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 1)

    assert "表1-1" not in rendered
    # The fenced rows themselves must still be present.
    assert "| A | B |" in rendered
def test_ensure_markdown_table_captions_warns_on_mismatched_code_fence(caplog):
    """A backtick fence followed by a tilde marker is expected to log a warning and add no caption."""
    markdown = """# 1 示例
```python
~~~
| A | B |
|---|---|
| 1 | 2 |
"""
    module_logger = table_caption_utils.__name__
    with caplog.at_level(logging.WARNING, logger=module_logger):
        rendered = ensure_markdown_table_captions(markdown, CHINESE, 1)

    assert "Mismatched Markdown code fence marker" in caplog.text
    assert "表1-1" not in rendered
def test_ensure_markdown_table_captions_splits_adjacent_tables_without_blank_line():
    """Two back-to-back tables are expected to get separate captions 表2-1 and 表2-2."""
    markdown = """# 2 指标
| 指标 | 值 |
|---|---|
| A | 1 |
| 维度 | 值 |
|---|---|
| B | 2 |
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 2)

    for caption_text in ("表2-1:指标(指标、值)", "表2-2:指标(维度、值)"):
        assert _centered_caption(caption_text) in rendered
    # Exactly one centered wrapper per table.
    assert rendered.count('<div style="text-align: center;">') == 2
def test_ensure_markdown_table_captions_truncates_long_explicit_caption():
    """An overlong explicit title is expected not to appear in full in the output."""
    long_title = "核心指标" * 30
    markdown = f"""# 1 长标题
| 指标 | 值 |
|---|---|
| A | 1 |
表格标题:{long_title}
"""
    rendered = ensure_markdown_table_captions(markdown, CHINESE, 1)

    assert long_title not in rendered
    assert rendered.count("表1-1:") == 1
def test_table_caption_markup_cleaning_keeps_prefix_removal_separate():
    """Markup normalization keeps the "下表总结了" prefix; full cleaning removes it too."""
    raw_caption = "**下表总结了[核心指标](https://example.com)**[citation:1]"

    markup_only = table_caption_utils.normalize_caption_markup(raw_caption)
    fully_cleaned = table_caption_utils.clean_caption_text(raw_caption)

    assert markup_only == "下表总结了核心指标"
    assert fully_cleaned == "核心指标"
def test_table_caption_line_override_keeps_existing_on_conflict(caplog):
    """A second rewrite for an already-overridden line is expected to be skipped and logged."""
    line_overrides = {3: "first rewrite"}
    module_logger = table_caption_utils.__name__
    with caplog.at_level(logging.DEBUG, logger=module_logger):
        table_caption_utils._set_line_override(line_overrides, 3, "second rewrite")

    assert line_overrides[3] == "first rewrite"
    assert "Skip conflicting table-reference rewrite" in caplog.text
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.ainvoke_llm_with_stats", new_callable=AsyncMock)
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_generate_sub_report(mock_llm_cls, mock_ainvoke_llm):
    """Run ``generate_sub_report`` end to end against a prompt-routing fake LLM.

    The fake LLM dispatches on keywords in the message content. On the
    classification prompt it also verifies that only the expected document
    fields reach the model: ``original_content`` and ``doc_time`` must be
    present, while internal bookkeeping fields must be absent.
    """
    mock_session = MagicMock()
    mock_session.write_custom_stream = AsyncMock()
    token = session_context.set(mock_session)

    async def mock_ainvoke_llm_with_stats(llm, messages, llm_type: str = "basic", agent_name="AI", schema=None,
                                          tools=None, need_stream_out=False):
        if any("classification" in msg.get("content", "") for msg in messages):
            user_content = next(msg.get("content", "") for msg in messages if msg.get("role") == "user")
            assert "'original_content': 'fake original_content'" in user_content
            assert "'doc_time': '2024 8月'" in user_content
            # Internal fields must not be forwarded to the classification prompt.
            assert "doc_id" not in user_content
            assert "source_id" not in user_content
            assert "content_ref" not in user_content
            assert "scores" not in user_content
            assert "key_passages" not in user_content
            assert "brief_reason" not in user_content
            return {"content": '{\"chapter\": \"企业经营与行业分析\", \"selected_url_list\": [\"fake_url\"]}'}
        elif any("subsection outline" in msg.get("content", "") for msg in messages):
            # Each outline entry sits on its own line. The original literal used
            # "\3.2", which Python parses as the octal escape for "\x03".
            return {"content": "3 企业经营与行业分析\n3.1 经营风险评价\n3.2 杠杆风险评估"}
        elif any("write the chapter" in msg.get("content", "") for msg in messages):
            return {"content": "fake subsection report content"}
        else:
            return {"content": "default response"}

    try:
        mock_ainvoke_llm.side_effect = mock_ainvoke_llm_with_stats
        reporter = Reporter("basic")
        current_inputs = dict(
            has_template=False,
            language=CHINESE,
            report_template='',
            report_style='scholarly',
            section_idx=3,
            report_task='XX有限公司尽职调查报告',
            section_task='企业经营与行业分析',
            section_iscore=True,
            section_description='fake section_description',
            doc_infos=[{
                'doc_id': 'web_1',
                'source_id': 'web_1_p123',
                'doc_time': '2024 8月',
                'publish_time': '2024 8月',
                'original_content': 'fake original_content',
                'url': 'fake_url',
                'title': 'XX有限公司 - 企业详情',
                'source': 'local',
                'scores': {'authority': 8, 'relevance': 9, 'answerability': 7, 'data_density': 6},
                'brief_reason': 'fake reason',
                'key_passages': ['fake passage'],
                'content_ref': {'type': 'source_store', 'source_id': 'web_1_p123'},
            }],
            gathered_info=[{'url': 'fake_url', 'title': 'XX有限公司 - 企业详情', 'content': 'fake content'}],
            sub_evaluation_details='',
            max_generate_retry_num=3,
            max_sub_report_evaluate_num=0
        )
        success, _report, _sub_report_content, _classified_content = await reporter.generate_sub_report(
            current_inputs
        )
    finally:
        # Always restore the session context so the mock session cannot leak
        # into other tests, even when the call above raises.
        session_context.reset(token)

    assert success is True
    assert current_inputs["sub_section_core_content"] == ["fake original_content"]
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_classify_doc_infos_returns_selected_url_list(mock_llm_cls):
    """Only ``selected_url_list`` from the LLM JSON reply is expected in the result."""
    llm_reply = '{"chapter": "企业经营与行业分析", "selected_url_list": ["fake_url"]}'
    reporter = Reporter("basic")
    reporter._classify_with_llm = AsyncMock(return_value=(True, llm_reply))

    only_doc = {
        "title": "XX有限公司 - 企业详情",
        "url": "fake_url",
        "original_content": "fake original_content",
    }
    current_inputs = {
        "section_idx": 3,
        "section_task": "企业经营与行业分析",
        "doc_infos": [only_doc],
        "classify_doc_infos_single_time_num": 60,
        "classify_doc_infos_res_top_k_num": 10,
    }

    success, classified_content = await reporter._classify_doc_infos(current_inputs)

    assert success is True
    # The "chapter" key from the reply is not expected in the returned mapping.
    assert classified_content == {"selected_url_list": ["fake_url"]}
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_classify_doc_infos_preserves_llm_url_order(mock_llm_cls):
    """The returned URL list is expected to keep the order given by the LLM reply."""
    selected_urls = ["https://example.com/order/2", "https://example.com/order/0", "https://example.com/order/1"]
    llm_reply = json.dumps({"chapter": "chapter", "selected_url_list": selected_urls})

    reporter = Reporter("basic")
    reporter._classify_with_llm = AsyncMock(return_value=(True, llm_reply))

    docs = [_report_doc(position, url=link) for position, link in enumerate(selected_urls)]
    classify_inputs = {
        "section_idx": 3,
        "section_task": "企业经营与行业分析",
        "doc_infos": docs,
        "classify_doc_infos_single_time_num": 60,
        "classify_doc_infos_res_top_k_num": len(selected_urls),
        "classify_doc_infos_prefilter_multiplier": 5,
    }

    success, classified_content = await reporter._classify_doc_infos(classify_inputs)

    assert success is True
    assert classified_content == {"selected_url_list": selected_urls}
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_classify_doc_infos_prefilters_and_keeps_same_url_different_content(mock_llm_cls):
    """80 documents are expected to be prefiltered to one batch of 50.

    Two documents share a URL but differ in content; the fake classifier
    asserts that both variants are present in the batch it receives.
    """
    reporter = Reporter("basic")
    observed_batch_lengths = []

    async def fake_classify(current_inputs, section_task, batch):
        observed_batch_lengths.append(len(batch))
        shared_url_hits = sum(1 for doc in batch if doc["url"] == "https://example.com/same")
        assert shared_url_hits == 2
        return True, '{"selected_url_list": ["https://example.com/same"]}'

    reporter._classify_with_llm = AsyncMock(side_effect=fake_classify)

    docs = [
        {
            "title": f"doc-{idx}",
            "url": f"https://example.com/{idx}",
            "original_content": f"content-{idx}",
            "plan_idx": 0,
            "step_idx": idx % 4,
            "scores": {"relevance": idx % 10, "answerability": 9, "authority": 8, "data_density": 7},
        }
        for idx in range(80)
    ]
    # Turn the first two documents into same-URL variants with top relevance.
    for variant_doc, variant_text in zip(docs[:2], ("variant A", "variant B")):
        variant_doc["url"] = "https://example.com/same"
        variant_doc["original_content"] = variant_text
        variant_doc["scores"]["relevance"] = 10

    classify_inputs = {
        "section_idx": 3,
        "section_task": "企业经营与行业分析",
        "doc_infos": docs,
        "classify_doc_infos_single_time_num": 60,
        "classify_doc_infos_res_top_k_num": 10,
        "classify_doc_infos_prefilter_multiplier": 5,
    }
    success, classified_content = await reporter._classify_doc_infos(classify_inputs)

    assert success is True
    assert classified_content == {"selected_url_list": ["https://example.com/same"]}
    assert observed_batch_lengths == [50]
def test_get_classified_infos_returns_all_distinct_content_variants_for_selected_url():
    """Both content variants of the selected URL are expected, under a single reference."""
    same_url = "https://example.com/same"
    doc_infos = [
        {"title": "A", "url": same_url, "original_content": "variant A"},
        {"title": "A", "url": same_url, "original_content": "variant B"},
        {"title": "B", "url": "https://example.com/other", "original_content": "other"},
    ]

    classified_infos, classified_doc_infos = _get_classified_infos(doc_infos, [same_url])

    assert classified_infos["references"] == ["[A](https://example.com/same)"]
    assert classified_infos["core_content_list"] == ["variant A", "variant B"]
    # Only the two documents under the selected URL are expected back.
    assert classified_doc_infos == doc_infos[:2]
def test_get_classified_infos_deduplicates_same_content_without_source_id():
    """Identical content under one URL is expected to collapse to the higher-relevance doc."""
    low_scored = {
        "title": "A low",
        "url": "https://example.com/same",
        "original_content": "same content",
        "scores": {"relevance": 1},
    }
    high_scored = {
        "title": "A high",
        "url": "https://example.com/same",
        "original_content": "same content",
        "scores": {"relevance": 9},
    }
    doc_infos = [low_scored, high_scored]

    classified_infos, classified_doc_infos = _get_classified_infos(
        doc_infos,
        ["https://example.com/same"],
    )

    assert classified_infos["core_content_list"] == ["same content"]
    assert classified_doc_infos == [doc_infos[1]]
def test_get_classified_infos_keeps_top10_source_ids_by_score():
    """With 12 candidates and a cap of 10, the ten highest-relevance sources are expected."""
    same_url = "https://example.com/same"
    doc_infos = []
    for idx in range(12):
        # Relevance rises with idx, so source-0 and source-1 score lowest.
        doc_infos.append(_classified_doc(f"doc-{idx}", same_url, f"source-{idx}", idx * 0.8))

    classified_infos, classified_doc_infos = _get_classified_infos(
        doc_infos,
        [same_url],
        max_source_id_count=10,
    )

    kept_source_ids = {doc["source_id"] for doc in classified_doc_infos}
    expected_source_ids = {f"source-{idx}" for idx in range(2, 12)}

    assert len(classified_doc_infos) == 10
    assert kept_source_ids == expected_source_ids
    assert classified_doc_infos[0]["source_id"] == "source-11"
    assert len(classified_infos["core_content_list"]) == 10
    assert classified_infos["references"] == ["[doc\\-11](https://example.com/same)"]
def test_get_classified_infos_keeps_each_selected_url_before_filling_variants():
    """With a cap of 2, each selected URL is expected to get a slot before extra variants."""
    url_a = "https://example.com/a"
    url_b = "https://example.com/b"
    doc_infos = [
        _classified_doc("A-0", url_a, "a-0", 10),
        _classified_doc("A-1", url_a, "a-1", 9),
        _classified_doc("B", url_b, "b-0", 1),
    ]

    classified_infos, classified_doc_infos = _get_classified_infos(
        doc_infos,
        [url_a, url_b],
        max_source_id_count=2,
    )

    kept_urls = [doc["url"] for doc in classified_doc_infos]
    # "B" scores far below "A-1" yet is expected to be kept so both URLs appear.
    assert kept_urls == ["https://example.com/a", "https://example.com/b"]
    assert classified_infos["references"] == [
        "[A\\-0](https://example.com/a)",
        "[B](https://example.com/b)",
    ]
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_classify_doc_infos_fallbacks_when_prefilter_result_returns_empty_urls(mock_llm_cls):
    """An empty first selection is expected to trigger a second classification call."""
    reporter = Reporter("basic")
    batch_lengths = []

    async def fake_classify(current_inputs, section_task, batch):
        batch_lengths.append(len(batch))
        # First call selects nothing; every later call selects the single doc.
        if len(batch_lengths) > 1:
            return True, '{"selected_url_list": ["https://example.com/1"]}'
        return True, '{"selected_url_list": []}'

    reporter._classify_with_llm = AsyncMock(side_effect=fake_classify)

    single_doc = {
        "title": "doc",
        "url": "https://example.com/1",
        "original_content": "content",
        "scores": {"relevance": 9, "answerability": 9, "authority": 9, "data_density": 9},
    }
    classify_inputs = {
        "section_idx": 3,
        "section_task": "企业经营与行业分析",
        "doc_infos": [single_doc],
        "classify_doc_infos_single_time_num": 60,
        "classify_doc_infos_res_top_k_num": 10,
        "classify_doc_infos_prefilter_multiplier": 5,
    }
    success, classified_content = await reporter._classify_doc_infos(classify_inputs)

    assert success is True
    assert classified_content == {"selected_url_list": ["https://example.com/1"]}
    assert batch_lengths == [1, 1]
@pytest.mark.asyncio
@patch("openjiuwen_deepsearch.algorithm.report.report.ainvoke_llm_with_stats", new_callable=AsyncMock)
@patch("openjiuwen_deepsearch.algorithm.report.report.llm_context", new_callable=MagicMock)
async def test_generate_sub_report_with_background_knowledge_only(mock_llm_cls, mock_ainvoke_llm):
    """Generate a sub-report from parent-section background knowledge alone.

    ``doc_infos`` is empty, so the fake LLM raises if a classification prompt
    ever reaches it. The core content is expected to be built from the
    background summary, prefixed with its parent section id.
    """
    mock_session = MagicMock()
    mock_session.write_custom_stream = AsyncMock()
    token = session_context.set(mock_session)

    async def mock_ainvoke_llm_with_stats(llm, messages, llm_type: str = "basic", agent_name="AI", schema=None,
                                          tools=None, need_stream_out=False):
        if any("classification" in msg.get("content", "") for msg in messages):
            raise AssertionError("classification should not run when doc_infos is empty but background exists")
        if any("subsection outline" in msg.get("content", "") for msg in messages):
            return {"content": "2 企业经营分析\n2.1 上游章节要点承接\n2.2 当前章节判断"}
        if any("write the chapter" in msg.get("content", "") for msg in messages):
            return {"content": "fake subsection report content from background knowledge"}
        return {"content": "background summary"}

    try:
        mock_ainvoke_llm.side_effect = mock_ainvoke_llm_with_stats
        reporter = Reporter("basic")
        current_inputs = dict(
            has_template=False,
            language=CHINESE,
            report_template='',
            report_style='scholarly',
            section_idx=2,
            report_task='XX有限公司尽职调查报告',
            section_task='企业经营分析',
            section_iscore=False,
            section_description='结合父章节摘要继续撰写',
            doc_infos=[],
            gathered_info=[],
            sub_report_background_knowledge=[
                {"section_id": "1", "content_summary": "父章节总结:公司主营业务稳定,收入结构清晰。"}
            ],
            sub_evaluation_details='',
            max_generate_retry_num=3,
            max_sub_report_evaluate_num=0
        )
        success, _report, sub_report_content, classified_content = await reporter.generate_sub_report(
            current_inputs
        )
    finally:
        # Reset in ``finally`` so a failure inside generate_sub_report cannot
        # leave the mock session installed for later tests.
        session_context.reset(token)

    assert success is True
    assert sub_report_content
    assert classified_content == []
    assert current_inputs["sub_section_core_content"] == [
        "[Parent Section 1] 父章节总结:公司主营业务稳定,收入结构清晰。"
    ]