from typing import List, Tuple, Any, Callable, Dict, Optional, Union, Type
import re
import os
import inspect
import functools
import stat
import tensorflow as tf
from rec_sdk_common.validator.validator import Validator, IntValidator, FloatValidator, StringValidator
from rec_sdk_common.constants.constants import FileParams
from rec_sdk_common.log.log import LoggingProxy
class TensorShapeValidator(Validator):
    """
    Validate that a value is a rank-1 ``tf.TensorShape`` whose single dimension passes an ``IntValidator``.
    """

    def __init__(self, name: str, value: tf.TensorShape, int_checker_args: dict = None, msg: str = ""):
        """
        @param name: parameter name, used in validation messages
        @param value: the shape to validate
        @param int_checker_args: extra keyword arguments forwarded to ``IntValidator`` for the dimension check
        @param msg: failure message for the type/rank check; a default message is used when empty
        """
        super().__init__(name, value)
        self.int_checker_args = int_checker_args or {}
        self.msg = msg
        self.register()

    def register(self):
        """Register the shape checker with the base validator."""

        def check_tensor_shape():
            # Anything that is not a rank-1 TensorShape fails with the registered message.
            if not (isinstance(self.value, tf.TensorShape) and self.value.ndims == 1):
                return False
            # The single dimension is range-checked by IntValidator; failures there are reported by
            # IntValidator's own check() (presumably by raising) rather than by returning False here.
            only_dim = self.value.as_list()[0]
            IntValidator(name=self.name, value=only_dim, **self.int_checker_args).check_value().check()
            return True

        self.register_checker(check_tensor_shape,
                              self.msg or f"type of '{self.name}' is not TensorShape or ndims is not 1")
class LearningRateValidator(FloatValidator):
    """
    Validate a learning rate that may be given either as a number or as a ``tf.Tensor``.

    A tensor is first evaluated in a temporary TF1-style session to obtain a Python scalar; the
    resulting value is then range-checked by ``FloatValidator``.
    """

    def __init__(self, name: str, value: Union[tf.Tensor, float], min_value: float, max_value: float):
        """
        @param name: parameter name, used in validation and log messages
        @param value: the learning rate; a tensor must be evaluable by ``sess.run`` without a feed dict
        @param min_value: lower bound forwarded to FloatValidator; may be None
        @param max_value: upper bound forwarded to FloatValidator
        """
        if isinstance(value, tf.Tensor):
            # tf.Session only exists in TF 1.x; TF 2.x exposes the same class under tf.compat.v1.
            session_factory = tf.Session if tf.__version__.startswith("1.") else tf.compat.v1.Session
            with session_factory() as sess:
                try:
                    # .item() requires a single-element result; anything else raises and is handled below.
                    value = sess.run(value).item()
                except Exception as e:
                    # Best effort: report the parameter by *name* (previously the tensor object was logged
                    # here) and fall back to min_value (or 0.0) so validation can still proceed.
                    LoggingProxy.warning("[Validator] Parameter %s is passed, and an exception occurred while getting "
                                         "the value in the tensor: \n%s\n. Ensure that the passed parameter is a "
                                         "constant tensor or a tf.placeholder that feeds a constant value. Otherwise, "
                                         "an exception may occur.", name, e)
                    value = 0.0 if min_value is None else float(min_value)
        super().__init__(name, value, min_value=min_value, max_value=max_value)
class FileValidator(StringValidator):
    """
    Validate a file path and, on request, attributes of the file it names.

    Every check_* method registers one more checker and returns ``self`` so calls can be chained.
    """

    def __init__(self, name, value):
        """
        @param name: parameter name, used in validation messages
        @param value: the file path; should not be an empty string and should not contain double dot (../).
            NOTE(review): neither rule is enforced in this class itself - presumably StringValidator does; confirm.
        """
        super().__init__(name, value)
        self.register_checker(lambda: isinstance(self.value, str), "parameter value's type is not str")

    def check_file_size(self, max_size=FileParams.MAX_SIZE.value, min_size=FileParams.MIN_SIZE.value):
        """
        Require the file size (in the unit reported by tf.io.gfile.stat) to lie in (min_size, max_size].

        NOTE: the file is stat-ed immediately, so stat errors (e.g. a missing path) surface from this call.
        """
        length = tf.io.gfile.stat(self.value).length

        def _size_in_range():
            return min_size < length <= max_size

        self.register_checker(_size_in_range, f"file size: {length} is invalid, not in ({min_size}, {max_size}]")
        return self

    def check_not_soft_link(self):
        """
        Require the normalized absolute path to equal the fully resolved real path (i.e. no symlinks).

        The path is read from self.value when the registered checker is invoked, not at registration time.
        """

        def _resolves_to_itself():
            return os.path.abspath(self.value) == os.path.realpath(self.value)

        self.register_checker(_resolves_to_itself,
                              f"soft link or relative path: {self.value} should not be in the path parameter")
        return self

    def check_user_group(self):
        """Require the file's owner uid to equal the effective uid, or its gid to equal the effective gid."""
        euid, egid = os.geteuid(), os.getegid()
        owner_info = os.stat(self.value)
        owner_uid, owner_gid = owner_info.st_uid, owner_info.st_gid
        self.register_checker(lambda: euid == owner_uid or egid == owner_gid,
                              "Invalid log file user or group.")
        return self

    def check_file_mode(self, unsupported_mode=0o022):
        """
        Reject files whose permission bits overlap ``unsupported_mode``.

        @param unsupported_mode: forbidden permission bits; the default 0o022 forbids group- and other-write
        """
        permission_bits = stat.S_IMODE(os.stat(self.value).st_mode)

        def _mode_allowed():
            return (permission_bits & unsupported_mode) == 0

        self.register_checker(_mode_allowed, f"Current file mode {oct(permission_bits)} is unsupported")
        return self