"""向量存储集成测试"""

import os
import pytest
from storage.vector_db import OpenGaussVectorDB, VectorRecord


@pytest.mark.skipif(
    not os.environ.get("TEST_DB_PASSWORD"), reason="Test database not configured"
)
class TestOpenGaussVectorDBIntegration:
    """Integration tests for the openGauss-backed vector store.

    These tests talk to a real database and are skipped unless the
    ``TEST_DB_PASSWORD`` environment variable is set. Connection settings are
    read from the ``test_config`` fixture (presumably provided by
    ``conftest.py`` — not visible here).
    """

    # Embedding dimension shared by the fixture's table and every test vector,
    # kept in one place so the two cannot drift apart.
    _DIMENSION = 1536

    @classmethod
    def _vector(cls, i):
        """Return a constant-valued test vector of length ``_DIMENSION`` for index ``i``.

        The value is offset by one so no record receives an all-zero vector,
        which is degenerate under cosine-style similarity (the store's actual
        distance metric is not visible here — NOTE(review): confirm).
        """
        return [(i + 1) * 0.001] * cls._DIMENSION

    @pytest.fixture
    def vector_db(self, test_config):
        """Yield an ``OpenGaussVectorDB`` bound to a dedicated test table.

        The table is cleared before the test as well as after it, so rows left
        behind by an earlier aborted run cannot skew results or ``count()``.
        """
        db = OpenGaussVectorDB(
            host=test_config["db_host"],
            port=test_config["db_port"],
            database=test_config["db_name"],
            user=test_config["db_user"],
            password=test_config["db_password"],
            sslmode=test_config.get("db_sslmode"),
            gssencmode=test_config.get("db_gssencmode"),
            table_name="test_integration_vectors",
            dimension=self._DIMENSION,
        )
        db.clear()
        yield db
        db.clear()

    def test_insert_and_search(self, vector_db):
        """Insert records, then run a pure vector (non-hybrid) search."""
        records = [
            VectorRecord(
                id=f"test-{i}",
                vector=self._vector(i),
                metadata={"index": i, "tag": f"tag-{i % 3}"},
                text=f"content {i}",
            )
            for i in range(20)
        ]

        vector_db.insert(records)

        query_vector = [0.5] * self._DIMENSION
        results = vector_db.search(query_vector=query_vector, limit=5, use_hybrid=False)

        # Require a non-empty result: an upper bound alone would also accept [].
        assert 0 < len(results) <= 5
        assert all("score" in r for r in results)

    def test_hybrid_search_workflow(self, vector_db):
        """Insert records and run a hybrid (vector + full-text) search."""
        records = [
            VectorRecord(
                id=f"hybrid-{i}",
                vector=self._vector(i),
                metadata={"section": f"section-{i % 2}"},
                text=f"This is content about section-{i % 2} with keyword Python.",
            )
            for i in range(15)
        ]

        vector_db.insert(records)

        results = vector_db.search(
            query_vector=[0.001] * self._DIMENSION,
            query_text="Python",
            limit=5,
            use_hybrid=True,
        )

        assert 0 < len(results) <= 5
        assert all("score" in r for r in results)

    def test_bm25_search(self, vector_db):
        """Exercise the internal BM25 full-text search path directly."""
        records = [
            VectorRecord(
                id=f"bm25-{i}",
                vector=self._vector(i),
                metadata={"source": "test"},
                text="This document mentions Python programming language and related concepts.",
            )
            for i in range(10)
        ]

        vector_db.insert(records)

        # Close the cursor explicitly so it does not leak. The connection itself
        # is left open: ``_get_conn`` may hand out a shared/pooled connection
        # (not visible here), so closing it could break the fixture's teardown.
        cursor = vector_db._get_conn().cursor()
        try:
            results = vector_db._bm25_search(
                cursor, "Python programming", limit=5, filters=None
            )
        finally:
            cursor.close()

        assert 0 < len(results) <= 5
        assert all("score" in r for r in results)

    def test_rrf_reranking(self, vector_db):
        """Fuse two ranked lists with Reciprocal Rank Fusion (no DB access)."""
        vector_results = [
            {"id": f"v-{i}", "score": 0.9 - i * 0.05, "text": f"vector {i}"}
            for i in range(10)
        ]

        bm25_results = [
            {"id": f"b-{i}", "score": 0.8 - i * 0.05, "text": f"bm25 {i}"}
            for i in range(10)
        ]

        fused = vector_db._rrf_rerank(vector_results, bm25_results, limit=5)

        # 20 distinct candidates are available, so fusion must yield something.
        assert 0 < len(fused) <= 5
        assert all("score" in r for r in fused)

    def test_full_workflow(self, vector_db):
        """Run the full insert -> search -> delete cycle and verify the count."""
        records = [
            VectorRecord(
                id=f"workflow-{i}",
                vector=self._vector(i),
                metadata={"step": f"step-{i}"},
                text=f"Workflow step {i} content",
            )
            for i in range(5)
        ]

        vector_db.insert(records)

        results = vector_db.search(
            query_vector=[0.5] * self._DIMENSION, limit=3, use_hybrid=False
        )

        assert 0 < len(results) <= 3

        deleted_count = vector_db.delete([r["id"] for r in results])

        assert deleted_count == len(results)
        # Derive the expected remainder rather than hard-coding it, so the check
        # stays correct even if search returns fewer than ``limit`` rows.
        assert vector_db.count() == len(records) - deleted_count