"""大模型提供商:Text Embedding Inference"""
import logging
import httpx
from typing_extensions import override
from apps.models import LLMType
from .base import BaseProvider
# Module-level logger, named after this module via ``__name__``.
_logger = logging.getLogger(__name__)
class TEIProvider(BaseProvider):
    """Embedding provider backed by a Text Embedding Inference (TEI) HTTP endpoint.

    Reads ``llmType``, ``baseUrl`` and ``apiKey`` from ``self.config`` and
    sends embedding requests to the ``/embed`` route of the configured server.
    """

    @override
    def _check_type(self) -> None:
        """Ensure the configured model declares the embedding capability.

        Raises:
            RuntimeError: if ``LLMType.EMBEDDING`` is not in ``self.config.llmType``.
        """
        if LLMType.EMBEDDING not in self.config.llmType:
            err = "模型能力不包含Embedding"
            _logger.error(err)
            raise RuntimeError(err)

    def _validate_input(self, text: list[str]) -> list[str]:
        """Validate the texts to embed and drop blank entries.

        Entries that are empty or whitespace-only are removed, so the returned
        list (and therefore the result of :meth:`embedding`) may be shorter
        than the input.

        Args:
            text: texts to be embedded.

        Returns:
            The input texts with blank entries removed, in the original order.

        Raises:
            ValueError: if ``text`` is empty, or contains only blank strings.
        """
        if not text:
            err = "待向量化文本不能为空"
            _logger.error(err)
            raise ValueError(err)

        filtered_text = [t for t in text if t.strip()]
        if not filtered_text:
            err = "待向量化文本全为空字符串"
            _logger.error(err)
            raise ValueError(err)

        return filtered_text

    @override
    def _init_client(self) -> None:
        """Prepare the endpoint URL and request headers for the embedding API.

        Sets ``self._api_url`` and ``self._headers``. An ``Authorization``
        bearer header is added only when ``self.config.apiKey`` is truthy.
        """
        # Strip trailing slashes so a base URL such as "http://host/" does not
        # produce "http://host//embed".
        self._api_url = self.config.baseUrl.rstrip("/") + "/embed"
        self._headers = {
            "Content-Type": "application/json",
        }
        if self.config.apiKey:
            self._headers["Authorization"] = f"Bearer {self.config.apiKey}"

    @override
    async def embedding(self, text: list[str]) -> list[list[float]]:
        """Embed the given texts through the configured ``/embed`` endpoint.

        One POST request is sent per non-blank text, sequentially, each with
        ``normalize`` set to ``True`` and a 60 second timeout.

        Args:
            text: texts to embed; blank entries are dropped before sending.

        Returns:
            One entry per non-blank input text, in input order. Each entry is
            the first element of the JSON array returned by the server.

        Raises:
            ValueError: if ``text`` is empty or contains only blank strings.
            httpx.HTTPStatusError: if the server answers with a 4xx/5xx status.
        """
        text = self._validate_input(text)

        # NOTE(review): verify=False disables TLS certificate verification.
        # It is kept to preserve existing behaviour (e.g. self-signed
        # deployments), but it should be made configurable.
        async with httpx.AsyncClient(verify=False) as client:
            result = []
            for single_text in text:
                payload = {
                    "inputs": single_text,
                    "normalize": True,
                }
                response = await client.post(
                    self._api_url,
                    json=payload,
                    headers=self._headers,
                    timeout=60.0,
                )
                # Fail on HTTP errors explicitly; otherwise an error body
                # (a JSON object) would surface as an opaque KeyError below.
                try:
                    response.raise_for_status()
                except httpx.HTTPStatusError:
                    _logger.exception("TEI Embedding API请求失败")
                    raise
                embeddings = response.json()
                result.append(embeddings[0])
            return result