""" BigGAN: The Authorized Unofficial PyTorch release
Code by A. Brock and A. Andonian
This code is an unofficial reimplementation of
"Large-Scale GAN Training for High Fidelity Natural Image Synthesis,"
by A. Brock, J. Donahue, and K. Simonyan (arXiv 1809.11096).
Let's go.
"""
import os
import time
import torch
import torch.distributed as dist
import inception_utils
import train_fns
import utils
def device_id_to_process_device_map(device_list):
    """Assign each local process index to a device id.

    ``device_list`` is a comma-separated string of integer device ids such as
    ``"3,1,2"``. The ids are sorted ascending, and process ``i`` receives the
    ``i``-th smallest id, e.g. ``"3,1,2"`` -> ``{0: 1, 1: 2, 2: 3}``.
    """
    ordered_ids = sorted(int(token) for token in device_list.split(","))
    return dict(enumerate(ordered_ids))
def get_device_name(device_type, device_order):
    """Build the torch device string for device number ``device_order``.

    ``device_type == 'npu'`` produces ``'npu:<order>'``. Any other value
    falls back to ``'cuda:<order>'``.
    """
    template = 'npu:{}' if device_type == 'npu' else 'cuda:{}'
    return template.format(device_order)
def profiling(data_loader, G, D, train, config):
    """Run a few warm-up training steps, then profile one step and export it.

    The first 5 batches are trained normally. The 6th batch is trained under
    ``torch.autograd.profiler.profile``: ``use_npu=True`` when
    ``config['device'] == 'npu'`` (requires an NPU-enabled PyTorch build),
    ``use_cuda=True`` otherwise. The result is written as a Chrome trace to
    ``"<device>.prof"``.

    Args:
        data_loader: Iterable of ``(x, y)`` batches.
        G: Generator; switched to training mode.
        D: Discriminator; switched to training mode.
        train: Callable ``train(x, y)`` that performs one training step.
        config: Dict providing ``'D_fp16'`` (cast ``x`` to half precision),
            ``'loc'`` (device the batch is moved to) and ``'device'``
            (selects the profiler flavour and names the trace file).

    If the loader yields 5 or fewer batches, no step is profiled. In that case
    a message is printed and no trace file is written.
    """
    warmup_iters = 5
    print("profiling mode ...")
    G.train()
    D.train()
    # Stays None unless the profiled iteration is actually reached.
    prof = None
    for i, (x, y) in enumerate(data_loader):
        if config['D_fp16']:
            x, y = x.to(config['loc']).half(), y.to(config['loc'])
        else:
            x, y = x.to(config['loc']), y.to(config['loc'])
        if i < warmup_iters:
            print("iter: ", i)
            train(x, y)
            continue
        if config['device'] == 'npu':
            profiler_kwargs = {'use_npu': True}
        else:
            profiler_kwargs = {'use_cuda': True}
        with torch.autograd.profiler.profile(**profiler_kwargs) as prof:
            train(x, y)
        break
    if prof is None:
        # Previously an unbound `prof` crashed here on short loaders.
        print("profiling skipped: data loader yielded only warm-up batches, "
              "no iteration was profiled")
        return
    prof.export_chrome_trace("%s.prof" % config['device'])
def run(gpu, ngpus_per_node, config):
    """Set up models, data and loggers for one process, then run training.

    Args:
        gpu: Index into ``config['process_device_map']``. The mapped device id
            is stored back into ``config['gpu']`` and selects the device.
        ngpus_per_node: Number of devices used on this node. In distributed
            mode it divides ``config['num_workers']`` (rounded up).
        config: Run configuration dict. It is mutated in place: ``'gpu'`` and
            ``'loc'`` are always set, and in distributed mode ``'batch_size'``
            and ``'num_workers'`` are rescaled to per-process values.

    Returns ``None``. The function returns early in profiling mode
    (``config['prof']``), after the first training iteration when
    ``config['cann_prof']`` is set, or once ``state_dict['itr']`` reaches a
    positive ``config['stop_iter']``.
    """
    config['gpu'] = config['process_device_map'][gpu]
    if config['distributed']:
        print("use distributed training... gpu:", config['gpu'])
        # The NPU branch passes no init_method, so torch.distributed falls back
        # to its default env:// rendezvous (MASTER_ADDR / MASTER_PORT, which
        # main() sets). The CUDA branch uses the explicit dist_url.
        if config['device'] == 'npu':
            dist.init_process_group(backend=config['dist_backend'],
                                    world_size=config['world_size'],
                                    rank=config['rank'])
        else:
            dist.init_process_group(backend=config['dist_backend'],
                                    init_method=config['dist_url'],
                                    world_size=config['world_size'],
                                    rank=config['rank'])
        print('rank: {} / {}'.format(config['rank'], config['world_size']))
    # Device string such as 'npu:0' / 'cuda:0'; used for every .to() below.
    device_loc = get_device_name(config['device'], config['gpu'])
    config['loc'] = device_loc
    print('set_device ', device_loc)
    if config['device'] == 'npu':
        torch.npu.set_device(device_loc)
    else:
        torch.cuda.set_device(config['gpu'])
    torch.backends.cudnn.benchmark = True
    # The model definition module is imported dynamically by its name.
    model = __import__(config['model'])
    G = model.Generator(**config).to(device_loc)
    D = model.Discriminator(**config).to(device_loc)
    if config['ema']:
        print('Preparing EMA for G with decay of {}'.format(config['ema_decay']))
        # Separate Generator instance that holds the EMA weights; it is built
        # with skip_init and no_optim forced on.
        G_ema = model.Generator(**{**config, 'skip_init': True, 'no_optim': True}).to(device_loc)
        ema = utils.ema(G, G_ema, config['ema_decay'], config['ema_start'])
    else:
        G_ema, ema = None, None
    if config['distributed']:
        # Convert global settings into per-process values. The batch size is
        # truncated toward zero; workers are split across local devices,
        # rounding up.
        config['batch_size'] = int(config['batch_size'] / config['world_size'])
        config['num_workers'] = int((config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)
    if config['G_fp16']:
        print('Casting G to float16...')
        G = G.half()
        if config['ema']:
            G_ema = G_ema.half()
    if config['D_fp16']:
        print('Casting D to fp16...')
        D = D.half()
    GD = model.G_D(G, D)
    # Training-progress state. It is shared with load_weights, write_metadata,
    # the training function and save_and_sample, and is updated by the loop.
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'save_best_num': 0,
                  'best_IS': 0, 'best_FID': 999999, 'config': config}
    if config['resume']:
        print('Loading weights...device:', device_loc)
        utils.load_weights(G, D, state_dict,
                           config['weights_root'], config['experiment_name'],
                           config['load_weights'] if config['load_weights'] else None,
                           G_ema if config['ema'] else None)
    if config['distributed']:
        # Only the combined G_D module is wrapped. G and D stay un-wrapped
        # references used for sampling, logging and saving.
        GD = torch.nn.parallel.DistributedDataParallel(GD, device_ids=[config['gpu']], find_unused_parameters=True)
    # Loggers are created only by the non-distributed process or by the process
    # on the first mapped device. Other processes get None and never use them,
    # because every use below sits behind the same check.
    if not config['distributed'] or (config['distributed'] and config['gpu'] == config['process_device_map'][0]):
        test_metrics_fname = '%s/%s_log.jsonl' % (config['logs_root'],
                                                  config['experiment_name'])
        train_metrics_fname = '%s/%s' % (config['logs_root'], config['experiment_name'])
        print('Inception Metrics will be saved to {}'.format(test_metrics_fname))
        # NOTE(review): test_log is not used later in this function.
        test_log = utils.MetricsLogger(test_metrics_fname,
                                       reinitialize=(not config['resume']))
        print('Training Metrics will be saved to {}'.format(train_metrics_fname))
        train_log = utils.MyLogger(train_metrics_fname,
                                   reinitialize=(not config['resume']),
                                   logstyle=config['logstyle'])
    else:
        test_log = None
        train_log = None
    if not config['distributed'] or (config['distributed'] and config['gpu'] == config['process_device_map'][0]):
        utils.write_metadata(config['logs_root'], config['experiment_name'], config, state_dict)
    # Loader batch size: per-process batch size * num_D_steps * num_D_accumulations.
    D_batch_size = (config['batch_size'] * config['num_D_steps'] * config['num_D_accumulations'])
    loaders = utils.get_data_loaders(**{**config, 'batch_size': D_batch_size,
                                        'start_itr': state_dict['itr']})
    # In distributed mode get_data_loaders is indexed as (sampler, loader);
    # otherwise the loader is the first element.
    if config['distributed']:
        train_sampler = loaders[0]
        loader = loaders[1]
    else:
        loader = loaders[0]
    G_batch_size = max(config['G_batch_size'], config['batch_size'])
    # z_/y_ are handed to the training function and save_and_sample.
    # fixed_z/fixed_y are sampled once here and then passed unchanged to
    # save_and_sample.
    z_, y_ = utils.prepare_z_y(G_batch_size, G.dim_z, config['n_classes'],
                               device=device_loc, fp16=config['G_fp16'])
    fixed_z, fixed_y = utils.prepare_z_y(G_batch_size, G.dim_z,
                                         config['n_classes'], device=device_loc,
                                         fp16=config['G_fp16'])
    fixed_z.sample_()
    fixed_y.sample_()
    if config['which_train_fn'] == 'GAN':
        train = train_fns.GAN_training_function(G, D, GD, z_, y_,
                                                ema, state_dict, config)
    else:
        train = train_fns.dummy_training_function()
    if config['prof']:
        profiling(loader, G, D, train, config)
        return
    print('Beginning training at epoch %d...' % state_dict['epoch'])
    start_time = time.time()
    # Total iteration count shown in the progress line.
    total = config['num_epochs'] * len(loader)
    for epoch in range(state_dict['epoch'], config['num_epochs']):
        if config['distributed']:
            # Lets the sampler vary its ordering per epoch (assumes a
            # DistributedSampler-like object — confirm against get_data_loaders).
            train_sampler.set_epoch(epoch)
        batch_time = utils.AverageMeter('Time', ':6.3f')
        data_time = utils.AverageMeter('Data', ':6.3f')
        end = time.time()
        for i, (x, y) in enumerate(loader):
            # Time spent waiting for this batch since the previous step ended.
            data_time.update(time.time() - end)
            state_dict['itr'] += 1
            # Re-enter training mode, since saving below may have called eval().
            G.train()
            D.train()
            if config['ema']:
                G_ema.train()
            if config['D_fp16']:
                x, y = x.to(device_loc).half(), y.to(device_loc)
            else:
                x, y = x.to(device_loc), y.to(device_loc)
            metrics = train(x, y)
            # Full step time (data wait + training step) since the last `end`.
            cost_time = time.time() - end
            batch_time.update(cost_time)
            end = time.time()
            metrics['data_val'] = data_time.val
            metrics['data_avg'] = data_time.avg
            metrics['batch_val'] = batch_time.val
            metrics['batch_avg'] = batch_time.avg
            # Samples/second summed over world_size processes; guarded
            # against a zero average.
            metrics['FPS'] = D_batch_size * config['world_size'] / batch_time.avg if batch_time.avg else 0
            if not config['distributed'] or (
                    config['distributed'] and config['gpu'] == config['process_device_map'][0]):
                train_log.log(itr=int(state_dict['itr']), epoch=epoch, **metrics)
                if (config['sv_log_interval'] > 0) and (not (state_dict['itr'] % config['sv_log_interval'])):
                    train_log.log(itr=int(state_dict['itr']), epoch=epoch,
                                  **{**utils.get_SVs(G, 'G'), **utils.get_SVs(D, 'D')})
                if config['pbar'] == 'mine':
                    print(', '.join(
                        ["Epoch: %d" % epoch,
                         'itr/total: %d/%d' % (state_dict['itr'], total),
                         "time: %d:%02d" % tuple(divmod(time.time() - start_time, 60))]
                        + ['%s : %+4.3f' % (key, metrics[key]) for key in metrics]), end=' ')
                    print()
                # Every save_every iterations, save weights and samples.
                if not (state_dict['itr'] % config['save_every']):
                    if config['G_eval_mode']:
                        G.eval()
                        if config['ema']:
                            G_ema.eval()
                    train_fns.save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                                              state_dict, config, config['experiment_name'], device=device_loc)
            # CANN profiling run: stop after a single iteration.
            if config['cann_prof']:
                return
            # Chained comparison: stop_iter > 0 and stop_iter == current itr.
            if 0 < config['stop_iter'] == state_dict['itr']:
                return
        state_dict['epoch'] += 1
def main():
    """Parse command-line options, derive run-wide settings and start `run`.

    This sets the torch.distributed rendezvous environment variables, seeds the
    RNGs, scales ``world_size`` by the number of devices used on this node, and
    fills in dataset- and activation-derived config entries before handing off
    to ``run`` with ``config['rank']`` as the device-map index.
    """
    config = vars(utils.prepare_parser().parse_args())
    config['process_device_map'] = device_id_to_process_device_map(config['device_list'])

    # Rendezvous address for torch.distributed; the port is fixed.
    os.environ['MASTER_ADDR'] = config['addr']
    os.environ['MASTER_PORT'] = '29688'
    utils.seed_rng(config['seed'])

    # NPU runs, or runs without a pinned GPU, use every listed device;
    # a pinned GPU means exactly one device on this node.
    use_all_devices = config['device'] == 'npu' or config['gpu'] is None
    ngpus_per_node = len(config['process_device_map']) if use_all_devices else 1
    config['world_size'] *= ngpus_per_node
    config['distributed'] = config['world_size'] > 1

    dataset = config['dataset']
    config['resolution'] = utils.imsize_dict[dataset]
    config['n_classes'] = utils.nclass_dict[dataset]
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]

    if config['resume']:
        print('Skipping initialization for training resumption...')
        config['skip_init'] = True

    config = utils.update_config_roots(config)
    # Derive an experiment name from the config when none was supplied.
    if not config['experiment_name']:
        config['experiment_name'] = utils.name_from_config(config)
    run(config['rank'], ngpus_per_node, config)
# Script entry point: start training only when the file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()