import os
import sys
import stat
import re
import json
import logging
from enum import Enum
from typing import Optional
from ms_service_profiler.utils.log import logger
from ms_service_profiler.utils.constants import PATH_WHITE_LIST_REGEX
from ms_service_profiler.utils.constants import CONFIG_FILE_MAX_SIZE
# Sentinel for "no size limit": check_file_size skips the size comparison
# when max_size equals this value.
MAX_SIZE_UNLIMITE = -1
# Numeric values are 10 MiB, 4 GiB and 100 GiB when read as byte counts.
# Not referenced by the code visible in this file — presumably consumed by callers.
MAX_SIZE_LIMITE_CONFIG_FILE = 10 * 1024 * 1024
MAX_SIZE_LIMITE_NORMAL_FILE = 4 * 1024 * 1024 * 1024
MAX_SIZE_LIMITE_MODEL_FILE = 100 * 1024 * 1024 * 1024
# Negated character class: matches any character that is NOT one of
# "_", ":", "\", A-Z, a-z, 0-9, "/", "." or "-". Used for Windows paths
# by is_match_path_white_list.
PATH_WHITE_LIST_REGEX_WIN = re.compile(r"[^_:\\A-Za-z0-9/.-]")
# Octal permission modes: 0o640 is rw-r-----, 0o600 is rw-------.
PERMISSION_NORMAL = 0o640
PERMISSION_KEY = 0o600
# Both masks are the group-write | other-write bits (identical values).
READ_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH
WRITE_FILE_NOT_PERMITTED_STAT = stat.S_IWGRP | stat.S_IWOTH
# Custom logging levels: 35 sits between WARNING (30) and ERROR (40),
# 45 sits between ERROR (40) and CRITICAL (50).
SOLUTION_LEVEL = 35
SOLUTION_LEVEL_WIN = 45
# Register display names for the custom levels at import time. The non-Windows
# name is wrapped in ANSI escape codes (bold green, then reset).
logging.addLevelName(SOLUTION_LEVEL, "\033[1;32m" + "SOLUTION" + "\033[0m")
logging.addLevelName(SOLUTION_LEVEL_WIN, "SOLUTION_WIN")
# Identifiers handed to solution_log / solution_log_win.
# NOTE(review): every value ends with a stray double quote — presumably a
# leftover from a longer quoted URL/anchor string; confirm before changing.
SOFT_LINK_SUB_CHAPTER = 'soft_link_warning_log_solution"'
PATH_LENGTH_SUB_CHAPTER = 'path_length_overflow_warning_log_solution"'
OWNER_SUB_CHAPTER = 'owner_or_ownergroup_warning_log_solution"'
PERMISSION_SUB_CHAPTER = 'path_permission_error_log_solution"'
ILLEGAL_CHAR_SUB_CHAPTER = 'path_contain_illegal_char_error_log_solution"'
RAW_INPUT_PATH = "RAW_INPUT_PATH"
# Used by sanitize_csv_value. Anchored at the start of the value: one character
# from the class (= + - % @), then ";", then another character from the class.
# NOTE(review): "=", "+" and "-" are repeated inside each class — possibly other
# (e.g. full-width) characters were lost in an encoding conversion; confirm.
MALICIOUS_CSV_PATTERN = re.compile(r'^[=+-+-=%@];[=+-+-=%@]')
def solution_log(content):
    """Emit a "visit ... for detailed solution" hint at the custom SOLUTION level.

    Args:
        content: Value substituted into the hint message.
    """
    hint_template = "visit %s for detailed solution"
    logger.log(SOLUTION_LEVEL, hint_template, content)
def solution_log_win(content):
    """Emit a "visit ... for detailed solution" hint at the SOLUTION_WIN level.

    Args:
        content: Value substituted into the hint message.
    """
    hint_template = "visit %s for detailed solution"
    logger.log(SOLUTION_LEVEL_WIN, hint_template, content)
def is_legal_path_length(path):
    """Warn about over-long paths; the result is always ``True``.

    Length violations are reported as warnings only, so this function never
    rejects a path. Checks, in order:

    * total length > 4096 on non-Windows platforms;
    * total length > 260 on Windows;
    * any "/"-separated component longer than 255 characters.

    Only the first violation found is reported.

    Args:
        path: Path string to inspect.

    Returns:
        Always ``True``.
    """
    on_windows = sys.platform.startswith("win")
    total_length = len(path)

    if not on_windows and total_length > 4096:
        logger.warning("file total path %s length out of range (4096), please check the file(or directory) path", path)
        solution_log(PATH_LENGTH_SUB_CHAPTER)
        return True

    if on_windows and total_length > 260:
        logger.warning("file total path %s length out of range (260), please check the file(or directory) path", path)
        solution_log_win(PATH_LENGTH_SUB_CHAPTER)
        return True

    # Report only the first over-long component, then stop looking.
    too_long = next((part for part in path.split("/") if len(part) > 255), None)
    if too_long is not None:
        logger.warning("file name %s length out of range (255), please check the file(or directory) path", too_long)
        solution_log(PATH_LENGTH_SUB_CHAPTER)
    return True
def is_match_path_white_list(path):
    """Check that a path contains only whitelisted characters.

    The Windows whitelist regex is used on Windows; the imported
    ``PATH_WHITE_LIST_REGEX`` is used everywhere else. A regex hit means an
    illegal character was found.

    Args:
        path: Path string to inspect.

    Returns:
        ``False`` (after logging an error and a solution hint) when an illegal
        character is found, otherwise ``True``.
    """
    if sys.platform.startswith("win"):
        illegal = PATH_WHITE_LIST_REGEX_WIN.search(path)
        if illegal:
            logger.error("path: %s contains illegal char, legal chars include A-Z a-z 0-9 _ - / . : \\", path)
            solution_log_win(ILLEGAL_CHAR_SUB_CHAPTER)
            return False
        return True

    illegal = PATH_WHITE_LIST_REGEX.search(path)
    if illegal:
        logger.error("path: %s contains illegal char, legal chars include A-Z a-z 0-9 _ - / .", path)
        solution_log(ILLEGAL_CHAR_SUB_CHAPTER)
        return False
    return True
def is_legal_args_path_string(path):
    """Validate a path-like command-line argument.

    Args:
        path: Path string; a falsy value (``None``, ``""``) is accepted as-is.

    Returns:
        ``True`` for a falsy ``path``; otherwise ``True`` only when both the
        length check and the character whitelist check pass.
    """
    if not path:
        # Nothing supplied, so there is nothing to validate.
        return True
    return is_legal_path_length(path) and is_match_path_white_list(path)
class SanitizeErrorType(Enum):
    """Accepted values for the ``errors`` argument of ``sanitize_csv_value``."""

    # Raise ValueError when a malicious value is detected.
    strict = "strict"
    # Return the value untouched without inspecting it.
    ignore = "ignore"
    # Neutralise a malicious value by prefixing it with a space.
    replace = "replace"
def sanitize_csv_value(value: str, errors=SanitizeErrorType.strict.value):
    """Guard a value against CSV injection before it is written to a csv file.

    Non-string values, and every value when ``errors`` is ``"ignore"``, are
    returned unchanged. Strings that parse as a float are returned unchanged
    too. Any other string is tested against ``MALICIOUS_CSV_PATTERN``.

    Args:
        value: The cell value to check.
        errors: One of the ``SanitizeErrorType`` values. ``"replace"`` prefixes
            a matching value with a space; any value other than ``"ignore"``
            or ``"replace"`` is handled as strict.

    Returns:
        The original value, or the space-prefixed value in replace mode.

    Raises:
        ValueError: If the value matches the malicious pattern in strict mode.
    """
    if errors == SanitizeErrorType.ignore.value or not isinstance(value, str):
        return value

    try:
        float(value)
    except ValueError as e:
        looks_malicious = MALICIOUS_CSV_PATTERN.search(value) is not None
        if not looks_malicious:
            return value
        if errors == SanitizeErrorType.replace.value:
            return ' ' + value
        logger.error("Please check the value written to the csv")
        raise ValueError(f'Malicious value is not allowed to be written to the csv {value}') from e

    # Plain numeric strings are harmless.
    return value
class OpenException(Exception):
    """Raised by this module's file helpers when a path or file fails a safety check."""
class FileStat:
    """Capture existence and ``os.stat`` metadata for a path and expose checks on it.

    The path is validated with ``is_legal_path_length`` and
    ``is_match_path_white_list`` before the filesystem is touched.

    Attributes:
        file: The path exactly as passed in.
        file_stat: ``os.stat_result`` of the path, or ``None`` when it could
            not be stat-ed (e.g. it does not exist).
        is_file_exist: ``True`` when ``file_stat`` was obtained.
        realpath: ``os.path.realpath(file)`` for an existing path, else ``None``.

    Raises:
        OpenException: If the path fails the length or whitelist check.
    """

    def __init__(self, file) -> None:
        if not is_legal_path_length(file) or not is_match_path_white_list(file):
            raise OpenException("Path name is too long or contains invalid characters.")
        self.file = file
        # Stat once instead of exists() followed by stat(): the path may vanish
        # between the two calls, which used to surface as an uncaught
        # FileNotFoundError. The exception set mirrors what os.path.exists swallows.
        try:
            self.file_stat = os.stat(file)
        except (OSError, ValueError):
            self.file_stat = None
        self.is_file_exist = self.file_stat is not None
        # Always define realpath so reading it on a missing path yields None
        # rather than raising AttributeError.
        self.realpath = os.path.realpath(file) if self.is_file_exist else None

    @property
    def is_exists(self):
        """Whether the path could be stat-ed when this object was created."""
        return self.is_file_exist

    @property
    def is_softlink(self):
        """Whether the path is a symbolic link; ``False`` when no stat is available."""
        return os.path.islink(self.file) if self.file_stat else False

    @property
    def is_file(self):
        """Whether the stat result describes a regular file."""
        return stat.S_ISREG(self.file_stat.st_mode) if self.file_stat else False

    @property
    def is_dir(self):
        """Whether the stat result describes a directory."""
        return stat.S_ISDIR(self.file_stat.st_mode) if self.file_stat else False

    @property
    def file_size(self):
        """``st_size`` of the path, or 0 when no stat is available."""
        return self.file_stat.st_size if self.file_stat else 0

    @property
    def permission(self):
        """Permission bits of the path, or ``0o777`` when no stat is available."""
        return stat.S_IMODE(self.file_stat.st_mode) if self.file_stat else 0o777

    @property
    def owner(self):
        """``st_uid`` of the path, or -1 when no stat is available."""
        return self.file_stat.st_uid if self.file_stat else -1

    @property
    def group_owner(self):
        """``st_gid`` of the path, or -1 when no stat is available."""
        return self.file_stat.st_gid if self.file_stat else -1

    @property
    def is_owner(self):
        """Whether the path's uid equals the effective uid (0 where ``os.geteuid`` is missing)."""
        return self.owner == (os.geteuid() if hasattr(os, "geteuid") else 0)

    @property
    def is_group_owner(self):
        """Whether the path's gid is among the process groups (``[0]`` where ``os.getgroups`` is missing)."""
        return self.group_owner in (os.getgroups() if hasattr(os, "getgroups") else [0])

    @property
    def is_user_or_group_owner(self):
        """Whether either the owner check or the group check passes."""
        return self.is_owner or self.is_group_owner

    @property
    def is_user_and_group_owner(self):
        """Whether both the owner check and the group check pass."""
        return self.is_owner and self.is_group_owner

    def is_basically_legal(self, perm='none', strict_permission=True):
        """Run the platform-specific permission check.

        Args:
            perm: Intended access; ``'write'`` allows a path that does not exist yet.
            strict_permission: Forwarded to the Linux check, which currently ignores it.

        Returns:
            ``True`` when the check passes, otherwise ``False``.
        """
        if sys.platform.startswith("win"):
            return self.check_windows_permission(perm)
        return self.check_linux_permission(perm, strict_permission=strict_permission)

    def check_basic_permission(self, perm='none'):
        """Require the path to exist unless it is about to be written.

        Returns:
            ``False`` (after logging an error) when the path is missing and
            ``perm`` is not ``'write'``; otherwise ``True``.
        """
        if not self.is_exists and perm != 'write':
            logger.error("path: %s not exist, please check if file or dir is exist", self.file)
            return False
        return True

    def check_linux_permission(self, perm='none', strict_permission=True):
        """Non-Windows permission check; currently only the basic existence check.

        ``strict_permission`` is accepted for interface compatibility and unused.
        """
        return self.check_basic_permission(perm=perm)

    def check_windows_permission(self, perm='none'):
        """Windows permission check; currently only the basic existence check."""
        return self.check_basic_permission(perm=perm)

    def is_legal_file_size(self, max_size):
        """Check that the path is a regular file no larger than ``max_size``.

        Returns:
            ``False`` (after logging an error) when the path is not a regular
            file or its size exceeds ``max_size``; otherwise ``True``.
        """
        if not self.is_file:
            logger.error("path: %s is not a file", self.file)
            return False
        if self.file_size > max_size:
            logger.error("file_size: %s byte out of max limit %s byte", self.file_size, max_size)
            return False
        return True

    def is_legal_file_type(self, file_types: list):
        """Check that the path's extension is one of ``file_types``.

        Args:
            file_types: Extensions without the leading dot, e.g. ``["json"]``.

        Returns:
            ``False`` (after logging an error) when the path exists but is not
            a regular file, or when its extension is not listed; otherwise ``True``.
        """
        if not self.is_file and self.is_exists:
            logger.error("path: %s is not a file", self.file)
            return False
        suffix = os.path.splitext(self.file)[1]
        if any(suffix == f".{file_type}" for file_type in file_types):
            return True
        logger.error("path: %s, file type not in %s", self.file, file_types)
        return False
def check_file_exists_and_type(file_stat, file):
    """Reject a path that exists but is a directory.

    Args:
        file_stat: Object exposing ``is_exists`` and ``is_dir``.
        file: The path, used only in the error message.

    Raises:
        OpenException: If the path exists and is a directory.
    """
    is_existing_dir = file_stat.is_exists and file_stat.is_dir
    if not is_existing_dir:
        return
    raise OpenException(f"Expecting a file, but it's a folder. {file}")
def check_file_size(file_stat, file, max_size):
    """Enforce a size limit on a file that is about to be read.

    Args:
        file_stat: Object exposing ``file_size``.
        file: The path, used only in log and error messages.
        max_size: Maximum allowed size; ``MAX_SIZE_UNLIMITE`` disables the
            comparison, ``None`` is rejected outright.

    Raises:
        OpenException: If ``max_size`` is ``None`` or the file is larger than
            ``max_size``.
    """
    if max_size is None:
        logger.warning("Reading files should have a size limit control. %s", file)
        raise OpenException(f"Reading files must have a size limit control. {file}")

    if max_size == MAX_SIZE_UNLIMITE:
        # Caller explicitly opted out of the limit.
        return

    if max_size < file_stat.file_size:
        logger.warning("The file size has exceeded the specifications and cannot be read. %s", file)
        raise OpenException(f"The file size has exceeded the specifications and cannot be read. {file}")
def check_file_owner(file_stat, file):
    """Refuse to write to a file that the current process user does not own.

    Args:
        file_stat: Object exposing ``is_owner``.
        file: The path, used only in log and error messages.

    Raises:
        OpenException: If ``file_stat.is_owner`` is falsy.
    """
    if file_stat.is_owner:
        return
    logger.warning(
        "The file owner is inconsistent with the current process user and is not allowed to write. %s", file
    )
    raise OpenException(
        f"The file owner is inconsistent with the current process user and is not allowed to write. {file}"
    )
def ms_open(
    file, mode="r", max_size=CONFIG_FILE_MAX_SIZE, softlink=False, write_permission=PERMISSION_NORMAL, **kwargs
):
    """Open a file after running this module's safety checks.

    Read-only modes require the file to exist and to respect ``max_size``.
    Modes containing ``w``, ``a``, ``x`` or ``+`` require an already existing
    file to be owned by the current process user. Text modes default to
    UTF-8 unless ``encoding`` is supplied.

    Args:
        file: Path to open.
        mode: Mode string as accepted by the built-in ``open``.
        max_size: Size limit applied to read-only opens.
        softlink: Accepted for interface compatibility; discarded.
        write_permission: Accepted for interface compatibility; discarded.
        **kwargs: Forwarded to the built-in ``open``.

    Returns:
        The file object returned by the built-in ``open``.

    Raises:
        OpenException: If a check fails or ``open`` raises ``OSError``.
    """
    del softlink, write_permission

    file_stat = FileStat(file)
    check_file_exists_and_type(file_stat, file)

    wants_write = any(flag in mode for flag in ("w", "a", "x", "+"))
    if wants_write:
        if file_stat.is_exists:
            check_file_owner(file_stat, file)
    elif "r" in mode:
        if not file_stat.is_exists:
            raise OpenException(f"No such file or directory. {file}")
        check_file_size(file_stat, file, max_size)

    try:
        if "b" not in mode:
            # Text mode: fall back to UTF-8 when the caller gave no encoding.
            kwargs.setdefault("encoding", "utf-8")
        return open(file, mode, **kwargs)
    except OSError as exc:
        raise OpenException(str(exc)) from exc
class UmaskWrapper:
    """Context manager that sets a umask on entry and restores the previous one on exit.

    Example:
        with UmaskWrapper():
            ...  # code that creates files under the preset umask
    """

    def __init__(self, umask=0o027):
        # The umask to apply, and a slot for the one it replaces.
        self.umask = umask
        self.ori_umask = None

    def __enter__(self):
        # os.umask returns the previous mask; keep it for __exit__.
        previous = os.umask(self.umask)
        self.ori_umask = previous

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        saved = self.ori_umask
        os.umask(saved)
def get_valid_lib_path(so_name: str) -> Optional[str]:
    """Resolve an allow-listed shared library to a path under ``ASCEND_HOME_PATH``.

    Args:
        so_name: Library file name; only ``libms_service_profiler.so`` is allowed.

    Returns:
        ``None`` when ``so_name`` is not allow-listed. Otherwise the real path
        of ``$ASCEND_HOME_PATH/aarch64-linux/lib64/<so_name>`` when that file
        exists and is readable, else the bare ``so_name``.
    """
    if so_name not in {"libms_service_profiler.so"}:
        # NOTE(review): this goes through the root `logging` module, unlike the
        # `logger` object used elsewhere in this file — confirm that is intended.
        logging.error("%s is invalid.", so_name)
        return None

    ascend_home = os.getenv("ASCEND_HOME_PATH")
    if not ascend_home:
        return so_name

    resolved = os.path.realpath(os.path.join(ascend_home, "aarch64-linux", "lib64", so_name))
    usable = os.path.exists(resolved) and os.access(resolved, os.R_OK)
    return resolved if usable else so_name
def safe_json_dump(obj, *args, **kwargs):
    """Serialize ``obj`` to a JSON string after checking its top-level type.

    Only the type of ``obj`` itself is checked; nested values are left to
    ``json.dumps``.

    Args:
        obj: A ``str``, ``int``, ``float``, ``bool``, ``None``, ``list``,
            ``dict`` or ``tuple``.
        *args: Forwarded to ``json.dumps``.
        **kwargs: Forwarded to ``json.dumps``.

    Returns:
        The JSON string produced by ``json.dumps``.

    Raises:
        TypeError: If ``obj`` is not one of the accepted top-level types.
    """
    json_safe_types = (str, int, float, bool, type(None), list, dict, tuple)
    if isinstance(obj, json_safe_types):
        return json.dumps(obj, *args, **kwargs)
    raise TypeError(f"Object of type {type(obj)} is not safe JSON serializable")