import argparse
import torch
import os
class parse_args():
    """Command-line options for the CycleGAN scripts.

    Covers model/training flags, checkpoint paths, and ONNX / NPU export and
    evaluation paths. ``initialize`` registers every option on
    ``self.parser``, parses the command line and returns the resulting
    ``argparse.Namespace``. That Namespace has an extra ``process_device_map``
    attribute derived from ``--pu_ids``.
    """

    def __init__(self, isTrain=True, isTest=False):
        # isTrain / isTest are stored for callers; nothing in this class reads them.
        self.isTrain = isTrain
        self.isTest = isTest
        self.parser = self._new_parser()

    @staticmethod
    def _new_parser():
        """Return an empty ArgumentParser carrying this tool's description."""
        return argparse.ArgumentParser(description='Pytorch CycleGAN training')

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as 'true', 'no' or '1' to a bool.

        Without this conversion argparse would store the raw string, and
        ``--npu False`` would evaluate as truthy.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    def initialize(self, args=None):
        """Register all options on ``self.parser`` and parse the command line.

        Args:
            args: list of argument strings to parse. ``None`` (the default)
                parses ``sys.argv[1:]``, as ``argparse`` does.

        Returns:
            The parsed ``argparse.Namespace``. It carries an extra
            ``process_device_map`` dict built from ``--pu_ids``.

        Note:
            Options are added to the existing ``self.parser``. Calling this
            twice on the same parser makes argparse raise ``ArgumentError``
            for duplicate options; use ``change_parser`` to re-parse.
        """
        if self.parser is None:
            self.parser = self._new_parser()
        parser = self.parser
        parser.add_argument('--npu', type=self._str2bool, default=False,
                            help='whether to use npu to speed up training')
        parser.add_argument('--pu_ids', type=str, default='1',
                            help='gpu ids(npu ids): e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--dataroot', type=str, default='./datasets/maps',
                            help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='maps_cycle_gan',
                            help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--model', type=str, default='cycle_gan',
                            help='chooses which model to use. [cycle_gan]')
        parser.add_argument('--input_nc', type=int, default=3,
                            help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3,
                            help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic',
                            help='specify discriminator architecture [basic | n_layers | pixel]. '
                                 'The basic model is a 70x70 PatchGAN. n_layers allows you to'
                                 ' specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks',
                            help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance',
                            help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal',
                            help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02,
                            help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--batch_size', type=int, default=1,
                            help='batch_size')
        # Forces no_dropout=True even when the flag is absent, so --no_dropout
        # currently has no observable effect.
        parser.set_defaults(no_dropout=True)
        parser.add_argument('--model_ga_path', type=str,
                            default='./checkpoints/maps_cycle_gan/latest_net_G_A.pth',
                            help='path for modelga')
        parser.add_argument('--model_gb_path', type=str,
                            default='./checkpoints/maps_cycle_gan/latest_net_G_B.pth',
                            help='path for modelgb')
        parser.add_argument('--onnx_path', type=str,
                            default='./onnxmodel/',
                            help='directory for exported onnx models')
        parser.add_argument('--model_ga_onnx_name', type=str,
                            default='model_Ga.onnx',
                            help='onnx name for modelga')
        parser.add_argument('--model_gb_onnx_name', type=str,
                            default='model_Gb.onnx',
                            help='onnx name for modelgb')
        parser.add_argument('--gpuPerformance', type=str,
                            default='./gpuPerformance/',
                            help='file for t4 test result ')
        parser.add_argument('--npu_bin_file', type=str,
                            default='./result/dumpOutput_device0/',
                            help='npu bin ')
        parser.add_argument('--om_save',
                            action="store_true",
                            help='save om results ')
        parser.add_argument('--onnx_save',
                            action="store_true",
                            help='save onnx results ')
        # Overrides the 'cycle_gan' default declared for --model above.
        parser.set_defaults(model='test')
        # NOTE(review): no --crop_size option is registered here, so
        # get_default('crop_size') is None and load_size defaults to None.
        # Kept as-is because callers may read opt.load_size; confirm intent.
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        opt = parser.parse_args(args)
        opt.process_device_map = self.device_id_to_process_device_map(opt.pu_ids)
        return opt

    def device_id_to_process_device_map(self, device_list):
        """Map process rank -> device id for a comma-separated id string.

        The ids are sorted ascending, and rank ``i`` gets the ``i``-th
        smallest id. For example, ``"2,0"`` gives ``{0: 0, 1: 2}``. Empty
        tokens (e.g. from a trailing comma) are ignored.

        Raises:
            ValueError: if a non-empty token is not an integer.
        """
        devices = sorted(int(token) for token in device_list.split(",") if token.strip())
        return dict(enumerate(devices))

    def change_parser(self, isTrain=True, isTest=False, args=None):
        """Reset the train/test flags and re-parse with a fresh parser.

        Args:
            isTrain: stored on ``self.isTrain``.
            isTest: stored on ``self.isTest``.
            args: forwarded to ``initialize``; ``None`` parses ``sys.argv[1:]``.

        Returns:
            The Namespace returned by ``initialize``.
        """
        self.isTest = isTest
        self.isTrain = isTrain
        # initialize() needs a parser instance. Re-adding the same options to
        # the old parser would raise argparse.ArgumentError, so start from a
        # new, empty one.
        self.parser = self._new_parser()
        return self.initialize(args)

    def printParser(self, args=None):
        """Print every parsed option, flagging values that differ from their default.

        Args:
            args: list of argument strings to parse. ``None`` parses
                ``sys.argv[1:]``. Options are only known once ``initialize``
                has populated ``self.parser``.
        """
        opt = self.parser.parse_args(args)
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = '\t[default: %s]' % str(default) if value != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        print('\n'.join(lines))