"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import io
import os
from abc import ABC
import zipfile
import base64
import psutil
from loguru import logger
from PIL import Image
from mx_rag.llm import Img2TextLLM
from mx_rag.utils import file_check
from mx_rag.utils.common import validate_params, STR_TYPE_CHECK_TIP_1024, MAX_IMAGE_PIXELS, MIN_IMAGE_WIDTH, \
MIN_IMAGE_HEIGHT, MAX_BASE64_SIZE
class BaseLoader(ABC):
    """Common base class for file loaders.

    Holds the path of the file to load, runs the security file check on it at
    construction time, and provides helpers for zip-bomb detection and for
    preparing images (size verification, base64 encoding, VLM summarisation).
    """

    # Maximum accepted size of the source file, in bytes (100 MiB).
    MAX_SIZE = 100 * 1024 * 1024
    # Limits not used in this class; presumably consumed by subclasses.
    MAX_PAGE_NUM = 1000
    MAX_WORD_NUM = 500000
    MAX_FILE_CNT = 1024
    MAX_NESTED_DEPTH = 10

    @validate_params(
        file_path=dict(validator=lambda x: isinstance(x, str) and 0 < len(x) <= 1024, message=STR_TYPE_CHECK_TIP_1024),
    )
    def __init__(self, file_path):
        """Store ``file_path`` and run ``SecFileCheck`` on it.

        Args:
            file_path: Non-empty path string of at most 1024 characters.
        """
        self.file_path = file_path
        # Multiplier applied to MAX_FILE_CNT / MAX_SIZE when inspecting archives.
        self.multi_size = 5
        file_check.SecFileCheck(self.file_path, self.MAX_SIZE).check()

    @staticmethod
    def _base64_len(byte_count):
        """Return the exact length of the padded base64 encoding of ``byte_count`` bytes."""
        return (byte_count + 2) // 3 * 4

    def _is_zip_bomb(self):
        """Check whether ``self.file_path`` looks like a zip bomb.

        Returns:
            True when the archive holds too many entries, declares too large an
            uncompressed size, would leave too little free disk space, or
            cannot be inspected at all (invalid archive or unexpected error);
            False otherwise.
        """
        try:
            with zipfile.ZipFile(self.file_path, "r") as zip_ref:
                entries = zip_ref.infolist()

                max_file_count = self.MAX_FILE_CNT * self.multi_size
                file_count = len(entries)
                if file_count >= max_file_count:
                    logger.error(f'zip file ({self.file_path}) contains {file_count} files, exceed '
                                 f'the limit of {max_file_count}')
                    return True

                max_total_size = self.MAX_SIZE * self.multi_size
                total_uncompressed_size = sum(zinfo.file_size for zinfo in entries)
                if total_uncompressed_size > max_total_size:
                    logger.error(f"zip file '{self.file_path}' uncompressed size is {total_uncompressed_size} bytes, "
                                 f"exceeds the limit of {max_total_size} bytes, Potential ZIP bomb")
                    return True

                # Require a safety margin of 2 * MAX_SIZE free bytes after extraction.
                # NOTE(review): free space is measured on the current working
                # directory's filesystem - confirm that is where callers extract.
                remain_size = psutil.disk_usage(os.getcwd()).free
                if remain_size - total_uncompressed_size < self.MAX_SIZE * 2:
                    logger.error(f'zip file ({self.file_path}) uncompressed size is {total_uncompressed_size} bytes'
                                 f' only {remain_size} bytes of disk space available')
                    return True
                return False
        except zipfile.BadZipfile as e:
            logger.error(f"The provided path '{self.file_path}' is not a valid zip file or is corrupted: {e}")
            return True
        except Exception as e:
            # Fail closed: anything that prevents inspection is treated as unsafe.
            logger.error(f"Unexpected error occurred while checking zip bomb: {e}")
            return True

    def _verify_image_size(self, image_bytes):
        """Verify if the image dimensions are within acceptable limits.

        Args:
            image_bytes: Encoded image data.

        Returns:
            True when the image can be opened and its dimensions are accepted;
            False when it is too large, too small, or cannot be read.
        """
        try:
            with Image.open(io.BytesIO(image_bytes)) as img:
                width, height = img.size
                total_pixels = width * height
                if total_pixels > MAX_IMAGE_PIXELS:
                    logger.warning(f"Image too large: {width}x{height} pixels. Skipping.")
                    return False
                # NOTE(review): rejected only when BOTH sides are below the
                # minimum - confirm 'and' (rather than 'or') is intended.
                if width < MIN_IMAGE_WIDTH and height < MIN_IMAGE_HEIGHT:
                    logger.warning(f"Image too small: {width}x{height} pixels. Skipping.")
                    return False
                return True
        except OSError as err:
            logger.warning(f"Failed to open image file: {err}")
            return False
        except ValueError as err:
            logger.warning(f"Invalid image data: {err}")
            return False
        except MemoryError as err:
            logger.warning(f"Insufficient memory to process image: {err}")
            return False
        except Exception as err:
            logger.warning(f"Failed to verify image size: {err}")
            return False

    def _convert_to_base64(self, image_data,
                           max_base64_size=MAX_BASE64_SIZE,
                           max_iterations=10):
        """Encode an image as base64, keeping the result within ``max_base64_size``.

        If the raw data already fits, it is encoded as-is. Otherwise the image
        is converted to RGB, scaled down and re-encoded as JPEG (quality 90),
        shrinking further on each attempt until the encoded size fits or
        ``max_iterations`` attempts have been made.

        Args:
            image_data: Encoded image bytes.
            max_base64_size: Maximum allowed length of the base64 string.
            max_iterations: Maximum number of resize attempts (at least one
                attempt is always made).

        Returns:
            The base64 string. If the limit cannot be met within
            ``max_iterations`` attempts, the smallest attempt is returned and a
            warning is logged.
        """
        raw_b64_len = self._base64_len(len(image_data))
        if raw_b64_len <= max_base64_size:
            return base64.b64encode(image_data).decode("utf-8")

        with Image.open(io.BytesIO(image_data)) as image:
            rgb_image = image.convert("RGB")
            # Encoded size scales roughly with pixel count, so scale each side
            # by the square root of the required size reduction.
            ratio = (max_base64_size / raw_b64_len) ** 0.5
            attempts = max(1, max_iterations)
            compressed_data = b""
            fits = False
            for _ in range(attempts):
                # Clamp to 1 pixel: resize() rejects zero-sized dimensions.
                new_w = max(1, int(rgb_image.width * ratio))
                new_h = max(1, int(rgb_image.height * ratio))
                resized_img = rgb_image.resize((new_w, new_h), Image.LANCZOS)
                buffer = io.BytesIO()
                resized_img.save(buffer, format="JPEG", quality=90)
                compressed_data = buffer.getvalue()

                current_b64_len = self._base64_len(len(compressed_data))
                if current_b64_len <= max_base64_size:
                    fits = True
                    break
                if new_w == 1 and new_h == 1:
                    # Cannot shrink any further.
                    break
                # Shrink by at least 10% per attempt so the loop always progresses.
                ratio *= min(0.9, (max_base64_size / current_b64_len) ** 0.5)

            if not fits:
                logger.warning(f"image base64 size still exceeds {max_base64_size} after resizing")
            return base64.b64encode(compressed_data).decode("utf-8")

    def _interpret_image(self, image_data, vlm: Img2TextLLM):
        """Produce a base64 payload and a VLM-generated summary for an image.

        Args:
            image_data: Encoded image bytes.
            vlm: Image-to-text model; its ``chat`` method is called with an
                ``image_url`` dict holding a data URL.

        Returns:
            A ``(img_base64, image_summary)`` tuple. Both are empty strings
            when the image fails size verification; ``img_base64`` is also
            emptied when the model returns an empty summary.
        """
        # Validate before encoding so corrupt or oversized data is skipped
        # instead of being decoded/resized (which could raise or waste memory).
        if not self._verify_image_size(image_data):
            logger.warning("image size is invalid")
            return "", ""

        img_base64 = self._convert_to_base64(image_data)
        # NOTE(review): the MIME type is always image/jpeg, even when the
        # original bytes are passed through without re-encoding.
        image_url = {"url": f"data:image/jpeg;base64,{img_base64}"}
        image_summary = vlm.chat(image_url=image_url)
        if image_summary == "":
            img_base64 = ""
            logger.warning("image summary func exec failed")
        return img_base64, image_summary