"""向量嵌入接口"""
from abc import ABC, abstractmethod
from typing import List
class EmbeddingBase(ABC):
    """Abstract base class for embedding models.

    Concrete providers must implement all four methods below; the class
    cannot be instantiated until every abstract method is overridden.
    """

    @abstractmethod
    def get_dimension(self) -> int:
        """Return the dimension of the vectors this model produces."""

    @abstractmethod
    def get_model_name(self) -> str:
        """Return the name of the underlying model."""

    @abstractmethod
    async def encode(self, texts: List[str]) -> List[List[float]]:
        """Encode a list of texts into vectors.

        Args:
            texts: The texts to encode.

        Returns:
            A list of float vectors.
        """

    @abstractmethod
    async def encode_single(self, text: str) -> List[float]:
        """Encode a single text into a vector.

        Args:
            text: The text to encode.

        Returns:
            The float vector for ``text``.
        """
class EmbeddingFactory:
    """Factory that creates embedding model instances by provider name."""

    # Registry mapping provider name -> provider class (or other callable).
    # Stored on the class, so registrations are shared by every caller.
    _providers = {}

    @classmethod
    def register(cls, name: str, provider_class):
        """Register an embedding provider under ``name``.

        Registering a name that already exists replaces the previous entry.

        Args:
            name: Key used later to look the provider up in ``create``.
            provider_class: Callable invoked with the keyword arguments
                passed to ``create`` to build the instance.
        """
        registry = cls._providers
        registry[name] = provider_class

    @classmethod
    def create(cls, name: str, **kwargs) -> EmbeddingBase:
        """Create an embedding instance from a registered provider.

        Args:
            name: Name the provider was registered under.
            **kwargs: Forwarded unchanged to the provider's constructor.

        Returns:
            Whatever the registered provider returns when called.

        Raises:
            ValueError: If no provider is registered under ``name``.
        """
        registry = cls._providers
        if name in registry:
            provider = registry[name]
            return provider(**kwargs)
        raise ValueError(f"Unknown embedding provider: {name}")

    @classmethod
    def list_providers(cls) -> List[str]:
        """Return the names of all registered providers as a new list."""
        return list(cls._providers)