from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.utils.data as data
import os
import time
import torch
import argparse
import numpy as np
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.npu
from data.config import cfg
from pyramidbox import build_net
from layers.modules import MultiBoxLoss
from data.widerface import WIDERDetection, detection_collate
from torch.nn.parallel import DistributedDataParallel as DDP
import apex
from apex import amp
import torch.distributed as dist
import torch.multiprocessing as mp
parser = argparse.ArgumentParser(
description='Pyramidbox face Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--basenet',
default='vgg16_reducedfc.pth',
help='Pretrained base model')
parser.add_argument('--batch_size',
default=16, type=int,
help='Batch size for training')
parser.add_argument('--resume',
default=None, type=str,
help='Checkpoint state_dict file to resume training from')
parser.add_argument('--num_workers',
default=4, type=int,
help='Number of workers used in dataloading')
parser.add_argument('--npu',
default=True, type=bool,
help='Use NPU to train model')
parser.add_argument('--performance',
default=False, type=bool,
help='performance to train')
parser.add_argument('--lr', '--learning-rate',
default=1e-3, type=float,
help='initial learning rate')
parser.add_argument('--momentum',
default=0.9, type=float,
help='Momentum value for optim')
parser.add_argument('--weight_decay',
default=5e-4, type=float,
help='Weight decay for SGD')
parser.add_argument('--gamma',
default=0.1, type=float,
help='Gamma update for SGD')
parser.add_argument('--multinpu',
default=False, type=bool,
help='Use mutil Gpu training')
parser.add_argument('--save_folder',
default='weights/',
help='Directory for saving checkpoint models')
parser.add_argument('--local_rank',
default=-1, type=int,
help='rank for current process')
parser.add_argument('--world_size', default=-1, type=int,
help='number of distributed processes')
parser.add_argument('--device_list', default='0', type=str,
help='NPU id to use.')
args = parser.parse_args()
if args.npu:
if args.multinpu:
device_id = int(args.device_list.split(',')[args.local_rank])
device = 'npu:{}'.format(device_id)
else:
device = 'npu:0'
torch.npu.set_device(device)
if not os.path.exists(args.save_folder):
os.makedirs(args.save_folder)
train_dataset = WIDERDetection(cfg.FACE.TRAIN_FILE, mode='train')
val_dataset = WIDERDetection(cfg.FACE.VAL_FILE, mode='val')
val_batchsize = 1
val_loader = data.DataLoader(val_dataset, val_batchsize,
num_workers=1,
shuffle=False,
collate_fn=detection_collate,
pin_memory=True)
min_loss = np.inf
def train():
iteration = 0
start_epoch = 0
step_index = 0
per_epoch_size = len(train_dataset) // args.batch_size
if args.local_rank==0 or args.multinpu==False:
print('------build_net start-------')
pyramidbox = build_net('train', cfg.NUM_CLASSES)
if args.local_rank==0 or args.multinpu==False:
print('------build_net end-------')
net = pyramidbox
if args.multinpu:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=(train_sampler is None),
num_workers=args.num_workers,
pin_memory=False,
sampler=train_sampler,
collate_fn=detection_collate,
drop_last=True)
else:
train_loader = data.DataLoader(train_dataset, args.batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=detection_collate,
pin_memory=True)
if args.resume:
print('Resuming training, loading {}...'.format(args.resume))
start_epoch = net.load_weights(args.resume)
iteration = start_epoch * per_epoch_size
else:
vgg_weights = torch.load(args.save_folder + args.basenet)
if args.local_rank==0 or args.multinpu==False:
print('Load base network....')
net.vgg.load_state_dict(vgg_weights)
if args.local_rank==0 or args.multinpu==False:
print('load base network end--------')
if not args.resume:
if args.local_rank==0 or args.multinpu==False:
print('Initializing weights...')
pyramidbox.bn64.apply(pyramidbox.weights_init)
pyramidbox.extras.apply(pyramidbox.weights_init)
pyramidbox.lfpn_topdown.apply(pyramidbox.weights_init)
pyramidbox.lfpn_later.apply(pyramidbox.weights_init)
pyramidbox.cpm.apply(pyramidbox.weights_init)
pyramidbox.loc_layers.apply(pyramidbox.weights_init)
pyramidbox.conf_layers.apply(pyramidbox.weights_init)
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
if args.npu:
net.npu()
net, optimizer = amp.initialize(net, optimizer, opt_level="O1",loss_scale=64.0)
if args.multinpu:
device_id = int(args.device_list.split(',')[args.local_rank])
device = 'npu:{}'.format(device_id)
net = DDP(net, device_ids=[device_id],broadcast_buffers=False)
cudnn.benckmark = True
criterion1 = MultiBoxLoss(cfg, args.npu)
criterion2 = MultiBoxLoss(cfg, args.npu, use_head_loss=True)
if args.local_rank==0 or args.multinpu==False:
print('Loading wider dataset...')
print('Using the specified args:')
print(args)
warmup_steps = 1000
net.train()
if args.local_rank==0 or args.multinpu==False:
print('start train--------')
for epoch in range(start_epoch, cfg.EPOCHES):
if args.multinpu:
train_sampler.set_epoch(epoch)
losses = 0
for batch_idx, (images, face_targets, head_targets) in enumerate(train_loader):
if args.npu:
images = Variable(images.npu())
with torch.no_grad():
face_targets = [Variable(ann) for ann in face_targets]
head_targets = [Variable(ann) for ann in head_targets]
else:
images = Variable(images)
with torch.no_grad():
face_targets = [Variable(ann) for ann in face_targets]
head_targets = [Variable(ann) for ann in head_targets]
adjust_learning_rate(optimizer,iteration,warmup_steps,15000)
t0 = time.time()
out = net(images)
optimizer.zero_grad()
face_loss_l, face_loss_c = criterion1(out, face_targets)
head_loss_l, head_loss_c = criterion2(out, head_targets)
loss = face_loss_l + face_loss_c + head_loss_l + head_loss_c
losses += loss.item()
if args.npu:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
t1 = time.time()
face_loss = (face_loss_l + face_loss_c).item()
head_loss = (head_loss_l + head_loss_c).item()
if args.performance:
if iteration == 50:
t50_0 = time.time()
if iteration == 100:
t100_0 = time.time()
if args.multinpu:
if args.local_rank==0:
print('cost time:{} batch_size:{} num_gpu:{} FPS:{}'.format(t100_0-t50_0,args.batch_size,args.world_size,(50*args.batch_size*args.world_size)/(t100_0-t50_0)))
else:
print('cost time:{} batch_size:{} FPS:{}'.format(t100_0-t50_0,args.batch_size,(50*args.batch_size)/(t100_0-t50_0)))
if iteration == 110:
break
iteration += 1
continue
if iteration % 10 == 0 and (args.local_rank==0 or args.multinpu==False):
loss_ = losses / (batch_idx + 1)
print('Timer: {:.4f} sec.'.format(t1 - t0))
print('epoch ' + repr(epoch) + ' iter ' +
repr(iteration) + ' || Loss:%.4f' % (loss_))
print('->> face Loss: {:.4f} || head loss : {:.4f}'.format(
face_loss, head_loss))
print('->> lr: {}'.format(optimizer.param_groups[0]['lr']))
if args.multinpu:
print('iter:{} cost time:{} batch_size:{} num_gpu:{} FPS:{}'.format(iteration,t1-t0,args.batch_size,args.world_size,(args.batch_size*args.world_size)/(t1-t0)))
else:
print('iter:{} cost time:{} batch_size:{} FPS:{}'.format(iteration,t1-t0,args.batch_size,args.batch_size/(t1-t0)))
if iteration != 0 and iteration % 2000 == 0 and (args.local_rank==0 or args.multinpu==False):
print('Saving state, iter:', iteration)
file = 'pyramidbox_' + repr(iteration) + '.pth'
torch.save(pyramidbox.state_dict(),
os.path.join(args.save_folder, file))
iteration += 1
if args.performance:
break
if epoch>50 and epoch%5==0:
val(epoch, net, pyramidbox, criterion1, criterion2)
net.train()
def val(epoch,
net,
pyramidbox,
criterion1,
criterion2):
net.eval()
face_losses = 0
head_losses = 0
step = 0
t1 = time.time()
for batch_idx, (images, face_targets, head_targets) in enumerate(val_loader):
if args.npu:
images = Variable(images.npu())
with torch.no_grad():
face_targets = [Variable(ann) for ann in face_targets]
head_targets = [Variable(ann) for ann in head_targets]
else:
images = Variable(images)
with torch.no_grad():
face_targets = [Variable(ann)
for ann in face_targets]
head_targets = [Variable(ann)
for ann in head_targets]
out = net(images)
face_loss_l, face_loss_c = criterion1(out, face_targets)
head_loss_l, head_loss_c = criterion2(out, head_targets)
face_losses += (face_loss_l + face_loss_c).item()
head_losses += (head_loss_l + head_loss_c).item()
step += 1
tloss = face_losses / step
t2 = time.time()
if args.local_rank==0:
print('test Timer:{:.4f} .sec'.format(t2 - t1))
print('epoch ' + repr(epoch) + ' || Loss:%.4f' % (tloss))
global min_loss
if tloss < min_loss and args.local_rank==0:
print('Saving best state,epoch', epoch)
torch.save(pyramidbox.state_dict(), os.path.join(
args.save_folder, 'pyramidbox.pth'))
min_loss = tloss
states = {
'epoch': epoch,
'weight': pyramidbox.state_dict(),
}
if args.local_rank==0:
torch.save(states, os.path.join(
args.save_folder, 'pyramidbox_checkpoint.pth'))
def lr_warmup(optimizer,step,base_lr,warmup_steps):
if not step <warmup_steps:
return
lr = base_lr*(step+1)/warmup_steps
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def adjust_learning_rate(optimizer,step,warmup_step,total_step):
"""Sets the learning rate to the initial LR decayed by 10 at every
specified step
# Adapted from PyTorch Imagenet example:
# https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
lr = 0.5 * (1 + np.cos(np.pi * (step - warmup_step) / (total_step - warmup_step))) * args.lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
if __name__ == '__main__':
if args.multinpu:
print('in multi--------')
dist.init_process_group(backend='hccl', world_size=args.world_size, rank=args.local_rank)
train()
else:
print('in train------')
train()