import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openjiuwen_deepsearch.algorithm.user_feedback_processor.action_definitions import (
    ResolvedUserAction,
    SupplementarySearchActionSubcategory,
    UserFeedbackActionCategory,
    SynonymRewriteActionSubcategory,
    UserFeedbackRewriteStreamResult,
)
from openjiuwen_deepsearch.common.exception import CustomRuntimeException, CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.algorithm.user_feedback_processor.user_feedback_processor import (
    UserFeedbackProcessor,
)
from openjiuwen_deepsearch.utils.common_utils.stream_utils import StreamEvent
from openjiuwen_deepsearch.utils.constants_utils.node_constants import NodeId


class TestParseFeedback:
    """Tests for ``UserFeedbackProcessor.parse_feedback``."""

    @pytest.mark.parametrize(
        ("raw_input", "expected_action", "expected_selected_text"),
        [
            # Expand request that leaves ``rewrite_scope`` out altogether.
            pytest.param(
                json.dumps(
                    {
                        "action": "expand",
                        "selected_text": "原文",
                        "start_offset": 10,
                        "end_offset": 12,
                        "user_instruction": "扩写",
                    }
                ),
                "expand",
                "原文",
            ),
            # Expand request that sends an empty ``rewrite_scope``.
            pytest.param(
                json.dumps(
                    {
                        "action": "expand",
                        "selected_text": "原文",
                        "start_offset": 10,
                        "end_offset": 12,
                        "user_instruction": "扩写",
                        "rewrite_scope": "",
                    }
                ),
                "expand",
                "原文",
            ),
            # Bare finish request: no selection to compare.
            pytest.param(json.dumps({"action": "finish"}), "finish", None),
        ],
    )
    def test_parse_valid_requests(self, raw_input, expected_action, expected_selected_text):
        """Well-formed payloads parse and are expected to carry ``rewrite_scope == "selected_only"``."""
        parsed = UserFeedbackProcessor.parse_feedback(raw_input)

        assert parsed["action"] == expected_action
        assert parsed["rewrite_scope"] == "selected_only"
        if expected_selected_text is None:
            # Nothing further to verify when the case carries no selection.
            return
        assert parsed["selected_text"] == expected_selected_text

    def test_parse_valid_sync_request(self):
        """A sync payload keeps its ``selected_text`` and is expected to get the same scope value."""
        raw_request = json.dumps(
            {"action": "sync", "selected_text": "前端完整报告"},
            ensure_ascii=False,
        )

        parsed = UserFeedbackProcessor.parse_feedback(raw_request)

        assert parsed["action"] == "sync"
        assert parsed["selected_text"] == "前端完整报告"
        assert parsed["rewrite_scope"] == "selected_only"

    @pytest.mark.parametrize(
        "action_value",
        [
            pytest.param(None, id="missing"),
            pytest.param("", id="empty_string"),
        ],
    )
    def test_parse_feedback_rejects_invalid_action(self, action_value):
        """A missing or empty ``action`` is expected to raise the invalid-action error."""
        # ``None`` stands for "leave the key out of the payload entirely".
        action_entry = {} if action_value is None else {"action": action_value}
        payload = {
            "selected_text": "原文",
            "start_offset": 10,
            "end_offset": 12,
            "user_instruction": "补充这一段的信息",
            **action_entry,
        }
        serialized = json.dumps(payload)

        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.parse_feedback(serialized)

        expected_code = StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code
        assert exc_info.value.error_code == expected_code

    @pytest.mark.parametrize(
        ("raw_input", "expected_error_code", "message_fragment"),
        [
            # Input that is not JSON at all.
            pytest.param(
                "not json",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_JSON.code,
                "Expecting value",
            ),
            # Valid JSON whose top-level value is not an object.
            pytest.param(
                "1",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_JSON.code,
                "expected JSON object",
            ),
        ],
    )
    def test_parse_invalid_requests(self, raw_input, expected_error_code, message_fragment):
        """Unparseable input is expected to raise the invalid-JSON error with a telling message."""
        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.parse_feedback(raw_input)

        raised = exc_info.value
        assert raised.error_code == expected_error_code
        assert message_fragment in raised.message


class TestValidate:
    """Tests for ``UserFeedbackProcessor.validate``."""

    def test_valid_input(self):
        """An expand request whose offsets point at the selected text validates cleanly."""
        outcome = UserFeedbackProcessor.validate(
            {
                "action": "expand",
                "selected_text": "原文",
                "start_offset": 10,
                "end_offset": 12,
            },
            "0123456789原文0123456789",
        )

        assert outcome is None

    @pytest.mark.parametrize(
        ("feedback", "report_content", "expected_error_code"),
        [
            # Selected text that differs from the report content.
            pytest.param(
                {
                    "action": "expand",
                    "selected_text": "不匹配的文本",
                    "start_offset": 0,
                    "end_offset": 6,
                },
                "实际的报告内容",
                StatusCode.USER_FEEDBACK_PROCESSOR_OFFSET_MISMATCH.code,
            ),
            # Action name the tests expect to be rejected.
            pytest.param(
                {
                    "action": "unknown",
                    "selected_text": "text",
                    "start_offset": 0,
                    "end_offset": 4,
                },
                "text",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code,
            ),
            # ``selected_text`` is None and ``start_offset`` is a string.
            pytest.param(
                {
                    "action": "expand",
                    "selected_text": None,
                    "start_offset": "0",
                    "end_offset": 4,
                },
                "text",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code,
            ),
            # ``user_instruction`` given as a list.
            pytest.param(
                {
                    "action": "expand",
                    "selected_text": "text",
                    "start_offset": 0,
                    "end_offset": 4,
                    "user_instruction": ["not", "a", "string"],
                },
                "text",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code,
            ),
            # ``rewrite_scope`` given as an int on an expand request.
            pytest.param(
                {
                    "action": "expand",
                    "selected_text": "text",
                    "start_offset": 0,
                    "end_offset": 4,
                    "rewrite_scope": 123,
                },
                "text",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code,
            ),
            # ``rewrite_scope`` given as an int on a supplementary search request.
            pytest.param(
                {
                    "action": "supplementary_search",
                    "selected_text": "text",
                    "start_offset": 0,
                    "end_offset": 4,
                    "rewrite_scope": 123,
                },
                "text",
                StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code,
            ),
        ],
    )
    def test_invalid_input(self, feedback, report_content, expected_error_code):
        """Each malformed request is expected to raise with its matching error code."""
        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate(feedback, report_content)

        raised = exc_info.value
        assert raised.error_code == expected_error_code

    def test_finish_action_skips_offset_validation(self):
        """A finish request without selection or offsets is accepted."""
        outcome = UserFeedbackProcessor.validate({"action": "finish"}, "any report content")

        assert outcome is None

    def test_validate_sync_skips_offset_validation(self):
        """A sync request is accepted even though its text differs from the stored report."""
        sync_request = {"action": "sync", "selected_text": "前端完整报告"}

        outcome = UserFeedbackProcessor.validate(sync_request, "旧报告")

        assert outcome is None

    def test_validate_sync_requires_selected_text(self):
        """A sync request without ``selected_text`` is expected to be rejected."""
        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate({"action": "sync"}, "旧报告")

        raised = exc_info.value
        assert raised.error_code == StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code
        assert "selected_text" in raised.message

    def test_validate_sync_rejects_empty_selected_text(self):
        """A sync request with an empty ``selected_text`` is expected to be rejected."""
        empty_sync = {"action": "sync", "selected_text": ""}

        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate(empty_sync, "旧报告")

        raised = exc_info.value
        assert raised.error_code == StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code
        assert "selected_text" in raised.message

    def test_validate_expand_no_longer_rejects_long_selection(self):
        """Selecting an entire 5000-character report for expansion is accepted."""
        long_report = "a" * 5000
        whole_report_selection = {
            "action": "expand",
            "selected_text": long_report,
            "start_offset": 0,
            "end_offset": len(long_report),
        }

        outcome = UserFeedbackProcessor.validate(whole_report_selection, long_report)

        assert outcome is None

    @pytest.mark.parametrize(
        "extra_fields",
        [
            pytest.param({"user_instruction": 123}),
            pytest.param({"rewrite_scope": 123}),
        ],
    )
    def test_finish_action_rejects_non_string_optional_fields(self, extra_fields):
        """Optional fields are still type-checked on a finish request."""
        finish_request = {"action": "finish", **extra_fields}

        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate(finish_request, "any report content")

        expected_code = StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_PARAM_TYPE.code
        assert exc_info.value.error_code == expected_code

    def test_validate_rejects_supplementary_search_with_invalid_rewrite_scope(self):
        """A supplementary search with an unrecognised scope string is expected to be rejected."""
        bad_scope_request = {
            "action": "supplementary_search",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
            "rewrite_scope": "invalid_scope",
        }

        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate(bad_scope_request, "原文")

        expected_code = StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_REWRITE_SCOPE.code
        assert exc_info.value.error_code == expected_code

    def test_validate_accepts_expand_with_non_enum_rewrite_scope(self):
        """An arbitrary scope string on an expand request is accepted."""
        expand_request = {
            "action": "expand",
            "selected_text": "原文",
            "start_offset": 10,
            "end_offset": 12,
            "rewrite_scope": "ignored_for_non_supplementary",
        }

        outcome = UserFeedbackProcessor.validate(expand_request, "0123456789原文0123456789")

        assert outcome is None

    def test_validate_rejects_scope_encoded_supplementary_action(self):
        """An action name with the scope baked into it is expected to be an invalid action."""
        scope_in_action = {
            "action": "supplementary_search_selected_and_related",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
        }

        with pytest.raises(CustomValueException) as exc_info:
            UserFeedbackProcessor.validate(scope_in_action, "原文")

        expected_code = StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code
        assert exc_info.value.error_code == expected_code

class TestUserFeedbackProcessorDispatch:
    """Tests for ``UserFeedbackProcessor.execute`` dispatching and ``build_stream_result``."""

    @pytest.fixture
    def processor(self):
        """Provide a processor built with a placeholder model name."""
        return UserFeedbackProcessor(llm_model_name="mock_model")

    @pytest.mark.asyncio
    async def test_execute_dispatches_rewrite_actions_to_synonym_rewrite_service(self, processor):
        """An expand request is forwarded to the synonym rewriter and its result is returned."""
        feedback = {
            "action": "expand",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
            "user_instruction": "",
        }
        expected_result = {
            "new_report": "改写后的文本后续内容",
            "original_text": "原文",
            "original_start_offset": 0,
            "original_end_offset": 2,
            "original_text_clean": "原文",
            "rewritten_text": "改写后的文本",
            "rewritten_start_offset": 0,
            "rewritten_end_offset": 6,
        }

        with patch.object(
            processor._synonym_rewriter,
            "synonym_rewrite",
            new_callable=AsyncMock,
        ) as mock_synonym_rewrite:
            # Give the mock its own copy so that an in-place mutation of the
            # returned dict cannot make the equality check below pass trivially.
            mock_synonym_rewrite.return_value = dict(expected_result)

            result = await processor.execute(
                feedback=feedback,
                final_result={
                    "response_content": "原文后续内容",
                    "citation_messages": {},
                    "infer_messages": [],
                },
                language="zh-CN",
            )

        assert result == expected_result
        mock_synonym_rewrite.assert_awaited_once_with(
            feedback=feedback,
            report_content="原文后续内容",
            language="zh-CN",
        )

    @pytest.mark.asyncio
    async def test_execute_rejects_unsupported_action(self, processor):
        """An unrecognised action is expected to raise the invalid-action error."""
        with pytest.raises(CustomValueException) as exc_info:
            await processor.execute(
                feedback={"action": "unsupported"},
                final_result={"response_content": "report"},
                language="zh-CN",
            )

        assert exc_info.value.error_code == StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code

    @pytest.mark.asyncio
    async def test_execute_dispatches_supplementary_search_to_service(self, processor):
        """A supplementary search request is forwarded to the supplementary searcher."""
        feedback = {
            "action": "supplementary_search",
            "rewrite_scope": "selected_only",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
            "user_instruction": "补充这一段的信息",
        }
        final_result = {
            "response_content": "原文后续内容",
            "citation_messages": {},
            "infer_messages": [],
        }

        with patch.object(
            processor._supplementary_searcher,
            "supplementary_search",
            new_callable=AsyncMock,
        ) as mock_supplementary_search:
            mock_supplementary_search.return_value = {
                "new_report": "原文后续内容",
                "original_text": "原文",
                "original_start_offset": 0,
                "original_end_offset": 2,
                "original_text_clean": "原文",
                "rewritten_text": "## 第二章\n新章节内容",
                "rewritten_start_offset": 0,
                "rewritten_end_offset": 10,
            }

            result = await processor.execute(
                feedback=feedback,
                final_result=final_result,
                language="zh-CN",
            )

        mock_supplementary_search.assert_awaited_once_with(
            feedback=feedback,
            final_result=final_result,
            language="zh-CN",
        )

        assert result["rewritten_text"] == "## 第二章\n新章节内容"

    @pytest.mark.asyncio
    async def test_execute_dispatches_new_task_to_processor(self, processor):
        """A new-task request is forwarded to the new-task processor."""
        feedback = {
            "action": "new_task",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
            "user_instruction": "补充行业背景",
        }
        final_result = {
            "response_content": "原文后续内容",
            "citation_messages": {},
            "infer_messages": [],
        }

        with patch.object(
            processor._new_task_processor,
            "run_new_task",
            new_callable=AsyncMock,
        ) as mock_run_new_task:
            mock_run_new_task.return_value = {
                "new_report": "## 第一章\n新章节内容",
                "original_text": "## 第一章\n旧章节内容",
                "original_start_offset": 0,
                "original_end_offset": 11,
                "original_text_clean": "## 第一章\n旧章节内容",
                "rewritten_text": "## 第一章\n新章节内容",
                "rewritten_start_offset": 0,
                "rewritten_end_offset": 11,
                "section_start_offset": 0,
                "section_end_offset": 11,
                "section_title": "第一章",
                "matched_section_id": "1",
                "match_mode": "title_exact",
                "assessment_summary": "历史资料足够",
                "used_historical_doc_count": 2,
                "used_new_doc_count": 0,
                "missing_aspects": [],
            }

            result = await processor.execute(
                feedback=feedback,
                final_result=final_result,
                language="zh-CN",
            )

        assert result["section_title"] == "第一章"
        mock_run_new_task.assert_awaited_once_with(
            feedback=feedback,
            final_result=final_result,
            language="zh-CN",
        )

    @pytest.mark.asyncio
    async def test_execute_sync_returns_updated_report_without_touching_metadata(self, processor):
        """A sync request returns the new report and leaves citation/infer metadata as it was."""
        final_result = {
            "response_content": "旧报告",
            "citation_messages": {"code": 0, "msg": "success", "data": [{"id": 0}]},
            "infer_messages": [{"id": 9, "content": "保留"}],
        }

        result = await processor.execute(
            feedback={"action": "sync", "selected_text": "用户改后的完整报告"},
            final_result=final_result,
            language="zh-CN",
        )

        assert result == {
            "sync_only": True,
            "new_report": "用户改后的完整报告",
            "original_text": "旧报告",
            "original_start_offset": 0,
            "original_end_offset": 3,
            "rewritten_text": "用户改后的完整报告",
            "rewritten_start_offset": 0,
            "rewritten_end_offset": 9,
        }
        # Compare against fresh literals rather than the objects handed to
        # ``execute`` so that an in-place mutation of the metadata is detected.
        assert final_result["citation_messages"] == {"code": 0, "msg": "success", "data": [{"id": 0}]}
        assert final_result["infer_messages"] == [{"id": 9, "content": "保留"}]

    def test_build_stream_result_returns_none_for_non_rewrite_action(self):
        """A finish request produces no stream result."""
        feedback = {
            "action": "finish",
            "selected_text": "",
            "start_offset": 0,
            "end_offset": 0,
        }
        action_result = {
            "rewritten_text": "",
            "start_offset": 0,
            "new_end_offset": 0,
        }

        assert UserFeedbackProcessor.build_stream_result(feedback, action_result) is None

    def test_build_stream_result_returns_none_for_sync_action(self):
        """A sync request produces no stream result."""
        stream_result = UserFeedbackProcessor.build_stream_result(
            {"action": "sync", "selected_text": "完整报告"},
            {"sync_only": True, "new_report": "完整报告"},
        )

        assert stream_result is None

    def test_build_stream_result_builds_synonym_payload_from_action_result(self):
        """Texts and offsets come from the action result, not from the incoming feedback."""
        # The feedback deliberately carries different text and offsets than the
        # action result, so the assertion shows which source is used.
        feedback = {
            "action": "expand",
            "selected_text": "前端原文",
            "start_offset": 100,
            "end_offset": 104,
        }
        action_result = {
            "original_text": "执行结果原文",
            "original_start_offset": 3,
            "original_end_offset": 7,
            "rewritten_text": "执行结果改写",
            "rewritten_start_offset": 3,
            "rewritten_end_offset": 9,
        }

        result = UserFeedbackProcessor.build_stream_result(feedback, action_result)

        assert result == UserFeedbackRewriteStreamResult(
            original_text="执行结果原文",
            original_start_offset=3,
            original_end_offset=7,
            rewritten_text="执行结果改写",
            rewritten_start_offset=3,
            rewritten_end_offset=9,
            action_category=UserFeedbackActionCategory.SYNONYM_REWRITE,
            action_subcategory=SynonymRewriteActionSubcategory.EXPAND,
        )

    def test_build_stream_result_builds_local_edit_payload_for_supplementary_search(self):
        """A supplementary search result is wrapped with the supplementary search category."""
        feedback = {
            "action": "supplementary_search",
            "rewrite_scope": "selected_only",
            "selected_text": "原文",
            "start_offset": 3,
            "end_offset": 5,
        }
        action_result = {
            "original_text": "## 第二章\n旧章节内容",
            "original_start_offset": 0,
            "original_end_offset": 11,
            "rewritten_text": "## 第二章\n新章节内容",
            "rewritten_start_offset": 0,
            "rewritten_end_offset": 11,
        }

        result = UserFeedbackProcessor.build_stream_result(feedback, action_result)

        assert result == UserFeedbackRewriteStreamResult(
            original_text="## 第二章\n旧章节内容",
            original_start_offset=0,
            original_end_offset=11,
            rewritten_text="## 第二章\n新章节内容",
            rewritten_start_offset=0,
            rewritten_end_offset=11,
            action_category=UserFeedbackActionCategory.SUPPLEMENTARY_SEARCH,
            action_subcategory=SupplementarySearchActionSubcategory.SUPPLEMENTARY_SEARCH,
        )

    def test_build_stream_result_builds_local_edit_payload_for_new_task(self):
        """A new-task result is wrapped with the new-task category."""
        feedback = {
            "action": "new_task",
            "selected_text": "原文",
            "start_offset": 3,
            "end_offset": 5,
        }
        action_result = {
            "original_text": "## 第一章\n旧章节内容",
            "original_start_offset": 0,
            "original_end_offset": 11,
            "rewritten_text": "## 第一章\n新章节内容",
            "rewritten_start_offset": 0,
            "rewritten_end_offset": 11,
        }

        result = UserFeedbackProcessor.build_stream_result(feedback, action_result)

        # Fail with a readable assertion instead of an AttributeError on None.
        assert result is not None
        assert result.action_category == UserFeedbackActionCategory.NEW_TASK
        assert result.rewritten_text == "## 第一章\n新章节内容"

    def test_build_stream_result_uses_rewrite_error_errmsg_for_invalid_rewrite_mapping(self):
        """A category/subcategory mismatch is expected to raise the formatted rewrite error."""
        feedback = {
            "action": "expand",
            "selected_text": "原文",
            "start_offset": 0,
            "end_offset": 2,
        }
        action_result = {
            "original_text": "原文",
            "original_start_offset": 0,
            "original_end_offset": 2,
            "rewritten_text": "改写",
            "rewritten_start_offset": 0,
            "rewritten_end_offset": 2,
        }
        # Force the resolver to pair the synonym rewrite category with a
        # subcategory belonging to supplementary search.
        mismatched_resolution = patch(
            "openjiuwen_deepsearch.algorithm.user_feedback_processor.user_feedback_processor.resolve_feedback_action",
            return_value=ResolvedUserAction(
                action_category=UserFeedbackActionCategory.SYNONYM_REWRITE,
                action_subcategory=SupplementarySearchActionSubcategory.SUPPLEMENTARY_SEARCH,
            ),
        )

        with mismatched_resolution, pytest.raises(CustomRuntimeException) as exc_info:
            UserFeedbackProcessor.build_stream_result(feedback, action_result)

        assert exc_info.value.error_code == StatusCode.USER_FEEDBACK_PROCESSOR_REWRITE_ERROR.code
        assert exc_info.value.message == StatusCode.USER_FEEDBACK_PROCESSOR_REWRITE_ERROR.errmsg.format(
            e="Rewrite stream result requires synonym_rewrite subcategory, got supplementary_search"
        )


class TestSendResult:
    """Tests for ``UserFeedbackProcessor.send_result`` using a mocked session stream."""

    @pytest.mark.asyncio
    async def test_send_result_outputs_full_rewrite_metadata_with_updated_final_result_field(self):
        """The streamed content holds the rewrite fields plus the whole ``final_result``."""
        session = MagicMock()
        session.write_custom_stream = AsyncMock()

        feedback = {"action": "expand"}
        result = UserFeedbackRewriteStreamResult(
            original_text="原始文本",
            original_start_offset=10,
            original_end_offset=14,
            rewritten_text="改写后的文本",
            rewritten_start_offset=10,
            rewritten_end_offset=16,
            action_category=UserFeedbackActionCategory.SYNONYM_REWRITE,
            action_subcategory=SynonymRewriteActionSubcategory.EXPAND,
        )
        final_result = {
            "response_content": "完整改写后的报告",
            "citation_messages": {"code": 0, "msg": "success", "data": []},
            "infer_messages": [{"id": 0, "content": "保留推理"}],
            "exception_info": "[212405] stale error",
            "warning_info": "ignored",
        }

        await UserFeedbackProcessor.send_result(
            session=session,
            feedback=feedback,
            result=result,
            final_result=final_result,
        )

        session.write_custom_stream.assert_awaited_once()
        payload = session.write_custom_stream.await_args.args[0]
        assert payload["agent"] == NodeId.USER_FEEDBACK_PROCESSOR.value
        assert payload["event"] == StreamEvent.SUMMARY_RESPONSE.value

        content = json.loads(payload["content"])
        assert content == {
            "original_text": "原始文本",
            "original_start_offset": 10,
            "original_end_offset": 14,
            "rewritten_text": "改写后的文本",
            "rewritten_start_offset": 10,
            "rewritten_end_offset": 16,
            "action_category": "synonym_rewrite",
            "action_subcategory": "expand",
            "feedback_interaction_count": 0,
            "final_result": {
                "response_content": "完整改写后的报告",
                "citation_messages": {"code": 0, "msg": "success", "data": []},
                "infer_messages": [{"id": 0, "content": "保留推理"}],
                "exception_info": "[212405] stale error",
                "warning_info": "ignored",
            },
        }

    @pytest.mark.asyncio
    async def test_send_result_noops_for_finish_category(self):
        """Nothing is written to the stream for a finish request."""
        session = MagicMock()
        session.write_custom_stream = AsyncMock()

        await UserFeedbackProcessor.send_result(
            session=session,
            feedback={"action": "finish"},
            result=None,
            final_result=None,
        )

        session.write_custom_stream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_send_result_outputs_supplementary_search_metadata(self):
        """The streamed content carries the supplementary search category and the new report."""
        session = MagicMock()
        session.write_custom_stream = AsyncMock()

        feedback = {"action": "supplementary_search"}
        result = UserFeedbackRewriteStreamResult(
            original_text="原始文本",
            original_start_offset=0,
            original_end_offset=4,
            rewritten_text="## 第二章\n新章节内容",
            rewritten_start_offset=0,
            rewritten_end_offset=11,
            action_category=UserFeedbackActionCategory.SUPPLEMENTARY_SEARCH,
            action_subcategory=SupplementarySearchActionSubcategory.SUPPLEMENTARY_SEARCH,
        )

        await UserFeedbackProcessor.send_result(
            session=session,
            feedback=feedback,
            result=result,
            final_result={
                "response_content": "新报告",
                "citation_messages": {"data": []},
                "infer_messages": [],
            },
        )

        # ``await_args`` is None when the mock was never awaited; assert first
        # so that a missing write fails with a clear message.
        session.write_custom_stream.assert_awaited_once()
        payload = session.write_custom_stream.await_args.args[0]
        content = json.loads(payload["content"])
        assert content["action_category"] == "supplementary_search"
        assert content["action_subcategory"] == "supplementary_search"
        assert content["final_result"]["response_content"] == "新报告"

    @pytest.mark.asyncio
    async def test_send_result_outputs_lightweight_sync_ack(self):
        """A sync request streams only a small acknowledgement payload."""
        session = MagicMock()
        session.write_custom_stream = AsyncMock()

        await UserFeedbackProcessor.send_result(
            session=session,
            feedback={"action": "sync", "selected_text": "完整报告"},
            result=None,
        )

        # Guard the ``await_args`` access below against a missing write.
        session.write_custom_stream.assert_awaited_once()
        payload = session.write_custom_stream.await_args.args[0]
        assert json.loads(payload["content"]) == {
            "action_category": "sync",
            "synced": True,
        }


class TestSendError:
    """Tests for ``UserFeedbackProcessor.send_error`` using a mocked session stream."""

    @pytest.mark.asyncio
    async def test_send_error_outputs_single_error_field_from_custom_exception(self):
        """The streamed content is a JSON object holding only the stringified error."""
        session = MagicMock()
        session.write_custom_stream = AsyncMock()
        error = CustomValueException(
            StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.code,
            StatusCode.USER_FEEDBACK_PROCESSOR_INVALID_ACTION.errmsg.format(action="bad"),
        )

        await UserFeedbackProcessor.send_error(session, error)

        # ``await_args`` is None when the mock was never awaited; assert first
        # so that a missing write fails with a clear message.
        session.write_custom_stream.assert_awaited_once()
        payload = session.write_custom_stream.await_args.args[0]
        content = json.loads(payload["content"])
        assert content == {"error": str(error)}