import os
import shutil
import argparse
import sys
import time
import random
import warnings
import torch
# torch_npu is only imported for torch >= 1.8.1. The release is compared
# numerically: as plain strings '1.11.0' >= '1.8.1' is False (the character
# '1' sorts before '8'), which would silently skip the import on newer torch.
import re
_torch_release = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', torch.__version__)
if _torch_release and tuple(int(part or 0) for part in _torch_release.groups()) >= (1, 8, 1):
    import torch_npu
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.distributed as dist
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from apex import amp
import apex
import models
# Best top-1 validation accuracy seen so far. main_worker rebinds it through
# `global`: restored from a checkpoint on resume and raised after validation.
best_acc1 = 0
def main():
    """Parse command-line options and start single- or multi-process training.

    Distributed mode is enabled by ``--ddp`` or by ``--nodes`` greater than 1;
    one worker process is then spawned per entry of ``--device_list``.
    Otherwise ``main_worker`` runs once in the current process.
    """
    parser = argparse.ArgumentParser(description='PyTorch Se-ResNeXt101 ImageNet Training')
    parser.add_argument('--data_path', metavar='DIR', default='/opt/npu/imagenet',
                        help='path to dataset')
    parser.add_argument('--workers', default=192, type=int, metavar='N',
                        help='number of data loading workers (default: 192)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=128, type=int,
                        metavar='N',
                        help='mini-batch size (default: 128), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', default=0.6, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # NOTE(review): default=True together with store_true means this flag can
    # never be switched off from the command line; kept for compatibility.
    parser.add_argument('--amp', default=True, action='store_true',
                        help='use amp to train the model')
    parser.add_argument('--opt-level', default="O2", type=str, help='apex optimize level')
    parser.add_argument('--loss-scale-value', default=None, type=float, help='static loss scale value')
    parser.add_argument('--combine-grad', default=True, action='store_true',
                        help='enable the combine_grad option of apex amp.initialize')
    parser.add_argument('--ddp',
                        dest='ddp',
                        action='store_true',
                        help='use distribution training')
    parser.add_argument('--nodes', default=1, type=int,
                        help='number of nodes taking part in training (default: 1)')
    parser.add_argument('--node_rank', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('--device_list', default='0,1,2,3,4,5,6,7', type=str, help='device id list')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--print-freq', default=10, type=int,
                        metavar='N', help='print frequency (default: 10)')
    parser.add_argument('--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--prof',
                        dest='prof',
                        action='store_true',
                        help='print model profile on training')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--stop-step-num', default=None, type=int,
                        help='after the stop-step, killing the training task')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--num_classes', default=1000, type=int,
                        help='The number of classes.')
    args = parser.parse_args()
    print(args)

    args.distributed = (args.nodes > 1) or args.ddp
    args.process_device_map = device_id_to_process_device_map(args.device_list)
    # One worker per listed device when distributed, a single worker otherwise.
    if args.distributed:
        ngpus_per_node = len(args.process_device_map)
    else:
        ngpus_per_node = 1
    print("Use multiprocessing for training :", args.distributed)
    args.world_size = ngpus_per_node * args.nodes

    # Rendezvous settings for the process group. Values already present in the
    # environment are respected so that a multi-node launch can point every
    # node at the real master instead of the loopback address.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12345')

    if args.seed is not None:
        seed_everything(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.distributed:
        # mp.spawn passes the process index as the first worker argument.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(0, ngpus_per_node, args)
def main_worker(process_id, ngpus_per_node, args):
    """Set up one training process and run evaluation, profiling or training.

    Args:
        process_id: index of this process on the node; used as the key into
            ``args.process_device_map`` to pick the NPU device id.
        ngpus_per_node: number of worker processes on this node. The
            per-process batch size and loader worker count are the command-line
            totals divided by this value.
        args: parsed command-line namespace. This function adds ``rank``,
            ``mainprocess``, ``loader_workers`` and ``start_epoch`` to it and
            overwrites ``batch_size`` with the per-process value.

    Side effects: rebinds the module-level ``best_acc1`` and, on the main
    process, writes a checkpoint after every validated epoch.
    """
    global best_acc1

    deviceid = args.process_device_map[process_id]
    loc = "npu:{}".format(deviceid)
    if deviceid is not None:
        print("Use NPU: {} for training".format(deviceid))

    args.rank = args.node_rank * ngpus_per_node + process_id
    # Only rank 0 (or the sole process) logs metrics and writes checkpoints.
    args.mainprocess = (args.distributed is False) or (args.rank == 0)
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.loader_workers = int(args.workers / ngpus_per_node)

    if args.distributed:
        dist.init_process_group(backend='hccl',
                                world_size=args.world_size,
                                rank=args.rank)
    torch.npu.set_device(loc)

    model = models.seresnext101_32x4d(num_classes=args.num_classes)
    model = model.npu()

    if os.path.isfile(args.resume):
        checkpoint = torch.load(args.resume, map_location=loc)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        state_dict = checkpoint['state_dict']
        if args.pretrained:
            # Drop the classifier weights so the freshly built head is kept;
            # strict=False tolerates the resulting missing keys.
            if "fc.weight" in state_dict:
                state_dict.pop('fc.weight')
                # A checkpoint may lack the bias entry; do not fail on that.
                state_dict.pop('fc.bias', None)
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        args.start_epoch = 0
        print("=> no checkpoint found at '{}'".format(args.resume))

    optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr,
                                            momentum=args.momentum,
                                            weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss().npu()

    # Honour --loss-scale-value when given; fall back to dynamic loss scaling.
    if args.loss_scale_value is None:
        loss_scale = 'dynamic'
    else:
        loss_scale = args.loss_scale_value
    # NOTE(review): args.amp is never consulted; amp is always initialized.
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=args.opt_level,
                                      combine_grad=args.combine_grad,
                                      loss_scale=loss_scale)

    if args.distributed:
        # With --pretrained, unused-parameter detection is switched on, as in
        # the original two-branch construction.
        ddp_options = {'device_ids': [deviceid], 'broadcast_buffers': False}
        if args.pretrained:
            ddp_options['find_unused_parameters'] = True
        model = torch.nn.parallel.DistributedDataParallel(model, **ddp_options)

    cudnn.benchmark = True

    train_loader, val_loader, train_sampler = dataloader(args)

    if args.evaluate:
        validate(val_loader, model, criterion, deviceid, args)
        return

    if args.prof:
        profiling(train_loader, model, criterion, optimizer, deviceid, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Gives the distributed sampler a new shuffling order per epoch.
            train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, args)
        if args.mainprocess:
            print("lr adjust to :", optimizer.param_groups[0]['lr'])

        train(train_loader, model, criterion, optimizer, epoch, deviceid, args)

        # A step limit means a short smoke/perf run: stop after the first
        # (truncated) epoch without validating or checkpointing.
        if args.stop_step_num is not None:
            break

        acc1 = validate(val_loader, model, criterion, deviceid, args)

        if args.mainprocess:
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)
            file_name = "checkpoint"
            # Save the bare module's weights so the checkpoint loads without
            # the DistributedDataParallel wrapper.
            if args.distributed:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(epoch, {
                'epoch': epoch + 1,
                'arch': "seresnext101",
                'state_dict': state_dict,
                'best_acc1': best_acc1,
            }, is_best, file_name)
def profiling(train_loader, model, criterion, optimizer, deviceid, args, epoch=0):
    """Warm up on one cached batch, then profile a single training step.

    The first batch of ``train_loader`` is moved to the NPU once and reused:
    100 warm-up steps are timed (main process only), then one more step runs
    under the autograd profiler. The profiler table is printed and a trace is
    exported to ``output.prof``.
    """
    step_timer = AverageMeter('Time', ':6.3f')
    load_timer = AverageMeter('Data', ':6.3f')
    reporter = ProgressMeter(
        len(train_loader),
        [step_timer, load_timer],
        prefix="Epoch: [{}]".format(epoch))

    def run_step(batch, labels):
        """Forward, scaled backward and optimizer update for one batch."""
        logits = model(batch)
        step_loss = criterion(logits, labels.long())
        with amp.scale_loss(step_loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return logits, labels, step_loss

    model.train()
    optimizer.zero_grad()
    end = time.time()

    # Fetch a single batch and keep it on the device for every step below.
    images, target = next(iter(train_loader))
    images = images.npu(non_blocking=True)
    target = target.long().npu(non_blocking=True)

    is_main = args.mainprocess
    for step in range(100):
        if is_main:
            load_timer.update(time.time() - end)

        _, target, _ = run_step(images, target)

        if is_main:
            step_timer.update(time.time() - end)
            end = time.time()
            reporter.display(step)

    if is_main:
        print("start log prof")
    # Every process runs the profiled step, not just the main one.
    with torch.autograd.profiler.profile(use_npu=True) as prof:
        run_step(images, target)
    print(prof.key_averages().table())
    prof.export_chrome_trace("output.prof")
def train(train_loader, model, criterion, optimizer, epoch, deviceid, args):
    """Run one training epoch with apex amp loss scaling.

    Every process performs the forward/backward/update steps; only the main
    process tracks loss, accuracy and timing and prints progress. The epoch is
    cut short once the batch index reaches ``args.stop_step_num`` (when set).
    A throughput summary is printed at the end by the main process.
    """
    step_timer = AverageMeter('Time', ':6.3f')
    load_timer = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    top1_meter = AverageMeter('Acc@1', ':6.2f')
    top5_meter = AverageMeter('Acc@5', ':6.2f')
    reporter = ProgressMeter(
        len(train_loader),
        [step_timer, load_timer, loss_meter, top1_meter, top5_meter],
        prefix="Epoch: [{}]".format(epoch))
    is_main = args.mainprocess

    model.train()
    optimizer.zero_grad()

    end = time.time()
    for step, (images, target) in enumerate(train_loader):
        images = images.npu(non_blocking=True)
        target = target.long().npu(non_blocking=True)

        if is_main:
            # Time spent waiting for the batch and moving it to the device.
            load_timer.update(time.time() - end)

        output = model(images)
        loss = criterion(output, target.long())
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if is_main:
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            samples = images.size(0)
            loss_meter.update(loss.item(), samples)
            top1_meter.update(acc1[0], samples)
            top5_meter.update(acc5[0], samples)

            step_timer.update(time.time() - end)
            end = time.time()

        if step % args.print_freq == 0 and is_main:
            reporter.display(step)

        if args.stop_step_num is not None and step >= args.stop_step_num:
            break

    if is_main and step_timer.avg:
        global_batch = args.world_size * args.batch_size
        print("[npu id:", deviceid, "]", "batch_size:", global_batch,
              'Time: {:.3f}'.format(step_timer.avg), '* FPS@all {:.3f}'.format(
                  global_batch / step_timer.avg))
def validate(val_loader, model, criterion, deviceid, args):
    """Evaluate the model on ``val_loader`` and return the top-1 average.

    The forward pass and accuracy computation run on every process; the meters
    are updated and printed only on the main process.

    Returns:
        The ``avg`` field of the top-1 meter.
    """
    step_timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    top1_meter = AverageMeter('Acc@1', ':6.2f')
    top5_meter = AverageMeter('Acc@5', ':6.2f')
    reporter = ProgressMeter(
        len(val_loader),
        [step_timer, loss_meter, top1_meter, top5_meter],
        prefix='Test: ')
    is_main = args.mainprocess

    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, (images, target) in enumerate(val_loader):
            images = images.npu(non_blocking=True)
            # Labels travel to the device as int32 and are widened for the loss.
            target = target.to(torch.int32).npu(non_blocking=True)

            output = model(images)
            loss = criterion(output, target.long())
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if is_main:
                samples = images.size(0)
                loss_meter.update(loss.item(), samples)
                top1_meter.update(acc1[0], samples)
                top5_meter.update(acc5[0], samples)

                step_timer.update(time.time() - end)
                end = time.time()

            if step % args.print_freq == 0 and is_main:
                reporter.display(step)

        if is_main:
            summary = '[AVG-ACC] * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
                top1=top1_meter, top5=top5_meter)
            print("[gpu id:", deviceid, "]", summary)

    return top1_meter.avg
def save_checkpoint(epoch, state, is_best, filename='checkpoint'):
    """Serialize ``state`` to ``<filename>.pth`` and keep a copy of the best one.

    When ``is_best`` is true the saved file is also copied to
    ``<filename>model_best.pth``. ``epoch`` is accepted but not used.
    """
    checkpoint_path = filename + ".pth"
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    best_path = filename + 'model_best.pth'
    shutil.copyfile(checkpoint_path, best_path)
def dataloader(args):
    """Build the training and validation loaders from ``args.data_path``.

    Images are read with ``ImageFolder`` from the ``train`` and ``val``
    sub-directories. In distributed mode the training set is sharded with a
    ``DistributedSampler``; otherwise it is shuffled by the loader itself.

    Returns:
        ``(train_loader, val_loader, train_sampler)``; the sampler is ``None``
        when not running distributed.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(args.data_path, 'train'),
                                         train_transform)
    val_dataset = datasets.ImageFolder(os.path.join(args.data_path, 'val'),
                                       val_transform)

    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=args.rank)

    # Options shared by both loaders; incomplete final batches are dropped.
    shared_options = dict(batch_size=args.batch_size,
                          num_workers=args.loader_workers,
                          drop_last=True,
                          pin_memory=False)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               shuffle=(train_sampler is None),
                                               sampler=train_sampler,
                                               **shared_options)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             shuffle=False,
                                             sampler=None,
                                             **shared_options)
    return train_loader, val_loader, train_sampler
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: 2-D score tensor indexed as (sample, class).
        target: 1-D tensor of class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per requested k.

    A k larger than the number of classes is treated as "all classes", so the
    call no longer raises when the model has fewer classes than ``max(topk)``.
    """
    with torch.no_grad():
        # topk() raises if asked for more entries than the class dimension has.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        # Transpose so row r holds every sample's (r+1)-th ranked prediction.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        percent_per_hit = 100.0 / batch_size
        return [
            correct[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]
def adjust_learning_rate(optimizer, epoch, args):
    """Apply a step schedule: ``args.lr`` scaled by 0.1 for every 30 full epochs.

    The resulting rate is written into every parameter group of ``optimizer``.
    """
    completed_decays = epoch // 30
    new_lr = args.lr * (0.1 ** completed_decays)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
class AverageMeter(object):
    """Track the latest value and a running average that skips warm-up updates.

    The first ``start_count_index`` batches (measured in units of the first
    update's ``n``) are excluded from ``sum`` and ``avg``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()
        # Number of initial batches left out of the average.
        self.start_count_index = 10

    def reset(self):
        """Zero the current value, average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples."""
        if self.count == 0:
            # The first update fixes the batch size used for the warm-up window.
            self.batchsize = n

        self.val = val
        self.count += n

        warmup_samples = self.start_count_index * self.batchsize
        if self.count <= warmup_samples:
            return
        self.sum += val * n
        self.avg = self.sum / (self.count - warmup_samples)

    def __str__(self):
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))
class ProgressMeter(object):
    """Print a tab-separated progress line built from a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header] + list(map(str, self.meters))))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``[{:Nd}/total]`` template, N being the width of the total."""
        width = len(str(num_batches // 1))
        field = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(field, field.format(num_batches))
def flush_print(func):
    """Wrap ``func`` so that ``sys.stdout`` is flushed after every call.

    Intended for print-like functions on servers where log lines should be
    recorded immediately. The wrapper keeps the wrapped function's name and
    docstring and passes its return value through.
    """
    import functools

    @functools.wraps(func)
    def new_print(*args, **kwargs):
        """Call the wrapped function, flush standard output, return its result."""
        result = func(*args, **kwargs)
        sys.stdout.flush()
        return result
    return new_print
def device_id_to_process_device_map(device_list):
    """Map process indices to device ids from a comma-separated id string.

    The ids are converted to integers and sorted in ascending order; process
    ``i`` is assigned the ``i``-th smallest id.
    """
    ordered_ids = sorted(int(token) for token in device_list.split(","))
    return dict(enumerate(ordered_ids))
def seed_everything(seed):
    """Seed the Python and torch random generators and request determinism.

    Also exports ``PYTHONHASHSEED`` to the environment and turns on the cuDNN
    deterministic flag.
    """
    random.seed(seed)
    os.environ.update(PYTHONHASHSEED=str(seed))
    torch.manual_seed(seed)
    cudnn.deterministic = True
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()