"""图片数据混淆接口"""
import base64
import json
import math
import os
from io import BytesIO
from typing import Tuple
import torch
from PIL import Image
from torchvision.transforms.v2 import functional as F
from ..api.asset_obfuscation import AssetObfuscation
from ..constants import Constant, ErrorCode
from ..exception import ObfException
from ..utils import log, clean_bytearray, data_dec_mul, \
generate_patch_and_channel_permute, apply_patch_and_channel_permute
def _check_local_save_path(is_local_save, seed_ciphertext_dir) -> bool:
return is_local_save and isinstance(seed_ciphertext_dir, str) and not os.path.islink(seed_ciphertext_dir) \
and os.path.exists(seed_ciphertext_dir)
class ImageDataAssetObfuscation(AssetObfuscation):
    """Public interface for obfuscating image inference data.

    An instance stores the visual preprocessing parameters and exposes
    ``image_base64_obf`` / ``image_bytearray_obf``. ``set_seed_content`` must
    return successfully before either obfuscation method is called; otherwise
    they raise ``ObfException`` with ``SEED_CONTENT_NOT_SET``.
    """

    def __init__(self, patch_size: int = 16, merge_size: int = 2, longest_edge: int = 16777216,
                 shortest_edge: int = 65536, temporal_patch_size: int = 2):
        """Store and validate the preprocessing parameters.

        :param patch_size: edge length, in pixels, of one square patch
        :param merge_size: multiplied with ``patch_size`` to form the resize factor
        :param longest_edge: passed to ``_smart_resize`` as the maximum total pixel count
        :param shortest_edge: passed to ``_smart_resize`` as the minimum total pixel count
        :param temporal_patch_size: stored on the instance; not used by the
            methods of this class
        :raises ObfException: ``INVALID_PARAM`` when any argument is not an
            ``int`` or when ``patch_size * merge_size`` or ``longest_edge`` is zero
        """
        args = [patch_size, merge_size, longest_edge, shortest_edge, temporal_patch_size]
        if not all(isinstance(arg, int) for arg in args):
            log.error("The parameters for visual data preprocessing must be of int type.")
            raise ObfException(ErrorCode.INVALID_PARAM.value)
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.longest_edge = longest_edge
        self.shortest_edge = shortest_edge
        self.temporal_patch_size = temporal_patch_size
        # Resized height and width are forced to be multiples of this value.
        self.factor = self.patch_size * self.merge_size
        if self.factor == 0 or self.longest_edge == 0:
            log.error("The 'factor' or 'longest_edge' cannot be zero.")
            raise ObfException(ErrorCode.INVALID_PARAM.value)
        # Flipped to True by _set_seed_core once the seed has been accepted.
        self.is_seed_content = False

    @staticmethod
    def _to_rgb(pil_image: Image.Image) -> Image.Image:
        """Convert an image to RGB mode.

        RGBA images are composited onto a white background using their alpha
        channel as the paste mask; every other mode goes through
        ``Image.convert("RGB")``.

        :param pil_image: source image
        :return: a new image in RGB mode
        """
        if pil_image.mode == 'RGBA':
            white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
            # Band index 3 of an RGBA image is the alpha channel.
            white_background.paste(pil_image, mask=pil_image.split()[3])
            return white_background
        return pil_image.convert("RGB")

    @staticmethod
    def _smart_resize(height: int, width: int, factor: int, min_pixels: int, max_pixels: int) -> Tuple[int, int]:
        """Compute the target size of an image so that:

        1. height and width are both divisible by ``factor``;
        2. the total pixel count is steered into ``[min_pixels, max_pixels]``;
        3. the aspect ratio is kept as close as possible to the original.

        NOTE(review): the arithmetic looks like the smart-resize rule used by
        Qwen2-VL-style image processors - confirm before changing it, since
        the patch grid presumably has to line up with the model's own
        preprocessing.

        :param height: original height
        :param width: original width
        :param factor: value the returned height and width must be divisible by
        :param min_pixels: minimum allowed total pixel count
        :param max_pixels: maximum allowed total pixel count
        :return: ``(new_height, new_width)``
        :raises ObfException: ``INVALID_IMAGE`` when a dimension is not
            positive, ``INVALID_IMAGE_ASPECT_RATO`` when the long/short edge
            ratio exceeds ``Constant.IMAGE_MAX_RATIO``
        """
        # Guard the divisions below; a zero-sized image cannot be resized.
        if height <= 0 or width <= 0:
            log.error("Image height and width must be positive.")
            raise ObfException(ErrorCode.INVALID_IMAGE.value)
        if max(height, width) / min(height, width) > Constant.IMAGE_MAX_RATIO:
            log.error(f"Absolute aspect ratio must be smaller than {Constant.IMAGE_MAX_RATIO}")
            raise ObfException(ErrorCode.INVALID_IMAGE_ASPECT_RATO.value)
        # Start from the nearest multiples of `factor`.
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            # Too large: shrink both edges by the same ratio and round down.
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            # Too small: enlarge both edges by the same ratio and round up.
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    @staticmethod
    def _tensor_to_pil(tensor) -> Image.Image:
        """Convert a tensor to a PIL image.

        ``float32`` tensors are multiplied by 255, clamped to ``[0, 255]`` and
        cast to bytes; tensors of any other dtype are passed on unchanged.

        :param tensor: input tensor
        :return: PIL.Image object produced by ``to_pil_image``
        """
        if tensor.dtype == torch.float32:
            tensor = (tensor * 255).clamp(0, 255)
            tensor = tensor.byte()
        return F.to_pil_image(tensor)

    @staticmethod
    def _validate_base64_image(image: str) -> Tuple[str, Image.Image]:
        """Validate and decode a base64-encoded image string.

        Everything up to and including the first ``"base64,"`` marker (for
        example a data-URL header) is split off and returned as the prefix.
        The caller owns the returned image and is responsible for closing it.

        :param image: base64-encoded image data, optionally with a prefix
        :return: ``(prefix, image)``; the prefix is ``""`` when no marker is present
        :raises ObfException: ``INVALID_PARAM`` for a non-``str`` or oversized
            input, ``INVALID_BASE64_STRING`` when decoding fails,
            ``INVALID_IMAGE`` when the decoded bytes cannot be opened as an image
        """
        if not isinstance(image, str) or len(image) > Constant.MAX_BASE64_IMAGE_LENGTH * 1024 * 1024:
            log.error("The image data type or length validation failed.")
            raise ObfException(ErrorCode.INVALID_PARAM.value)
        base64_field = "base64,"
        base64_prefix = ""
        if base64_field in image:
            base64_prefix, image = image.split(base64_field, 1)
            base64_prefix += base64_field
        try:
            decoded_bytes = base64.b64decode(image, validate=True)
        except Exception as e:
            log.error("Failed to decode base64 string.")
            raise ObfException(ErrorCode.INVALID_BASE64_STRING.value) from e
        try:
            pil_img = Image.open(BytesIO(decoded_bytes))
        except Exception as e:
            log.error("Failed to open image.")
            raise ObfException(ErrorCode.INVALID_IMAGE.value) from e
        return base64_prefix, pil_img

    @classmethod
    def create_by_config(cls, config_path: str) -> "ImageDataAssetObfuscation":
        """Build an instance from a preprocessing configuration file.

        The JSON object is read for the keys ``patch_size``, ``merge_size``,
        ``temporal_patch_size`` and ``size`` (an object holding
        ``longest_edge`` and ``shortest_edge``). Missing keys are forwarded as
        ``None``, which the constructor rejects.

        :param config_path: path of the model's image preprocessing config JSON
        :return: ImageDataAssetObfuscation object
        :raises ObfException: ``INVALID_PATH`` when the path is not a ``str``
            or not an existing file, ``INVALID_JSON_FILE`` when the file cannot
            be parsed or its root is not a JSON object, ``INVALID_PARAM`` from
            the constructor for missing or non-int values
        """
        if not isinstance(config_path, str) or not os.path.isfile(config_path):
            log.error("Invalid preprocess config json file path")
            raise ObfException(ErrorCode.INVALID_PATH.value)
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            log.error("Failed to load the preprocess config json file")
            raise ObfException(ErrorCode.INVALID_JSON_FILE.value) from e
        # Valid JSON may still be a list or scalar, which has no .get().
        if not isinstance(data, dict):
            log.error("Failed to load the preprocess config json file")
            raise ObfException(ErrorCode.INVALID_JSON_FILE.value)
        size = data.get("size")
        if not isinstance(size, dict):
            size = {}
        return cls(
            data.get("patch_size"),
            data.get("merge_size"),
            size.get("longest_edge"),
            size.get("shortest_edge"),
            data.get("temporal_patch_size"),
        )

    def set_seed_content(self, seed_content: str = None, is_local_save: bool = False,
                         seed_ciphertext_dir: str = None) -> Tuple[int, str]:
        """Set the obfuscation seed used for data obfuscation.

        The seed comes from one of two sources:

        1. ``seed_content`` given directly as a string;
        2. when ``seed_content`` is ``None``, the locally protected ciphertext
           read from ``seed_ciphertext_dir``.

        :param seed_content: obfuscation seed text, or ``None`` to use local storage
        :param is_local_save: whether the seed should be read from local storage
        :param seed_ciphertext_dir: ciphertext directory; validated when the
            seed is read from local storage
        :return: error code value; ``INVALID_SEED_CONTENT`` when the seed has
            the wrong type or length or no usable source is available,
            otherwise the result of ``_set_seed_core``
        """
        has_literal_seed = isinstance(seed_content, str)
        # A non-None value of the wrong type is rejected outright; a missing
        # value is acceptable only when the local ciphertext path is usable.
        if not has_literal_seed and (
                seed_content is not None
                or not _check_local_save_path(is_local_save, seed_ciphertext_dir)):
            log.error("The obfuscate seed content data type validation failed.")
            return ErrorCode.INVALID_SEED_CONTENT.value
        if has_literal_seed and (len(seed_content) > Constant.SEED_CONTENT_MAX_LEN
                                 or len(seed_content) < Constant.SEED_CONTENT_MIN_LEN):
            log.error("The obfuscate seed content len validation failed.")
            return ErrorCode.INVALID_SEED_CONTENT.value
        if has_literal_seed:
            seed_content_bytes = bytearray(seed_content, "utf-8")
        else:
            seed_content_bytes = data_dec_mul(seed_ciphertext_dir, Constant.DATA_CIPHERTEXT_FILE_NAME)
        try:
            return self._set_seed_core(seed_content_bytes, Constant.DATA_SEED_TYPE)
        finally:
            # Hand the seed buffer to clean_bytearray even if the core step raised.
            clean_bytearray(seed_content_bytes)

    def image_base64_obf(self, image: str) -> str:
        """Obfuscate an image supplied as a base64 string.

        :param image: base64-encoded image data, optionally with a prefix
            ending in ``"base64,"``
        :return: the original prefix followed by the base64 encoding of the
            obfuscated image, which is always saved as JPEG
        :raises ObfException: ``SEED_CONTENT_NOT_SET`` when no seed has been
            set, the validation errors of ``_validate_base64_image``, or
            ``FAILED`` when obfuscation or JPEG encoding fails
        """
        if not self.is_seed_content:
            log.error("The seed content is not set.")
            raise ObfException(ErrorCode.SEED_CONTENT_NOT_SET.value)
        base64_prefix, pil_img = self._validate_base64_image(image)
        try:
            obf_image_bytes = self._obf_to_jpeg_bytes(pil_img)
        finally:
            # The opened image is owned here; close it on every path.
            pil_img.close()
        return base64_prefix + base64.b64encode(obf_image_bytes).decode('utf-8')

    def image_bytearray_obf(self, image: bytearray) -> bytearray:
        """Obfuscate an image supplied as a bytearray stream.

        :param image: image file content as a bytearray
        :return: the obfuscated image, saved as JPEG, as a bytearray
        :raises ObfException: ``SEED_CONTENT_NOT_SET`` when no seed has been
            set, ``INVALID_PARAM`` for a non-bytearray or oversized input,
            ``INVALID_IMAGE`` when the data cannot be opened as an image,
            ``FAILED`` when obfuscation or JPEG encoding fails
        """
        if not self.is_seed_content:
            log.error("The seed content is not set.")
            raise ObfException(ErrorCode.SEED_CONTENT_NOT_SET.value)
        if not isinstance(image, bytearray) or len(image) > Constant.MAX_BYTES_IMAGE_LENGTH * 1024 * 1024:
            log.error("The image data type or length validation failed.")
            raise ObfException(ErrorCode.INVALID_PARAM.value)
        # Opening is kept separate from obfuscation so that an obfuscation
        # failure is not reported as an unreadable image.
        try:
            pil_img = Image.open(BytesIO(image))
        except Exception as e:
            log.error("Failed to open image.")
            raise ObfException(ErrorCode.INVALID_IMAGE.value) from e
        try:
            return bytearray(self._obf_to_jpeg_bytes(pil_img))
        finally:
            pil_img.close()

    def _obf_to_jpeg_bytes(self, pil_img: Image.Image) -> bytes:
        """Obfuscate an opened image and return it encoded as JPEG bytes.

        :param pil_img: opened source image; not closed by this method
        :return: JPEG file content of the obfuscated image
        :raises ObfException: ``FAILED`` when obfuscation or encoding fails
        """
        try:
            obf_pil_image = self._image_obf(pil_img)
            obf_image_bytes = BytesIO()
            obf_pil_image.save(obf_image_bytes, "JPEG")
        except Exception as e:
            log.error(f"Failed to obfuscate image: {e}")
            raise ObfException(ErrorCode.FAILED.value) from e
        return obf_image_bytes.getvalue()

    def _image_obf(self, image: Image.Image) -> Image.Image:
        """Obfuscate a PIL.Image object.

        The image is converted to RGB, resized to the dimensions chosen by
        ``_smart_resize``, cut into ``patch_size`` x ``patch_size`` patches,
        handed to ``apply_patch_and_channel_permute``, and reassembled.

        :param image: PIL.Image object
        :return: obfuscated PIL.Image object at the resized dimensions
        """
        image = self._to_rgb(image)
        width, height = image.size
        resized_height, resized_width = self._smart_resize(
            height,
            width,
            factor=self.factor,
            min_pixels=self.shortest_edge,
            max_pixels=self.longest_edge,
        )
        image = image.resize((resized_width, resized_height))
        image = F.pil_to_tensor(image)
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        channel = image.shape[0]
        # (C, H, W) -> (C, grid_h, patch, grid_w, patch) -> (C, grid_h, grid_w, patch, patch)
        patches = image.reshape(channel, grid_h, self.patch_size, grid_w, self.patch_size)
        patches = patches.permute(0, 1, 3, 2, 4)
        # Flatten each patch so the last axis holds its patch_size**2 pixels.
        patches_flat = patches.reshape(channel, grid_h, grid_w, self.patch_size * self.patch_size)
        numpy_array = patches_flat.numpy()
        # Presumably applies the permutation prepared by
        # generate_patch_and_channel_permute in _set_seed_core - verify in utils.
        numpy_array = apply_patch_and_channel_permute(numpy_array)
        permuted_flat = torch.tensor(numpy_array)
        # Undo the patch layout: back to (C, grid_h, patch, grid_w, patch), then (C, H, W).
        permuted_patches = permuted_flat.reshape(channel, grid_h, grid_w, self.patch_size, self.patch_size)
        permuted_patches = permuted_patches.permute(0, 1, 3, 2, 4)
        restored_image = permuted_patches.reshape(channel, grid_h * self.patch_size, grid_w * self.patch_size)
        return self._tensor_to_pil(restored_image)

    def _set_seed_core(self, seed_content_bytes, seed_type):
        """Generate the obfuscation permutation from the seed.

        On success ``is_seed_content`` is set to True, which unlocks the
        obfuscation methods.

        :param seed_content_bytes: obfuscation seed bytes
        :param seed_type: seed type supplied by the caller; not used in this method
        :return: ``ErrorCode.SUCCESS`` value on success, or the
            ``(code, message)`` of the ``ObfException`` raised while
            generating the permutation
        """
        log.info("Start to set seed core.")
        try:
            generate_patch_and_channel_permute(seed_content_bytes, self.patch_size * self.patch_size,
                                               Constant.SEED_ADD_IN.encode("utf-8"))
        except ObfException as e:
            return e.code, e.message
        log.info("Set seed core successful.")
        self.is_seed_content = True
        return ErrorCode.SUCCESS.value