from unittest.mock import AsyncMock, patch

import pytest

from openjiuwen_deepsearch.algorithm.user_feedback_processor.report_edit_utils import strip_markup_in_range
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.algorithm.user_feedback_processor.synonym_rewrite import (
    ACTION_TO_PROMPT,
    SYNONYM_REWRITE_ACTIONS,
    SynonymRewriter,
)


SYNONYM_REWRITER_MODULE_PATH = "openjiuwen_deepsearch.algorithm.user_feedback_processor.synonym_rewrite"


class TestCitationHelpers:
    """Tests for ``strip_markup_in_range`` on citation and inference markup."""

    def test_strip_citations_only_within_selected_range(self):
        """Only the checked citation inside the selection is reported as removed."""
        inner_citation = "[checked_citation:1][[2]](https://b.com)"
        selection = f"这是要改写的段落{inner_citation}结束"
        report = "".join(
            (
                "前缀[checked_citation:0][[1]](https://a.com)",
                selection,
                "[checked_citation:2][[3]](https://c.com)尾部",
            )
        )
        sel_start = report.index("这是要改写的段落")
        sel_end = sel_start + len(selection)
        cite_start = report.index("[checked_citation:1]")
        expected_span = (cite_start, cite_start + len(inner_citation))

        stripped, spans, inference_ids = strip_markup_in_range(report, sel_start, sel_end)

        assert spans == {expected_span}
        assert inference_ids == []
        # Citations before and after the selection must survive untouched.
        for kept_marker in ("[checked_citation:0][[1]]", "[checked_citation:2][[3]]"):
            assert kept_marker in stripped
        assert "[checked_citation:1][[2]]" not in stripped

    def test_strip_citations_keeps_legacy_citation_format_when_trace_source_is_disabled(self):
        """A legacy ``[citation: N]`` marker is removed inside the selection and kept outside it."""
        legacy_inside = "[citation: 2]"
        selection = f"这是要改写的段落{legacy_inside}结束"
        report = f"前缀{selection}[citation: 3]尾部"
        sel_start = report.index("这是要改写的段落")
        sel_end = sel_start + len(selection)
        cite_start = report.index(legacy_inside)
        expected_span = (cite_start, cite_start + len(legacy_inside))

        stripped, spans, inference_ids = strip_markup_in_range(report, sel_start, sel_end)

        assert spans == {expected_span}
        assert inference_ids == []
        assert legacy_inside not in stripped
        assert "[citation: 3]" in stripped

    def test_strip_citations_removes_multiple_citations_within_selection(self):
        """Two citations inside the selection are removed; the trailing one stays."""
        report = "".join(
            (
                "前缀",
                "需要改写[checked_citation:1][[2]](https://b.com)的段落[checked_citation:2][[3]](https://c.com)内容",
                "尾部[checked_citation:3][[4]](https://d.com)",
            )
        )
        tail_word = "内容"
        sel_start = report.index("需要改写")
        sel_end = report.index(tail_word) + len(tail_word)

        stripped, spans, inference_ids = strip_markup_in_range(report, sel_start, sel_end)

        assert len(spans) == 2
        assert inference_ids == []
        for dropped_marker in ("[checked_citation:1][[2]]", "[checked_citation:2][[3]]"):
            assert dropped_marker not in stripped
        assert "[checked_citation:3][[4]]" in stripped

    def test_strip_citations_rewrites_inference_markers_to_plain_text_and_collects_ids(self):
        """Inference links collapse to their label text and their ids are returned in order."""
        report = "前缀[结论A](#inference:1)中间[结论B](#inference:3)尾部"
        sel_start = report.index("[结论A]")
        sel_end = report.index("尾部")

        stripped, spans, inference_ids = strip_markup_in_range(report, sel_start, sel_end)

        # Inference markers are not reported as removed ranges, only as ids.
        assert spans == set()
        assert inference_ids == [1, 3]
        assert stripped == "前缀结论A中间结论B尾部"

class TestSynonymRewriter:
    """Tests for prompt-name lookup and ``_generate_synonym_rewrite_text`` of ``SynonymRewriter``."""

    @pytest.fixture
    def synonym_rewriter(self):
        """Provide a ``SynonymRewriter`` constructed with a placeholder model name."""
        rewriter = SynonymRewriter(llm_model_name="mock_model")
        return rewriter

    def test_prompt_names_match_registered_actions(self, synonym_rewriter):
        """The action set mirrors the prompt mapping keys and each action maps to its prompt name."""
        assert SYNONYM_REWRITE_ACTIONS == frozenset(ACTION_TO_PROMPT)
        for action in ("expand", "polish", "shorten"):
            assert synonym_rewriter.get_prompt_name(action) == f"synonym_rewrite_{action}"

    def test_invalid_prompt_name_raises(self, synonym_rewriter):
        """An unregistered action raises ``CustomValueException`` with the invalid-action code."""
        expected_code = StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code

        with pytest.raises(CustomValueException) as raised:
            synonym_rewriter.get_prompt_name("unknown_action")

        assert raised.value.error_code == expected_code

    @pytest.mark.asyncio
    async def test_generate_synonym_rewrite_text_calls_llm_wrapper(self, synonym_rewriter):
        """The generated text is the ``content`` of the mocked LLM response."""
        prompt_patch = patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.apply_system_prompt",
            return_value=[{"role": "system", "content": "prompt"}],
        )
        llm_patch = patch(f"{SYNONYM_REWRITER_MODULE_PATH}.get_llm_instance", return_value=object())
        invoke_patch = patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.ainvoke_llm_with_stats",
            new_callable=AsyncMock,
            return_value={"content": "这是改写后的文本内容"},
        )

        with prompt_patch as mock_prompt, llm_patch, invoke_patch as mock_ainvoke:
            result = await synonym_rewriter._generate_synonym_rewrite_text(
                action="expand",
                original_text="这是原文",
                language="zh-CN",
                user_instruction="请详细展开",
            )

        assert result == "这是改写后的文本内容"
        mock_prompt.assert_called_once()
        mock_ainvoke.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_generate_synonym_rewrite_text_wraps_unknown_exception(self, synonym_rewriter):
        """A ``RuntimeError`` from the LLM call surfaces as ``CustomValueException`` with the rewrite-error code."""
        prompt_patch = patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.apply_system_prompt",
            return_value=[{"role": "system", "content": "prompt"}],
        )
        llm_patch = patch(f"{SYNONYM_REWRITER_MODULE_PATH}.get_llm_instance", return_value=object())
        failing_invoke_patch = patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.ainvoke_llm_with_stats",
            new_callable=AsyncMock,
            side_effect=RuntimeError("boom"),
        )

        with prompt_patch, llm_patch, failing_invoke_patch:
            with pytest.raises(CustomValueException) as raised:
                await synonym_rewriter._generate_synonym_rewrite_text(
                    action="expand",
                    original_text="这是原文",
                    language="zh-CN",
                    user_instruction="",
                )

        assert raised.value.error_code == StatusCode.USER_FEEDBACK_PROCESSOR_REWRITE_ERROR.code


class TestRewriteFlow:
    """Flow tests for ``SynonymRewriter.synonym_rewrite``.

    Every test passes a report plus a feedback selection and checks fields of
    the returned mapping (``new_report``, ``rewritten_text``, offsets). The LLM
    layer is always mocked, either by patching ``_generate_synonym_rewrite_text``
    on the instance or by patching the prompt/LLM helpers in the rewriter module.
    """

    @pytest.fixture
    def synonym_rewriter(self):
        """Provide a ``SynonymRewriter`` constructed with a placeholder model name."""
        return SynonymRewriter(llm_model_name="mock_model")

    @pytest.mark.asyncio
    async def test_synonym_rewrite_uses_generated_text(self, synonym_rewriter):
        """The generated text replaces the selected span and is returned as ``rewritten_text``."""
        with patch.object(
            synonym_rewriter,
            "_generate_synonym_rewrite_text",
            new_callable=AsyncMock,
            return_value="改写后的文本",
        ):
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": "原文",
                    "start_offset": 0,
                    "end_offset": 2,
                    "user_instruction": "",
                },
                report_content="原文后续内容",
                language="zh-CN",
            )

        assert result["rewritten_text"] == "改写后的文本"
        assert result["new_report"] == "改写后的文本后续内容"

    @pytest.mark.asyncio
    async def test_synonym_rewrite_strips_legacy_citation_format_without_source_tracer_metadata(self, synonym_rewriter):
        """A legacy ``[citation: N]`` marker is stripped from the text handed to the generator."""
        report = "前缀需要改写[citation: 2]的段落尾部"
        selected = "需要改写[citation: 2]的段落"
        start = report.index("需要改写")
        end = start + len(selected)

        with patch.object(synonym_rewriter, "_generate_synonym_rewrite_text", new_callable=AsyncMock) as mock_generate:
            mock_generate.return_value = "改写后的段落"

            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": selected,
                    "start_offset": start,
                    "end_offset": end,
                    "user_instruction": "",
                },
                report_content=report,
                language="zh-CN",
            )

        # The generator must receive the selection with the citation marker removed.
        mock_generate.assert_awaited_once_with(
            action="expand",
            original_text="需要改写的段落",
            language="zh-CN",
            user_instruction="",
        )
        assert result["new_report"] == "前缀改写后的段落尾部"

    @pytest.mark.asyncio
    async def test_full_rewrite_expand_flow_keeps_citation_messages_intact(self):
        """Run the expand flow through the mocked LLM helpers.

        Checks that text and the citation outside the selection stay in the new
        report, that the selected citation is gone, and that the returned offsets
        span exactly the rewritten text starting at the original start offset.
        """
        report = (
            "引言部分。[checked_citation:0][[1]](https://a.com)"
            "这是需要扩写的段落内容。[checked_citation:1][[2]](https://b.com)结论部分。"
        )
        selected = "这是需要扩写的段落内容。[checked_citation:1][[2]](https://b.com)"
        start = report.index("这是需要扩写的段落内容。")
        end = start + len(selected)
        feedback = {
            "action": "expand",
            "selected_text": selected,
            "start_offset": start,
            "end_offset": end,
            "user_instruction": "补充技术细节",
        }

        with patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.apply_system_prompt",
            return_value=[{"role": "system", "content": "prompt"}],
        ), patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.get_llm_instance",
            return_value=object(),
        ), patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.ainvoke_llm_with_stats",
            new_callable=AsyncMock,
            return_value={"content": "这是经过详细扩写后的段落内容,补充了技术细节和实现方案。"},
        ):
            # Constructed inside the patches, as in the original flow test.
            synonym_rewriter = SynonymRewriter(llm_model_name="mock")
            result = await synonym_rewriter.synonym_rewrite(
                feedback=feedback,
                report_content=report,
                language="zh-CN",
            )

        assert "引言部分" in result["new_report"]
        assert "结论部分" in result["new_report"]
        assert "[checked_citation:0][[1]]" in result["new_report"]
        assert "[checked_citation:1][[2]]" not in result["new_report"]
        assert result["rewritten_text"] == "这是经过详细扩写后的段落内容,补充了技术细节和实现方案。"
        assert result["rewritten_start_offset"] == start
        assert result["rewritten_end_offset"] == start + len(result["rewritten_text"])

    @pytest.mark.asyncio
    async def test_full_rewrite_expand_flow_keeps_multiple_citation_messages_intact(self):
        """Both citations inside the selection disappear; the one after it remains."""
        citation_token_2 = "[checked_citation:0][[2]](https://b.com)"
        citation_token_3 = "[checked_citation:1][[3]](https://c.com)"
        citation_token_4 = "[checked_citation:2][[4]](https://d.com)"
        selected = f"需要扩写{citation_token_2}的段落{citation_token_3}内容"
        report = f"前言{selected}尾注{citation_token_4}"
        start = report.index("需要扩写")
        end = start + len(selected)

        with patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.apply_system_prompt",
            return_value=[{"role": "system", "content": "prompt"}],
        ), patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.get_llm_instance",
            return_value=object(),
        ), patch(
            f"{SYNONYM_REWRITER_MODULE_PATH}.ainvoke_llm_with_stats",
            new_callable=AsyncMock,
            return_value={"content": "扩写后的段落内容"},
        ):
            synonym_rewriter = SynonymRewriter(llm_model_name="mock")
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": selected,
                    "start_offset": start,
                    "end_offset": end,
                    "user_instruction": "",
                },
                report_content=report,
                language="zh-CN",
            )

        assert citation_token_2 not in result["new_report"]
        assert citation_token_3 not in result["new_report"]
        assert citation_token_4 in result["new_report"]

    @pytest.mark.asyncio
    async def test_synonym_rewrite_preserves_trailing_citation_offsets(self, synonym_rewriter):
        """A citation located after the selection is carried over verbatim when the text shrinks."""
        trailing_citation = "[checked_citation:0][[4]](https://d.com)"
        report = f"前言需要精简的段落内容尾注{trailing_citation}"
        selected = "需要精简的段落内容"
        start = report.index(selected)
        end = start + len(selected)

        with patch.object(
            synonym_rewriter,
            "_generate_synonym_rewrite_text",
            new_callable=AsyncMock,
            return_value="精简后",
        ):
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "shorten",
                    "selected_text": selected,
                    "start_offset": start,
                    "end_offset": end,
                    "user_instruction": "",
                },
                report_content=report,
                language="zh-CN",
            )

        assert result["new_report"] == f"前言精简后尾注{trailing_citation}"

    @pytest.mark.asyncio
    async def test_synonym_rewrite_keeps_unselected_duplicate_reference_instances(self, synonym_rewriter):
        """A second citation of the same reference outside the selection stays in the report."""
        selected = "first shared citation [checked_citation:0][[1]](https://shared.com) needs rewrite."
        second_citation = "[checked_citation:1][[1]](https://shared.com)"
        report = f"{selected} later there is another shared citation {second_citation} that should remain."
        start = report.index(selected)
        end = start + len(selected)
        rewritten_text = "rewritten first segment."

        with patch.object(
            synonym_rewriter,
            "_generate_synonym_rewrite_text",
            new_callable=AsyncMock,
            return_value=rewritten_text,
        ):
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": selected,
                    "start_offset": start,
                    "end_offset": end,
                    "user_instruction": "",
                },
                report_content=report,
                language="en-US",
            )

        assert second_citation in result["new_report"]

    @pytest.mark.asyncio
    async def test_synonym_rewrite_preserves_inference_messages(self, synonym_rewriter):
        """The selected inference marker is dropped while the unselected one stays in the report."""
        trailing_citation = "[checked_citation:0][[4]](https://d.com)"
        report = (
            "前缀[已选结论](#inference:1)中段"
            "[保留结论](#inference:10)"
            f"尾注{trailing_citation}"
        )
        selected = "[已选结论](#inference:1)中段"
        start = report.index(selected)
        end = start + len(selected)

        with patch.object(
            synonym_rewriter,
            "_generate_synonym_rewrite_text",
            new_callable=AsyncMock,
            return_value="改写后",
        ):
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": selected,
                    "start_offset": start,
                    "end_offset": end,
                    "user_instruction": "",
                },
                report_content=report,
                language="zh-CN",
            )

        assert "[已选结论](#inference:1)" not in result["new_report"]
        assert "[保留结论](#inference:10)" in result["new_report"]

    @pytest.mark.asyncio
    async def test_synonym_rewrite_preserves_metadata_after_checked_citation_cleanup(self, synonym_rewriter):
        """Rewrite a selection that ends with an inference marker, after an unselected citation.

        The expected report keeps the leading checked citation verbatim and
        replaces the whole selection (including its inference marker) with the
        generated text. The feedback deliberately omits ``user_instruction``.
        """
        report = "前缀[checked_citation:0][[1]](https://a.com)选中内容[保留结论](#inference:3)后缀"
        start_offset = report.index("选中内容")
        end_offset = report.index("后缀")

        with patch.object(
            synonym_rewriter,
            "_generate_synonym_rewrite_text",
            new_callable=AsyncMock,
            return_value="新的选中内容",
        ):
            result = await synonym_rewriter.synonym_rewrite(
                feedback={
                    "action": "expand",
                    "selected_text": report[start_offset:end_offset],
                    "start_offset": start_offset,
                    "end_offset": end_offset,
                },
                report_content=report,
                language="zh-CN",
            )

        assert result["new_report"] == "前缀[checked_citation:0][[1]](https://a.com)新的选中内容后缀"