import base64
import json
from unittest.mock import patch, AsyncMock

import pytest

from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer import SourceTracerInfer
from openjiuwen_deepsearch.algorithm.source_tracer_infer.infer_call_model import (
    GraphInfo,
)


class TestSourceTracerInfer:
    """Test cases for SourceTracerInfer core functionality."""

    def setup_method(self):
        """Set up test fixtures."""
        self.context = {
            "language": "zh-CN",
            "llm_model_name": "mock_model",
            "source_tracer_response": "test response",
            "conclusion_with_records": None,
        }
        self.source_tracer_infer = SourceTracerInfer(self.context)

    def test_init(self):
        """Test SourceTracerInfer initialization."""
        assert self.source_tracer_infer.context == self.context
        assert self.source_tracer_infer.language == "zh-CN"
        assert self.source_tracer_infer.model_name == "mock_model"
        assert self.source_tracer_infer.response == "test response"
        assert self.source_tracer_infer.conclusion_with_records is None
        assert self.source_tracer_infer.checker_infos == {
            "graph_infos": [],
            "search_records": [],
        }
        assert hasattr(self.source_tracer_infer, "node_number")
        assert hasattr(self.source_tracer_infer, "supplement_graph")

    def test_encode_html_to_base64_valid(self):
        """Test _encode_html_to_base64 with valid HTML."""
        html_content = "<html><body>test</body></html>"
        result = self.source_tracer_infer._encode_html_to_base64(html_content)

        # Verify it's valid base64
        decoded = base64.b64decode(result).decode("utf-8")
        assert decoded == html_content

    def test_encode_html_to_base64_invalid(self):
        """Test _encode_html_to_base64 with invalid encoding."""
        with patch("base64.b64encode", side_effect=Exception("Encoding error")):
            with pytest.raises(Exception):
                self.source_tracer_infer._encode_html_to_base64("test")

    @pytest.mark.asyncio
    async def test_get_conclusion_and_records_new(self):
        """Test get_conclusion_and_records when conclusion_with_records is None."""
        expected_result = [{"conclusion": "test", "search_records": []}]

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.ResearchInferPreprocess"
        ) as mock_preprocessor:
            mock_instance = AsyncMock()
            mock_instance.run.return_value = expected_result
            mock_preprocessor.return_value = mock_instance

            await self.source_tracer_infer.get_conclusion_and_records()

            assert self.source_tracer_infer.conclusion_with_records == expected_result

    @pytest.mark.asyncio
    async def test_extract_reference_no_results(self):
        """Test extract_reference when no valid references found."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [{"content": "test"}],
        }

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = []

            result = await self.source_tracer_infer.extract_reference(datas)

            assert result == {}

    @pytest.mark.asyncio
    async def test_extract_reference_valid(self):
        """Test extract_reference with valid data."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [
                {"content": "test content 1"},
                {"content": "test content 2"},
            ],
        }

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = [0, 1]  # Both records are valid

            result = await self.source_tracer_infer.extract_reference(datas)

            assert result["conclusion"] == "test conclusion"
            assert len(result["reference"]) == 2
            assert result["reference"][0]["id"] == 0
            assert result["reference"][0]["content"] == "test content 1"

    @pytest.mark.asyncio
    async def test_extract_reference_invalid_index(self):
        """Test extract_reference with invalid index in results."""
        datas = {
            "conclusion": ["test conclusion"],
            "search_records": [{"content": "test content"}],
        }

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = [5]  # Invalid index

            result = await self.source_tracer_infer.extract_reference(datas)

            assert result["conclusion"] == "test conclusion"
            assert len(result["reference"]) == 0

    @pytest.mark.asyncio
    async def test_infer_basic(self):
        """Test infer with basic data."""
        evidences = {
            "conclusion": "test conclusion",
            "reference": [{"content": "test"}],
        }

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = ["test inference"]

            result = await self.source_tracer_infer.infer(evidences)

            assert result["conclusion"] == "test conclusion"
            assert result["inference"] == "test inference"

    @pytest.mark.asyncio
    async def test_infer_empty_result(self):
        """Test infer with empty LLM result."""
        evidences = {"conclusion": "test conclusion", "reference": []}

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = []

            result = await self.source_tracer_infer.infer(evidences)

            assert result["conclusion"] == "test conclusion"
            assert result["inference"] == ""

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_valid(self):
        """Test filter_invalid_infer with valid inference."""
        inferences = {"inference": "valid inference"}

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = "true"

            result = await self.source_tracer_infer.filter_invalid_infer(inferences)

            assert result == inferences

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_empty_result(self):
        """Test filter_invalid_infer with empty/falsy result raises ValueError."""
        inferences = {"inference": "invalid inference"}

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = ""

            with pytest.raises(ValueError) as exc_info:
                await self.source_tracer_infer.filter_invalid_infer(inferences)

            assert "invalid inference" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_filter_invalid_infer_false_result(self):
        """Test filter_invalid_infer with 'false' string result."""
        inferences = {"inference": "invalid inference"}

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = "false"

            result = await self.source_tracer_infer.filter_invalid_infer(inferences)

            assert result == inferences

    @pytest.mark.asyncio
    async def test_structured_infer_valid(self):
        """Test structured_infer with valid inference."""
        inference = {"inference": "test"}
        expected_result = [([0], "relation", 1)]

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = expected_result

            result = await self.source_tracer_infer.structured_infer(inference)

            assert result == expected_result

    @pytest.mark.asyncio
    async def test_structured_infer_empty_result(self):
        """Test structured_infer with empty result."""
        inference = {"inference": "test"}

        with patch(
            "openjiuwen_deepsearch.algorithm.source_tracer_infer.infer.call_model",
            new_callable=AsyncMock,
        ) as mock_call_model:
            mock_call_model.return_value = []

            with pytest.raises(ValueError) as exc_info:
                await self.source_tracer_infer.structured_infer(inference)

            assert "unstructured inference" in str(exc_info.value)

    def test_generate_html_basic(self):
        """Test generate_html with basic graph data."""
        checked_infer_graph = GraphInfo(
            [
                [[0], "relation", 1]
            ],  # structured_inference - this should be a list of tuples
            {
                0: {"label": "node0", "url": "https://www.example.com"},
                1: {"label": "node1"},
            },  # node_map
            [0],  # citation_index
            [1],  # conclusion_ids
        )

        result = self.source_tracer_infer.generate_html.run(checked_infer_graph)

        assert isinstance(result, str)
        assert "<html>" in result
        assert "</html>" in result

    def test_generate_html_multiple_heads(self):
        """Test generate_html with multiple head nodes."""
        checked_infer_graph = GraphInfo(
            [[[0, 1], "relation", 2]],
            {0: {"label": "node0"}, 1: {"label": "node1"}, 2: {"label": "node2"}},
            [0],
            [2],
        )

        result = self.source_tracer_infer.generate_html.run(checked_infer_graph)

        assert isinstance(result, str)
        assert "<html>" in result

    def test_mark_conclusion_in_report_basic(self):
        """Test mark_conclusion_in_report with basic data."""
        infer_messages = [{"id": 0, "conclusion": "test conclusion"}]
        conclusion_infos = [{"start_pos": 0, "end_pos": 5}]
        self.source_tracer_infer.response = "original"

        result = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        assert "[test conclusion](#inference:0)" in result

    def test_mark_conclusion_in_report_multiple(self):
        """Test mark_conclusion_in_report with multiple conclusions."""
        infer_messages = [
            {"id": 0, "conclusion": "conclusion1"},
            {"id": 1, "conclusion": "conclusion2"},
        ]
        conclusion_infos = [
            {"start_pos": 0, "end_pos": 5},
            {"start_pos": 10, "end_pos": 15},
        ]
        self.source_tracer_infer.response = "original text here"

        result = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        # The method processes conclusions sequentially and modifies response in place
        # This can cause position-based replacements to interfere with each other
        # Check that method completed without error and returned a string
        assert isinstance(result, str)
        assert len(result) > 0  # Result should not be empty

    def test_mark_conclusion_in_report_error(self):
        """Test mark_conclusion_in_report error handling."""
        infer_messages = [{"id": 0, "conclusion": "test"}]
        conclusion_infos = [{"start_pos": 0, "end_pos": 100}]  # Invalid position
        original_response = "original"
        self.source_tracer_infer.response = original_response

        result = self.source_tracer_infer.mark_conclusion_in_report(
            infer_messages, conclusion_infos
        )

        # The method only returns original response if there's an exception
        # Invalid position doesn't cause exception, it just slices the string
        # So we expect the method to complete normally
        assert result != original_response  # Should be modified

    @pytest.mark.asyncio
    async def test_async_run_success(self):
        """Test async_run with successful execution."""
        datas = {"conclusion": ["test"], "search_records": [{"content": "test"}]}

        with patch.object(
            self.source_tracer_infer, "extract_reference", new_callable=AsyncMock
        ) as mock_extract:
            mock_extract.return_value = {"conclusion": "test", "reference": []}

            with patch.object(
                self.source_tracer_infer, "infer", new_callable=AsyncMock
            ) as mock_infer:
                mock_infer.return_value = {
                    "conclusion": "test",
                    "inference": "inference",
                }

                with patch.object(
                    self.source_tracer_infer,
                    "filter_invalid_infer",
                    new_callable=AsyncMock,
                ) as mock_filter:
                    mock_filter.return_value = {
                        "conclusion": "test",
                        "inference": "inference",
                    }

                    with patch.object(
                        self.source_tracer_infer,
                        "structured_infer",
                        new_callable=AsyncMock,
                    ) as mock_structured:
                        mock_structured.return_value = [([0], "relation", 1)]

                        with patch.object(
                            self.source_tracer_infer.node_number, "number_node"
                        ) as mock_number:
                            mock_number.return_value = (
                                [([0], "relation", 1)],
                                {0: {"label": "test"}},
                                [0],
                                [1],
                            )

                            with patch.object(
                                self.source_tracer_infer.supplement_graph,
                                "run",
                                new_callable=AsyncMock,
                            ) as mock_supplement:
                                mock_supplement.return_value = (
                                    [([0], "relation", 1)],
                                    {0: {"label": "test"}},
                                    [0],
                                    [1],
                                )

                                with patch.object(
                                    self.source_tracer_infer.generate_html, "run"
                                ) as mock_html:
                                    mock_html.return_value = "<html>test</html>"

                                    (
                                        infer_message,
                                        checked_graph,
                                    ) = await self.source_tracer_infer.async_run(datas)

                                    assert infer_message["conclusion"] == "test"
                                    assert infer_message["inference"] == "inference"
                                    assert "html_base64" in infer_message
                                    assert checked_graph is not None

    @pytest.mark.asyncio
    async def test_async_run_extract_reference_empty(self):
        """Test async_run when extract_reference returns empty."""
        datas = {"conclusion": [], "search_records": []}

        with patch.object(
            self.source_tracer_infer, "extract_reference", new_callable=AsyncMock
        ) as mock_extract:
            mock_extract.return_value = {}

            infer_message, checked_graph = await self.source_tracer_infer.async_run(
                datas
            )

            assert infer_message == {}
            assert checked_graph is None

    @pytest.mark.asyncio
    async def test_async_run_error_handling(self):
        """Test async_run error handling."""
        datas = {"conclusion": ["test"], "search_records": [{"content": "test"}]}

        with patch.object(
            self.source_tracer_infer, "extract_reference", new_callable=AsyncMock
        ) as mock_extract:
            mock_extract.side_effect = Exception("Test error")

            infer_message, checked_graph = await self.source_tracer_infer.async_run(
                datas
            )

            assert infer_message == {}
            assert checked_graph is None

    @pytest.mark.asyncio
    async def test_run_success(self):
        """Test run with successful execution."""
        self.source_tracer_infer.conclusion_with_records = [
            {
                "conclusion": ["test"],
                "search_records": [{"content": "test"}],
                "start_pos": 0,
                "end_pos": 5,
            }
        ]

        with patch.object(
            self.source_tracer_infer, "async_run", new_callable=AsyncMock
        ) as mock_async_run:
            mock_async_run.return_value = (
                {
                    "id": 0,
                    "conclusion": "test",
                    "inference": "inference",
                    "html_base64": "base64",
                },
                ([([0], "relation", 1)], {0: {"label": "test"}}, [0], [1]),
            )

            with patch.object(
                self.source_tracer_infer, "mark_conclusion_in_report"
            ) as mock_mark:
                mock_mark.return_value = "marked response"

                (
                    response,
                    infer_messages,
                    checker_infos,
                    error,
                ) = await self.source_tracer_infer.run()

                assert response == "marked response"
                assert len(infer_messages) == 1
                assert infer_messages[0]["id"] == 0
                assert len(checker_infos["graph_infos"]) == 1
                assert len(checker_infos["search_records"]) == 1
                assert error is None

    @pytest.mark.asyncio
    async def test_run_no_conclusion_with_records(self):
        """Test run when conclusion_with_records is None."""

        with patch.object(
            self.source_tracer_infer,
            "get_conclusion_and_records",
            new_callable=AsyncMock,
        ) as mock_get:
            mock_get.return_value = None

            # When conclusion_with_records is None, the run method tries to iterate over None
            # causing "NoneType object is not iterable" error
            (
                response,
                infer_messages,
                checker_infos,
                error,
            ) = await self.source_tracer_infer.run()

            assert response == "test response"  # Original response
            assert infer_messages == []
            assert checker_infos["graph_infos"] == []
            assert checker_infos["search_records"] == []
            assert error is not None  # Should have an error about NoneType not iterable

    @pytest.mark.asyncio
    async def test_run_error_handling(self):
        """Test run error handling."""
        self.source_tracer_infer.conclusion_with_records = [
            {"conclusion": ["test"], "search_records": [{"content": "test"}]}
        ]

        with patch.object(
            self.source_tracer_infer, "async_run", new_callable=AsyncMock
        ) as mock_async_run:
            mock_async_run.side_effect = Exception("Test error")

            (
                response,
                infer_messages,
                checker_infos,
                error,
            ) = await self.source_tracer_infer.run()

            assert response == "test response"  # Original response
            assert infer_messages == []
            assert checker_infos["graph_infos"] == []
            assert checker_infos["search_records"] == []
            assert error == "Test error"

    @pytest.mark.asyncio
    async def test_run_empty_html_base64(self):
        """Test run filters out messages with empty html_base64."""
        self.source_tracer_infer.conclusion_with_records = [
            {"conclusion": ["test"], "search_records": [{"content": "test"}]}
        ]

        with patch.object(
            self.source_tracer_infer, "async_run", new_callable=AsyncMock
        ) as mock_async_run:
            mock_async_run.return_value = (
                {"id": 0, "html_base64": ""},  # Empty html_base64
                None,
            )

            (
                response,
                infer_messages,
                checker_infos,
                error,
            ) = await self.source_tracer_infer.run()

            assert len(infer_messages) == 0  # Should be filtered out
            assert len(checker_infos["graph_infos"]) == 0