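"""
Op context: a per-thread stack of OpContext objects, helpers to access the
active one, and a registry of custom context factories.
"""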
import functools
import threading
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from .op_info import OpInfo
# Per-thread stacks of active OpContext objects, keyed by thread identifier.
# NOTE(review): entries are never removed, so one (possibly empty) list is
# retained for every thread ident that ever touched a context.
_contexts: Dict[int, List[Any]] = {}
# Registered custom-context factories: name -> zero-argument callable.
_custom_contexts: Dict[Any, Callable[[], Any]] = {}


def _get_contexts():
    """
    Return the context stack of the calling thread, creating it on first use.

    :return: the list used as this thread's OpContext stack; callers mutate
        it in place (push on enter, pop on exit).
    """
    # threading.currentThread() is a deprecated alias (DeprecationWarning
    # since Python 3.10); get_ident() yields the same thread identifier.
    return _contexts.setdefault(threading.get_ident(), [])
class OpContext:
    """
    Op Context.

    Holds the state collected for an operator: op mode, op infos, graph op
    info, compile info, build results, buffer manager, build type, missing
    support info, build json results, workspaces, additional parameters and
    one instance of every registered custom context.

    Use it as a context manager: entering pushes the instance onto the calling
    thread's context stack, exiting pops the top of that stack.
    """

    def __init__(self, op_mode=None):
        """
        :param op_mode: dynamic, static, pre-static
        """
        self._op_mode = op_mode
        self._graph_op_info = None
        self._op_info = []
        self._compile_info = {}
        self._build_res = {}
        self._buffer_manager = None
        self._build_type = None
        self._missing_support_info = None
        self._build_json_result = {}
        # (name, size, addr_type) tuples in insertion order.
        self._workspaces = []
        self._additional_params = {}
        # Call every registered factory so each OpContext owns its own,
        # independent custom-context objects.
        self._custom_context = {
            ctx_name: factory() for ctx_name, factory in _custom_contexts.items()
        }

    def __enter__(self):
        """Push this context onto the calling thread's stack and return it."""
        _get_contexts().append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Pop the top entry of the calling thread's context stack.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        # NOTE(review): pops whatever is on top without checking it is self;
        # relies on strictly nested with-blocks on the same thread.
        _get_contexts().pop()

    def set_op_mode(self, op_mode):
        """
        Set the op mode.

        :param op_mode: dynamic, static, pre-static
        :return: None
        """
        self._op_mode = op_mode

    def get_op_mode(self):
        """
        :return: the op mode (dynamic, static, pre-static) or None if unset.
        """
        return self._op_mode

    def add_op_info(self, op_info):
        """
        Append an op info to this context.

        :param op_info: op info object; looked up later by its ``op_name``.
        :return: None
        """
        self._op_info.append(op_info)

    def get_op_info(self, name=None):
        """
        Look up op info.

        :param name: op name to match against ``op_info.op_name``. If None,
            return all op info.
        :return: the list of all op infos (not a copy) when ``name`` is None;
            otherwise the first op info whose ``op_name`` equals ``name``,
            or None when there is no match.
        """
        if name is None:
            return self._op_info
        return next((info for info in self._op_info if info.op_name == name), None)

    def set_graph_op_info(self, graph_op_info):
        """
        Set the graph op info.

        :param graph_op_info: graph-level op info object.
        :return: None
        """
        self._graph_op_info = graph_op_info

    def get_graph_op_info(self):
        """
        :return: the graph op info, or None if it was never set.
        """
        return self._graph_op_info

    def add_compile_info(self, k, v):
        """
        Store one compile-info entry, overwriting any existing value.

        :param k: compile-info key.
        :param v: value to store under ``k``.
        :return: None
        """
        self._compile_info[k] = v

    def get_compile_info(self, k=None):
        """
        :param k: compile-info key. If None, return all compile info.
        :return: the whole compile-info dict (not a copy) when ``k`` is None,
            otherwise the value stored under ``k`` or None if absent.
        """
        if k is None:
            return self._compile_info
        return self._compile_info.get(k)

    def set_compile_info(self, compile_info):
        """
        Replace the whole compile-info mapping.

        :param compile_info: new mapping; stored by reference, not copied.
        :return: None
        """
        self._compile_info = compile_info

    def add_build_res(self, k, v):
        """
        Store one build-result entry, overwriting any existing value.

        :param k: build-result key.
        :param v: value to store under ``k``.
        :return: None
        """
        self._build_res[k] = v

    def get_build_res(self, k=None):
        """
        :param k: build-result key. If None, return all build res.
        :return: the whole build-result dict (not a copy) when ``k`` is None,
            otherwise the value stored under ``k`` or None if absent.
        """
        if k is None:
            return self._build_res
        return self._build_res.get(k)

    def get_buffer_manager(self):
        """
        :return: the buffer manager, or None if it was never set.
        """
        return self._buffer_manager

    def set_buffer_manager(self, buffer_manager):
        """
        :param buffer_manager: buffer manager object to attach.
        :return: None
        """
        self._buffer_manager = buffer_manager

    def get_custom_context(self, k):
        """
        :param k: name a custom context was registered under.
        :return: this context's instance for ``k``, or None when no custom
            context with that name was registered at construction time.
        """
        return self._custom_context.get(k)

    def get_build_type(self):
        """
        :return: the build type, or None if it was never set.
        """
        return self._build_type

    def set_build_type(self, build_type):
        """
        :param build_type: build type to record.
        :return: None
        """
        self._build_type = build_type

    def get_missing_support_info(self):
        """
        :return: the missing-support info, or None if it was never set.
        """
        return self._missing_support_info

    def set_missing_support_info(self, missing_support_info):
        """
        :param missing_support_info: missing-support info to record.
        :return: None
        """
        self._missing_support_info = missing_support_info

    def add_build_json_result(self, k, v):
        """
        Store one build-json-result entry, overwriting any existing value.

        :param k: result key.
        :param v: value to store under ``k``.
        :return: None
        """
        self._build_json_result[k] = v

    def get_build_json_result(self, k=None):
        """
        :param k: result key. If None, return all build json result.
        :return: the whole dict (not a copy) when ``k`` is None, otherwise
            the value stored under ``k`` or None if absent.
        """
        if k is None:
            return self._build_json_result
        return self._build_json_result.get(k)

    def add_workspace(self, name, size=-1, addr_type=0):
        """
        Register a workspace; it is recorded as a ``(name, size, addr_type)``
        tuple.

        :param name: workspace name.
        :param size: workspace size; defaults to -1.
        :param addr_type: address type; defaults to 0.
        :return: None
        """
        self._workspaces.append((name, size, addr_type))

    def get_workspaces(self):
        """
        :return: list (not a copy) of ``(name, size, addr_type)`` tuples in
            the order they were added.
        """
        return self._workspaces

    def add_addition(self, key, value):
        """
        Store an additional parameter, overwriting any existing value.

        :param key: parameter key.
        :param value: parameter value.
        :return: None
        """
        self._additional_params[key] = value

    def get_addition(self, key):
        """
        :param key: parameter key.
        :return: the value stored under ``key``, or None if absent.
        """
        return self._additional_params.get(key)
def get_context():
    """
    Return the innermost active OpContext of the calling thread.

    :return: the most recently entered, not yet exited OpContext on this
        thread's stack, or None when no context is active.
    """
    # Fetch the stack once instead of looking it up twice.
    stack = _get_contexts()
    return stack[-1] if stack else None
def get_op_mode():
    """
    Look up the op mode of the active context.

    :return: the op mode of the innermost active OpContext on this thread,
        or None when no context is active.
    """
    current = get_context()
    if not current:
        return None
    return current.get_op_mode()
def in_dynamic():
    """
    Tell whether the active context runs in dynamic mode.

    :return: True when the current op mode equals "dynamic"; False otherwise,
        including when no context is active.
    """
    mode = get_op_mode()
    return mode == "dynamic"
def register_custom_context(name):
    """
    Build a decorator that registers a custom context factory under ``name``.

    Every OpContext constructed after registration calls the factory with no
    arguments and exposes the result via ``OpContext.get_custom_context(name)``.

    :param name: key to register under; registering the same name again
        replaces the previous factory.
    :return: a decorator that registers its argument (typically a class) and
        returns it unchanged.
    """
    def decorator(func):
        # Register and return the object itself. Wrapping it in a plain
        # pass-through function would turn a decorated class into a function,
        # breaking isinstance(), subclassing and class-attribute access while
        # adding no behavior.
        _custom_contexts[name] = func
        return func
    return decorator