import errno
import os.path
import platform
import shutil
import stat
import string
import tarfile
import zipfile
# Directory containing this module; ProjectPath derives the project root from it.
_CUR_DIR = os.path.dirname(__file__)

# Characters accepted in path values read from the environment on Linux:
# ASCII digits, ASCII letters, a few punctuation characters and the space.
PATH_WHITE_LIST_LIN = "".join([string.digits, string.ascii_letters, ':~-+_./ '])

# Inclusive length bounds applied to validated path values.
MIN_PATH_LEN, MAX_PATH_LEN = 1, 4096
class ParameterTypes:
    """
    Constant holder for the type names used in Ansible module parameter specs.
    """
    # Scalar types
    STR = "str"
    INT = "int"
    BOOL = "bool"
    # Container types
    DICT = "dict"
    LIST = "list"
def get_validated_env(
        env_name,
        whitelist=PATH_WHITE_LIST_LIN,
        check_symlink=True,
        check_owner=True
):
    """
    Read an environment variable and validate it as a path (Python 2/3 compatible).

    :param env_name: name of the environment variable to read
    :param whitelist: string containing every character the value may use
    :param check_symlink: if True, reject the value when it names a symbolic link
    :param check_owner: if True and the path exists, require owner_check(value) to pass
    :return: the validated value, or None when the variable is not set
    :raises ValueError: when any validation step fails, including when the
        ownership of the path cannot be determined
    """
    value = os.getenv(env_name)
    if value is None:
        return None
    whitelist_check(value, whitelist)
    length_check(env_name, MAX_PATH_LEN, MIN_PATH_LEN, value)
    if check_symlink:
        symlink_check(env_name, value)
    if check_owner and os.path.lexists(value):
        # owner_check() uses os.stat, which follows symlinks. With check_symlink=False a
        # dangling link passes lexists() but makes os.stat raise; report that as a
        # validation failure (fail closed) instead of leaking a raw OSError.
        try:
            owned = owner_check(value)
        except (OSError, IOError) as e:
            raise ValueError("Failed to check owner of path {}: {}".format(value, str(e)))
        if not owned:
            raise ValueError("The path {} is not owned by current user or root.".format(value))
    return value
def whitelist_check(value, whitelist):
    """
    Ensure every character of ``value`` appears in ``whitelist``.

    :param value: string to validate
    :param whitelist: string of permitted characters
    :raises ValueError: if ``value`` contains a character outside ``whitelist``
    """
    if any(ch not in whitelist for ch in value):
        raise ValueError(
            "The path is invalid. The path can contain only char in '{}'".format(whitelist))
def owner_check(path):
    """
    Decide whether ``path`` has acceptable ownership.

    The path is stat'ed with os.stat (symlinks are followed). It is accepted when it
    is owned by root, when its group ID is in os.getgroups() for this process, or
    when its owner UID equals both the real and the effective UID of this process.

    :param path: filesystem path to inspect
    :return: True if the ownership is acceptable, otherwise False
    """
    info = os.stat(path)
    if info.st_uid == 0:
        return True
    if info.st_gid in os.getgroups():
        return True
    return info.st_uid == os.getuid() and info.st_uid == os.geteuid()
def length_check(env_name, max_length, min_length, value):
    """
    Enforce inclusive length bounds on ``value``.

    :param env_name: variable name, used only in error messages
    :param max_length: maximum allowed length, or None for no upper bound
    :param min_length: minimum allowed length, or None for no lower bound
    :param value: the sized value whose ``len()`` is checked
    :raises ValueError: if the length is below ``min_length`` (checked first)
        or above ``max_length``
    """
    actual = len(value)
    if min_length is not None and actual < min_length:
        template = "Value for {} is too short. Minimum length: {}, actual: {}"
        raise ValueError(template.format(env_name, min_length, actual))
    if max_length is not None and actual > max_length:
        template = "Value for {} is too long. Maximum length: {}, actual: {}"
        raise ValueError(template.format(env_name, max_length, actual))
def symlink_check(env_name, value):
    """
    Reject a path that is itself a symbolic link.

    :param env_name: variable name, used only in error messages
    :param value: path to inspect; bytes are decoded as UTF-8 with undecodable
        bytes replaced
    :return: None when the path is not a symlink (or does not exist)
    :raises ValueError: if the path is a symlink, or if inspecting it fails with
        an OS error other than ENOENT
    """
    path_value = value.decode('utf-8', 'replace') if isinstance(value, bytes) else value
    try:
        is_link = os.path.lexists(path_value) and os.path.islink(path_value)
    except (OSError, IOError) as err:
        # A missing path cannot be a symlink; any other OS failure is reported.
        if err.errno == errno.ENOENT:
            return
        raise ValueError(
            "Error checking symlink for {}: {} - {}".format(env_name, path_value, str(err))
        )
    if is_link:
        raise ValueError(
            "Path for {} is a symlink: {}. Symlinks are not allowed for security reasons.".format(
                env_name, path_value
            )
        )
class ProjectPath:
    """
    Locations inside the ascend-deployer project tree.

    ROOT is normally the parent of this module's directory. When that path lies
    inside an installed package (site-packages / dist-packages), ROOT becomes
    <home>/ascend-deployer instead. Here <home> is the current working directory,
    or on Linux the validated ASCEND_DEPLOYER_HOME or HOME value when one is set.
    """
    USER_HOME = os.path.expanduser("~")
    ROOT = os.path.dirname(_CUR_DIR)
    if 'site-packages' in ROOT or 'dist-packages' in ROOT:
        deployer_home = os.getcwd()
        if platform.system() == 'Linux':
            # Keep the working-directory default when neither variable is set;
            # otherwise os.path.join(None, ...) would raise TypeError at import time.
            deployer_home = (get_validated_env('ASCEND_DEPLOYER_HOME')
                             or get_validated_env('HOME')
                             or deployer_home)
        ROOT = os.path.join(deployer_home, 'ascend-deployer')
    PLAYBOOK_DIR = os.path.join(ROOT, "playbooks")
    # Bare file name (not a path) of the inventory file.
    INVENTORY_FILE = "inventory_file"
    PROCESS_PLAYBOOK_DIR = os.path.join(PLAYBOOK_DIR, "process")
class TmpPath:
    """Paths under the .ascend_deployer directory in the user's home directory."""
    # Base directory and its sub-directories.
    ROOT = os.path.join(ProjectPath.USER_HOME, ".ascend_deployer")
    DL_YAML_DIR = os.path.join(ROOT, "dl_yaml")
    DEPLOY_INFO = os.path.join(ROOT, "deploy_info")
    # JSON file locations inside DEPLOY_INFO.
    PROGRESS_JSON_NAME = "deployer_progress_output.json"
    PROGRESS_JSON = os.path.join(DEPLOY_INFO, PROGRESS_JSON_NAME)
    TEST_REPORT_JSON = os.path.join(DEPLOY_INFO, "test_report.json")
    CHECK_RES_OUTPUT_JSON = os.path.join(DEPLOY_INFO, "check_res_output.json")
class LargeScalePath:
    """Paths used by the large-scale deployment flow."""
    # Large-scale inventory file kept in the project root.
    INVENTORY_FILE_PATH = os.path.join(ProjectPath.ROOT, "large_scale_inventory.ini")

    # Scratch area under TmpPath.ROOT and the inventories derived there.
    ROOT_TMP_DIR = os.path.join(TmpPath.ROOT, "large_scale_deploy")
    PARSED_INVENTORY_FILE_PATH = os.path.join(ROOT_TMP_DIR, "parsed_inventory_file.ini")
    DEPLOY_NODE_INVENTORY_FILE_PATH = os.path.join(ROOT_TMP_DIR, "deploy_node_inventory_file.ini")

    # Deployer copy and per-remote-host artifacts.
    REMOTE_DEPLOYER_DIR = os.path.join(ROOT_TMP_DIR, "ascend_deployer")
    REMOTE_INVENTORY_FILE = os.path.join(REMOTE_DEPLOYER_DIR, ProjectPath.INVENTORY_FILE)
    REMOTE_START_SCRIPT = os.path.join(REMOTE_DEPLOYER_DIR, "install.sh")
    REMOTE_EXECUTE_RES_LOG = os.path.join(ROOT_TMP_DIR, "ascend_deployer_execute.log")
    REMOTE_HOST_RESULTS = os.path.join(ROOT_TMP_DIR, "remote_host_data")

    # Spread-task bookkeeping.
    SPREAD_TASK = os.path.join(ROOT_TMP_DIR, "spread_task")
    SPREAD_NODES_TREE_JSON = os.path.join(SPREAD_TASK, "spread_nodes_tree.json")
    EXEC_RESULTS_DIR = os.path.join(SPREAD_TASK, "exec_results")

    # Aggregated reports.
    REPORT_DIR = os.path.join(ROOT_TMP_DIR, "report")
    ALL_TEST_REPORT_CSV = os.path.join(REPORT_DIR, "test_report.csv")
class PathManager:
    """Helpers that create and reset the large-scale deployment working directories."""

    @classmethod
    def recover_dir(cls, dir_path):
        """
        Recreate ``dir_path`` as an empty directory.

        Any existing tree at ``dir_path`` is removed first. The directory (and any
        missing parents) is created with mode 0o750, subject to the process umask.
        """
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(dir_path, mode=0o750, exist_ok=True)

    @classmethod
    def init_large_scale_dirs(cls):
        """Reset the remote-host results directory and the report directory."""
        cls.recover_dir(LargeScalePath.REMOTE_HOST_RESULTS)
        cls.recover_dir(LargeScalePath.REPORT_DIR)

    @classmethod
    def clear_last_info_except_inventory(cls):
        """
        Reset the report directory, then delete every entry except the inventory file
        inside each per-host sub-directory of LargeScalePath.REMOTE_HOST_RESULTS.
        """
        cls.recover_dir(LargeScalePath.REPORT_DIR)
        results_root = LargeScalePath.REMOTE_HOST_RESULTS
        # Nothing from a previous run to clear (e.g. the directory was never created).
        if not os.path.isdir(results_root):
            return
        for ip in os.listdir(results_root):
            remote_info_path = os.path.join(results_root, ip)
            # Only per-host directories hold previous results; skip stray files.
            if not os.path.isdir(remote_info_path):
                continue
            for name in os.listdir(remote_info_path):
                if name == ProjectPath.INVENTORY_FILE:
                    continue
                entry = os.path.join(remote_info_path, name)
                # os.remove cannot delete directories; remove real sub-trees with
                # rmtree, but unlink symlinks rather than following them.
                if os.path.isdir(entry) and not os.path.islink(entry):
                    shutil.rmtree(entry)
                else:
                    os.remove(entry)
class CompressedFileCheckUtils:
    """
    Utility class for checking compressed files (ZIP and TAR) for security issues.

    Every member of an archive is validated against:
    - symbolic links (rejected unless the member lies under a WHITELIST_DIRS prefix);
    - TAR hard links whose target name is absolute or contains a traversal sequence;
    - path traversal attempts (e.g. names containing ../ or a bare ".." component);
    - absolute paths.

    Only archive metadata is inspected; nothing is extracted.
    """
    # Member-name prefixes whose symbolic links are tolerated.
    WHITELIST_DIRS = [
        "mpich-4.1.3/modules/"
    ]

    @staticmethod
    def is_in_whitelist(file_path):
        """Return True if ``file_path`` starts with one of the WHITELIST_DIRS prefixes."""
        return file_path.startswith(tuple(CompressedFileCheckUtils.WHITELIST_DIRS))

    @staticmethod
    def check_tar_file_symbolic_link(file_info):
        """
        Reject a TAR member that is a symbolic link, unless it is whitelisted.

        :param file_info: tarfile.TarInfo of the member
        :return: (True, '') if acceptable, otherwise (False, error message)
        """
        if CompressedFileCheckUtils.is_in_whitelist(file_info.path):
            return True, ''
        if file_info.issym():
            err_msg = "[ASCEND][ERROR] The file:{} is a symbolic link, please check it.".format(file_info.path)
            return False, err_msg
        return True, ''

    @staticmethod
    def check_tar_file_hard_link(file_info):
        """
        Reject a TAR hard link whose target name fails check_package_inner_file_name.

        An absolute or ".." hard-link target could let an extractor link to a file
        outside the destination directory.

        :param file_info: tarfile.TarInfo of the member
        :return: (True, '') if acceptable, otherwise (False, error message)
        """
        if not file_info.islnk():
            return True, ''
        ret, _ = CompressedFileCheckUtils.check_package_inner_file_name(file_info.linkname)
        if not ret:
            err_msg = ("[ASCEND][ERROR] The file:{} is a hard link to unsafe target:{}, "
                       "please check it.").format(file_info.path, file_info.linkname)
            return False, err_msg
        return True, ''

    @staticmethod
    def check_zip_file_symbolic_link(file_info):
        """
        Reject a ZIP member whose Unix mode marks it as a symbolic link, unless whitelisted.

        :param file_info: zipfile.ZipInfo of the member
        :return: (True, '') if acceptable, otherwise (False, error message)
        """
        if CompressedFileCheckUtils.is_in_whitelist(file_info.filename):
            return True, ''
        # The Unix st_mode lives in the upper 16 bits of external_attr. Test the
        # file-type bits with S_ISLNK (a magnitude comparison misclassifies modes),
        # and do so for every compression method: a symlink entry may be deflated.
        if stat.S_ISLNK(file_info.external_attr >> 16):
            err_msg = "[ASCEND][ERROR] The file:{} is a symbolic link, please check it.".format(
                file_info.filename)
            return False, err_msg
        return True, ''

    @staticmethod
    def check_package_inner_file_name(file_name):
        """
        Reject member names containing traversal sequences or absolute paths.

        :param file_name: member name (or hard-link target) taken from the archive
        :return: (True, '') if acceptable, otherwise (False, error message)
        """
        special_msg = "[ASCEND][ERROR] check compressed file:{} failed ,inner file has special string".format(
            file_name)
        check_str_list = ["../", "..\\", ".\\", "./", "~/"]
        for check_str in check_str_list:
            if check_str in file_name:
                return False, special_msg
        # A ".." component without a trailing separator (e.g. a member named "..")
        # refers to a parent directory yet contains none of the substrings above.
        if ".." in file_name.replace("\\", "/").split("/"):
            return False, special_msg
        if os.path.isabs(file_name):
            err_msg = "[ASCEND][ERROR] check compressed file:{} failed ,inner file cannot be abspath".format(
                file_name)
            return False, err_msg
        return True, ''

    @staticmethod
    def check_zip_file_info(filepath):
        """
        Run the ZIP member checks on every entry of ``filepath``.

        :return: (True, '') if all members pass, otherwise (False, first error message)
        :raises zipfile.BadZipfile, OSError: propagated from zipfile.ZipFile
        """
        with zipfile.ZipFile(filepath, 'r') as file_list:
            for file in file_list.infolist():
                checks = [
                    CompressedFileCheckUtils.check_zip_file_symbolic_link(file),
                    CompressedFileCheckUtils.check_package_inner_file_name(file.filename)
                ]
                for ret, err_msg in checks:
                    if not ret:
                        return False, err_msg
        return True, ''

    @staticmethod
    def check_tar_file_info(filepath):
        """
        Run the TAR member checks on every entry of ``filepath``.

        :return: (True, '') if all members pass, otherwise (False, error message);
            tarfile and OS errors are also reported as (False, message)
        """
        try:
            with tarfile.open(filepath, 'r') as file_list:
                for file in file_list:
                    checks = [
                        CompressedFileCheckUtils.check_tar_file_symbolic_link(file),
                        CompressedFileCheckUtils.check_tar_file_hard_link(file),
                        CompressedFileCheckUtils.check_package_inner_file_name(file.name)
                    ]
                    for ret, err_msg in checks:
                        if not ret:
                            return False, err_msg
            return True, ''
        except (tarfile.TarError, IOError, OSError) as e:
            return False, "[ASCEND][ERROR] Failed to check tar file {}: {}".format(filepath, str(e))

    @staticmethod
    def check_compressed_file_valid(filepath):
        """
        Validate a .tar / .tar.gz or .zip archive, chosen by file extension.

        :param filepath: path of the archive
        :return: (True, "") if every member passes; otherwise (False, error message).
            Unsupported extensions and archive/IO errors also yield (False, message).
        """
        try:
            if filepath.endswith((".tar.gz", ".tar")):
                return CompressedFileCheckUtils.check_tar_file_info(filepath)
            if filepath.endswith(".zip"):
                return CompressedFileCheckUtils.check_zip_file_info(filepath)
            err_msg = "[ASCEND][ERROR] unsupported compressed file format {}".format(filepath)
            return False, err_msg
        except (zipfile.BadZipfile, IOError, OSError) as e:
            err_msg = "[ASCEND][ERROR] {}".format(str(e))
            return False, err_msg