import os
from typing import Callable, Any
from typing import List, Tuple
from motor.engine_server.constants.constants import (MAX_SIZE, MIN_SIZE, MIN_RANK_SIZE, MAX_RANK_SIZE,
MAX_FILE_NUMS, MIN_DEVICE_NUM, MAX_DEVICE_NUM, MUSK_PRIVILEGE)
class Validator:
    """
    A validator to check the input parameters.

    Checkers are registered with ``register_checker`` and run lazily, in
    registration order, on the first call to ``check`` (or ``is_valid``).
    The first failing checker stops evaluation and its message is stored in
    ``msg``; on success ``msg`` becomes None. The outcome is cached in
    ``is_valid_state``, so checkers registered after evaluation are not run.
    """

    def __init__(self, value, msg="value is invalid"):
        """
        :param value: the value for validation
        :param msg: default error msg, used for checkers registered without one
        """
        self.value = value
        self.msg = msg
        self.checkers = []
        self.is_valid_state = None

    def register_checker(self, checker: Callable[[Any], bool], msg: str = None):
        """
        Append a checker.

        :param checker: predicate called with ``self.value``; a falsy result means invalid
        :param msg: error message for this checker; falls back to the current ``self.msg``
        """
        self.checkers.append((checker, msg if msg else self.msg))

    def check(self):
        """
        Run the registered checkers once and cache the result.

        :return: self, for chaining
        :raises: whatever a checker raises; the result is then left unevaluated
                 (``is_valid_state`` stays None) instead of being cached as valid
        """
        if self.is_valid_state is None:
            passed, failure_msg = True, None
            for checker, checker_msg in self.checkers:
                # Test truthiness instead of accumulating with bitwise ``&``,
                # which misreads truthy non-bool results (True & 2 == 0).
                if not checker(self.value):
                    passed, failure_msg = False, checker_msg
                    break
            # Commit only after every checker ran without raising, so an
            # exception never leaves a stale "valid" result behind.
            self.is_valid_state = passed
            if not passed:
                self.msg = failure_msg
        if self.is_valid_state:
            self.msg = None
        return self

    def is_valid(self):
        """
        Return the cached validation result, running ``check`` on first use.

        Any exception raised by a checker is treated as a validation failure.
        """
        if self.is_valid_state is None:
            try:
                self.check()
            except Exception:
                self.is_valid_state = False
        return self.is_valid_state

    def get_value(self, default=None):
        """Return the value if it is valid, otherwise ``default``."""
        return self.value if self.is_valid() else default
class ClassValidator(Validator):
    """
    Validator asserting that a value is an instance of the given class(es).
    """

    def __init__(self, value, classes):
        """
        :param value: the value for validation
        :param classes: a class or tuple of classes, as accepted by ``isinstance``
        """
        super().__init__(value)
        self.classes = classes

    def check_isinstance(self):
        """Register a checker requiring ``isinstance(value, self.classes)``; returns self."""
        def _matches_classes(candidate):
            return isinstance(candidate, self.classes)

        self.register_checker(_matches_classes, "Invalid parameter type")
        return self
class StringValidator(Validator):
    """
    String type validator.

    Registers a ``str`` type checker on construction; the methods below add
    further checks and return ``self`` for chaining.
    """

    def __init__(self, value, max_len=None, min_len=0):
        """
        :param value: the value for validation
        :param max_len: inclusive maximum length, or None for no upper bound
        :param min_len: inclusive minimum length, or None for no lower bound
        """
        super().__init__(value)
        self.max_len = max_len
        self.min_len = min_len
        self.register_checker(lambda x: isinstance(x, str), "type is not str")

    def check_string_length(self):
        """Register inclusive length-bound checks for whichever bounds are set."""
        if self.min_len is not None:
            self.register_checker(lambda x: len(x) >= self.min_len, f"length is less than {self.min_len}")
        if self.max_len is not None:
            self.register_checker(lambda x: len(x) <= self.max_len, f"length is bigger than {self.max_len}")
        return self

    def check_not_contain_black_element(self, element):
        """
        Register a checker rejecting values that contain ``element``.

        The check also fails when the value or ``element`` is None. The default
        message is used.
        """
        self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1)
        return self

    def can_be_transformed2int(self, min_value: int = None, max_value: int = None):
        """
        Require the value to be a digit string whose int value lies in [min_value, max_value].

        :param min_value: inclusive lower bound, defaults to MIN_RANK_SIZE
        :param max_value: inclusive upper bound, defaults to MAX_RANK_SIZE
        :return: self, for chaining

        If the validator has not been evaluated yet, the check is registered like
        any other checker, so it composes with (and runs after) the checks already
        registered, including the ``str`` type check. If it has already been
        evaluated, the cached result is narrowed immediately.
        """
        if min_value is None:
            min_value = MIN_RANK_SIZE
        if max_value is None:
            max_value = MAX_RANK_SIZE

        def _is_int_in_range(text):
            # isdigit() rejects signs and whitespace; int() can still fail for
            # digit characters it does not accept (e.g. superscripts).
            if not isinstance(text, str) or not text.isdigit():
                return False
            try:
                return min_value <= int(text) <= max_value
            except ValueError:
                return False

        fail_msg = f"value is not an integer in [{min_value}, {max_value}]"
        if self.is_valid_state is None:
            self.register_checker(_is_int_in_range, fail_msg)
        elif self.is_valid_state and not _is_int_in_range(self.value):
            self.is_valid_state = False
            self.msg = fail_msg
        return self
class IntValidator(Validator):
    """
    Int type validator.

    Registers an ``int`` type checker on construction (bool values pass it,
    since ``bool`` subclasses ``int``); ``check_value`` adds inclusive bound checks.
    """

    def __init__(self, value: int, min_value: int = None, max_value: int = None):
        """
        :param value: the value for validation
        :param min_value: inclusive lower bound used by ``check_value``, or None
        :param max_value: inclusive upper bound used by ``check_value``, or None
        """
        super().__init__(value)
        self.min_value, self.max_value = min_value, max_value

        def _is_int(candidate):
            return isinstance(candidate, int)

        self.register_checker(_is_int, "type is not int")

    def check_value(self):
        """Register lower/upper bound checkers for whichever bounds are not None; returns self."""
        if self.min_value is not None:
            too_small_msg = f"value is less than {self.min_value}"
            self.register_checker(lambda candidate: candidate >= self.min_value, too_small_msg)
        if self.max_value is not None:
            too_big_msg = f"value is bigger than {self.max_value}"
            self.register_checker(lambda candidate: candidate <= self.max_value, too_big_msg)
        return self
class MapValidator(Validator):
    """
    Map type validator.

    Requires the value to be a ``dict`` and, when ``inclusive_keys`` is given,
    to contain every one of those keys.
    """

    def __init__(self, value: dict, inclusive_keys: list = None):
        """
        :param value: the value for validation
        :param inclusive_keys: keys that must all be present in the dict
        """
        super().__init__(value)

        def _requires(key):
            # Factory so each checker closes over its own key.
            return lambda mapping: key in mapping

        self.register_checker(lambda x: isinstance(x, dict), "type is not dict")
        for required_key in ([] if inclusive_keys is None else inclusive_keys):
            self.register_checker(_requires(required_key), "Key error for the value of dict type")
class RankSizeValidator(IntValidator):
    """
    Distributed training job size validator.

    The ``int`` type checker is registered by ``IntValidator.__init__``; the
    methods below append range checks that compose with the checks already
    registered, so they can be chained.
    """

    def check_rank_size_valid(self):
        """Register a checker requiring MIN_RANK_SIZE <= value <= MAX_RANK_SIZE; returns self."""
        # Do not re-run __init__ here: that would discard every previously
        # registered checker (and any min/max bounds) on this validator.
        self.register_checker(lambda x: MIN_RANK_SIZE <= x <= MAX_RANK_SIZE,
                              "Invalid rank size")
        return self

    def check_device_num_valid(self):
        """Register a checker requiring MIN_DEVICE_NUM <= value <= MAX_DEVICE_NUM; returns self."""
        self.register_checker(lambda x: MIN_DEVICE_NUM <= x <= MAX_DEVICE_NUM,
                              "Invalid device num")
        return self
class DirectoryValidator(StringValidator):
    """
    Path validator for directories (and, via ``path_should_exist``, files).

    Every check method registers a checker and returns ``self`` for chaining.
    Checkers run lazily, in registration order, when ``check``/``is_valid`` is
    called. Filesystem queries therefore happen at validation time, and only
    after earlier checks (e.g. existence) have passed.
    """

    def __init__(self, value, max_len=None, min_len=1):
        """
        @param value: the path, should not be empty string, should not contain double dot(../)
        @param max_len: inclusive maximum path length, or None for no upper bound
        @param min_len: inclusive minimum path length
        """
        # StringValidator.__init__ already registers the ``str`` type checker.
        super().__init__(value, max_len, min_len)

    @staticmethod
    def remove_prefix(string: str | None, prefix: str | None) -> Tuple[bool, str | None]:
        """
        Strip ``prefix`` from ``string``.

        :return: ``(True, remainder)`` if ``string`` starts with ``prefix``,
                 otherwise ``(False, string)`` unchanged (also when either is None)
        """
        if string is None or prefix is None or not string.startswith(prefix):
            return False, string
        return True, string[len(prefix):]

    @staticmethod
    def check_is_children_path(path_: str, target_: str):
        """
        Return True if ``target_`` resolves to ``path_`` itself or to a location beneath it.

        Both arguments are resolved with ``os.path.realpath``; resolution errors
        and an empty ``target_`` yield False.
        """
        if not target_:
            return False
        try:
            realpath_ = os.path.realpath(path_)
            realpath_target = os.path.realpath(target_)
        except (TypeError, ValueError, OSError):
            return False
        is_prefix, rest_part = DirectoryValidator.remove_prefix(realpath_target, realpath_)
        if not is_prefix:
            return False
        # Re-join and re-resolve so a bare string prefix (e.g. "/usr/lib" vs
        # "/usr/lib64") is not mistaken for containment.
        joint_path = os.path.join(realpath_, rest_part.lstrip(os.path.sep))
        return os.path.realpath(joint_path) == realpath_target

    @staticmethod
    def __check_with_sensitive_words(path: str, words: List):
        """Return False if ``path`` has a non-empty final component and contains any of ``words``."""
        _, name = os.path.split(path)
        if not name:
            # Paths ending in a separator (or empty paths) are not checked.
            return True
        # NOTE(review): the match is against the whole path, not just ``name``; confirm intent.
        return not any(word in path for word in words)

    def check_directory_permissions(self, target_mode: int):
        """Register a checker requiring ``(st_mode & MUSK_PRIVILEGE) == target_mode``."""
        # stat at validation time, not registration time, so a missing path is
        # reported as invalid (or by an earlier existence check) instead of raising here.
        self.register_checker(lambda path: (os.stat(path).st_mode & MUSK_PRIVILEGE) == target_mode,
                              "permission error")
        return self

    def check_is_not_none(self):
        """Register a checker requiring a non-None, non-empty value."""
        self.register_checker(lambda path: path is not None and len(path) > 0,
                              "Invalid directory parameter")
        return self

    def check_dir_name(self):
        """Register a checker rejecting '..' and lengths outside [min_len, max_len]."""
        def is_path_valid(path):
            if '..' in path or len(path) < self.min_len:
                return False
            return self.max_len is None or len(path) <= self.max_len

        self.register_checker(is_path_valid, "the path parameter is invalid")
        return self

    def check_not_soft_link(self):
        """Register a checker requiring the normalized path to equal its resolved real path."""
        self.register_checker(lambda path: os.path.realpath(path) == os.path.normpath(path),
                              "soft link or relative path should not be in the path parameter")
        return self

    def check_dir_file_number(self):
        """Register a checker requiring at most MAX_FILE_NUMS entries in the directory."""
        self.register_checker(lambda path: len(os.listdir(path)) <= MAX_FILE_NUMS,
                              "Too many files in directory")
        return self

    def path_should_exist(self, is_file=True, msg=None):
        """Register checkers requiring the path to exist and, if ``is_file``, to be a regular file."""
        self.register_checker(os.path.exists, msg if msg else "path parameter does not exist")
        if is_file:
            self.register_checker(os.path.isfile, msg if msg else "path parameter is not a file")
        return self

    def path_should_not_exist(self):
        """Register a checker requiring the path not to exist."""
        self.register_checker(lambda path: not os.path.exists(path), "path parameter already exists")
        return self

    def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None):
        """
        Register a checker rejecting blacklisted paths.

        :param lst: blacklisted paths; defaults to common system directories
        :param exact_compare: if True, reject only paths resolving to a listed
                              path; otherwise also reject anything beneath one
        :param msg: custom error message
        """
        if lst is None:
            lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64"]
        if not lst:
            return self
        if msg is None:
            msg = "path should not be in blacklist"
        if exact_compare:
            # Resolve the candidate too, so spellings like "/etc/" or
            # "/usr/../etc" cannot slip past the resolved blacklist entries.
            self.register_checker(
                lambda path: os.path.realpath(path) not in {os.path.realpath(each) for each in lst}, msg
            )
        else:
            self.register_checker(
                lambda path: not any(DirectoryValidator.check_is_children_path(each, path) for each in lst), msg
            )
        return self

    def should_not_contains_sensitive_words(self, words: List = None, msg=None):
        """Register a checker rejecting paths that contain any of ``words`` (case-sensitive)."""
        if words is None:
            words = ["Key", "password", "privatekey"]
        self.register_checker(lambda path: DirectoryValidator.__check_with_sensitive_words(path, words), msg)
        return self

    def check_user_group(self):
        """
        Register a checker requiring the path's owner uid or gid to match the
        process's effective uid or gid.
        """
        def _owned_by_process(path):
            stat_info = os.stat(path)
            return os.geteuid() == stat_info.st_uid or os.getegid() == stat_info.st_gid

        self.register_checker(_owned_by_process, "Invalid dir user or group.")
        return self
class FileValidator(StringValidator):
    """
    Regular-file path validator.

    Construction registers the ``str`` type checker (from StringValidator) and
    an ``os.path.isfile`` checker. Chained methods add further checks, which run
    lazily when ``check``/``is_valid`` is called.
    """

    def __init__(self, value):
        """
        @param value: the file path, should not be empty string, should not contain double dot(../)
        """
        # StringValidator.__init__ already registers the ``str`` type checker.
        super().__init__(value)
        self.register_checker(os.path.isfile, "type is not file")

    def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE):
        """Register a checker requiring min_size < os.path.getsize(value) <= max_size (bytes)."""
        self.register_checker(lambda path: min_size < os.path.getsize(path) <= max_size,
                              "file size is invalid")
        return self

    def check_not_soft_link(self):
        """
        Register a checker requiring the value to equal its resolved real path.

        Relative paths, paths through symlinks, and non-canonical spellings
        (e.g. containing '.', '..' or repeated separators) all fail.
        """
        self.register_checker(lambda path: os.path.realpath(path) == path,
                              "soft link or relative path should not be in the path parameter")
        return self

    def check_user_group(self):
        """
        Register a checker requiring the file's owner uid or gid to match the
        process's effective uid or gid.
        """
        # stat at validation time, not registration time: a missing file is then
        # rejected by the earlier isfile checker instead of raising here.
        def _owned_by_process(path):
            stat_info = os.stat(path)
            return os.geteuid() == stat_info.st_uid or os.getegid() == stat_info.st_gid

        self.register_checker(_owned_by_process, "Invalid file user or group.")
        return self