from abc import ABC, abstractmethod
import logging
from typing import Optional, Union
from pydantic import BaseModel
from openjiuwen.core.foundation.store.base_embedding import EmbeddingConfig
from openjiuwen.core.retrieval.embedding.dashscope_embedding import DashscopeEmbedding
from openjiuwen.core.retrieval.embedding.openai_embedding import OpenAIEmbedding
from openjiuwen_deepsearch.algorithm.search_nodes.utils import strip_quotes
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.utils.common_utils.url_utils import validate_embedding_service_url
from openjiuwen_deepsearch.utils.log_utils.log_manager import LogManager
# Logger for this module, registered under the module's import path.
logger = logging.getLogger(name=__name__)
class AbstractEmbedder(ABC):
    """Base class for embedding backends reached over an HTTP API.

    This base only validates and normalizes the connection settings shared
    by every backend. Subclasses supply ``get_query_instruction`` and
    ``encode``.
    """

    default_instruction = "Given a web search query, retrieve relevant passages that answer the query"
    # Model identifier exactly as passed by the caller.
    model_name: str
    # Embedding width; 0 when the caller did not pass ``model_dim``.
    embed_dim: int

    def __init__(
        self,
        pretrained_model: str,
        api_token: Union[str, bytearray],
        api_url: str,
        timeout: int = 60,
        model_dim: Optional[int] = None,
    ):
        """Validate and store the connection settings.

        Args:
            pretrained_model: Name of the embedding model (``self.model_name``).
            api_token: API credential. It is copied into a mutable
                ``bytearray`` (a ``str`` is UTF-8 encoded first). Presumably
                this is so it can be wiped later; confirm with callers.
            api_url: Service URL. Surrounding quotes are stripped and the
                result is passed to ``validate_embedding_service_url``.
            timeout: Stored unchanged on ``self.timeout``.
            model_dim: Optional explicit embedding dimension. ``None`` leaves
                ``embed_dim`` at 0 and ``_user_provided_dim`` False.

        Raises:
            CustomValueException: If ``api_token`` or ``api_url`` is empty
                (the token is checked first).
        """
        param_error = StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR
        for value, problem in (
            (api_token, "API token not provided."),
            (api_url, "Embedding API URL not provided."),
        ):
            if not value:
                raise CustomValueException(
                    param_error.code,
                    param_error.errmsg.format(e=problem),
                )

        self.model_name = pretrained_model
        # Remembered separately because an explicit 0 and "not given" both
        # leave embed_dim at 0.
        self._user_provided_dim = model_dim is not None
        self.embed_dim = 0 if model_dim is None else model_dim

        raw_token = api_token.encode("utf-8") if isinstance(api_token, str) else api_token
        self.api_token = bytearray(raw_token)

        self.api_url = strip_quotes(api_url)
        validate_embedding_service_url(self.api_url)
        self.timeout = timeout

    @abstractmethod
    def get_query_instruction(self, query: str) -> str:
        """Return ``query`` wrapped with the backend's retrieval instruction."""

    @abstractmethod
    def encode(
        self,
        input_texts: list[str],
        is_query: bool = False,
    ) -> list[list[float]]:
        """Embed ``input_texts``; ``is_query`` marks them as search queries."""
class OpenJiuwenAPIEmbedder(AbstractEmbedder):
    """Embedder backed by the openjiuwen DashScope / OpenAI-compatible clients.

    The concrete client class is chosen from ``api_url`` by
    :meth:`jiuwen_embedding_class_for_url`. When no positive ``model_dim`` is
    given, the embedding dimension is taken from the client. If the client
    reports a non-positive value, the dimension is taken from the first
    non-empty ``encode`` result instead.
    """

    # Same value as the inherited attribute; kept explicit on this class.
    default_instruction = AbstractEmbedder.default_instruction

    def __init__(
        self,
        pretrained_model: str,
        api_token: Union[str, bytearray],
        api_url: str,
        timeout: int = 60,
        model_dim: Optional[int] = None,
        max_retries: int = 3,
    ):
        """Create the underlying openjiuwen embedding client.

        Args:
            pretrained_model: Embedding model name.
            api_token: API credential (see the base class for normalization).
            api_url: Service base URL; also selects the client class.
            timeout: Forwarded to the client constructor.
            model_dim: Optional explicit dimension. When positive, it is
                passed to the client as ``dimension`` and checked against the
                dimension the client reports.
            max_retries: Forwarded to the client constructor.

        Raises:
            CustomValueException: On a missing token or URL (from the base
                class), or when a user-supplied ``model_dim`` differs from the
                client-reported dimension.
        """
        super().__init__(pretrained_model, api_token, api_url, timeout, model_dim)
        # The base class keeps the token as a bytearray, but the client
        # config takes a str.
        token_str = strip_quotes(self.api_token.decode("utf-8"))
        embed_config = EmbeddingConfig(
            model_name=self.model_name,
            base_url=self.api_url,
            api_key=token_str,
        )
        embed_cls = self.jiuwen_embedding_class_for_url(self.api_url)
        embed_kwargs: dict = {
            "config": embed_config,
            "timeout": timeout,
            "max_retries": max_retries,
        }
        if self.embed_dim > 0:
            # Only request a specific width when the caller asked for one.
            embed_kwargs["dimension"] = self.embed_dim
        self.embedder = embed_cls(**embed_kwargs)

        reported_dim = int(self.embedder.dimension)
        if self.embed_dim <= 0:
            # Adopt the client's dimension. If that is also non-positive,
            # encode() learns it from the first response.
            self.embed_dim = reported_dim
        elif self._user_provided_dim and reported_dim != self.embed_dim:
            raise self._dimension_mismatch_error()

    def _dimension_mismatch_error(self):
        """Build the CustomValueException for a user dimension that the model does not match."""
        status = StatusCode.PARAM_CHECK_ERROR_EMBED_DIMENSION_MODEL_MISMATCH
        return CustomValueException(
            status.code,
            status.errmsg.format(
                dimension=self.embed_dim,
                model=self.model_name,
                status_code=200,
            ),
        )

    def get_query_instruction(self, query: str) -> str:
        """Prefix ``query`` with the retrieval instruction used for query embeddings."""
        return f"Instruct: {self.default_instruction}\nQuery:{query}"

    def encode(
        self,
        input_texts: list[str],
        is_query: bool = False,
    ) -> list[list[float]]:
        """Embed ``input_texts`` with the underlying client.

        Args:
            input_texts: Texts to embed. An empty (or falsy) value returns
                ``[]`` without contacting the service.
            is_query: When true, each text is first wrapped by
                :meth:`get_query_instruction`.

        Returns:
            The embeddings returned by the client's ``embed_documents_sync``.

        Raises:
            CustomValueException: ``EMBED_API_CALL_FAILED`` if the client call
                fails. A dimension-mismatch error if a user-supplied
                ``model_dim`` disagrees with the returned vectors.
        """
        if not input_texts:
            # Nothing to embed; avoid a pointless (and possibly rejected) request.
            return []
        if is_query:
            input_texts = [self.get_query_instruction(inp) for inp in input_texts]
        try:
            embeddings = self.embedder.embed_documents_sync(input_texts)
        except Exception as e:
            sensitive = LogManager.is_sensitive()
            detail = "" if sensitive else str(e)
            error = CustomValueException(
                StatusCode.EMBED_API_CALL_FAILED.code,
                StatusCode.EMBED_API_CALL_FAILED.errmsg.format(
                    status_code=None,
                    detail=detail,
                ),
            )
            # In sensitive mode the detail is withheld from the message.
            # Chaining `from e` would re-expose it in logged tracebacks via
            # __cause__, so drop the chain in that case.
            if sensitive:
                raise error from None
            raise error from e
        if embeddings:
            observed_dim = len(embeddings[0])
            if self.embed_dim <= 0:
                # Dimension was still unknown; learn it from the response.
                self.embed_dim = observed_dim
            elif self._user_provided_dim and observed_dim != self.embed_dim:
                raise self._dimension_mismatch_error()
        return embeddings

    @classmethod
    def jiuwen_embedding_class_for_url(cls, api_url: str):
        """DashScope native HTTP API vs OpenAI-compatible base URL (incl. DashScope compatible-mode).

        Returns ``DashscopeEmbedding`` when the lower-cased URL contains
        ``dashscope.aliyuncs.com`` but not ``compatible-mode``. Otherwise it
        returns ``OpenAIEmbedding``. The match is a plain substring test on
        the whole URL.

        NOTE(review): ``dashscope-intl.aliyuncs.com`` does not contain that
        substring, so international native endpoints fall through to
        ``OpenAIEmbedding``. Confirm whether that is intended.
        """
        url = api_url.lower()
        if "dashscope.aliyuncs.com" in url and "compatible-mode" not in url:
            return DashscopeEmbedding
        return OpenAIEmbedding