"""混淆和解混淆数据接口"""
import os
from typing import List
from .asset_obfuscation import AssetObfuscation
from ..constants import Constant, ErrorCode
from ..exception import ObfException
from ..model import TLSConfig, PskConfig
from ..utils import log, get_obf_dict_value_by_key, clean_bytearray, \
get_de_obf_dict_value_by_key, check_white_list, data_dec_mul, generate_obf_and_de_obf_dict
def _check_local_save_path(is_local_save, seed_ciphertext_dir) -> bool:
return is_local_save and isinstance(seed_ciphertext_dir, str) and not os.path.islink(seed_ciphertext_dir) \
and os.path.exists(seed_ciphertext_dir)
class DataAssetObfuscation(AssetObfuscation):
"""推理数据混淆对外接口
创建DataAssetObfuscation类实例
"""
def __init__(self, vocab_size: int, token_white_list: list = None):
"""参数:vocab_size:模型词汇表实际长度
token_white_list:词表tokenId白名单,该token不做混淆
"""
if not isinstance(vocab_size, int) or vocab_size <= 0:
log.error("The size of vocab must be greater than zero.")
raise ObfException(ErrorCode.VOCAB_SIZE_FAILED.value)
check_white_list(token_white_list, vocab_size)
self.vocab_size = vocab_size
self.white_set = [] if token_white_list is None else set(token_white_list)
def set_seed_safer(self, tls_info: tuple, psk_info: tuple) -> (int, str):
"""设置混淆因子,用于推理数据混淆
ca_file:ca证书路径 cert_file:证书路径 pri_keyfile:证书key路径 port:端口号 ks_path:ks工具路径
ciphertext_path:加密口令路径
psk_path:psk密文路径 ks_path_psk:工具路径 ciphertext_path_psk:加密口令路径
"""
if tls_info is None or len(tls_info) != Constant.TSL_INFO_PARAM_SIZE:
log.error("Tls info param size validation failed.")
return ErrorCode.INVALID_PARAM.value
if psk_info is None or len(psk_info) != Constant.PSK_INFO_PARAM_SIZE:
log.error("Psk info param size validation failed.")
return ErrorCode.INVALID_PARAM.value
ca_file, cert_file, pri_keyfile, port, ks_path, ciphertext_path = tls_info
psk_path, ks_path_psk, ciphertext_path_psk = psk_info
tls = TLSConfig(ca_file, cert_file, pri_keyfile, ks_path, ciphertext_path)
tls.set_port(port)
psk = PskConfig(psk_path, ks_path_psk, ciphertext_path_psk)
return self.export_set_obf_seed(tls, psk)
def set_seed_content(self, seed_content: str = None, is_local_save: bool = False,
seed_ciphertext_dir: str = None) -> (int, str):
"""
1.输入混淆因子内容,用于数据混淆
2.通过本地软保护设置混淆因子
:param seed_content: 混淆因子类型
:param is_local_save: 是否从本地获取
:param seed_ciphertext_dir: 密文保存路径,is_local_save为True时需要此参数需要被校验
:return: errorCode(int, str)
"""
if not isinstance(seed_content, str) and not _check_local_save_path(is_local_save, seed_ciphertext_dir):
log.error("The obfuscate seed content data type validation failed.")
return ErrorCode.INVALID_SEED_CONTENT.value
if seed_content is not None and (len(seed_content) > Constant.SEED_CONTENT_MAX_LEN
or len(seed_content) < Constant.SEED_CONTENT_MIN_LEN):
log.error("The obfuscate seed content len validation failed.")
return ErrorCode.INVALID_SEED_CONTENT.value
if seed_content is not None:
seed_content_bytes = bytearray(seed_content, "utf-8")
else:
seed_content_bytes = data_dec_mul(seed_ciphertext_dir, Constant.DATA_CIPHERTEXT_FILE_NAME)
set_seed_result = self._set_seed_core(seed_content_bytes, Constant.DATA_SEED_TYPE)
clean_bytearray(seed_content_bytes)
return set_seed_result
def data_2d_obf(self, tokens: List[List[int]]) -> List[List[int]]:
"""混淆二维数据
参数:tokens:待加混淆的tokens 注意:最内层列表内需为int数
返回值:obf_tokens:加混淆后的tokens
"""
obf_tokens = []
for token_row in tokens:
obf_row = []
for token in token_row:
self.__check_input_item(token)
token = get_obf_dict_value_by_key(token)
obf_row.append(token)
obf_tokens.append(obf_row)
return obf_tokens
def data_1d_obf(self, tokens: List[int]) -> List[int]:
"""混淆一维数据
参数:tokens:待加混淆的tokens 注意:列表内需为int数
返回值:obf_tokens:加混淆后的tokens
"""
obf_tokens = []
for token in tokens:
self.__check_input_item(token)
token = get_obf_dict_value_by_key(token)
obf_tokens.append(token)
return obf_tokens
def data_2d_deobf(self, tokens: List[List[int]]) -> List[List[int]]:
"""解混淆二维数据
参数:tokens:待解混淆的tokens 注意:最内层列表内需为int数
返回值:obf_tokens:解混淆后的tokens
"""
deobf_tokens = []
for row in tokens:
deobf_row = []
for token in row:
self.__check_input_item(token)
token = get_de_obf_dict_value_by_key(token)
deobf_row.append(token)
deobf_tokens.append(deobf_row)
return deobf_tokens
def data_1d_deobf(self, tokens: List[int]) -> List[int]:
"""解混淆一维数据
参数:tokens:待解混淆的tokens 注意:列表内需为int数
返回值:obf_tokens:解混淆后的tokens
"""
deobf_tokens = []
for token in tokens:
self.__check_input_item(token)
token = get_de_obf_dict_value_by_key(token)
deobf_tokens.append(token)
return deobf_tokens
def token_obf(self, token: int) -> int:
"""混淆int数据
参数:token:待加混淆的token
返回值;token:加混淆后的token
"""
if token >= self.vocab_size:
log.warning("The input token is out of range, return the original value.")
return token
return get_obf_dict_value_by_key(token)
def token_deobf(self, token: int) -> int:
"""解混淆int数据"""
return get_de_obf_dict_value_by_key(token)
def __check_input_item(self, item: int):
if not isinstance(item, int) or item >= self.vocab_size:
log.error("Item type must be int and item must less than vocab_size.")
raise ObfException(ErrorCode.ITEM_VALIDATE_FAILED.value)
def _set_seed_core(self, seed_content_bytes, seed_type):
log.info("Start to set seed core.")
try:
generate_obf_and_de_obf_dict(self.white_set, seed_content_bytes, self.vocab_size)
except ObfException as e:
return e.code, e.message
log.info("Set seed core successful.")
return ErrorCode.SUCCESS.value