# Names exported by ``from <this module> import *``; only the argument
# type-checking decorator is part of the star-import surface.
__all__ = ["assert_args_type"]
import inspect
from inspect import signature
from functools import wraps
from typing import Dict, Union, List, Tuple, Optional
from dataflow.flow_func import flow_func as ff
from dataflow.utils.msg_type_register import msg_type_register
import dataflow.dflow_wrapper as dwrapper
# Registry consulted by _is_instance: maps a generic annotation's __origin__
# (e.g. the origins of Dict / List, or Union itself) to its checker function.
_global_type_checker_functions = {}
# Runtime context stored by the set_running_* setters; None means "not set",
# in which case the matching getter logs an error.
_global_running_device_id = None
_global_running_instance_id = None
_global_running_instance_num = None
# Flag set to True by set_running_in_udf(); starts out False.
_global_running_in_udf = False
# Registry used by convert_flow_msg_to_object to deserialize non-tensor
# messages; replaceable via set_msg_type_register().
_global_msg_type_register = msg_type_register
class DfException(Exception):
    """Exception raised by the dataflow Python layer, carrying an error code.

    Args:
        obj: Error payload (typically a message); stored as ``message``.
        error_code: Numeric error code. When omitted or ``None`` it defaults
            to ``dwrapper.PARAM_INVALID``.

    ``str()`` and ``repr()`` both render as ``DfException:<code>:<message>``.
    """

    def __init__(self, obj, error_code: Optional[int] = None):
        if error_code is None:
            error_code = dwrapper.PARAM_INVALID
        # Forward the constructor arguments so ``self.args`` is populated.
        # BaseException pickling rebuilds the object as ``type(self)(*self.args)``;
        # with an empty ``args`` that call fails because ``obj`` is required,
        # so the exception could not be sent between processes.
        super().__init__(obj, error_code)
        self.message = obj
        self.error_code = error_code

    def __str__(self):
        return f"DfException:{self.error_code}:{self.message}"

    # repr intentionally matches str.
    __repr__ = __str__
class DfAbortException(DfException):
    """A distinct ``DfException`` type for abort-style errors.

    It adds no state of its own. The constructor is inherited unchanged and
    takes the same ``(obj, error_code=None)`` arguments as ``DfException``, so
    callers can tell it apart only by type (the meaning of "abort" is defined
    by the code that raises it).
    """
def _register_type_checker_function(type_key, type_check_function):
    """Install ``type_check_function`` as the checker for annotations keyed by ``type_key``.

    ``_is_instance`` looks a generic annotation's ``__origin__`` up in this
    registry. Registering the same key again replaces the previous checker.
    """
    _global_type_checker_functions.update({type_key: type_check_function})
def _is_instance(value, dtype):
    """Return whether ``value`` conforms to the annotation ``dtype``.

    Generic annotations whose ``__origin__`` has a registered checker (see
    ``_register_type_checker_function``) are delegated to that checker, which
    may recurse back into this function for element types. Plain classes are
    tested with ``isinstance``.

    Raises:
        TypeError: if ``dtype`` is neither a class nor a generic annotation
            with a registered checker.
    """
    # Consult the registry before the plain-class test: on Python 3.9/3.10
    # builtin generics such as ``list[int]`` also satisfy
    # ``isinstance(dtype, type)``, but ``isinstance(value, list[int])`` raises.
    origin = getattr(dtype, "__origin__", None)
    if origin is not None and origin in _global_type_checker_functions:
        return _global_type_checker_functions[origin](value, dtype)
    if isinstance(dtype, type):
        return isinstance(value, dtype)
    raise TypeError(f"The type {dtype} check of data {value} is not supported")
def _check_dict_key_value_type(arg, dtype):
    """Return True if ``arg`` is a dict whose keys and values all match ``dtype``.

    ``dtype.__args__[0]`` is the required key type and ``dtype.__args__[1]``
    the required value type. Each pair is checked key first, and the scan stops
    at the first mismatch. An empty dict always matches.
    """
    if not isinstance(arg, dict):
        return False
    type_params = dtype.__args__
    return all(
        _is_instance(k, type_params[0]) and _is_instance(v, type_params[1])
        for k, v in arg.items()
    )
# Route Dict[K, V] annotations to the dict checker: _is_instance looks up an
# annotation's __origin__ in the registry, so the key is Dict.__origin__.
_register_type_checker_function(Dict.__origin__, _check_dict_key_value_type)
def _check_union_type(arg, dtype):
    """Return True if ``arg`` matches at least one member of the Union ``dtype``.

    Members are tried in declaration order, and the search stops at the first
    match. ``Optional[X]`` is ``Union[X, None]``, so it is handled here as well.
    """
    return any(_is_instance(arg, member) for member in dtype.__args__)
# Union[...] annotations (and Optional[X], which is Union[X, None]) report
# Union itself as their __origin__, so Union is the registry key.
_register_type_checker_function(Union, _check_union_type)
def _check_list_value_type(arg, dtype):
    """Return True if ``arg`` is a list whose elements all match ``dtype.__args__[0]``.

    An annotation without any type parameter is rejected outright, even for an
    empty list. Otherwise an empty list always matches, and the scan stops at
    the first element that fails.
    """
    if not isinstance(arg, list):
        return False
    element_types = dtype.__args__
    if not element_types:
        return False
    element_type = element_types[0]
    return all(_is_instance(item, element_type) for item in arg)
# Route List[T] annotations to the list checker, keyed by List.__origin__
# (the value _is_instance reads from the annotation).
_register_type_checker_function(List.__origin__, _check_list_value_type)
def assert_args_type(*type_args, **type_kwargs):
    """Build a decorator that validates argument types on every call.

    The positional/keyword type specs are bound to the decorated function's
    parameters with ``Signature.bind_partial`` when decorating. On each call,
    every explicitly passed argument that has a spec is checked with
    ``_is_instance``. Defaults that were not passed are not checked.

    Raises:
        DfException: on a call whose argument does not match its spec.
        TypeError: if the specs or the call arguments do not fit the
            function's signature.
    """

    def decorate(func):
        func_sig = signature(func)
        expected_types = func_sig.bind_partial(*type_args, **type_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            call_args = func_sig.bind(*args, **kwargs).arguments
            # Walk the bound arguments in parameter order; only those with a
            # declared spec are validated.
            for name, value in call_args.items():
                if name not in expected_types:
                    continue
                expected = expected_types[name]
                if not _is_instance(value, expected):
                    raise DfException(f"Argument {name} must be {expected}")
            return func(*args, **kwargs)

        return wrapper

    return decorate
def set_running_device_id(running_device_id):
    """Store the device id later returned by ``get_running_device_id``."""
    global _global_running_device_id
    _global_running_device_id = running_device_id


def get_running_device_id():
    """Return the device id stored by ``set_running_device_id``.

    If no id has been stored (the value is ``None``), an error is logged
    through ``ff.FlowFuncLogger`` and ``None`` is returned.
    """
    if _global_running_device_id is None:
        # Only build a logger when there is something to report.
        ff.FlowFuncLogger().error("running device id is not set")
    return _global_running_device_id
def set_running_instance_id(running_instance_id):
    """Store the instance id later returned by ``get_running_instance_id``."""
    global _global_running_instance_id
    _global_running_instance_id = running_instance_id


def get_running_instance_id():
    """Return the instance id stored by ``set_running_instance_id``.

    If no id has been stored (the value is ``None``), an error is logged
    through ``ff.FlowFuncLogger`` and ``None`` is returned.
    """
    if _global_running_instance_id is None:
        # Only build a logger when there is something to report.
        ff.FlowFuncLogger().error("running instance id is not set")
    return _global_running_instance_id
def set_running_instance_num(running_instance_num):
    """Store the instance count later returned by ``get_running_instance_num``."""
    global _global_running_instance_num
    _global_running_instance_num = running_instance_num


def get_running_instance_num():
    """Return the instance count stored by ``set_running_instance_num``.

    If no count has been stored (the value is ``None``), an error is logged
    through ``ff.FlowFuncLogger`` and ``None`` is returned.
    """
    if _global_running_instance_num is None:
        # Only build a logger when there is something to report.
        ff.FlowFuncLogger().error("running instance num is not set")
    return _global_running_instance_num
def set_running_in_udf():
    """Set the module-level "running in UDF" flag to True."""
    global _global_running_in_udf
    _global_running_in_udf = True


def get_running_in_udf():
    """Report the module-level "running in UDF" flag."""
    return _global_running_in_udf
def set_msg_type_register(type_register):
    """Install ``type_register`` as the registry returned by ``get_msg_type_register``."""
    global _global_msg_type_register
    _global_msg_type_register = type_register


def get_msg_type_register():
    """Return the currently installed message-type registry."""
    return _global_msg_type_register
def convert_flow_msg_to_object(flow_msg):
    """Convert a flow message into the Python object it carries.

    - Messages whose type equals ``ff.MSG_TYPE_TENSOR_DATA`` (compared as
      ints) return ``flow_msg.get_tensor()``.
    - Message types registered in the current message-type registry are
      decoded by passing ``flow_msg.get_raw_data()`` to the registered
      deserialize function.
    - Any other message returns ``flow_msg.get_raw_data()`` unchanged.
    """
    # Fetch the message type and the registry once instead of per branch.
    msg_type = flow_msg.get_msg_type()
    if int(msg_type) == int(ff.MSG_TYPE_TENSOR_DATA):
        return flow_msg.get_tensor()
    type_register = get_msg_type_register()
    if type_register.registered(msg_type):
        deserialize_func = type_register.get_deserialize_func(msg_type)
        return deserialize_func(flow_msg.get_raw_data())
    return flow_msg.get_raw_data()
def get_param_count(func):
    """Count the parameters in ``func``'s signature, skipping any named ``self`` or ``cls``.

    Every other parameter counts once, including ``*args``/``**kwargs`` and
    parameters with defaults.
    """
    skipped_names = ("self", "cls")
    parameter_names = inspect.signature(func).parameters
    return sum(1 for name in parameter_names if name not in skipped_names)
def get_typing_num_returns(func):
    """Return how many values ``func`` declares it returns, based on its return annotation.

    Returns:
        -1 if ``func`` has no return annotation.
        ``len(__args__)`` for a parameterized tuple annotation such as
        ``Tuple[A, B]`` or ``tuple[A, B]``.
        1 for any other annotation, e.g. a plain class, ``None``,
        ``Dict[K, V]``, ``List[T]`` or ``Optional[T]``, since each of those
        describes a single returned object.
    """
    annotation = inspect.signature(func).return_annotation
    if annotation is inspect.Signature.empty:
        return -1
    # Only a tuple annotation describes several returned values. Other
    # generics also carry __args__ (Dict[K, V] has two, Optional[T] has two),
    # but those are type parameters of one object, not separate outputs.
    if getattr(annotation, "__origin__", None) is tuple:
        type_args = getattr(annotation, "__args__", None)
        if type_args is not None:
            # NOTE(review): Tuple[T, ...] is counted as 2 (T and Ellipsis), as
            # before; confirm what callers expect for variable-length tuples.
            return len(type_args)
    return 1