@@ -6,7 +6,7 @@ import warnings
import numpy as np
import onnx
from onnx import (helper, GraphProto)
-from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
+from onnx.helper import np_dtype_to_tensor_dtype
from . import (BaseGraph, OnnxNode, PLACEHOLDER, INITIALIZER)
from .node import OnnxNode
@@ -56,8 +56,8 @@ class OnnxGraph(BaseGraph):
except Exception as e:
print(e)
raise RuntimeError(
- f'{dtype} is illegal, only support basic data type: {NP_TYPE_TO_TENSOR_TYPE.keys()}')
- elem_type = NP_TYPE_TO_TENSOR_TYPE[dtype]
+ f'{dtype} is illegal, only support basic data type: {np_dtype_to_tensor_dtype.keys()}')
+ elem_type = np_dtype_to_tensor_dtype(dtype)
node = self._model.graph.input.add()
node.CopyFrom(helper.make_tensor_value_info(name, elem_type, shape))
ph = OnnxNode(node)
@@ -69,7 +69,7 @@ class OnnxGraph(BaseGraph):
assert name not in self._all_ops_name, f'The ({name}) has been existed in graph, please change the node.name.'
node = self._model.graph.initializer.add()
node.CopyFrom(helper.make_tensor(name,
- NP_TYPE_TO_TENSOR_TYPE[value.dtype],
+ np_dtype_to_tensor_dtype(value.dtype),
value.shape,
value.flatten().tolist()))
init = OnnxNode(node)
@@ -3,7 +3,7 @@ import numpy as np
from onnx import (NodeProto, TensorProto, ValueInfoProto,
TensorShapeProto, AttributeProto,
helper, numpy_helper)
-from onnx.mapping import (TENSOR_TYPE_TO_NP_TYPE, NP_TYPE_TO_TENSOR_TYPE)
+from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype
from . import (BaseNode, PLACEHOLDER, INITIALIZER)
from .utils.log import typeassert
@@ -75,7 +75,7 @@ class OnnxNode(BaseNode):
@property
def dtype(self):
tensor = self._node.type.tensor_type
- return TENSOR_TYPE_TO_NP_TYPE[tensor.elem_type]
+ return tensor_dtype_to_np_dtype(tensor.elem_type)
@dtype.setter
def dtype(self, data_type):
@@ -84,9 +84,9 @@ class OnnxNode(BaseNode):
except Exception as e:
print(e)
raise RuntimeError(
- f'{data_type} is illegal, only support basic data type: {NP_TYPE_TO_TENSOR_TYPE.keys()}')
+ f'{data_type} is illegal, only support basic data type: {np_dtype_to_tensor_dtype.keys()}')
tensor = self._node.type.tensor_type
- tensor.elem_type = NP_TYPE_TO_TENSOR_TYPE[dtype]
+ tensor.elem_type = np_dtype_to_tensor_dtype(dtype)
@property
def shape(self):