"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import re
import json
from typing import List
from langchain_core.embeddings import Embeddings
from loguru import logger
from mx_rag.utils import ClientParam
from mx_rag.utils.common import (
MAX_URL_LENGTH,
EMBEDDING_IMG_COUNT,
EMBEDDING_TEXT_COUNT,
STR_MAX_LEN,
MAX_BATCH_SIZE,
MB
)
from mx_rag.utils.common import validate_params, validate_list_str
from mx_rag.utils.file_check import FileCheckError, PathNotFileException
from mx_rag.utils.url import RequestUtils
def _validate_image_data_uri(image_data_uri):
pattern = r'^[A-Za-z0-9+/=]+$'
match = re.match(pattern, image_data_uri)
return match is not None
class CLIPEmbeddingError(Exception):
    """Error raised when the CLIP client cannot be created or a CLIP request fails."""
class CLIPEmbedding(Embeddings):
    """Embeddings implementation that delegates to a remote CLIP HTTP service.

    Texts and image strings are posted to ``url`` as JSON, one batch per
    request, and the ``embedding`` field of every returned item is collected
    in response order.
    """

    # NOTE: the ClientParam() default below is evaluated once, when the class
    # body is executed, so every instance created without an explicit
    # client_param receives the same ClientParam object.
    @validate_params(
        url=dict(validator=lambda x: isinstance(x, str) and 0 <= len(x) <= MAX_URL_LENGTH,
                 message=f"param must be str and str length range [0, {MAX_URL_LENGTH}]"),
        client_param=dict(validator=lambda x: isinstance(x, ClientParam),
                          message="param must be instance of ClientParam"))
    def __init__(self, url: str, client_param=ClientParam()):
        """Store the service URL and build the HTTP client.

        Args:
            url: Address of the CLIP service that requests are posted to.
            client_param: Settings handed to ``RequestUtils``.

        Raises:
            CLIPEmbeddingError: If building the client fails for any reason
                other than ``FileCheckError`` / ``PathNotFileException``.
                Those two are logged only and leave ``self.client`` as None;
                later embedding calls then raise ``CLIPEmbeddingError``.
        """
        self.url = url
        self.client = None
        self.headers = {'Content-Type': 'application/json'}
        try:
            self.client = RequestUtils(client_param=client_param)
        except FileCheckError as e:
            logger.error(f"CLIP client file param check failed:{e}")
        except PathNotFileException as e:
            logger.error(f"CLIP client crt is not a file:{e}")
        except Exception as e:
            raise CLIPEmbeddingError('CLIP client init failed') from e

    @staticmethod
    def create(**kwargs):
        """Build a ``CLIPEmbedding`` from keyword arguments.

        Returns:
            A new ``CLIPEmbedding``, or None (after logging an error) when
            ``url`` is missing from ``kwargs`` or is not a str.
        """
        if "url" not in kwargs or not isinstance(kwargs.get("url"), str):
            logger.error("url param error. ")
            return None

        return CLIPEmbedding(**kwargs)

    @validate_params(
        texts=dict(
            validator=lambda x: validate_list_str(x, [1, EMBEDDING_TEXT_COUNT], [1, STR_MAX_LEN]),
            message=f"param must meets: Type is List[str], list length range [1, {EMBEDDING_TEXT_COUNT}], "
                    f"str length range [1, {STR_MAX_LEN}]"),
        batch_size=dict(
            validator=lambda x: isinstance(x, int) and 1 <= x <= MAX_BATCH_SIZE,
            message=f"param must be int and value range [1, {MAX_BATCH_SIZE}]"))
    def embed_documents(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Embed a list of texts.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts sent per request.

        Returns:
            One embedding per input text.

        Raises:
            CLIPEmbeddingError: If a request fails or a response is malformed.
        """
        return self._encode(texts, batch_size=batch_size)

    @validate_params(
        text=dict(
            validator=lambda x: (isinstance(x, str)) and 1 <= len(x) <= STR_MAX_LEN,
            message=f"param must be str and length range [1, {STR_MAX_LEN}]"))
    def embed_query(self, text: str) -> List[float]:
        """Embed a single text by delegating to ``embed_documents``.

        Args:
            text: Text to embed.

        Returns:
            The embedding of ``text``.

        Raises:
            ValueError: If no embedding was produced.
            CLIPEmbeddingError: If the request fails or the response is malformed.
        """
        embeddings = self.embed_documents([text])
        if not embeddings:
            raise ValueError("Failed to embed text")
        return embeddings[0]

    @validate_params(
        images=dict(
            validator=lambda x: validate_list_str(x, [1, EMBEDDING_IMG_COUNT], [1, 10 * MB]),
            message=f"param must meets: Type is List[str], list length range [1, {EMBEDDING_IMG_COUNT}],"
                    f" str length range [1, {10 * MB}]"),
        batch_size=dict(
            validator=lambda x: isinstance(x, int) and 1 <= x <= MAX_BATCH_SIZE,
            message=f"param must be int and value range [1, {MAX_BATCH_SIZE}]"))
    def embed_images(self, images: List[str], batch_size: int = 32) -> List[List[float]]:
        """Embed a list of image strings.

        Args:
            images: Image strings made only of base64-alphabet characters
                (presumably base64-encoded image bytes; confirm against the service).
            batch_size: Maximum number of images sent per request.

        Returns:
            One embedding per input image.

        Raises:
            ValueError: If any image string contains a character outside the
                accepted set.
            CLIPEmbeddingError: If a request fails or a response is malformed.
        """
        if not all(_validate_image_data_uri(image_uri) for image_uri in images):
            raise ValueError("wrong image string, it must match r'^[A-Za-z0-9+/=]+$'")
        return self._encode(blobs=images, batch_size=batch_size)

    def _encode(self, texts: List[str] = None, blobs: List[str] = None, batch_size: int = 32) -> List[List[float]]:
        """Post texts and blobs to the service in batches and collect the embeddings.

        Texts are sent as ``{'text': ...}`` entries followed by blobs as
        ``{'blob': ...}`` entries.

        Args:
            texts: Texts to embed; None is treated as an empty list.
            blobs: Image strings to embed; None is treated as an empty list.
            batch_size: Maximum number of entries sent per request.

        Returns:
            The ``embedding`` value of every response item, batch after batch;
            exactly one per input entry.

        Raises:
            CLIPEmbeddingError: If the client was never initialised, a request
                fails, the response is not valid JSON, the response item count
                differs from the batch size, or an item has no ``embedding``.
        """
        inputs = [{'text': text} for text in texts or []]
        inputs.extend({'blob': blob} for blob in blobs or [])
        if not inputs:
            return []

        # __init__ only logs certain client-construction failures, so the
        # client may be missing here; report that directly instead of letting
        # it surface as an attribute error on None.
        if self.client is None:
            logger.error("CLIP client is not initialized")
            raise CLIPEmbeddingError("CLIP client is not initialized")

        result = []
        for start_index in range(0, len(inputs), batch_size):
            batched_inputs = inputs[start_index: start_index + batch_size]
            request_body = {"data": batched_inputs, "parameters": {"drop_image_content": True}}

            try:
                resp = self.client.post(self.url, json.dumps(request_body), headers=self.headers)
                if not resp.success:
                    raise CLIPEmbeddingError('failed to get response')

                resp_data = json.loads(resp.data or "{}")
                if not isinstance(resp_data, dict) or "data" not in resp_data:
                    raise TypeError("Response data should be a dictionary containing a 'data' key")

                items = resp_data["data"]
                if len(items) != len(batched_inputs):
                    raise ValueError("Response data length does not match input batch size")

                # Dropping items without an embedding would silently shift
                # every later embedding onto the wrong input, so fail instead.
                missing = [index for index, item in enumerate(items) if 'embedding' not in item]
                if missing:
                    raise ValueError(f"Response items at positions {missing} have no 'embedding' field")

                result.extend(item['embedding'] for item in items)
            except json.JSONDecodeError as e:
                logger.error(f"failed to parse json response for batch starting at {start_index}: {e}")
                raise CLIPEmbeddingError(f"unable to parse clip response content: {e}") from e
            except Exception as e:
                logger.error(f"Unexpected error while processing batch starting at {start_index}: {e}")
                raise CLIPEmbeddingError(
                    f"Failed to process CLIP response for batch starting at {start_index}: {e}") from e

        return result