import os
import sys
import onnx
import torch
import argparse
import warnings
# Silence every warning for the rest of the script; this runs before the
# onnxsim / maskrcnn_benchmark imports below, so their import-time warnings
# are hidden as well.
warnings.filterwarnings("ignore")
# Prepend the directory above this script to the import path so that
# ``maskrcnn_benchmark`` (imported below) can be resolved from there.
cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, "/".join((cur_path, "..")))
from onnxsim import simplify
from collections import OrderedDict
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.modeling.detector import build_detection_model
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Keys without the prefix are kept unchanged; insertion order is preserved.
    Only the exact ``module.`` prefix is stripped (presumably added by a
    data-parallel wrapper when the checkpoint was saved), so keys that merely
    start with ``module`` are not truncated.
    """
    cleaned = OrderedDict()
    for key, tensor in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = tensor
    return cleaned


def pth2onnx():
    """Convert a RetinaMask ``.pth`` checkpoint to ONNX and optionally simplify it.

    Command-line arguments:
        --cfg_path     maskrcnn_benchmark YAML config used to build the model.
        --weight_path  checkpoint file; its ``'model'`` entry is the state dict.
        --save_path    destination ONNX file. When simplification is enabled,
                       the simplified model overwrites this same file.
        --batch_size   batch dimension of the dummy (N, 3, 1344, 1344) input.
        --simplify     run onnx-simplifier afterwards (true/false, default true).

    Raises:
        RuntimeError: if onnx-simplifier cannot validate the simplified model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg_path", type=str,
                        default="../configs/retina/retinanet_mask_R-50-FPN_2x_adjust_std011_ms.yaml")
    parser.add_argument("--weight_path", type=str, default="./npu_8P_model_0020001.pth")
    parser.add_argument("--save_path", type=str, default="./retinamask.onnx")
    parser.add_argument("--batch_size", type=int, default=1)
    # type=bool would treat "--simplify False" as True; parse the text instead.
    parser.add_argument("--simplify", type=_str2bool, default=True)
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg_path)
    cfg.freeze()
    onnx_file = args.save_path

    device = torch.device('cpu')
    model = build_detection_model(cfg)
    model = model.to(device)
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(args.weight_path, map_location=device)
    model.load_state_dict(_strip_module_prefix(ckpt['model']))
    model.eval()

    dummy_input = torch.randn(args.batch_size, 3, 1344, 1344, dtype=torch.float32)
    input_names = ["input"]
    output_names = ["bboxs", "labels", "scores", "masks"]
    # enable_onnx_checker is no longer passed: it defaulted to True, was
    # deprecated in PyTorch 1.11, and is rejected by later releases.
    torch.onnx.export(model,
                      dummy_input,
                      onnx_file,
                      input_names=input_names,
                      output_names=output_names,
                      opset_version=11,
                      verbose=False)
    print("************* Convert to ONNX model file SUCCESS! *************")

    if args.simplify:
        onnx_model = onnx.load(onnx_file)
        onnx_sim_model, check = simplify(onnx_model, check_n=3)
        # Explicit raise instead of assert so the check survives `python -O`.
        if not check:
            raise RuntimeError("Simplified ONNX model could not be validated")
        onnx.save(onnx_sim_model, onnx_file)
        print('ONNX file simplified!')
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the .pth -> ONNX conversion.
    pth2onnx()