diff -uprN model_zoo/arcface_onnx.py model_zoo-modified/arcface_onnx.py
@@ -81,7 +81,7 @@ class ArcFaceONNX:
blob = cv2.dnn.blobFromImages(imgs, 1.0 / self.input_std, input_size,
(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
- net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ net_out = self.session.infer(feeds=[blob])[0]
return net_out
def forward(self, batch_data):
diff -uprN model_zoo/attribute.py model_zoo-modified/attribute.py
@@ -80,7 +80,7 @@ class Attribute:
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
- pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
+ pred = self.session.infer(feeds=[blob])[0][0]
if self.taskname=='genderage':
assert len(pred)==3
gender = np.argmax(pred[:2])
diff -uprN model_zoo/landmark.py model_zoo-modified/landmark.py
@@ -88,7 +88,7 @@ class Landmark:
input_size = tuple(aimg.shape[0:2][::-1])
#assert input_size==self.input_size
blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
- pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
+ pred = self.session.infer(feeds=[blob])[0][0]
if pred.shape[0] >= 3000:
pred = pred.reshape((-1, 3))
else:
diff -uprN model_zoo/model_zoo.py model_zoo-modified/model_zoo.py
@@ -15,30 +15,35 @@ from .landmark import *
from .attribute import Attribute
from .inswapper import INSwapper
from ..utils import download_onnx
+from ais_bench.infer.interface import InferSession
__all__ = ['get_model']
-class PickableInferenceSession(onnxruntime.InferenceSession):
+class PickableInferenceSession(InferSession):
# This is a wrapper to make the current InferenceSession class pickable.
- def __init__(self, model_path, **kwargs):
- super().__init__(model_path, **kwargs)
+ def __init__(self, model_path, om_path, device):
+ super().__init__(device, om_path)
self.model_path = model_path
+ self.om_path = om_path
+ self.device = device
def __getstate__(self):
return {'model_path': self.model_path}
def __setstate__(self, values):
model_path = values['model_path']
- self.__init__(model_path)
+ om_path = values['om_path']
+ device = values['device']
+ self.__init__(model_path, om_path, device)
class ModelRouter:
- def __init__(self, onnx_file):
+ def __init__(self, onnx_file, om_file):
self.onnx_file = onnx_file
+ self.om_file = om_file
- def get_model(self, **kwargs):
- session = PickableInferenceSession(self.onnx_file, **kwargs)
- print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
+ def get_model(self, device):
+ session = PickableInferenceSession(self.onnx_file, self.om_file, device)
inputs = session.get_inputs()
input_cfg = inputs[0]
input_shape = input_cfg.shape
@@ -73,7 +78,7 @@ def get_default_providers():
def get_default_provider_options():
return None
-def get_model(name, **kwargs):
+def get_model(name, om_file, **kwargs):
root = kwargs.get('root', '~/.insightface')
root = os.path.expanduser(root)
model_root = osp.join(root, 'models')
@@ -90,9 +95,10 @@ def get_model(name, **kwargs):
model_file = download_onnx('models', model_file, root=root, download_zip=download_zip)
assert osp.exists(model_file), 'model_file %s should exist'%model_file
assert osp.isfile(model_file), 'model_file %s should be a file'%model_file
- router = ModelRouter(model_file)
- providers = kwargs.get('providers', get_default_providers())
- provider_options = kwargs.get('provider_options', get_default_provider_options())
- model = router.get_model(providers=providers, provider_options=provider_options)
+ assert osp.exists(om_file), 'om_file %s should exist'%om_file
+ assert osp.isfile(model_file), 'om_file %s should be a file'%om_file
+ router = ModelRouter(model_file, om_file)
+ device = kwargs.get('device', 0)
+ model = router.get_model(device)
return model
diff -uprN model_zoo/retinaface.py model_zoo-modified/retinaface.py
@@ -149,7 +149,7 @@ class RetinaFace:
kpss_list = []
input_size = tuple(img.shape[0:2][::-1])
blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
- net_outs = self.session.run(self.output_names, {self.input_name : blob})
+ net_outs = self.session.infer(feeds=[blob])
input_height = blob.shape[2]
input_width = blob.shape[3]