import copy
import datetime
import os
import time
import sys
import torch
if torch.__version__ >= "1.8":
import torch_npu
import torch.utils.data
from torch import nn
import torch.npu
import random
import numpy as np
import utils
from models import RDN
from datasets import TrainDataset, EvalDataset
import argparse
import torch.multiprocessing as mp
from utils import AverageMeter
try:
import apex
from apex import amp
except ImportError:
amp = None
def flush_print(func):
    """Return a wrapper around *func* that flushes ``sys.stdout`` after each call.

    The wrapper forwards all positional and keyword arguments to *func*,
    flushes standard output, and then returns whatever *func* returned.
    ``functools.wraps`` keeps the wrapped callable's name and docstring.

    Args:
        func: Callable to wrap (in this file, the built-in ``print``).

    Returns:
        A callable with the same call signature as *func*.
    """
    # Local import: functools is not among this file's top-level imports.
    import functools

    @functools.wraps(func)
    def new_print(*args, **kwargs):
        result = func(*args, **kwargs)
        # Flush so log lines show up immediately even when stdout is piped.
        sys.stdout.flush()
        return result

    return new_print
# Rebind the module-level ``print`` so every print call in this file flushes
# stdout right after writing.
print = flush_print(print)
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args):
    """Run one training pass over ``data_loader`` and log per-step and per-epoch speed.

    Args:
        model: Network to train; put into training mode here.
        criterion: Loss callable, invoked as ``criterion(output, target)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device both tensors are moved to before the forward pass.
        epoch: Epoch number; only used in log messages.
        args: Namespace read for ``batch_size``, ``gpus``, ``apex`` and
            ``is_master_node``.

    NOTE(review): this file's indentation was lost before review. The block
    structure below is reconstructed; in particular the per-step log is
    assumed to be a sibling of the ``if i < cond`` timing guard rather than
    nested inside it. Confirm against the original script.
    """
    model.train()
    # Step-time meter. ``start_count_index=3`` presumably skips the first
    # few (warm-up) updates -- confirm against utils.AverageMeter.
    batchtime = AverageMeter(start_count_index=3)
    # Only the first ``cond`` steps feed the timing meter: 4000 divided by
    # the global batch size (per-device batch size times device count).
    cond = 4000 // (args.batch_size * args.gpus)
    for i, (image, target) in enumerate(data_loader):
        # The timed span covers the host-to-device copy, forward, backward
        # and the optimizer step.
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        if args.apex:
            # Mixed precision: backpropagate through apex's scaled loss.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        batch_size = image.shape[0]
        if i < cond:
            batchtime.update(time.time() - start_time)
        if args.is_master_node:
            # ``batchtime.val`` is the meter's current value, so img/s here
            # is this process's batch size divided by that value.
            print("Epoch {} step {},loss :{},img/s :{},time :{}".format(epoch, i, loss, batch_size / batchtime.val,
                                                                        batchtime.val))
    if args.is_master_node:
        # Epoch-level throughput: global batch size over the meter's average.
        print('epoch:{} FPS: {:.3f}'.format(epoch, args.gpus * args.batch_size / batchtime.avg))
def evaluate(model, criterion, data_loader, device, args):
    """Evaluate *model* on ``data_loader`` and return the averaged PSNR.

    The model is switched to eval mode and run without gradient tracking.
    For each batch the loss and ``utils.calc_psnr`` value are recorded in a
    ``utils.MetricLogger``; after the loop the logger is synchronised across
    processes and the ``epoch_psnr`` meter's ``global_avg`` is printed and
    returned.

    Args:
        model: Network to evaluate.
        criterion: Loss callable, invoked as ``criterion(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        args: Namespace; ``args.scale`` is forwarded to ``utils.calc_psnr``
            and the whole namespace is handed to ``log_every``.

    Returns:
        ``metric_logger.epoch_psnr.global_avg`` after synchronisation.
    """
    model.eval()
    stats = utils.MetricLogger(delimiter=" ")

    with torch.no_grad():
        for inputs, reference in stats.log_every(data_loader, args, 'Test:'):
            inputs = inputs.to(device, non_blocking=True)
            reference = reference.to(device, non_blocking=True)

            prediction = model(inputs)
            batch_loss = criterion(prediction, reference)
            batch_psnr = utils.calc_psnr(prediction, reference, args.scale, device)

            stats.update(loss=batch_loss.item())
            # Weight the PSNR by the number of samples in this batch.
            stats.meters['epoch_psnr'].update(batch_psnr.item(), n=inputs.shape[0])

    stats.synchronize_between_processes(device)
    average_psnr = stats.epoch_psnr.global_avg
    print('epoch_psnr={} '.format(average_psnr))
    return average_psnr
def _get_cache_path(filepath):
    """Map *filepath* to a cache file under ``~/.torch/vision/datasets/imagefolder``.

    The file name is the first ten hex characters of the SHA-1 digest of the
    encoded path followed by ``.pt``; ``~`` is expanded to the user's home
    directory.

    Args:
        filepath: Path string to derive the cache file name from.

    Returns:
        The expanded cache file path as a string.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    file_name = digest[:10] + ".pt"
    return os.path.expanduser(
        os.path.join("~", ".torch", "vision", "datasets", "imagefolder", file_name)
    )
def main(args):
    """Prepare the process environment and launch the training worker(s).

    Sets the rendezvous address/port environment variables, seeds Python,
    NumPy and torch from ``args.seed``, validates the apex requirement,
    creates the output directory, and then either spawns ``args.gpus``
    worker processes (distributed) or runs a single worker in-process.

    Args:
        args: Parsed command-line namespace.

    Raises:
        RuntimeError: If ``args.apex`` is set but the interpreter is
            Python 2 or apex could not be imported.
    """
    # Rendezvous endpoint for the 'env://' process-group initialisation.
    os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29688'})

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                               "to enable mixed-precision training.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    if not args.distributed:
        # Single process: the first argument is ignored by the worker here.
        main_worker(args.gpus, args)
        return

    # One process per device; spawn passes the process index as first argument.
    mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
def main_worker(nprocs, args):
    """Set up one training process and run the train/evaluate loop.

    Steps: initialise the process group (distributed only), pick the NPU
    device, build or load the datasets and loaders, build the RDN model,
    optimizer, optional apex AMP state and LR scheduler, optionally resume
    from a checkpoint, then train for ``args.epochs`` epochs while tracking
    the weights with the best PSNR.

    Args:
        nprocs: In distributed mode this is the process index supplied by
            ``torch.multiprocessing.spawn`` and is used as the process-group
            rank. In single-process mode ``main`` passes ``args.gpus`` and
            the value is not used.
        args: Parsed command-line namespace. It is mutated in place:
            ``is_master_node`` is added, ``device_id`` is offset by the rank,
            and ``start_epoch`` is overwritten when resuming.

    NOTE(review): this file's indentation was lost before review; the block
    structure below is reconstructed from the statement order.
    """
    local_rank = 0
    if args.distributed:
        torch.distributed.init_process_group(backend="hccl",
                                             init_method='env://',
                                             world_size=args.world_size * args.gpus,
                                             rank=nprocs)
        # The global rank is used as the device offset below.
        local_rank = torch.distributed.get_rank()
    args.is_master_node = not args.distributed or local_rank == 0
    if args.is_master_node:
        print(args)

    args.device_id = args.device_id + local_rank
    print('device_id=', args.device_id)
    device = torch.device(f'npu:{args.device_id}')
    torch.npu.set_device(device)

    # ------------------------------------------------------------------ data
    print("Loading data")
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_train from {}".format(cache_path))
        # SECURITY NOTE: torch.load unpickles the cache file; only use cache
        # files written by a trusted run of this script.
        dataset, _ = torch.load(cache_path)
    else:
        dataset = TrainDataset(args.train_file, patch_size=args.patch_size, scale=args.scale)
        if args.is_master_node and args.cache_dataset:
            print("Saving dataset_train to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        print("Loading dataset_test from {}".format(cache_path))
        # SECURITY NOTE: same pickle caveat as for the training cache above.
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = EvalDataset(args.eval_file)
        if args.is_master_node and args.cache_dataset:
            print("Saving dataset_test to {}".format(cache_path))
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, num_workers=0, pin_memory=True)

    # ----------------------------------------------------------------- model
    if args.is_master_node:
        print("Creating model")
    model = RDN(scale_factor=args.scale,
                num_channels=3,
                num_features=args.num_features,
                growth_rate=args.growth_rate,
                num_blocks=args.num_blocks,
                num_layers=args.num_layers)
    if args.pretrained:
        checkpoint = torch.load(args.pretrained_weight_path, map_location='cpu')
        # Drop the 'module.' marker left by a DistributedDataParallel wrapper.
        if 'module.' in list(checkpoint['model'].keys())[0]:
            checkpoint['model'] = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}
        model.load_state_dict(checkpoint['model'], strict=False)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.L1Loss().to(device)

    # apex is an optional import at module level (``amp`` is None when the
    # import failed, and then the name ``apex`` is not bound at all). Use the
    # fused optimizer only when apex is really available; otherwise fall back
    # to the stock Adam optimizer instead of crashing with a NameError.
    if amp is not None:
        optimizer = apex.optimizers.NpuFusedAdam(
            model.parameters(), lr=args.lr)
    else:
        if args.is_master_node:
            print("apex is not available, falling back to torch.optim.Adam")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level,
                                          loss_scale=args.loss_scale_value,
                                          combine_grad=True)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # Keep a handle on the bare model so state dicts carry no DDP prefix.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device_id], broadcast_buffers=False)
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device, args=args)
        return

    # -------------------------------------------------------------- training
    if args.is_master_node:
        print("Start training")
    best_weights = copy.deepcopy(model_without_ddp.state_dict())
    best_epoch = 0
    best_psnr = 0.0
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        epoch_psnr = evaluate(model, criterion, data_loader_test, device=device, args=args)

        # Periodic checkpoint: every 40th epoch, master process only.
        if (epoch + 1) % 40 == 0 and args.is_master_node and args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

        if args.is_master_node and epoch_psnr > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr
            best_weights = copy.deepcopy(model_without_ddp.state_dict())

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    if args.is_master_node:
        print('Training time {}'.format(total_time_str))
        print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
        # The best weights are only written for runs longer than 100 epochs.
        if args.epochs > 100:
            checkpoint = {
                'model': best_weights,
                'psnr': best_psnr,
                'epoch': best_epoch,
                'args': args}
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'best.pth'))
def parse_args(argv=None):
    """Build the command-line parser for this script and parse *argv*.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour callers
            relied on before this parameter existed.

    Returns:
        ``argparse.Namespace`` with the parsed options. ``--train_file`` and
        ``--eval_file`` are required; argparse exits the process when they
        are missing.
    """
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data_path', default='', help='dataset')
    parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--device_id', default=0, type=int, help='device id')
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr_step_size', default=600, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr_gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='outputs', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true"
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true"
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true"
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true"
    )
    parser.add_argument(
        '--pretrained_weight_path',
        default='',
        help='pretrained weight path')

    # Mixed-precision (apex) options.
    parser.add_argument(
        '--apex',
        action='store_true',
        help='Use apex for mixed precision training')
    parser.add_argument(
        '--apex-opt-level',
        default='O1',
        type=str,
        # Adjacent literals are concatenated verbatim, so each fragment needs
        # its own trailing separator.
        help='For apex mixed precision training: '
             'O0 for FP32 training, O1 for mixed precision training. '
             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
    )
    parser.add_argument(
        '--loss_scale_value',
        default=1024,
        type=int,
        help='set loss scale value.')

    # Distributed-training options.
    parser.add_argument(
        '--world-size',
        default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs.')
    parser.add_argument('--dist_rank', default=0, type=int, help='node rank for distributed training')
    parser.add_argument('--seed',
                        default=123, type=int, help='Manually set random seed')

    # Model and data options.
    parser.add_argument('--weights-file', type=str)
    parser.add_argument('--num-features', type=int, default=64)
    parser.add_argument('--growth-rate', type=int, default=64)
    parser.add_argument('--num-blocks', type=int, default=16)
    parser.add_argument('--num-layers', type=int, default=8)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--patch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--nodes', type=int, default=1)
    parser.add_argument('--gpus', type=int, default=1, help='num of gpus of per node')
    parser.add_argument('--train_file', type=str, required=True)
    parser.add_argument('--eval_file', type=str, required=True)

    args = parser.parse_args(argv)
    return args
if __name__ == "__main__":
    # Script entry point: parse the command line and hand it to main().
    main(parse_args())