from __future__ import print_function
import sys
sys.path.insert(0, './AlignedReID-Re-Production-Pytorch')
import torch
from aligned_reid.model.Model import Model
class Config(object):
    """Hyper-parameters needed to rebuild the AlignedReID model for export.

    Args:
        local_conv_out_channels: Value forwarded to
            ``Model(local_conv_out_channels=...)``. Defaults to 128, the
            value this script has always used.
    """

    def __init__(self, local_conv_out_channels=128):
        self.local_conv_out_channels = local_conv_out_channels
def main(pth_path, out_name):
    """Export an AlignedReID ``Model`` to an ONNX file.

    Args:
        pth_path: Path to a checkpoint readable by ``torch.load``. The weights
            are taken from ``checkpoint['state_dicts'][0]``. If empty (or
            otherwise falsy), no weights are loaded and the freshly
            constructed model is exported.
        out_name: Destination path of the ``.onnx`` file.
    """
    cfg = Config()
    # num_classes=751 is hard-coded here; presumably it must match the
    # classifier size stored in the checkpoint — confirm for other datasets.
    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=751)
    if pth_path:
        # Load onto CPU so the export does not require a GPU.
        checkpoint = torch.load(pth_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dicts'][0])
    model.eval()

    input_names = ["image"]
    output_names = ["global_feat", "local_feat", "logits"]
    # Declare dim 0 dynamic on the input AND on every output. Previously the
    # outputs were keyed under a non-existent name ('class'), which
    # torch.onnx.export ignores with a warning, so no output got a dynamic
    # batch axis. Reusing the same dim name ('-1') ties them together.
    dynamic_axes = {name: {0: '-1'} for name in input_names + output_names}
    # Tracing input of shape (1, 3, 256, 128); presumably NCHW at the
    # resolution the model expects — confirm against the training config.
    dummy_input = torch.randn(1, 3, 256, 128)
    torch.onnx.export(
        model, dummy_input, out_name,
        input_names=input_names, output_names=output_names,
        verbose=True, opset_version=11, dynamic_axes=dynamic_axes)
if __name__ == '__main__':
    # Command line: <script> <pth_path> <out_name>; pass '' as pth_path to
    # export without loading a checkpoint. Extra arguments are ignored.
    if len(sys.argv) < 3:
        # Fail with a readable usage message instead of an IndexError.
        sys.exit('usage: {} <pth_path> <out_name>'.format(sys.argv[0]))
    pth_path = sys.argv[1]
    out_name = sys.argv[2]
    main(pth_path, out_name)