"""向量存储单元测试"""

import os
import pytest
from storage.vector_db import OpenGaussVectorDB, VectorRecord


class TestVectorRecord:
    """Tests for the ``VectorRecord`` data class."""

    def test_create_record(self):
        """A record built from explicit fields exposes them unchanged."""
        rec = VectorRecord(
            text="test content",
            metadata={"source": "test"},
            vector=[0.1, 0.2, 0.3],
            id="test-1",
        )

        assert rec.id == "test-1"
        assert rec.text == "test content"
        assert rec.metadata["source"] == "test"
        assert len(rec.vector) == 3

    def test_create_record_empty_metadata(self):
        """An empty metadata dict is accepted and kept as-is."""
        rec = VectorRecord(text="text", metadata={}, vector=[0.0, 1.0], id="id-1")

        assert rec.id == "id-1"
        assert rec.metadata == {}

    def test_create_record_metadata_types(self):
        """Metadata values of several Python types are stored unchanged."""
        rec = VectorRecord(
            id="id-2",
            vector=[0.1] * 4,
            text="content",
            metadata={
                "str": "a",
                "int": 1,
                "float": 1.5,
                "bool": True,
                "list": [1, 2],
            },
        )

        expected = [("str", "a"), ("int", 1), ("float", 1.5), ("list", [1, 2])]
        for key, value in expected:
            assert rec.metadata[key] == value
        # Identity check: a merely truthy value such as 1 must not pass.
        assert rec.metadata["bool"] is True


@pytest.mark.skipif(
    not os.environ.get("TEST_DB_PASSWORD"), reason="Test database not configured"
)
class TestOpenGaussVectorDBUnit:
    """Unit tests for the openGauss vector store.

    These tests talk to a real database and are skipped unless the
    ``TEST_DB_PASSWORD`` environment variable is set. Connection settings are
    read from the ``test_config`` fixture (presumably provided by conftest).
    """

    # Embedding dimension of the test table; every vector below must match it.
    DIM = 1536

    @pytest.fixture
    def vector_db(self, test_config):
        """Yield an ``OpenGaussVectorDB`` bound to an empty ``test_unit_vectors`` table.

        The table is cleared before the test as well as after it, so rows left
        behind by an earlier aborted run cannot skew count-based assertions.
        The connection is always closed on teardown.
        """
        db = OpenGaussVectorDB(
            host=test_config["db_host"],
            port=test_config["db_port"],
            database=test_config["db_name"],
            user=test_config["db_user"],
            password=test_config["db_password"],
            sslmode=test_config.get("db_sslmode"),
            gssencmode=test_config.get("db_gssencmode"),
            table_name="test_unit_vectors",
            dimension=self.DIM,
        )
        # Start from a known-empty table: a previous run may have crashed
        # before its teardown could clean up.
        db.clear()
        yield db
        # test_close_connection closes the connection itself; in that case
        # there is nothing left to clean up or close.
        if not db._is_connection_closed():
            try:
                db.clear()
            finally:
                db.close()

    def test_connection(self, vector_db):
        """The fixture hands out an open connection."""
        assert vector_db._connection is not None
        assert not vector_db._is_connection_closed()

    def test_insert_single_record(self, vector_db):
        """Inserting one record reports and stores exactly one row."""
        record = VectorRecord(
            id="test-1",
            vector=[0.1] * self.DIM,
            metadata={"test": True},
            text="test content",
        )

        count = vector_db.insert([record])

        assert count == 1
        assert vector_db.count() == 1

    def test_insert_batch_records(self, vector_db):
        """Batch insert stores every record."""
        records = [
            VectorRecord(
                id=f"test-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={"index": i},
                text=f"content {i}",
            )
            for i in range(10)
        ]

        count = vector_db.insert(records)

        assert count == 10
        assert vector_db.count() == 10

    def test_insert_duplicate(self, vector_db):
        """Re-inserting an existing id updates it rather than adding a row."""
        record = VectorRecord(
            id="duplicate-1",
            vector=[0.1] * self.DIM,
            metadata={"version": 1},
            text="original content",
        )

        vector_db.insert([record])

        updated_record = VectorRecord(
            id="duplicate-1",
            vector=[0.2] * self.DIM,
            metadata={"version": 2},
            text="updated content",
        )

        vector_db.insert([updated_record])

        # An upsert must not leave two rows with the same id behind.
        assert vector_db.count() == 1
        retrieved = vector_db.get_by_id("duplicate-1")
        assert retrieved is not None
        assert retrieved["metadata"]["version"] == 2
        assert retrieved["text"] == "updated content"

    def test_get_by_id(self, vector_db):
        """A stored record can be fetched back by its id."""
        record = VectorRecord(
            id="get-test-1", vector=[0.1] * self.DIM, metadata={}, text="test content"
        )

        vector_db.insert([record])
        retrieved = vector_db.get_by_id("get-test-1")

        assert retrieved is not None
        assert retrieved["id"] == "get-test-1"
        assert retrieved["text"] == "test content"

    def test_get_by_id_not_found(self, vector_db):
        """Looking up an unknown id returns None."""
        result = vector_db.get_by_id("non-existent-id")

        assert result is None

    def test_delete_records(self, vector_db):
        """Deleting ids removes exactly those records."""
        records = [
            VectorRecord(
                id=f"delete-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={},
                text=f"content {i}",
            )
            for i in range(5)
        ]

        vector_db.insert(records)

        deleted_count = vector_db.delete(["delete-1", "delete-2"])

        assert deleted_count == 2
        assert vector_db.count() == 3
        # The deleted ids are gone; an untouched one is still there.
        assert vector_db.get_by_id("delete-1") is None
        assert vector_db.get_by_id("delete-2") is None
        assert vector_db.get_by_id("delete-0") is not None

    def test_delete_empty_list(self, vector_db):
        """Deleting an empty id list is a no-op that reports zero."""
        deleted_count = vector_db.delete([])

        assert deleted_count == 0

    def test_clear_all(self, vector_db):
        """clear() removes every stored record."""
        records = [
            VectorRecord(
                id=f"clear-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={},
                text=f"content {i}",
            )
            for i in range(5)
        ]

        vector_db.insert(records)
        assert vector_db.count() == 5

        vector_db.clear()

        assert vector_db.count() == 0

    def test_count_empty(self, vector_db):
        """count() is zero on an empty table."""
        assert vector_db.count() == 0

    def test_count_with_data(self, vector_db):
        """count() matches the number of inserted records."""
        records = [
            VectorRecord(
                id=f"count-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={},
                text=f"content {i}",
            )
            for i in range(3)
        ]

        vector_db.insert(records)

        assert vector_db.count() == 3

    def test_vector_search(self, vector_db):
        """Pure vector search honours the limit and attaches a score."""
        records = [
            VectorRecord(
                id=f"search-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={"tag": f"tag-{i % 3}"},
                text=f"content about tag-{i % 3}",
            )
            for i in range(10)
        ]

        vector_db.insert(records)

        query_vector = [0.5] * self.DIM
        results = vector_db.search(query_vector=query_vector, limit=5, use_hybrid=False)

        # Guard against the checks below passing vacuously on an empty result.
        assert results
        assert len(results) <= 5
        assert all("score" in r for r in results)

    def test_search_with_filters(self, vector_db):
        """Metadata filters restrict search results to matching records."""
        records = [
            VectorRecord(
                id=f"filter-{i}",
                vector=[i * 0.001] * self.DIM,
                metadata={"category": f"cat-{i % 2}"},
                text=f"content {i}",
            )
            for i in range(10)
        ]

        vector_db.insert(records)

        query_vector = [0.5] * self.DIM
        results = vector_db.search(
            query_vector=query_vector,
            limit=10,
            filters={"category": "cat-0"},
            use_hybrid=False,
        )

        # all() over an empty list is True, so require at least one hit.
        assert results
        assert all(r["metadata"]["category"] == "cat-0" for r in results)

    def test_search_empty_database(self, vector_db):
        """Searching an empty table returns no results."""
        results = vector_db.search(
            query_vector=[0.1] * self.DIM, limit=5, use_hybrid=False
        )

        assert len(results) == 0

    def test_close_connection(self, vector_db):
        """close() leaves the store reporting a closed connection."""
        vector_db.close()

        assert vector_db._is_connection_closed()