"""Trainer
"""
import argparse
import importlib
import logging
import os
import sys
from time import *
import torch
from torch.utils.data import DataLoader
import torch.npu
from apex.parallel import convert_syncbn_model
from apex import amp
import random
import numpy as np
import common.meters
import common.modes
def unPackPad(lr_image, width, height):
    """Centre-crop the last two axes of a padded 4-D array and wrap it as a tensor.

    The padding is assumed to be split evenly around the content, so the
    crop starts at ``(padded - target) // 2`` on each axis and spans exactly
    the requested extent. The crop bounds are derived from the input's own
    shape, so any padded size and any batch size are supported.

    Args:
        lr_image: 4-D NumPy array; axes 0 and 1 are kept untouched.
        width: Extent to keep along axis 2. (Despite the name, this is
            matched against axis 2 of ``lr_image``.)
        height: Extent to keep along axis 3.

    Returns:
        A ``torch.Tensor`` built with ``torch.from_numpy`` from the cropped
        view, with shape ``(N, C, width, height)`` whenever the input is at
        least that large on the cropped axes.
    """
    padded_w = lr_image.shape[2]
    padded_h = lr_image.shape[3]
    # Clamp at zero so an input smaller than the target is returned as-is
    # instead of being sliced with a negative start index.
    pad_w = max((padded_w - width) // 2, 0)
    pad_h = max((padded_h - height) // 2, 0)
    # Slice by "start + size" rather than "total - start": the result then
    # has exactly the requested extent even when the padding difference is odd.
    cropped = lr_image[:, :, pad_w:pad_w + width, pad_h:pad_h + height]
    return torch.from_numpy(cropped)
if __name__ == '__main__':
    # Script entry point: parse arguments, build datasets/loaders and the
    # model spec, optionally restore a checkpoint, then either evaluate once
    # (--eval_only) or run the train / checkpoint / evaluate loop.
    # NOTE(review): block nesting was reconstructed from statement order;
    # confirm the branches marked below match the intended layout.

    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        help='Dataset name.',
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument(
        '--model',
        help='Model name.',
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument(
        '--job_dir',
        help='Directory to write checkpoints and export models.',
        default=None,
        type=str,
        required=True,
    )
    parser.add_argument(
        '--ckpt',
        help='File path to load checkpoint.',
        default=None,
        type=str,
    )
    parser.add_argument(
        '--override_epoch',
        help='Override epoch number when loading from checkpoint.',
        default=None,
        type=int,
    )
    parser.add_argument(
        '--eval_only',
        default=False,
        action='store_true',
        help='Running evaluation only.',
    )
    parser.add_argument(
        '--eval_datasets',
        help='Dataset names for evaluation.',
        default=None,
        type=str,
        nargs='+',
    )
    parser.add_argument(
        '--save_checkpoints_epochs',
        help='Number of epochs to save checkpoint.',
        default=1,
        type=int)
    parser.add_argument(
        '--keep_checkpoints',
        help='Keeping intermediate checkpoints.',
        default=False,
        action='store_true')
    parser.add_argument(
        '--train_epochs',
        help='Number of epochs to run training totally.',
        default=10,
        type=int)
    parser.add_argument(
        '--log_steps',
        help='Number of steps for training logging.',
        default=100,
        type=int)
    parser.add_argument(
        '--random_seed',
        help='Random seed for Python, NumPy and PyTorch.',
        default=1234,
        type=int)
    parser.add_argument(
        '--opt_level',
        help='Apex AMP optimization level (e.g. O1, O2).',
        default='O2',
        type=str)
    parser.add_argument(
        '--sync_bn',
        default=False,
        action='store_true',
        help='Enabling apex sync BN.')
    parser.add_argument(
        '-v',
        '--verbose',
        action='count',
        default=1,
        help='Increasing output verbosity.',
    )
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--node_rank', default=0, type=int)

    # First pass: only the options above are known. The dataset and model
    # modules register their own options before the final parse below.
    args, _ = parser.parse_known_args()

    # The count starts at 1 and every -v adds one, so clamp the index:
    # "-vv" and beyond would otherwise run past the end of the list.
    log_levels = [logging.WARNING, logging.INFO, logging.DEBUG]
    logging.basicConfig(
        level=log_levels[min(args.verbose, len(log_levels) - 1)],
        format='%(asctime)s:%(levelname)s:%(message)s')

    dataset_module = importlib.import_module(
        'datasets.' + args.dataset if args.dataset else 'datasets')
    dataset_module.update_argparser(parser)
    model_module = importlib.import_module(
        'models.' + args.model if args.model else 'models')
    model_module.update_argparser(parser)
    params = parser.parse_args()
    logging.critical(params)

    # ------------------------------------------------------------------
    # Reproducibility and device selection
    # ------------------------------------------------------------------
    random.seed(params.random_seed)
    np.random.seed(params.random_seed)
    torch.manual_seed(params.random_seed)
    os.environ['PYTHONHASHSEED'] = str(params.random_seed)
    torch.backends.cudnn.benchmark = True

    device = torch.device('npu:{}'.format(params.local_rank))
    print(device)
    torch.npu.set_device(device)

    # ------------------------------------------------------------------
    # Distributed setup (driven by the WORLD_SIZE environment variable)
    # ------------------------------------------------------------------
    params.distributed = False
    params.master_proc = True
    cardNumber = 1
    if 'WORLD_SIZE' in os.environ:
        params.distributed = int(os.environ['WORLD_SIZE']) > 1
        cardNumber = int(os.environ['WORLD_SIZE'])
    if params.distributed:
        # NOTE(review): the process rank is taken from --local_rank, which
        # looks single-node only; confirm before running on several nodes.
        torch.distributed.init_process_group(
            backend='hccl',
            rank=params.local_rank,
            world_size=int(os.environ['WORLD_SIZE']))
        if params.node_rank or params.local_rank:
            params.master_proc = False

    # ------------------------------------------------------------------
    # Datasets and data loaders
    # ------------------------------------------------------------------
    train_dataset = dataset_module.get_dataset(common.modes.TRAIN, params)
    if params.eval_datasets:
        eval_datasets = [
            (eval_name,
             importlib.import_module('datasets.' + eval_name).get_dataset(
                 common.modes.EVAL, params))
            for eval_name in params.eval_datasets
        ]
    else:
        eval_datasets = [
            (params.dataset,
             dataset_module.get_dataset(common.modes.EVAL, params))
        ]

    # Evaluation is never sharded: eval_sampler is None in both modes.
    eval_sampler = None
    if params.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
        # --train_batch_size is divided across the participating cards.
        params.train_batch_size //= cardNumber
    else:
        train_sampler = None

    train_data_loader = DataLoader(
        dataset=train_dataset,
        num_workers=params.num_data_threads,
        batch_size=params.train_batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=(train_sampler is None),
        drop_last=True,
        pin_memory=False,
        sampler=train_sampler,
    )
    eval_data_loaders = [
        (data_name,
         DataLoader(
             dataset=dataset,
             num_workers=params.num_data_threads,
             batch_size=params.eval_batch_size,
             shuffle=False,
             drop_last=False,
             pin_memory=False,
             sampler=eval_sampler,
         ))
        for data_name, dataset in eval_datasets
    ]

    # ------------------------------------------------------------------
    # Model, loss, optimizer and mixed precision
    # ------------------------------------------------------------------
    model, criterion, optimizer, lr_scheduler, metrics = (
        model_module.get_model_spec(params))
    model = model.to(device)
    criterion = criterion.to(device)
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level=params.opt_level,
        verbosity=params.verbose,
        loss_scale=128.0,
        combine_grad=True)

    # ------------------------------------------------------------------
    # Checkpoint restore: explicit --ckpt wins, else <job_dir>/latest.pth
    # ------------------------------------------------------------------
    default_ckpt = os.path.join(params.job_dir, 'latest.pth')
    if params.ckpt or os.path.exists(default_ckpt):
        checkpoint = torch.load(params.ckpt or default_ckpt, device)
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        except RuntimeError as e:
            # Best effort: a state that fails to load is reported and the
            # run continues with whatever was restored before the failure.
            logging.critical(e)
        latest_epoch = checkpoint['epoch']
        logging.critical('Loaded checkpoint from epoch {}.'.format(latest_epoch))
        if params.override_epoch is not None:
            latest_epoch = params.override_epoch
            logging.critical(
                'Overrode epoch number to {}.'.format(latest_epoch))
    else:
        latest_epoch = 0

    if params.distributed:
        if params.sync_bn:
            model = convert_syncbn_model(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[params.local_rank],
            broadcast_buffers=False)

    def train(epoch):
        """Run one training epoch over ``train_data_loader``.

        Args:
            epoch: Epoch number; used to reseed the distributed sampler and
                in log messages.
        """
        if params.distributed:
            train_sampler.set_epoch(epoch)
        loss_meter = common.meters.AverageMeter()
        time_meter = common.meters.TimeMeter()
        model.train()
        for batch_idx, (data, target) in enumerate(train_data_loader):
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            # Restart the timer after the first batches (presumably to keep
            # warm-up cost out of the reported speed).
            if batch_idx == 10:
                time_meter.reset()
            # NOTE(review): meters are assumed to be updated only on logged
            # steps; confirm against the intended nesting.
            if batch_idx % params.log_steps == 0:
                loss_meter.update(loss.item(), data.size(0))
                time_meter.update_count(batch_idx + 1)
                logging.info(
                    'Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tSpeed: {:.6f} seconds/batch\tFPS:{FPS:.3f}'
                    .format(epoch, batch_idx * len(data),
                            len(train_data_loader) * len(data),
                            100. * batch_idx / len(train_data_loader),
                            loss.item(),
                            time_meter.avg,
                            FPS=train_data_loader.batch_size * cardNumber / time_meter.avg))
        logging.critical('Train epoch {} finished.\tLoss: {:.6f}'.format(
            epoch, loss_meter.avg))

    def evaluate():
        """Evaluate the model on every eval dataset and log averaged metrics."""
        with torch.no_grad():
            model.eval()
            for eval_data_name, eval_data_loader in eval_data_loaders:
                metric_meters = {
                    metric_name: common.meters.AverageMeter()
                    for metric_name in metrics
                }
                time_meter = common.meters.TimeMeter()
                for data, target in eval_data_loader:
                    data = data.to(device, non_blocking=True)
                    target = target.to(device, non_blocking=True)
                    output = model(data)
                    # Crop the output to the target's spatial size on the
                    # host, then move it back to the device for the metrics.
                    output = unPackPad(output.cpu().numpy(), target.shape[2],
                                       target.shape[3])
                    output = output.to(device, non_blocking=True)
                    for metric_name, metric_fn in metrics.items():
                        metric_meters[metric_name].update(
                            metric_fn(output, target).item(), data.size(0))
                    time_meter.update(data.size(0))
                for metric_name in metrics:
                    logging.critical('Eval {}: Average {}: {:.4f}'.format(
                        eval_data_name, metric_name,
                        metric_meters[metric_name].avg))
                logging.critical(
                    'Eval {}: Average Speed: {:.6f} seconds/sample'.format(
                        eval_data_name, time_meter.avg))

    if params.eval_only:
        evaluate()
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is not guaranteed to exist.
        sys.exit()

    # ------------------------------------------------------------------
    # Training loop: train, step the LR schedule, checkpoint, evaluate
    # ------------------------------------------------------------------
    for epoch in range(latest_epoch + 1, params.train_epochs + 1):
        train(epoch)
        lr_scheduler.step()
        if epoch % params.save_checkpoints_epochs == 0:
            if params.master_proc:
                # exist_ok avoids the check-then-create race.
                os.makedirs(params.job_dir, exist_ok=True)
                torch.save(
                    {
                        # Unwrap DistributedDataParallel before saving.
                        'model_state_dict':
                            model.module.state_dict()
                            if params.distributed else model.state_dict(),
                        'optimizer_state_dict':
                            optimizer.state_dict(),
                        'lr_scheduler_state_dict':
                            lr_scheduler.state_dict(),
                        'epoch':
                            epoch,
                    }, os.path.join(params.job_dir, 'epoch_{}.pth'.format(epoch)))
                latest_path = os.path.join(params.job_dir, 'latest.pth')
                # NOTE(review): nothing in this script creates latest.pth, so
                # this cleanup only runs when the file was created externally,
                # and os.readlink requires it to be a symlink. Confirm whether
                # an os.symlink('epoch_N.pth', latest_path) call was dropped.
                if os.path.exists(latest_path):
                    if not params.keep_checkpoints:
                        os.remove(
                            os.path.join(params.job_dir,
                                         os.readlink(latest_path)))
                    os.remove(latest_path)
            # NOTE(review): evaluation is assumed to run on every rank at
            # each checkpoint interval; confirm against the intended nesting.
            evaluate()