"""嵌入提供者单元测试"""
import os
import pytest
from embeddings.base import EmbeddingBase, EmbeddingFactory
import embeddings.openai
class TestEmbeddingFactory:
    """Tests for EmbeddingFactory: provider listing, construction and errors."""

    def test_list_providers(self):
        """list_providers() should return a list that includes "openai"."""
        providers = EmbeddingFactory.list_providers()
        # Check the type first so a non-list return (e.g. None) fails with a
        # clear assertion instead of a TypeError from the membership test.
        assert isinstance(providers, list)
        assert "openai" in providers

    @pytest.mark.requires_openai
    def test_create_openai_provider(self):
        """create("openai") should yield an EmbeddingBase with a model name
        and a positive dimension."""
        embedder = EmbeddingFactory.create("openai")
        assert isinstance(embedder, EmbeddingBase)
        assert embedder.get_model_name()
        assert embedder.get_dimension() > 0

    @pytest.mark.requires_openai
    def test_create_with_model_override(self):
        """A ``model`` keyword should override the provider's model, and the
        reported dimension should follow that model."""
        embedder = EmbeddingFactory.create(
            "openai",
            model="text-embedding-3-large",
        )
        assert embedder.get_model_name() == "text-embedding-3-large"
        # 3072 is the default output size OpenAI documents for
        # text-embedding-3-large.
        assert embedder.get_dimension() == 3072

    def test_invalid_provider(self):
        """An unregistered provider name should raise ValueError."""
        with pytest.raises(ValueError):
            EmbeddingFactory.create("invalid_provider")

    @pytest.mark.requires_openai
    def test_create_with_explicit_dimension(self):
        """An explicit ``dimension`` keyword should be reported back by
        get_dimension()."""
        # NOTE(review): "text-embedding-v4" is not an OpenAI model name;
        # presumably this targets an OpenAI-compatible endpoint — confirm.
        embedder = EmbeddingFactory.create(
            "openai",
            model="text-embedding-v4",
            dimension=1024,
        )
        assert embedder.get_model_name() == "text-embedding-v4"
        assert embedder.get_dimension() == 1024
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY"),
    reason="OPENAI_API_KEY not set"
)
class TestOpenAIEmbedding:
    """Integration tests for OpenAIEmbedding; skipped unless OPENAI_API_KEY is set."""

    @pytest.fixture
    def embedder(self):
        """Return an OpenAIEmbedding constructed with no arguments."""
        from embeddings.openai import OpenAIEmbedding
        return OpenAIEmbedding()

    @pytest.mark.asyncio
    async def test_encode_single(self, embedder):
        """A single text encodes to a float vector of the configured length."""
        result = await embedder.encode_single("test text")
        expected_len = embedder.get_dimension()
        assert len(result) == expected_len
        assert all(isinstance(component, float) for component in result)
        # abs() is never negative, so only the upper bound carries information;
        # unit-normalized embeddings keep every component within [-1, 1].
        assert all(abs(component) <= 1.0 for component in result)

    @pytest.mark.asyncio
    async def test_encode_batch(self, embedder):
        """Several texts encode to one float vector each, all of the same length."""
        inputs = ["text1", "text2", "text3"]
        results = await embedder.encode(inputs)
        expected_len = embedder.get_dimension()
        assert len(results) == 3
        assert all(len(row) == expected_len for row in results)
        assert all(isinstance(x, float) for row in results for x in row)

    @pytest.mark.asyncio
    async def test_encode_empty_list(self, embedder):
        """Encoding an empty batch returns an empty list."""
        results = await embedder.encode([])
        assert results == []

    @pytest.mark.asyncio
    async def test_encode_special_characters(self, embedder):
        """Non-ASCII input (CJK text and an emoji) encodes like any other text."""
        result = await embedder.encode_single("测试中文 🎉")
        assert len(result) == embedder.get_dimension()
        assert all(isinstance(component, float) for component in result)

    @pytest.mark.asyncio
    async def test_encode_single_empty_string(self, embedder):
        """An empty string should still yield a vector whose length equals the
        dimension."""
        # NOTE(review): the OpenAI embeddings API rejects empty-string input, so
        # this presumably relies on OpenAIEmbedding handling "" itself — confirm.
        result = await embedder.encode_single("")
        assert len(result) == embedder.get_dimension()
        assert all(isinstance(component, float) for component in result)
class TestOpenAIEmbeddingInit:
    """Construction-time behavior and error paths of OpenAIEmbedding."""

    def test_init_without_api_key_raises(self, monkeypatch):
        """With OPENAI_API_KEY unset and an empty api_key, construction should
        raise a ValueError whose message mentions OPENAI_API_KEY."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        from embeddings.openai import OpenAIEmbedding

        # ``match`` performs re.search on str(exc); the pattern contains no
        # regex metacharacters, so this is a plain substring check.
        with pytest.raises(ValueError, match="OPENAI_API_KEY"):
            OpenAIEmbedding(api_key="")

    def test_init_with_explicit_api_key_ok(self, monkeypatch):
        """Explicit api_key, model and dimension arguments should work with the
        related environment variables removed."""
        for env_name in (
            "OPENAI_API_KEY",
            "OPENAI_EMBEDDING_MODEL",
            "OPENAI_EMBEDDING_DIMENSION",
        ):
            monkeypatch.delenv(env_name, raising=False)
        from embeddings.openai import OpenAIEmbedding

        instance = OpenAIEmbedding(
            api_key="sk-test-fake",
            model="text-embedding-3-small",
            dimension=1536,
        )
        assert instance.get_dimension() == 1536
        assert instance.get_model_name() == "text-embedding-3-small"