from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openjiuwen_deepsearch.algorithm.user_feedback_processor.supplementary_search import (
    SupplementaryRewriteContext,
    SupplementarySearcher,
)
from openjiuwen_deepsearch.utils.constants_utils.node_constants import AgentLlmName

@pytest.mark.asyncio
async def test_supplementary_search_selected_only_replaces_only_span_and_preserves_metadata():
    """Check that the default rewrite path replaces only the selected span.

    The feedback carries no ``rewrite_scope``. The test therefore expects
    ``_rewrite_selected_only`` to be awaited exactly once. The text around the
    selection (``前缀`` / ``后缀``) must survive, while the selection itself,
    including its citation markup, is swapped for the mocked rewrite.

    NOTE(review): despite the test name, no assertion inspects the
    citation/infer metadata in ``result``; consider adding one once the
    expected output keys for that metadata are pinned down.
    """
    report = "# 标题\n\n## 第二章\n前缀选中内容[checked_citation:0][[1]](https://a.com)后缀\n"
    selected_text = "选中内容[checked_citation:0][[1]](https://a.com)"
    span_start = report.index("选中内容")
    span_end = span_start + len(selected_text)

    # Offsets of the markdown link "[[1]](https://a.com)" inside the report.
    link_start = report.index("[[1]]")
    link_end = link_start + len("[[1]](https://a.com)")
    citation_entry = {
        "id": 0,
        "reference_index": 1,
        "url": "https://a.com",
        "citation_start_offset": link_start,
        "citation_end_offset": link_end,
    }
    citation_messages = {"data": [citation_entry]}
    infer_messages = [{"id": 3, "content": "保留推理"}]

    rewrite_only_mock = AsyncMock(return_value="选中内容已补充")
    searcher = SupplementarySearcher(llm_model_name="mock")
    with patch.multiple(
        searcher,
        _build_research_task=AsyncMock(return_value="补充市场规模数据"),
        _run_collection=AsyncMock(return_value={"info_summary": "补充摘要"}),
        _rewrite_selected_only=rewrite_only_mock,
    ):
        result = await searcher.supplementary_search(
            feedback={
                "action": "supplementary_search",
                "selected_text": selected_text,
                "start_offset": span_start,
                "end_offset": span_end,
                "user_instruction": "补充这一段的信息",
            },
            final_result={
                "response_content": report,
                "citation_messages": citation_messages,
                "infer_messages": infer_messages,
            },
            language="zh-CN",
        )

    rewrite_only_mock.assert_awaited_once()
    assert result["new_report"] == "# 标题\n\n## 第二章\n前缀选中内容已补充后缀\n"
    assert result["rewritten_text"] == "选中内容已补充"


@pytest.mark.asyncio
async def test_supplementary_search_selected_and_related_preserves_metadata():
    """Check the ``selected_and_related`` scope replaces the enclosing section.

    With ``rewrite_scope`` set to ``selected_and_related``, the mocked
    ``_rewrite_selected_and_related`` output replaces the whole
    ``## 第二章`` section. That includes the trailing ``更多内容`` line. The
    ``# 标题`` header before the section is kept.

    NOTE(review): despite the test name, no assertion inspects the
    citation/infer metadata in ``result``; consider adding one once the
    expected output keys for that metadata are pinned down.
    """
    report = "# 标题\n\n## 第二章\n这里是选中内容[checked_citation:0][[1]](https://a.com)\n更多内容\n"
    selected_text = "选中内容[checked_citation:0][[1]](https://a.com)"
    span_start = report.index("选中内容")
    span_end = span_start + len(selected_text)

    # Offsets of the markdown link "[[1]](https://a.com)" inside the report.
    link_start = report.index("[[1]]")
    link_end = link_start + len("[[1]](https://a.com)")
    citation_entry = {
        "id": 0,
        "reference_index": 1,
        "url": "https://a.com",
        "citation_start_offset": link_start,
        "citation_end_offset": link_end,
    }
    citation_messages = {"data": [citation_entry]}
    infer_messages = []

    rewrite_related_mock = AsyncMock(return_value="## 第二章\n新章节内容")
    searcher = SupplementarySearcher(llm_model_name="mock")
    with patch.multiple(
        searcher,
        _build_research_task=AsyncMock(return_value="补充市场规模数据"),
        _run_collection=AsyncMock(return_value={"info_summary": "补充摘要"}),
        _rewrite_selected_and_related=rewrite_related_mock,
    ):
        result = await searcher.supplementary_search(
            feedback={
                "action": "supplementary_search",
                "rewrite_scope": "selected_and_related",
                "selected_text": selected_text,
                "start_offset": span_start,
                "end_offset": span_end,
                "user_instruction": "补充这一段的信息",
            },
            final_result={
                "response_content": report,
                "citation_messages": citation_messages,
                "infer_messages": infer_messages,
            },
            language="zh-CN",
        )

    rewrite_related_mock.assert_awaited_once()
    assert result["new_report"] == "# 标题\n\n## 第二章\n新章节内容"
    assert result["rewritten_text"] == "## 第二章\n新章节内容"


@pytest.mark.asyncio
async def test_supplementary_search_selected_and_related_keeps_newline_before_next_heading():
    """Ensure a section rewrite keeps a blank line before the next heading.

    The mocked section rewrite has no trailing newline. The next section
    heading (``## 第二章``) must still be separated from the rewritten text by
    exactly one blank line, i.e. ``\\n\\n``, not ``\\n`` and not nothing.

    Args:
        None.

    Returns:
        None.
    """
    report = "# 标题\n\n## 第一章\n这里是选中内容\n本章结尾。\n\n## 第二章\n后续内容\n"
    selected_text = "选中内容"
    span_start = report.index(selected_text)
    span_end = span_start + len(selected_text)

    searcher = SupplementarySearcher(llm_model_name="mock")
    with patch.multiple(
        searcher,
        _build_research_task=AsyncMock(return_value="补充市场规模数据"),
        _run_collection=AsyncMock(return_value={"info_summary": "补充摘要"}),
        _rewrite_selected_and_related=AsyncMock(return_value="## 第一章\n改写后的章节内容"),
    ):
        result = await searcher.supplementary_search(
            feedback={
                "action": "supplementary_search",
                "rewrite_scope": "selected_and_related",
                "selected_text": selected_text,
                "start_offset": span_start,
                "end_offset": span_end,
                "user_instruction": "补充这一段的信息",
            },
            final_result={
                "response_content": report,
                "citation_messages": {"data": []},
                "infer_messages": [],
            },
            language="zh-CN",
        )

    rewritten_report = result["new_report"]
    assert "改写后的章节内容\n\n## 第二章" in rewritten_report
    assert "改写后的章节内容\n## 第二章" not in rewritten_report
    assert "改写后的章节内容## 第二章" not in rewritten_report


@pytest.mark.asyncio
async def test_run_collection_returns_summary_and_doc_infos_from_collector_service():
    """Check ``_run_collection`` repackages the collector service result.

    The session and model-context resolvers and
    ``CollectorExecutionService.run_plan`` are all patched. The test expects
    ``run_plan`` to be awaited once. It also expects its ``info_summary`` and
    ``doc_infos`` attributes to come back as a plain dict.
    """
    module_path = "openjiuwen_deepsearch.algorithm.user_feedback_processor.supplementary_search"
    searcher = SupplementarySearcher(llm_model_name="mock")

    # Values served by the fake session's get_global_state; unknown keys -> None.
    global_state = {
        "search_context.feedback_interaction_count": 2,
        "config.info_collector_initial_search_query_count": 2,
        "config.info_collector_max_research_loops": 1,
        "config.info_collector_max_react_recursion_limit": 4,
    }
    fake_session = MagicMock()
    fake_session.get_global_state.side_effect = lambda key: global_state.get(key)

    service_result = MagicMock()
    service_result.info_summary = "补充摘要"
    service_result.doc_infos = [{"title": "官方新闻稿"}]
    run_plan_mock = AsyncMock(return_value=service_result)

    with patch(f"{module_path}._resolve_session_collector", return_value=fake_session), patch(
        f"{module_path}._resolve_model_context_collector", return_value=None
    ), patch(f"{module_path}.CollectorExecutionService.run_plan", new=run_plan_mock):
        result = await searcher._run_collection("补充官方数据", "zh-CN")

    run_plan_mock.assert_awaited_once()
    expected = {
        "info_summary": "补充摘要",
        "doc_infos": [{"title": "官方新闻稿"}],
    }
    assert result == expected


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("method_name", "prompt_name", "agent_name"),
    [
        (
            "_rewrite_selected_only",
            "supplementary_search_rewrite_selected_only",
            AgentLlmName.USER_FEEDBACK_PROCESSOR_SUPPLEMENTARY_SEARCH_REWRITE_SELECTED_ONLY.value,
        ),
        (
            "_rewrite_selected_and_related",
            "supplementary_search_rewrite_selected_and_related",
            AgentLlmName.USER_FEEDBACK_PROCESSOR_SUPPLEMENTARY_SEARCH_REWRITE_SELECTED_AND_RELATED.value,
        ),
    ],
)
async def test_rewrite_methods_accept_named_context_and_forward_prompt_vars(method_name, prompt_name, agent_name):
    """Verify the rewrite methods forward prompt variables and the full agent name.

    Each rewrite method receives a ``SupplementaryRewriteContext``. It must
    call ``_invoke_prompt`` once, with the template name, the flattened prompt
    variables, and the full agent name. It must return the stripped LLM
    output.

    Args:
        method_name: Name of the rewrite method under test.
        prompt_name: Template name expected to be passed to ``_invoke_prompt``.
        agent_name: Full agent name expected to be passed to ``_invoke_prompt``.

    Returns:
        None.
    """
    source_doc = {
        "doc_id": "web_1",
        "source_id": "web_1_p123",
        "title": "doc",
        "url": "https://example.com",
        "publish_time": "2025-05",
        "original_content": "原文",
        "content_ref": {"type": "source_store", "source_id": "web_1_p123"},
        "scores": {"authority": 8},
        "key_passages": ["关键段落"],
    }
    rewrite_context = SupplementaryRewriteContext(
        user_instruction="补充数据",
        selected_text_clean="选中文本",
        section_text_clean="章节文本",
        collector_summary="摘要",
        doc_infos=[source_doc],
        language="zh-CN",
    )

    # The prompt receives a flattened projection of each doc. ``publish_time``
    # maps to ``doc_time``. Fields the expectation leaves empty (including
    # ``source_authority``, despite ``scores.authority``) come through as "".
    expected_doc = {
        "doc_time": "2025-05",
        "source_authority": "",
        "task_relevance": "",
        "original_content": "原文",
        "url": "https://example.com",
        "information_richness": "",
        "data_density": "",
        "title": "doc",
        "query": "",
    }
    expected_prompt_vars = {
        "language": "zh-CN",
        "user_instruction": "补充数据",
        "selected_text_clean": "选中文本",
        "section_text_clean": "章节文本",
        "collector_summary": "摘要",
        "doc_infos": [expected_doc],
    }

    searcher = SupplementarySearcher(llm_model_name="mock")
    invoke_prompt_mock = AsyncMock(return_value=" 改写结果 ")
    with patch.object(searcher, "_invoke_prompt", invoke_prompt_mock):
        rewrite_method = getattr(searcher, method_name)
        result = await rewrite_method(rewrite_context)

    assert result == "改写结果"
    invoke_prompt_mock.assert_awaited_once_with(prompt_name, expected_prompt_vars, agent_name)