import sys
from collections import OrderedDict
import torch
sys.path.append(r"PyTorch-GAN/implementations/dcgan")
from dcgan import Generator
def proc_nodes_module(checkpoint):
    """Return a copy of a state dict with a leading ``module.`` stripped from keys.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` (or
    ``DistributedDataParallel``) typically carry a ``module.`` prefix on every
    parameter name, which prevents loading them into the bare model.

    Only a *leading* prefix is removed. Keys that merely contain the text
    ``module.`` elsewhere (for example ``submodule.weight`` or
    ``encoder.module.bias``) are left unchanged, so legitimate parameter names
    are never corrupted.

    Args:
        checkpoint: Mapping of parameter names to values (a state dict).

    Returns:
        A new ``OrderedDict`` with the same values in the same order, keyed by
        the cleaned names. The input mapping is not modified.
    """
    prefix = "module."
    new_state_dict = OrderedDict()
    for key, value in checkpoint.items():
        # Slice instead of str.replace so that only the prefix is dropped.
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    return new_state_dict
def pth2onnx(input_file, output_file):
    """Export the generator stored in a checkpoint file to an ONNX model.

    The checkpoint is loaded onto the CPU, its ``'G'`` entry is taken as the
    generator state dict, cleaned with ``proc_nodes_module`` and loaded into a
    freshly constructed ``Generator``. The model is then switched to eval mode
    and exported with opset 11.

    Args:
        input_file: Path to the checkpoint passed to ``torch.load``; the
            loaded object must contain a ``'G'`` key.
        output_file: Destination of the exported ONNX model.
    """
    # Positional constructor arguments: img_size, latent_dim, channels.
    generator = Generator(32, 100, 1)

    generator_state = torch.load(input_file, map_location='cpu')['G']
    generator.load_state_dict(proc_nodes_module(generator_state))
    generator.eval()

    # Random tensor used only to trace the graph during export.
    noise = torch.randn((1, 100, 1, 1))

    torch.onnx.export(
        generator,
        noise,
        output_file,
        input_names=["noise"],
        output_names=["image"],
        # Dimension 0 of both tensors (presumably the batch size) is exported
        # as a dynamic axis labelled '-1'.
        dynamic_axes={'noise': {0: '-1'}, 'image': {0: '-1'}},
        opset_version=11,
        verbose=True,
    )
if __name__ == "__main__":
    # Two positional arguments are required: the input checkpoint path and the
    # output ONNX path. Exit with a usage message rather than an IndexError
    # traceback when either is missing; extra arguments are ignored as before.
    if len(sys.argv) < 3:
        sys.exit("usage: python {} <input_checkpoint> <output_onnx>".format(sys.argv[0]))
    pth2onnx(sys.argv[1], sys.argv[2])
    print('done! convert success!')