import argparse
from collections import OrderedDict
import torch
import torch.onnx
from concern.config import Configurable, Config
def proc_nodes_modile(checkpoint):
    """Strip ``module`` wrapper components from state-dict keys.

    Every dotted path component that is exactly ``module`` is removed, e.g.
    ``module.backbone.conv.weight`` -> ``backbone.conv.weight`` and
    ``model.module.backbone.conv.weight`` -> ``model.backbone.conv.weight``.
    This lets weights saved from a wrapped model be loaded into the
    unwrapped model built by this script.

    Only whole components are matched. A key such as
    ``decoder.submodule.weight`` is left intact; a plain substring replace
    of ``"module."`` would corrupt it to ``decoder.subweight``.

    Args:
        checkpoint: mapping of parameter/buffer names to values
            (anything with ``.items()``).

    Returns:
        An ``OrderedDict`` with the cleaned names, in the input's iteration
        order. If two keys collapse to the same name, the later one wins.
    """
    new_state_dict = OrderedDict()
    for key, value in checkpoint.items():
        parts = key.split(".")
        # The last component is the parameter/buffer name itself. Only
        # "module." (with a trailing dot) is a wrapper prefix, so never
        # drop the final component.
        kept = [part for part in parts[:-1] if part != "module"]
        kept.append(parts[-1])
        new_state_dict[".".join(kept)] = value
    return new_state_dict
def pth2onnx(model, output_path="dbnet.onnx", input_shape=(1, 3, 800, 768),
             opset_version=15):
    """Export *model* to an ONNX file using a random dummy input.

    Args:
        model: the ``torch.nn.Module`` to export. The caller is expected to
            have loaded weights and called ``eval()``.
        output_path: destination file for the ONNX graph.
        input_shape: shape of the dummy input tensor that is traced. Only
            axis 0 is declared dynamic below; the remaining dimensions are
            fixed to these values in the exported graph.
        opset_version: ONNX opset to target.

    The graph input is named ``actual_input_1`` and the output ``output1``.
    Nothing is returned; the result is written to *output_path*.
    """
    input_names = ["actual_input_1"]
    output_names = ["output1"]
    # Axis 0 of both the input and the output is marked dynamic. "-1" is
    # just the symbolic dimension name written into the graph.
    dynamic_axes = {'actual_input_1': {0: '-1'}, 'output1': {0: '-1'}}
    # Values are irrelevant; the tensor only drives tracing of the model.
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        do_constant_folding=False,
        input_names=input_names,
        dynamic_axes=dynamic_axes,
        output_names=output_names,
        autograd_inlining=False,
        opset_version=opset_version,
        # Per the torch.onnx docs, ops with no ONNX conversion are emitted
        # as-is instead of aborting the export.
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
        verbose=False,
    )
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='db pth2onnx')
    parser.add_argument('exp', type=str)
    # The checkpoint is mandatory: unset options are filtered out of `args`
    # below, so without it `args['resume']` would raise a bare KeyError.
    parser.add_argument('--resume', type=str, required=True,
                        help='Resume from checkpoint')
    args = vars(parser.parse_args())
    # Drop options that were not given on the command line before handing
    # them to the experiment config.
    args = {k: v for k, v in args.items() if v is not None}
    conf = Config()
    experiment_args = conf.compile(conf.load(args['exp']))['Experiment']
    experiment_args.update(cmd=args)
    experiment = Configurable.construct_class_from_config(experiment_args)
    global_model = experiment.structure.builder.build(torch.device('cpu'))
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    cpt = torch.load(args['resume'], map_location=torch.device('cpu'))
    # Rename keys saved from a (Distributed)DataParallel-wrapped model.
    cpt = proc_nodes_modile(cpt)
    global_model.load_state_dict(cpt)
    global_model.eval()
    pth2onnx(global_model)