from enum import Enum
import threading
import functools
import traceback
import logging
from typing import Any, Optional, Dict, Union, List, Tuple
import queue
import numpy as np
import dataflow.dataflow as df
import dataflow.data_type as dt
import dataflow.utils.utils as utils
from dataflow.flow_func import flowfunc_wrapper as fw
from dataflow import data_wrapper
from dataflow import dflow_wrapper as dw
fw.init_func_datatype_manager(
{
data_wrapper.FuncDataType.DT_FLOAT: np.array([], np.float32),
data_wrapper.FuncDataType.DT_FLOAT16: np.array([], np.float16),
data_wrapper.FuncDataType.DT_INT8: np.array([], np.int8),
data_wrapper.FuncDataType.DT_INT16: np.array([], np.int16),
data_wrapper.FuncDataType.DT_UINT16: np.array([], np.uint16),
data_wrapper.FuncDataType.DT_UINT8: np.array([], np.uint8),
data_wrapper.FuncDataType.DT_INT32: np.array([], np.int32),
data_wrapper.FuncDataType.DT_INT64: np.array([], np.int64),
data_wrapper.FuncDataType.DT_UINT32: np.array([], np.uint32),
data_wrapper.FuncDataType.DT_UINT64: np.array([], np.uint64),
data_wrapper.FuncDataType.DT_BOOL: np.array([], np.bool_),
data_wrapper.FuncDataType.DT_DOUBLE: np.array([], np.double),
}
)
FLOW_FUNC_SUCCESS = fw.FLOW_FUNC_SUCCESS
FLOW_FUNC_FAILED = fw.FLOW_FUNC_FAILED
FLOW_FUNC_ERR_PARAM_INVALID = fw.FLOW_FUNC_ERR_PARAM_INVALID
FLOW_FUNC_ERR_ATTR_NOT_EXITS = fw.FLOW_FUNC_ERR_ATTR_NOT_EXITS
FLOW_FUNC_ERR_ATTR_TYPE_MISMATCH = fw.FLOW_FUNC_ERR_ATTR_TYPE_MISMATCH
FLOW_FUNC_ERR_TIME_OUT_ERROR = fw.FLOW_FUNC_ERR_TIME_OUT_ERROR
FLOW_FUNC_STATUS_REDEPLOYING = fw.FLOW_FUNC_STATUS_REDEPLOYING
FLOW_FUNC_STATUS_EXIT = fw.FLOW_FUNC_STATUS_EXIT
FLOW_FUNC_ERR_DRV_ERROR = fw.FLOW_FUNC_ERR_DRV_ERROR
FLOW_FUNC_ERR_QUEUE_ERROR = fw.FLOW_FUNC_ERR_QUEUE_ERROR
FLOW_FUNC_ERR_MEM_BUF_ERROR = fw.FLOW_FUNC_ERR_MEM_BUF_ERROR
FLOW_FUNC_ERR_EVENT_ERROR = fw.FLOW_FUNC_ERR_EVENT_ERROR
FLOW_FUNC_ERR_USER_DEFINE_START = fw.FLOW_FUNC_ERR_USER_DEFINE_START
FLOW_FUNC_ERR_USER_DEFINE_END = fw.FLOW_FUNC_ERR_USER_DEFINE_END
MAX_USER_DATA_SIZE = 64
MSG_TYPE_TENSOR_DATA = fw.MsgType.MSG_TYPE_TENSOR_DATA
MSG_TYPE_RAW_MSG = fw.MsgType.MSG_TYPE_RAW_MSG
MSG_TYPE_TORCH_TENSOR_MSG = fw.MsgType.MSG_TYPE_TORCH_TENSOR_MSG
MSG_TYPE_PICKLED_MSG = fw.MsgType.MSG_TYPE_PICKLED_MSG
MSG_TYPE_USER_DEFINE_START = fw.MsgType.MSG_TYPE_USER_DEFINE_START
FLOW_FLAG_EOS = fw.FlowFlag.FLOW_FLAG_EOS
FLOW_FLAG_SEG = fw.FlowFlag.FLOW_FLAG_SEG
DEBUG_LOG = fw.FlowFuncLogType.DEBUG_LOG
RUN_LOG = fw.FlowFuncLogType.RUN_LOG
DEBUG = fw.FlowFuncLogLevel.DEBUG
INFO = fw.FlowFuncLogLevel.INFO
WARN = fw.FlowFuncLogLevel.WARN
ERROR = fw.FlowFuncLogLevel.ERROR
class FlowFuncLogger:
def __init__(self) -> None:
self._impl = fw.FlowFuncLogger()
@staticmethod
def _get_caller_info() -> str:
frame = logging.currentframe()
filename = frame.f_code.co_filename
line_number = frame.f_lineno
func_name = frame.f_code.co_name
location_message = ("[{}:{}][{}]").format(filename, line_number, func_name)
return location_message
def is_log_enable(
self, log_type: fw.FlowFuncLogType, log_level: fw.FlowFuncLogLevel
) -> bool:
if not isinstance(log_type, fw.FlowFuncLogType):
return False
if not isinstance(log_level, fw.FlowFuncLogLevel):
return False
return self._impl.is_log_enable(log_type, log_level)
def get_log_header(self) -> str:
return self._impl.get_log_header()
def debug(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(DEBUG_LOG, DEBUG):
return
try:
user_message = message % args
self._impl.debug_log_debug(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.debug_log_debug(self._get_caller_info(), message)
def error(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(DEBUG_LOG, ERROR):
return
try:
user_message = message % args
self._impl.debug_log_error(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.debug_log_error(self._get_caller_info(), message)
def info(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(DEBUG_LOG, INFO):
return
try:
user_message = message % args
self._impl.debug_log_info(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.debug_log_info(self._get_caller_info(), message)
def warn(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(DEBUG_LOG, WARN):
return
try:
user_message = message % args
self._impl.debug_log_warn(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.debug_log_warn(self._get_caller_info(), message)
def run_error(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(RUN_LOG, WARN):
return
try:
user_message = message % args
self._impl.run_log_error(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.run_log_error(self._get_caller_info(), message)
def run_info(self, message: str, *args: tuple) -> None:
if not self.is_log_enable(RUN_LOG, INFO):
return
try:
user_message = message % args
self._impl.run_log_info(self._get_caller_info(), user_message)
except BaseException as e:
self._impl.run_log_info(self._get_caller_info(), message)
logger = FlowFuncLogger()
class FlowMsg:
def __init__(self, flow_msg: Union[fw.FlowMsg, dw.FlowMsg]) -> None:
self._flow_msg = flow_msg
def __repr__(self):
return self._flow_msg.__repr__()
def get_msg_type(self) -> fw.MsgType:
return self._flow_msg.get_msg_type()
def set_msg_type(self, msg_type) -> None:
return self._flow_msg.set_msg_type(msg_type)
def get_tensor(self):
return (
None
if self._flow_msg.get_tensor() is None
else df.Tensor(self._flow_msg.get_tensor())
)
def get_raw_data(self):
return self._flow_msg.get_raw_data()
def get_ret_code(self) -> int:
return self._flow_msg.get_ret_code()
def set_ret_code(self, ret_code) -> None:
return self._flow_msg.set_ret_code(ret_code)
def get_start_time(self) -> int:
return self._flow_msg.get_start_time()
def set_start_time(self, start_time: int) -> None:
return self._flow_msg.set_start_time(start_time)
def get_end_time(self) -> int:
return self._flow_msg.get_end_time()
def set_end_time(self, end_time: int):
return self._flow_msg.set_end_time(end_time)
def get_flow_flags(self) -> int:
return self._flow_msg.get_flow_flags()
def set_flow_flags(self, flow_flags: int) -> None:
return self._flow_msg.set_flow_flags(flow_flags)
def set_route_label(self, route_label) -> None:
return self._flow_msg.set_route_label(route_label)
def get_transaction_id(self):
return self._flow_msg.get_transaction_id()
def set_transaction_id(self, transaction_id) -> None:
return self._flow_msg.set_transaction_id(transaction_id)
class FlowMsgQueue(queue.Queue):
def __init__(self, flow_msg_queue: fw.FlowMsgQueue) -> None:
super().__init__()
self.mutex = threading.Lock()
self._flow_msg_queue = flow_msg_queue
self.maxsize = self._flow_msg_queue.depth()
def task_done(self):
raise NotImplementedError("This method is not allowed.")
def join(self):
raise NotImplementedError("This method is not allowed.")
def qsize(self):
with self.mutex:
return self._flow_msg_queue.size()
def empty(self):
with self.mutex:
return not self._flow_msg_queue.size()
def full(self):
with self.mutex:
return 0 < self.maxsize <= self._flow_msg_queue.size()
def put(self, item, block=True, timeout=None):
raise NotImplementedError("This method is not allowed.")
def get(self, block=True, timeout=None):
if not block:
ret = self._flow_msg_queue.dequeue(0)
if ret[0] == FLOW_FUNC_SUCCESS:
if int(ret[1].get_msg_type()) == int(MSG_TYPE_TENSOR_DATA):
return None if ret[1].get_tensor() is None else df.Tensor(ret[1])
return utils.convert_flow_msg_to_object(FlowMsg(ret[1]))
raise queue.Empty
elif timeout is None:
return self._get_with_timeout(-1)
elif timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
else:
return self._get_with_timeout(timeout)
def put_nowait(self, item):
raise NotImplementedError("This method is not allowed.")
def get_nowait(self):
return self.get(block=False)
def _get_with_timeout(self, timeout):
ret = self._flow_msg_queue.dequeue(timeout)
if ret[0] == FLOW_FUNC_STATUS_REDEPLOYING:
raise utils.DfAbortException("Redeploying, schedule aborted!", ret[0])
elif ret[0] == FLOW_FUNC_STATUS_EXIT:
raise utils.DfAbortException("Process exiting, schedule aborted!", ret[0])
elif ret[0] == FLOW_FUNC_ERR_TIME_OUT_ERROR:
raise queue.Empty
elif ret[0] == FLOW_FUNC_SUCCESS:
if int(ret[1].get_msg_type()) == int(MSG_TYPE_TENSOR_DATA):
return None if ret[1].get_tensor() is None else df.Tensor(ret[1])
return utils.convert_flow_msg_to_object(FlowMsg(ret[1]))
else:
raise utils.DfException("Queue get item failed", ret[0])
class PyMetaParams:
def __init__(self, meta_params: fw.MetaParams) -> None:
self._meta_params = meta_params
def __repr__(self):
return self._meta_params.__repr__()
def get_name(self) -> str:
return self._meta_params.get_name()
def get_attr_int(self, name: str) -> Tuple[int, int]:
return self._meta_params.get_int64(name)
def get_attr_int_list(self, name: str) -> Tuple[int, List[int]]:
return self._meta_params.get_int64_vector(name)
def get_attr_int_list_list(self, name: str) -> Tuple[int, List[List[int]]]:
return self._meta_params.get_int64_vector_vector(name)
def get_attr_bool(self, name: str) -> Tuple[int, bool]:
return self._meta_params.get_bool(name)
def get_attr_bool_list(self, name: str) -> Tuple[int, List[bool]]:
return self._meta_params.get_bool_list(name)
def get_attr_float(self, name: str) -> Tuple[int, float]:
return self._meta_params.get_float(name)
def get_attr_float_list(self, name: str) -> Tuple[int, List[float]]:
return self._meta_params.get_float_list(name)
def get_attr_tensor_dtype(self, name: str):
dtype = self._meta_params.get_tensor_dtype(name)
if dtype[0] == fw.FLOW_FUNC_SUCCESS:
wrapper_type = dt.get_python_dtype_from_dwrapper_dtype(dtype[1])
return (
fw.FLOW_FUNC_SUCCESS,
dt._dflow_dtype_to_np_dtype.get(wrapper_type, None),
)
else:
return (fw.FLOW_FUNC_ERR_ATTR_NOT_EXITS, None)
def get_attr_tensor_dtype_list(self, name: str):
dtype_list = self._meta_params.get_tensor_dtype_list(name)
result = []
if dtype_list[0] == fw.FLOW_FUNC_SUCCESS:
for i in dtype_list[1]:
wrapper_type = dt.get_python_dtype_from_dwrapper_dtype(i)
result.append(dt._dflow_dtype_to_np_dtype.get(wrapper_type, None))
return (fw.FLOW_FUNC_SUCCESS, result)
else:
return (fw.FLOW_FUNC_ERR_ATTR_NOT_EXITS, result)
def get_attr_str(self, name: str) -> Tuple[int, str]:
return self._meta_params.get_string(name)
def get_attr_str_list(self, name: str) -> Tuple[int, List[str]]:
return self._meta_params.get_string_list(name)
def get_input_num(self) -> int:
return self._meta_params.get_input_number()
def get_output_num(self) -> int:
return self._meta_params.get_output_number()
def get_work_path(self) -> str:
return self._meta_params.get_work_path()
def get_running_device_id(self) -> int:
return self._meta_params.get_running_device_id()
def get_running_instance_id(self) -> int:
return self._meta_params.get_running_instance_id()
def get_running_instance_num(self) -> int:
return self._meta_params.get_running_instance_num()
class AffinityPolicy(Enum):
NO_AFFINITY = fw.AffinityPolicy.NO_AFFINITY
ROW_AFFINITY = fw.AffinityPolicy.ROW_AFFINITY
COL_AFFINITY = fw.AffinityPolicy.COL_AFFINITY
def __init__(self, inner_type):
self.inner_type = inner_type
class BalanceConfig:
def __init__(
self,
row_num: int,
col_num: int,
affinity_policy: AffinityPolicy = AffinityPolicy.NO_AFFINITY,
) -> None:
self._impl = fw.BalanceConfig()
self._impl.set_balance_weight(row_num, col_num)
self._impl.set_affinity_policy(affinity_policy.inner_type)
def set_data_pos(self, pos: List[Tuple[int, int]]):
self._impl.set_data_pos(pos)
def get_inner_config(self):
return self._impl
class MetaRunContext:
def __init__(self, run_context: fw.MetaRunContext) -> None:
self._context = run_context
self._user_data = bytes(MAX_USER_DATA_SIZE)
def __repr__(self):
return self._context.__repr__()
def alloc_tensor_msg(
self, shape: Union[List[int], Tuple[int]], dtype, align: Optional[int] = 64
) -> FlowMsg:
if isinstance(dtype, dt.DType):
dtype = dtype.dtype
elif dtype in df._support_np_dtype:
dtype = dt._np_dtype_to_dflow_dtype.get(np.dtype(dtype)).dtype
else:
raise TypeError(f"The dtype of data is {dtype} is invalid.")
return FlowMsg(self._context.alloc_tensor_msg(shape, dtype, align))
def alloc_raw_data_msg(self, size, align: Optional[int] = 64) -> FlowMsg:
return FlowMsg(self._context.alloc_raw_data_msg(size, align))
def to_flow_msg(self, tensor):
return FlowMsg(self._context.to_flow_msg(tensor._impl))
def set_output(
self,
index,
output: Union[FlowMsg, np.ndarray, fw.FlowMsg],
balance_config: BalanceConfig = None,
) -> int:
if isinstance(output, FlowMsg):
output = output._flow_msg
elif isinstance(output, np.ndarray):
output_b = output
output = self.alloc_tensor_msg(
output.shape,
dt._np_dtype_to_dflow_dtype.get(np.dtype(output.dtype), None),
)
out_np = output.get_tensor().numpy()
out_np[...] = output_b
output = output._flow_msg
if balance_config is None:
return self._context.set_output(index, output)
else:
return self._context.set_output(
index, output, balance_config.get_inner_config()
)
def set_multi_outputs(
self,
index,
outputs: List[Union[FlowMsg, np.ndarray, fw.FlowMsg]],
balance_config: BalanceConfig,
) -> int:
output_msgs = [None] * len(outputs)
for i, output in enumerate(outputs):
if isinstance(output, fw.FlowMsg):
output_msgs[i] = output
elif isinstance(output, FlowMsg):
output_msgs[i] = output._flow_msg
elif isinstance(output, np.ndarray):
output_b = output
output = self.alloc_tensor_msg(
output.shape,
dt._np_dtype_to_dflow_dtype.get(np.dtype(output.dtype), None),
)
out_np = output.get_tensor().numpy()
out_np[...] = output_b
output_msgs[i] = output._flow_msg
else:
raise TypeError(
f"The outputs:{i} type { type(output) } is unsupported."
)
return self._context.set_multi_outputs(
index, output_msgs, balance_config.get_inner_config()
)
def alloc_empty_data_msg(self) -> FlowMsg:
return FlowMsg(self._context.alloc_empty_msg(fw.MsgType.MSG_TYPE_TENSOR_DATA))
def run_flow_model(
self, model_key: str, input_msgs: List[FlowMsg], timeout: int
) -> Tuple[int, List[FlowMsg]]:
wrapper_flow_msg = [msg._flow_msg for msg in input_msgs]
outputs = self._context.run_flow_model(model_key, wrapper_flow_msg, timeout)
return outputs
def get_user_data(self, size: int, offset: int = 0):
ret = self._context.get_user_data(self._user_data, size, offset)
if ret == fw.FLOW_FUNC_SUCCESS:
return self._user_data
else:
return bytearray(MAX_USER_DATA_SIZE)
def raise_exception(self, exception_code: int, user_context_id: int):
return self._context.raise_exception(exception_code, user_context_id)
def get_exception(self) -> tuple:
r"""
tuple(bool, exception_code, user_context_id)
the first element in tuple : False(exception no existed). True(exception existed)
"""
return self._context.get_exception()