"""
API Embedding Model Implementation
Universal HTTP embedding client implementation.
"""
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import chain
from typing import Any, List, Optional
import requests
from openjiuwen.core.common.exception.codes import StatusCode
from openjiuwen.core.common.exception.errors import build_error
from openjiuwen.core.common.logging import logger
from openjiuwen.core.common.security.ssl_utils import SslUtils
from openjiuwen.core.foundation.store.base_embedding import Embedding, EmbeddingConfig
from openjiuwen.core.retrieval.common.callbacks import BaseCallback
class APIEmbedding(Embedding):
    """Universal HTTP embedding client.

    Request:
        - payload: ``{"model": <model_name>, "input": <text or list>}``; extra keyword
          arguments passed to the embed methods are merged into the payload.
        - headers: ``Content-Type: application/json`` by default, plus
          ``Authorization: Bearer <api_key>`` when an API key is configured, plus any
          ``extra_headers``.

    Accepted response shapes (checked in this order):
        - ``{"embedding": [...]}`` (one vector, or a list of vectors)
        - ``{"embeddings": [...]}``
        - ``{"data": [{"embedding": [...]}, ...]}``
    """

    # Names of the environment variables consulted for HTTPS endpoints.
    _EMBEDDING_SSL_VERIFY = "EMBEDDING_SSL_VERIFY"
    _EMBEDDING_SSL_CERT = "EMBEDDING_SSL_CERT"

    def __init__(
        self,
        config: EmbeddingConfig,
        timeout: int = 60,
        max_retries: int = 3,
        extra_headers: Optional[dict] = None,
        max_batch_size: int = 8,
        max_concurrent: int = 50,
    ):
        """Configure the client; no network request is made here.

        Args:
            config: Supplies ``model_name``, ``api_key`` and ``base_url``.
            timeout: Per-request timeout passed to ``requests.post``.
            max_retries: Number of attempts made per request before giving up.
            extra_headers: Additional HTTP headers; they override the defaults on conflict.
            max_batch_size: Upper bound on the number of texts sent in one request.
            max_concurrent: Size of the async semaphore and of the sync thread pool.
        """
        self.config = config
        self.model_name = config.model_name
        self.api_key = config.api_key
        self.max_batch_size = max_batch_size
        self.api_url = config.base_url or ""
        self.timeout = timeout
        self.max_retries = max_retries
        self.limiter = asyncio.Semaphore(max_concurrent)

        self._headers = {"Content-Type": "application/json"}
        if self.api_key:
            self._headers["Authorization"] = f"Bearer {self.api_key}"
        if extra_headers:
            self._headers.update(extra_headers)

        # ``verify`` only matters for HTTPS; plain HTTP (or an empty URL) gets False.
        if self.api_url.startswith("https://"):
            # presumably true when the env var is set to "false" -- confirm against SslUtils
            is_ssl_verify_off = SslUtils._bool_env(self._EMBEDDING_SSL_VERIFY, ["false"])
            ssl_cert = os.getenv(self._EMBEDDING_SSL_CERT)
            if is_ssl_verify_off:
                self._verify_ssl = False
            elif ssl_cert:
                # Value of EMBEDDING_SSL_CERT is handed to requests as ``verify``.
                self._verify_ssl = ssl_cert
            else:
                self._verify_ssl = True
        else:
            self._verify_ssl = False

        self._dimension: Optional[int] = None
        self._max_concurrent = max_concurrent
        # Created lazily by the ``executor`` property on first sync batch call.
        self._executor: Optional[ThreadPoolExecutor] = None

    def __del__(self):
        """Custom destructor to shutdown thread pool executor at delete"""
        pool = getattr(self, "_executor", None)
        if pool is None:
            return
        self._executor = None
        try:
            pool.shutdown(wait=False, cancel_futures=True)
        except Exception:
            # Best effort only: a destructor must never raise.
            pass

    @property
    def dimension(self) -> int:
        """Return embedding dimension.

        When not yet known, it is determined by embedding the text ``"test"``
        through the sync request path, then cached.
        """
        if self._dimension is not None:
            return self._dimension
        embedding = self.embed_query_sync("test")
        self._dimension = len(embedding)
        logger.debug("Determined embedding dimension: %s", self._dimension)
        return self._dimension

    @property
    def executor(self) -> ThreadPoolExecutor:
        """Return thread pool executor for concurrent sync requests (created on first use)."""
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self._max_concurrent, thread_name_prefix="openjiuwen_embed")
        return self._executor

    @staticmethod
    def validate_embed_docs(texts: List[str | Any], kwargs: dict) -> list[str]:
        """Validate input for batch embedding request and return the list of documents.

        Args:
            texts: Documents to embed; every item must be a non-blank string.
            kwargs: Keyword arguments of the calling embed method; only
                ``callback_cls`` is inspected.

        Returns:
            The validated documents as a new list.

        Raises:
            Error built with RETRIEVAL_EMBEDDING_INPUT_INVALID when ``texts`` is empty
            or contains a non-string or blank item; error built with
            RETRIEVAL_EMBEDDING_CALLBACK_INVALID when ``callback_cls`` is not a
            subclass of ``BaseCallback``.
        """
        if not texts:
            raise build_error(StatusCode.RETRIEVAL_EMBEDDING_INPUT_INVALID, error_msg="Empty texts list provided")

        callback_cls = kwargs.get("callback_cls", BaseCallback)
        if not isinstance(callback_cls, type) or not issubclass(callback_cls, BaseCallback):
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_CALLBACK_INVALID,
                error_msg=(
                    f"callback_cls in APIEmbedding.embed_documents must be a subclass of "
                    f"BaseCallback, got {type(callback_cls)}"
                ),
            )

        # Reject non-strings explicitly instead of failing with AttributeError on .strip().
        non_string_count = sum(1 for t in texts if not isinstance(t, str))
        if non_string_count:
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_INPUT_INVALID,
                error_msg=f"{non_string_count} chunks are not strings while embedding",
            )

        non_empty = [t for t in texts if t.strip()]
        if len(non_empty) != len(texts):
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_INPUT_INVALID,
                error_msg=f"{len(texts) - len(non_empty)} chunks are empty while embedding",
            )
        return non_empty

    async def embed_query(self, text: str, **kwargs) -> List[float]:
        """Embed a single query text.

        Extra keyword arguments are merged into the request payload.
        Raises an error built with RETRIEVAL_EMBEDDING_INPUT_INVALID for blank text.
        """
        if not text.strip():
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_INPUT_INVALID, error_msg="Empty text provided for embedding"
            )
        embeddings = await self._get_embeddings(text, **kwargs)
        return embeddings[0]

    def embed_query_sync(self, text: str, **kwargs) -> List[float]:
        """Embed a single query text (sync version)."""
        if not text.strip():
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_INPUT_INVALID, error_msg="Empty text provided for embedding"
            )
        embeddings = self._get_embeddings_sync(text, **kwargs)
        return embeddings[0]

    async def embed_documents(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        **kwargs,
    ) -> List[List[float]]:
        """Embed document texts concurrently, one request per batch.

        Args:
            texts: Non-empty list of non-blank strings.
            batch_size: Requested batch size; capped at ``max_batch_size``.
            **kwargs: ``callback_cls`` (a ``BaseCallback`` subclass) is instantiated
                with the batch start indices and called after each batch; all other
                keyword arguments are merged into the request payload.

        Returns:
            Embeddings concatenated in batch order.
        """
        non_empty = self.validate_embed_docs(texts, kwargs)
        bsz = self._resolve_batch_size(batch_size)
        indices = list(range(0, len(non_empty), bsz))
        callback_obj = kwargs.get("callback_cls", BaseCallback)(seq=indices)

        async def process_batch(i: int) -> List[List[float]]:
            """Process a single batch with semaphore for concurrency control."""
            async with self.limiter:
                j = i + bsz
                batch = non_empty[i:j]
                embeddings = await self._get_embeddings(batch, **kwargs)
                # The callback receives the input texts of the batch here.
                callback_obj(start_idx=i, end_idx=j, batch=batch)
                return embeddings

        # gather() keeps results in submission order regardless of completion order.
        results = await asyncio.gather(*(process_batch(i) for i in indices))
        return list(chain.from_iterable(results))

    def embed_documents_sync(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        **kwargs,
    ) -> List[List[float]]:
        """Embed document texts (sync version), one thread-pool task per batch.

        Takes the same arguments as ``embed_documents``. The callback is invoked in
        completion order; the returned embeddings are concatenated in batch order.
        """
        non_empty = self.validate_embed_docs(texts, kwargs)
        bsz = self._resolve_batch_size(batch_size)
        indices = list(range(0, len(non_empty), bsz))
        callback_obj = kwargs.get("callback_cls", BaseCallback)(seq=indices)

        pending = {}
        for i in indices:
            pending[self.executor.submit(self._get_embeddings_sync, non_empty[i : i + bsz], **kwargs)] = i

        # Keep each batch under its start index and join at the end. Writing into a
        # pre-sized list by slice would shift every later batch if the server returned
        # a different number of vectors than requested.
        batches: dict[int, List[List[float]]] = {}
        for future in as_completed(pending):
            i = pending[future]
            batch = future.result()
            batches[i] = batch
            # NOTE(review): ``batch`` is the embeddings here, whereas the async version
            # passes the input texts -- kept as-is; confirm the intended callback contract.
            callback_obj(start_idx=i, end_idx=i + bsz, batch=batch)
        return list(chain.from_iterable(batches[i] for i in indices))

    def _resolve_batch_size(self, batch_size: Optional[int]) -> int:
        """Return the effective batch size: the request capped at ``max_batch_size``, at least 1."""
        bsz = batch_size or self.max_batch_size or 1
        if self.max_batch_size:
            bsz = min(bsz, self.max_batch_size)
        # A non-positive step would make range() yield nothing and silently return no embeddings.
        return max(bsz, 1)

    def _build_payload(self, text: str | List[str], kwargs: dict) -> dict:
        """Build the JSON payload; ``callback_cls`` is client-side only and is left out."""
        extra = {key: value for key, value in kwargs.items() if key != "callback_cls"}
        return {"model": self.model_name, "input": text, **extra}

    def _post(self, payload: dict) -> Any:
        """Send one blocking POST and return the decoded JSON body.

        Raises ``requests`` exceptions for transport failures and HTTP error statuses.
        """
        resp = requests.post(
            self.api_url,
            json=payload,
            headers=self._headers,
            timeout=self.timeout,
            verify=self._verify_ssl,
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _parse_embeddings(result: Any) -> List[List[float]]:
        """Extract a non-empty list of vectors from a decoded response body.

        Raises an error built with RETRIEVAL_EMBEDDING_RESPONSE_INVALID when the body
        matches none of the supported shapes or carries no vectors.
        """
        missing_msg = f"No embeddings in response: {result}"
        if not isinstance(result, dict):
            raise build_error(StatusCode.RETRIEVAL_EMBEDDING_RESPONSE_INVALID, error_msg=missing_msg)

        if "embedding" in result:
            emb = result["embedding"]
            if not emb:
                raise build_error(StatusCode.RETRIEVAL_EMBEDDING_RESPONSE_INVALID, error_msg=missing_msg)
            # A flat vector is wrapped so callers always get a list of vectors.
            return emb if isinstance(emb[0], list) else [emb]

        if "embeddings" in result:
            embeddings = result["embeddings"]
            if not embeddings:
                raise build_error(StatusCode.RETRIEVAL_EMBEDDING_RESPONSE_INVALID, error_msg=missing_msg)
            return embeddings

        if isinstance(result.get("data"), list):
            embeddings = [item["embedding"] for item in result["data"] if "embedding" in item]
            if not embeddings:
                raise build_error(
                    StatusCode.RETRIEVAL_EMBEDDING_RESPONSE_INVALID,
                    error_msg=f"No embeddings field found in data items: {result}",
                )
            return embeddings

        raise build_error(StatusCode.RETRIEVAL_EMBEDDING_RESPONSE_INVALID, error_msg=missing_msg)

    def _remember_dimension(self, embeddings: List[List[float]]) -> None:
        """Cache the vector length from the first successful response."""
        if self._dimension is None and embeddings and embeddings[0]:
            self._dimension = len(embeddings[0])
            logger.debug("Determined embedding dimension: %s", self._dimension)

    def _on_request_failure(self, attempt: int, exc: Exception) -> None:
        """Log a failed attempt, or raise the final error once the last attempt has failed."""
        if attempt == self.max_retries - 1:
            raise build_error(
                StatusCode.RETRIEVAL_EMBEDDING_REQUEST_CALL_FAILED,
                error_msg=f"Failed to get embedding after {self.max_retries} attempts",
                cause=exc,
            ) from exc
        logger.warning(
            "Embedding request failed (attempt %s/%s): %s",
            attempt + 1,
            self.max_retries,
            exc,
        )

    async def _get_embeddings(self, text: str | List[str], **kwargs) -> List[List[float]]:
        """Get embedding vectors, retrying on ``requests`` errors up to ``max_retries`` attempts."""
        payload = self._build_payload(text, kwargs)
        for attempt in range(self.max_retries):
            try:
                # The whole blocking round-trip runs in a worker thread, off the event loop.
                result = await asyncio.to_thread(self._post, payload)
            except requests.exceptions.RequestException as e:
                self._on_request_failure(attempt, e)
                continue
            embeddings = self._parse_embeddings(result)
            self._remember_dimension(embeddings)
            return embeddings
        # Only reached when max_retries < 1, i.e. no attempt was made at all.
        raise build_error(
            StatusCode.RETRIEVAL_EMBEDDING_UNREACHABLE_CALL_FAILED, error_msg="Unreachable code in _get_embeddings"
        )

    def _get_embeddings_sync(self, text: str | List[str], **kwargs) -> List[List[float]]:
        """Get embedding vectors (sync version)."""
        payload = self._build_payload(text, kwargs)
        for attempt in range(self.max_retries):
            try:
                result = self._post(payload)
            except requests.exceptions.RequestException as e:
                self._on_request_failure(attempt, e)
                continue
            embeddings = self._parse_embeddings(result)
            self._remember_dimension(embeddings)
            return embeddings
        # Only reached when max_retries < 1, i.e. no attempt was made at all.
        raise build_error(
            StatusCode.RETRIEVAL_EMBEDDING_UNREACHABLE_CALL_FAILED, error_msg="Unreachable code in _get_embeddings_sync"
        )