"""OpenAI 嵌入提供者"""
import os
from typing import List, Optional
from .base import EmbeddingBase, EmbeddingFactory
class OpenAIEmbedding(EmbeddingBase):
    """Text-embedding provider backed by the OpenAI embeddings API.

    Each setting is resolved in this order: explicit constructor argument,
    then environment variable, then built-in default.

    * model    -- ``OPENAI_EMBEDDING_MODEL``, default ``text-embedding-3-small``
    * api_key  -- ``OPENAI_API_KEY`` (required)
    * base_url -- ``OPENAI_BASE_URL`` (optional)
    * dimension -- ``OPENAI_EMBEDDING_DIMENSION``; only honoured when the
      active model is the one named by ``OPENAI_EMBEDDING_MODEL``.
    """

    # Native output size of each known model. Models not listed here are
    # probed with one real request at construction time.
    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        "text-embedding-v4": 1024,
    }

    # Known models that reject the ``dimensions`` request parameter; a custom
    # dimension is never sent for these.
    _FIXED_DIMENSION_MODELS = frozenset({"text-embedding-ada-002"})

    # Maximum number of texts sent in a single embeddings request.
    ENCODE_BATCH_SIZE = 10

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        dimension: Optional[int] = None,
    ):
        """Resolve configuration and determine the embedding dimension.

        Args:
            model: Embedding model name.
            api_key: API key; falls back to ``OPENAI_API_KEY``.
            base_url: Alternative API endpoint; falls back to
                ``OPENAI_BASE_URL``.
            dimension: Embedding size to report (and request, where the
                model supports it). Overrides the environment and the
                built-in table.

        Raises:
            ValueError: If no API key is available, if
                ``OPENAI_EMBEDDING_DIMENSION`` is not an integer, or if the
                dimension of an unknown model cannot be detected.
        """
        model_from_env = os.environ.get("OPENAI_EMBEDDING_MODEL")
        self.model = model or model_from_env or "text-embedding-3-small"
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.base_url = base_url or os.environ.get("OPENAI_BASE_URL")
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")

        dim_env = os.environ.get("OPENAI_EMBEDDING_DIMENSION")
        if dimension is not None:
            self._dimension = dimension
        elif dim_env and model_from_env == self.model:
            # The env dimension describes the env model only; ignore it when
            # the caller selected a different model explicitly.
            try:
                self._dimension = int(dim_env)
            except ValueError as exc:
                raise ValueError(
                    f"OPENAI_EMBEDDING_DIMENSION must be an integer, got {dim_env!r}"
                ) from exc
        else:
            self._dimension = self._get_model_dimension()

        # The async client is created lazily on first use.
        self._client = None

    def _client_kwargs(self) -> dict:
        """Build the keyword arguments shared by the sync and async clients."""
        kwargs = {"api_key": self.api_key}
        if self.base_url:
            kwargs["base_url"] = self.base_url
        return kwargs

    def _request_options(self) -> dict:
        """Return extra per-request arguments for ``embeddings.create``.

        ``dimensions`` is sent only when the configured size differs from the
        known native size of a model that accepts the parameter. Without it
        the API returns native-size vectors and ``get_dimension()`` would
        disagree with the output of ``encode()``.
        """
        native = self.MODEL_DIMENSIONS.get(self.model)
        if (
            native is None
            or native == self._dimension
            or self.model in self._FIXED_DIMENSION_MODELS
        ):
            return {}
        return {"dimensions": self._dimension}

    def _get_client(self):
        """Return the async client, creating it on first use.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if self._client is None:
            try:
                import openai
            except ImportError as exc:
                raise ImportError("Please install: pip install openai") from exc
            self._client = openai.AsyncOpenAI(**self._client_kwargs())
        return self._client

    def _get_model_dimension(self) -> int:
        """Return the model's embedding size.

        Known models are answered from ``MODEL_DIMENSIONS``. Any other model
        is probed by embedding a short sample with a temporary synchronous
        client and measuring the returned vector.

        Raises:
            ValueError: If the probe fails for any reason (including a
                missing ``openai`` package).
        """
        if self.model in self.MODEL_DIMENSIONS:
            return self.MODEL_DIMENSIONS[self.model]
        try:
            import openai

            # The context manager closes the probe client's HTTP resources.
            with openai.OpenAI(**self._client_kwargs()) as client:
                response = client.embeddings.create(input=["test"], model=self.model)
            return len(response.data[0].embedding)
        except Exception as e:
            raise ValueError(
                f"Failed to detect dimension for model {self.model}: {e}"
            ) from e

    def get_dimension(self) -> int:
        """Return the embedding vector size."""
        return self._dimension

    def get_model_name(self) -> str:
        """Return the model name in use."""
        return self.model

    async def encode(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` in batches of at most ``ENCODE_BATCH_SIZE``.

        Args:
            texts: Texts to embed; an empty list returns ``[]`` without
                creating a client.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            RuntimeError: If any embeddings request fails.
        """
        if not texts:
            return []
        client = self._get_client()
        options = self._request_options()
        size = self.ENCODE_BATCH_SIZE
        results: List[List[float]] = []
        try:
            for start in range(0, len(texts), size):
                batch = texts[start : start + size]
                response = await client.embeddings.create(
                    input=batch, model=self.model, **options
                )
                results.extend(item.embedding for item in response.data)
            return results
        except Exception as e:
            raise RuntimeError(f"Failed to encode texts: {e}") from e

    async def encode_single(self, text: str) -> List[float]:
        """Embed one text; returns ``[]`` if no embedding came back."""
        results = await self.encode([text])
        return results[0] if results else []
# Register OpenAIEmbedding with EmbeddingFactory under the name "openai".
EmbeddingFactory.register("openai", OpenAIEmbedding)