import os
import torch
import torch.onnx
from collections import OrderedDict
from parse import parse_args
import networks
class load_networks():
def __init__(self, opt):
self.opt = opt
self.gpu = 0
self.netG_A = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.netG,
self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain,
self.gpu)
self.netG_B = networks.define_G(self.opt.output_nc, self.opt.input_nc, self.opt.ngf, self.opt.netG,
self.opt.norm, not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain,
self.gpu)
if (opt.npu == False):
self.device = torch.device('cuda:{}'.format(self.gpu))
else:
self.device = torch.device("cpu")
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys):
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def proc_nodes_module(self, checkpoint):
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
if "module." in k:
name = k.replace("module.", "")
else:
name = k
new_state_dict[name] = v
return new_state_dict
def loadnetworks(self, net, load_path):
state_dict = torch.load(load_path, map_location=torch.device('cpu'))
state_dict = self.proc_nodes_module(state_dict)
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
for key in list(state_dict.keys()):
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
return net
def get_networks(self, load_patha, load_pathb):
model_Ga = self.loadnetworks(self.netG_A, load_patha)
model_Gb = self.loadnetworks(self.netG_B, load_pathb)
return model_Ga, model_Gb
def main():
paser = parse_args(True, True)
opt = paser.initialize()
lnetworks = load_networks(opt)
model_Ga, model_Gb = lnetworks.get_networks(opt.model_ga_path, opt.model_gb_path)
device_cpu = torch.device("cpu")
model_Ga = model_Ga.to(device_cpu)
model_Gb = model_Gb.to(device_cpu)
dummy_input = torch.randn(1, 3, 256, 256)
input_names = ["img_sat_maps"]
output_names = ["maps"]
dynamic_axes = {'img_sat_maps': {0: '-1'}, 'maps': {0: '-1'}}
input_names1 = ["img_maps_sat"]
output_names1 = ["sat"]
dynamic_axes1 = {'img_maps_sat': {0: '-1'}, 'sat': {0: '-1'}}
if (os.path.exists(opt.onnx_path) == False):
os.makedirs(opt.onnx_path)
torch.onnx.export(model_Ga, dummy_input, f=opt.onnx_path + '/' + opt.model_ga_onnx_name, verbose=True, training=False, \
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names, opset_version=11)
torch.onnx.export(model_Gb, dummy_input, f=opt.onnx_path + '/' + opt.model_gb_onnx_name, verbose=True, training=False, \
dynamic_axes=dynamic_axes1, input_names=input_names1, output_names=output_names1,
opset_version=11)
if __name__ == '__main__':
main()