"""Script for multi-gpu training."""
import json
import os
import numpy as np
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.utils.data
from tqdm import tqdm
from alphapose.models import builder
from alphapose.opt import cfg, logger, opt
from alphapose.utils.logger import board_writing, debug_writing
from alphapose.utils.metrics import DataLogger, calc_accuracy, calc_integral_accuracy, evaluate_mAP
from alphapose.utils.transforms import get_func_heatmap_to_coord
import random
from apex import amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import time
import torch.npu
from apex.optimizers import NpuFusedAdam
# Number of worker processes handed to every DataLoader built in this script.
worker_num = 16

# Normalisation layer class: the synchronised variant when `opt.sync` is
# truthy, plain 2-D batch norm otherwise.
norm_layer = nn.SyncBatchNorm if opt.sync else nn.BatchNorm2d
class AverageMeter(object):
    """Store the latest value and a running average that skips a warm-up.

    Every update is counted, but only updates arriving after the warm-up
    window contribute to ``sum`` and ``avg``.  The window length is
    ``start_count_index * N``, where ``N`` is the ``n`` passed to the first
    update after construction or ``reset()``.

    Args:
        name: label printed by ``__str__``.
        fmt: format spec (e.g. ``':6.3f'``) applied to ``val`` and ``avg``.
        start_count_index: number of first-update-sized steps to exclude
            from the average.
    """

    def __init__(self, name, fmt=':f', start_count_index=10):
        self.name = name
        self.fmt = fmt
        self.start_count_index = start_count_index
        self.reset()

    def reset(self):
        """Clear all statistics, including the remembered first-update size."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        # Defined here so the attribute always exists; it is overwritten by
        # the first update() because count is 0 at that point.
        self.N = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times.

        ``avg`` stays at its previous value until ``count`` exceeds the
        warm-up window.
        """
        if self.count == 0:
            self.N = n
        self.val = val
        self.count += n
        warmup = self.start_count_index * self.N
        if self.count > warmup:
            self.sum += val * n
            self.avg = self.sum / (self.count - warmup)

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` using the configured format."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def train(opt, train_loader, m, criterion, optimizer, gpu):
    """Train ``m`` for one pass over ``train_loader``.

    Args:
        opt: run options.  ``opt.trainIters`` is incremented once per batch
            and ``opt.world_size`` is read for the throughput log.
        train_loader: iterable yielding
            ``(inps, labels, label_masks, _, bboxes)`` batches; ``inps`` may
            be a single tensor or a list of tensors.
        m: model to optimise; switched to training mode here.
        criterion: loss callable.  Called as ``criterion(output, labels)`` on
            masked tensors for ``'MSELoss'`` and as
            ``criterion(output, labels, label_masks)`` otherwise.
        optimizer: optimizer usable with ``amp.scale_loss``.
        gpu: index of this process; when ``opt.world_size > 1`` only index 0
            writes the throughput log.

    Returns:
        Tuple ``(loss_logger.avg, acc_logger.avg)`` for the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    loss_logger = DataLogger()
    acc_logger = DataLogger()
    m.train()

    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    # The loss type does not change during the epoch, so look it up once.
    use_mse = cfg.LOSS.get('TYPE') == 'MSELoss'

    train_loader = tqdm(train_loader, dynamic_ncols=True)
    end = time.time()
    for inps, labels, label_masks, _, _bboxes in train_loader:
        if isinstance(inps, list):
            inps = [inp.npu().requires_grad_() for inp in inps]
            batch_size = inps[0].size(0)
        else:
            inps = inps.npu().requires_grad_()
            batch_size = inps.size(0)
        labels = labels.npu()
        label_masks = label_masks.npu()

        output = m(inps)

        if use_mse:
            masked_output = output.mul(label_masks)
            masked_labels = labels.mul(label_masks)
            loss = 0.5 * criterion(masked_output, masked_labels)
            acc = calc_accuracy(masked_output, masked_labels)
        else:
            loss = criterion(output, labels, label_masks)
            acc = calc_integral_accuracy(
                output, labels, label_masks, output_3d=False, norm_type=norm_type)

        loss_logger.update(loss.item(), batch_size)
        acc_logger.update(acc, batch_size)

        optimizer.zero_grad()
        # apex amp scales the loss before backward for mixed precision.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        opt.trainIters += 1
        batch_time.update(time.time() - end)
        end = time.time()

        train_loader.set_description(
            'loss: {loss:.8f} | acc: {acc:.4f}'.format(
                loss=loss_logger.avg,
                acc=acc_logger.avg)
        )

    # batch_time.avg stays 0 until the meter's warm-up window has passed, so
    # very short epochs emit no throughput line.
    if (opt.world_size == 1 or gpu == 0) and batch_time.avg > 0:
        # Use the configured per-device batch size (the value the training
        # DataLoader is built with) rather than a hard-coded 32.
        per_device_batch = cfg.TRAIN.BATCH_SIZE
        logger.info('#####npu * FPS@all {:.3f}, TIME@all {:.3f}'.format(
            opt.world_size * per_device_batch / batch_time.avg, batch_time.avg))

    train_loader.close()
    return loss_logger.avg, acc_logger.avg
def validate(gpu, m, opt, heatmap_to_coord, batch_size=20):
    """Evaluate ``m`` on the ``cfg.DATASET.TEST`` dataset and return its mAP.

    Predicted keypoints are written to
    ``<opt.work_dir>/test_kpt_<gpu>.json`` and that file is then scored with
    ``evaluate_mAP`` against the ``cfg.DATASET.VAL`` annotation file.

    Args:
        gpu: index of this process; used only to name the result file.
        m: model to evaluate; switched to eval mode here.
        opt: run options; forwarded to the dataset builder and read for
            ``work_dir``.
        heatmap_to_coord: callable turning one sample's heatmaps and a box
            into ``(pose_coords, pose_scores)``.
        batch_size: evaluation batch size.

    Returns:
        Whatever ``evaluate_mAP`` returns for the written result file.
    """
    det_dataset = builder.build_dataset(
        cfg.DATASET.TEST, preset_cfg=cfg.DATA_PRESET, train=False, opt=opt)
    det_loader = torch.utils.data.DataLoader(
        det_dataset, batch_size=batch_size, shuffle=False,
        num_workers=worker_num, drop_last=False)
    eval_joints = det_dataset.EVAL_JOINTS
    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    kpt_json = []
    m.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inps, crop_bboxes, bboxes, img_ids, scores, _imghts, _imgwds in tqdm(
                det_loader, dynamic_ncols=True):
            if isinstance(inps, list):
                inps = [inp.npu() for inp in inps]
            else:
                inps = inps.npu()
            output = m(inps)
            assert output.dim() == 4
            # Select the evaluated joints once, along axis 1.  Indexing the
            # per-sample slice with EVAL_JOINTS a second time would break
            # for any joint list that is not 0..K-1.
            pred = output[:, eval_joints, :, :]

            for i in range(output.shape[0]):
                bbox = crop_bboxes[i].tolist()
                pose_coords, pose_scores = heatmap_to_coord(
                    pred[i], bbox, hm_shape=hm_size, norm_type=norm_type)

                keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
                kpt_json.append({
                    'bbox': bboxes[i, 0].tolist(),
                    'image_id': int(img_ids[i]),
                    # Detection score plus mean and max keypoint confidence.
                    'score': float(scores[i] + np.mean(pose_scores) + np.max(pose_scores)),
                    'category_id': 1,
                    'keypoints': keypoints.reshape(-1).tolist(),
                })

    res_file = os.path.join(opt.work_dir, 'test_kpt_{}.json'.format(gpu))
    with open(res_file, 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(
        res_file, ann_type='keypoints',
        ann_file=os.path.join(cfg.DATASET.VAL.ROOT, cfg.DATASET.VAL.ANN))
    return res
def validate_gt(gpu, m, opt, cfg, heatmap_to_coord, batch_size=20):
    """Evaluate ``m`` on the ``cfg.DATASET.VAL`` dataset and return its mAP.

    Uses the boxes supplied by the validation dataset itself.  Predicted
    keypoints are written to ``<opt.work_dir>/test_gt_kpt_<gpu>.json`` and
    scored with ``evaluate_mAP`` against the ``cfg.DATASET.VAL`` annotation
    file.

    Args:
        gpu: index of this process; used only to name the result file.
        m: model to evaluate; switched to eval mode here.
        opt: run options; read for ``work_dir``.
        cfg: experiment configuration (dataset, preset and loss sections
            are read).
        heatmap_to_coord: callable turning one sample's heatmaps and a box
            into ``(pose_coords, pose_scores)``.
        batch_size: evaluation batch size.

    Returns:
        Whatever ``evaluate_mAP`` returns for the written result file.
    """
    gt_val_dataset = builder.build_dataset(
        cfg.DATASET.VAL, preset_cfg=cfg.DATA_PRESET, train=False)
    eval_joints = gt_val_dataset.EVAL_JOINTS
    gt_val_loader = torch.utils.data.DataLoader(
        gt_val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=worker_num, drop_last=False)
    norm_type = cfg.LOSS.get('NORM_TYPE', None)
    hm_size = cfg.DATA_PRESET.HEATMAP_SIZE

    kpt_json = []
    m.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inps, _labels, _label_masks, img_ids, bboxes in tqdm(
                gt_val_loader, dynamic_ncols=True):
            if isinstance(inps, list):
                inps = [inp.npu() for inp in inps]
            else:
                inps = inps.npu()
            output = m(inps)
            assert output.dim() == 4
            # Select the evaluated joints once, along axis 1.  Indexing the
            # per-sample slice with EVAL_JOINTS a second time would break
            # for any joint list that is not 0..K-1.
            pred = output[:, eval_joints, :, :]

            for i in range(output.shape[0]):
                bbox = bboxes[i].tolist()
                pose_coords, pose_scores = heatmap_to_coord(
                    pred[i], bbox, hm_shape=hm_size, norm_type=norm_type)

                keypoints = np.concatenate((pose_coords, pose_scores), axis=1)
                kpt_json.append({
                    'bbox': bbox,
                    'image_id': int(img_ids[i]),
                    # Mean plus max keypoint confidence.
                    'score': float(np.mean(pose_scores) + np.max(pose_scores)),
                    'category_id': 1,
                    'keypoints': keypoints.reshape(-1).tolist(),
                })

    res_file = os.path.join(opt.work_dir, 'test_gt_kpt_{}.json'.format(gpu))
    with open(res_file, 'w') as fid:
        json.dump(kpt_json, fid)
    res = evaluate_mAP(
        res_file, ann_type='keypoints',
        ann_file=os.path.join(cfg.DATASET.VAL.ROOT, cfg.DATASET.VAL.ANN))
    return res
def main():
    """Log the configuration, seed the RNGs and spawn one trainer per device.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in the environment and starts
    ``opt.world_size`` processes running ``parallel_train(rank, cfg, opt)``.
    """
    separator = '******************************'
    logger.info(separator)
    logger.info(opt)
    logger.info(separator)
    logger.info(cfg)
    logger.info(separator)

    if opt.seed is not None:
        random.seed(opt.seed)
    # Honour a user-supplied seed; 0 is only the fallback.  Previously an
    # unconditional torch.manual_seed(0) overwrote opt.seed.
    torch.manual_seed(opt.seed if opt.seed is not None else 0)

    os.environ['MASTER_ADDR'] = opt.addr
    os.environ['MASTER_PORT'] = '59629'
    mp.spawn(parallel_train, nprocs=opt.world_size, args=(cfg, opt))
def parallel_train(gpu, cfg, opt):
    """Per-process training entry point started by ``mp.spawn``.

    Binds the process to ``npu:<gpu>``, joins the ``hccl`` process group,
    builds model / loss / optimizer / data loader, then trains from
    ``cfg.TRAIN.BEGIN_EPOCH`` to ``cfg.TRAIN.END_EPOCH``.  Every
    ``opt.snapshot`` epochs the model is validated and process 0 writes a
    checkpoint; process 0 also writes ``final_DPG.pth`` at the end.

    Args:
        gpu: process index supplied by ``mp.spawn``; used as both device
            index and distributed rank.
        cfg: experiment configuration.
        opt: run options; ``opt.trainIters`` and ``opt.epoch`` are written.

    Raises:
        ValueError: if ``cfg.TRAIN.OPTIMIZER`` is neither ``'adam'`` nor
            ``'rmsprop'``.
    """
    torch.npu.set_device('npu:{}'.format(gpu))
    dist.init_process_group(
        backend='hccl',
        world_size=opt.world_size,
        rank=gpu
    )

    m = preset_model(cfg)
    m = m.npu()
    criterion = builder.build_loss(cfg.LOSS).npu()

    if cfg.TRAIN.OPTIMIZER == 'adam':
        optimizer = NpuFusedAdam(m.parameters(), lr=cfg.TRAIN.LR)
    elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
        optimizer = torch.optim.RMSprop(m.parameters(), lr=cfg.TRAIN.LR)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            'Unsupported cfg.TRAIN.OPTIMIZER {!r}; expected "adam" or "rmsprop"'.format(
                cfg.TRAIN.OPTIMIZER))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.TRAIN.LR_STEP, gamma=cfg.TRAIN.LR_FACTOR)

    train_dataset = builder.build_dataset(
        cfg.DATASET.TRAIN, preset_cfg=cfg.DATA_PRESET, train=True)
    if opt.world_size > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=opt.world_size,
            rank=gpu
        )
    else:
        train_sampler = None
    # DataLoader rejects shuffle=True together with a sampler, so shuffle
    # only when no sampler is in use.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
        shuffle=(train_sampler is None), num_workers=worker_num,
        sampler=train_sampler)

    heatmap_to_coord = get_func_heatmap_to_coord(cfg)
    opt.trainIters = 0

    m, optimizer = amp.initialize(m, optimizer, opt_level="O2", combine_grad=True)
    m = nn.parallel.DistributedDataParallel(m, device_ids=[gpu], broadcast_buffers=False)

    is_main_process = opt.world_size == 1 or gpu == 0
    ckpt_dir = './exp/{}-{}'.format(opt.exp_id, cfg.FILE_NAME)

    for i in range(cfg.TRAIN.BEGIN_EPOCH, cfg.TRAIN.END_EPOCH):
        if train_sampler is not None:
            # Tell the distributed sampler which epoch this is.
            train_sampler.set_epoch(i)
        opt.epoch = i
        current_lr = optimizer.state_dict()['param_groups'][0]['lr']
        logger.info(f'############# Starting Epoch {opt.epoch} | LR: {current_lr} #############')

        loss, miou = train(opt, train_loader, m, criterion, optimizer, gpu)
        logger.epochInfo('Train', opt.epoch, loss, miou)
        lr_scheduler.step()

        if (i + 1) % opt.snapshot == 0:
            with torch.no_grad():
                gt_AP = validate_gt(gpu, m, opt, cfg, heatmap_to_coord)
            if is_main_process:
                logger.info(f'##### Epoch {opt.epoch} | gt mAP: {gt_AP} #####')
                ckpt_path = '{}/model_{}.pth'.format(ckpt_dir, opt.epoch)
                # Save the wrapped module's weights, not the DDP wrapper's.
                torch.save(m.module.state_dict(), ckpt_path)
                logger.info(f'##### {ckpt_path} saved!#####')

    if is_main_process:
        torch.save(m.module.state_dict(), '{}/final_DPG.pth'.format(ckpt_dir))
def preset_model(cfg):
    """Build the pose model and initialise its weights according to ``cfg``.

    Weight source, in priority order:

    1. ``cfg.MODEL.PRETRAINED`` - state dict loaded in full.
    2. ``cfg.MODEL.TRY_LOAD`` - only entries whose key exists in the model
       and whose tensor size matches are taken; the rest keep the model's
       own values.
    3. Otherwise the model's ``_initialize()`` is called.

    Args:
        cfg: experiment configuration (``MODEL`` and ``DATA_PRESET``
            sections are read).

    Returns:
        The constructed model.  It is not moved to a device here.
    """
    model = builder.build_sppe(cfg.MODEL, preset_cfg=cfg.DATA_PRESET)

    # NOTE: torch.load unpickles the file; only point these settings at
    # checkpoints from a trusted source.
    # map_location='cpu' keeps every process from restoring tensors onto the
    # device the checkpoint was saved from.
    if cfg.MODEL.PRETRAINED:
        logger.info(f'Loading model from {cfg.MODEL.PRETRAINED}...')
        model.load_state_dict(torch.load(cfg.MODEL.PRETRAINED, map_location='cpu'))
    elif cfg.MODEL.TRY_LOAD:
        logger.info(f'Loading model from {cfg.MODEL.TRY_LOAD}...')
        pretrained_state = torch.load(cfg.MODEL.TRY_LOAD, map_location='cpu')
        model_state = model.state_dict()
        # Keep only parameters that exist in this model with the same size.
        compatible_state = {
            k: v for k, v in pretrained_state.items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(compatible_state)
        model.load_state_dict(model_state)
    else:
        logger.info('Create new model')
        logger.info('=> init weights')
        model._initialize()

    return model
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()