from base_model import Loss
import sys
import os
from opt_loss import OptLoss
from mlperf_logger import configure_logger, log_start, log_end, log_event, set_seeds, get_rank, barrier
from mlperf_logging.mllog import constants
import torch
from torch.autograd import Variable
import time
import numpy as np
import io
from bisect import bisect
from apex import amp
from ssd300 import SSD300
from master_params import create_flat_master
from parse_config import parse_args, validate_arguments, validate_group_bn
from data.build_pipeline import prebuild_pipeline, build_pipeline
from box_coder import dboxes300_coco, build_ssd300_coder
from async_evaluator import AsyncEvaluator
from eval import coco_eval
from apex.optimizers import NpuFusedSGD
import gc
from torch.nn.parallel import DistributedDataParallel
import torch.utils.data.distributed
import torch.distributed as dist
try:
import apex_C
import apex
from apex.parallel.LARC import LARC
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install APEX from https://github.com/nvidia/apex")
from contextlib import redirect_stdout
import logging
class Logger(object):
    """Tee-like text stream: echo every write to a terminal stream and append it to a log file.

    Instances are installed as ``sys.stdout`` / ``sys.stderr`` so printed output
    is both shown and persisted. File logging is best-effort: if the log file
    cannot be opened or written, the message still reaches the terminal.
    """

    # Path of the log file; "" disables file logging.
    logfile = ""

    def __init__(self, filename="", stream=None):
        """
        Args:
            filename: log file to append to; "" disables file logging.
            stream: stream to echo writes to; defaults to the ``sys.stdout``
                current at construction time.
        """
        self.logfile = filename
        self.terminal = sys.stdout if stream is None else stream

    def write(self, message):
        """Write ``message`` to the terminal stream, then append it to the log file."""
        self.terminal.write(message)
        if self.logfile != "":
            # Reopen per message so the file is closed (and flushed) after every write.
            try:
                with open(self.logfile, "a") as log:
                    log.write(message)
            except (OSError, ValueError):
                # Best-effort logging: a missing directory, permission problem or
                # encoding error must not break printing.
                pass

    def flush(self):
        """Flush the terminal stream (the log file is already closed after each write)."""
        self.terminal.flush()
def print_message(rank, *print_args):
    """Forward ``print_args`` to ``print`` only on rank 0; other ranks stay silent."""
    if rank != 0:
        return
    print(*print_args)
def load_checkpoint(model, checkpoint):
    """Load backbone weights from a checkpoint file into ``model``.

    The checkpoint is expected to be a dict with a ``"model"`` state dict, as
    written by the training loop. A leading ``module.`` prefix (added when the
    model was wrapped in DistributedDataParallel at save time) is stripped.
    Every ``mbox.*`` entry is dropped, whether or not it was prefixed, and the
    rest is loaded with ``strict=False``.

    Args:
        model: the (unwrapped) module to load into.
        checkpoint: path to the checkpoint file.
    """
    print("loading model checkpoint", checkpoint)
    # Map to CPU: the model has not been moved to the device yet, and the saved
    # tensors may reference a device index that does not exist for this rank.
    od = torch.load(checkpoint, map_location='cpu')
    prefix = 'module.'
    state = {}
    for key, value in od["model"].items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        # Test the stripped name so DDP-saved 'module.mbox.*' keys are dropped too.
        if name.startswith('mbox.'):
            continue
        state[name] = value
    model.load_state_dict(state, strict=False)
def check_async_evals(args, evaluator, threshold):
    """Report whether any finished asynchronous evaluation reached ``threshold``.

    Only rank 0 reads the evaluator's finished tasks and logs their
    EVAL_START / EVAL_ACCURACY / EVAL_STOP events. In a distributed run, rank 0's
    verdict is broadcast so that every rank returns the same answer.

    Returns:
        bool: True if some finished evaluation's accuracy is >= ``threshold``.
    """
    reached = 0
    if args.rank == 0:
        results = evaluator.finished_tasks()
        for eval_epoch, accuracy in results.items():
            log_start(key=constants.EVAL_START, metadata={'epoch_num' : eval_epoch})
            log_event(key=constants.EVAL_ACCURACY,
                      value=accuracy,
                      metadata={'epoch_num' : eval_epoch})
            log_end(key=constants.EVAL_STOP, metadata={'epoch_num' : eval_epoch})
            if accuracy >= threshold:
                reached = 1

    # A single process has nothing to share: answer from the local result.
    if not args.distributed:
        return reached == 1

    # Broadcast rank 0's flag so all ranks agree on whether to stop.
    with torch.no_grad():
        flag = torch.tensor([reached], dtype=torch.int32, device=torch.device('npu'))
        torch.distributed.broadcast(flag, src=0)
        return flag.item() >= 1
def lr_warmup(optim, warmup_iter, iter_num, epoch, base_lr, args):
    """Apply a linear learning-rate warmup for the first ``warmup_iter`` iterations.

    For ``iter_num < warmup_iter`` every param group gets
    ``base_lr - (warmup_iter - iter_num) * base_lr / (warmup_iter * 2**args.warmup_factor)``.
    From ``warmup_iter`` onwards the optimizer is left untouched. ``epoch`` is
    accepted for interface compatibility but is not used.
    """
    if iter_num >= warmup_iter:
        return
    per_iter_increment = base_lr / (warmup_iter * (2 ** args.warmup_factor))
    iters_left = warmup_iter - iter_num
    lr = base_lr - iters_left * per_iter_increment
    for group in optim.param_groups:
        group['lr'] = lr
def setup_distributed(args):
    """Initialise (optionally distributed) execution state on ``args`` and return it.

    Distributed mode is enabled when the ``WORLD_SIZE`` environment variable is
    set to a value greater than 1. In that case the NPU for ``args.local_rank``
    is selected and an HCCL process group is created. The function always sets
    ``MASTER_ADDR``/``MASTER_PORT`` to 127.0.0.1:29688, seeds RNGs via
    ``set_seeds`` and fills in ``args.distributed``, ``args.local_seed``,
    ``args.N_gpu`` and ``args.rank``.
    """
    world_size_env = os.environ.get('WORLD_SIZE')
    args.distributed = world_size_env is not None and int(world_size_env) > 1
    os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29688'})

    if args.distributed:
        torch.npu.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='hccl',
            world_size=int(os.environ['WORLD_SIZE']),
            rank=args.local_rank,
        )

    # Seeding happens after process-group init and before rank/world size are recorded.
    args.local_seed = set_seeds(args)

    if args.distributed:
        args.N_gpu, args.rank = torch.distributed.get_world_size(), torch.distributed.get_rank()
    else:
        args.N_gpu, args.rank = 1, 0

    validate_group_bn(args.bn_group)
    return args
def train300_mlperf_coco(args):
    """Train SSD300 on COCO following the MLPerf schedule, with async evaluation.

    Sets up (optionally distributed) NPU execution and builds the training and
    evaluation models. A dummy forward/backward pass is run first. The model is
    then trained for ``args.epochs`` epochs with learning-rate warmup/decay.
    COCO evaluation runs at the epochs listed in ``args.evaluation``, and
    training stops early once ``check_async_evals`` reports that
    ``args.threshold`` has been reached.

    Args:
        args: parsed command-line namespace (see ``parse_args``).

    Returns:
        True when the accuracy threshold was reached, False when all epochs ran
        without reaching it, and None when training stopped at the profiling
        iteration ``args.profile``.

    Exits the process with status 1 if the training loss becomes Inf/NaN.
    """
    args = setup_distributed(args)
    if not args.distributed:
        torch.npu.set_device(args.device_id)

    model_options = {
        'use_nhwc' : args.nhwc,
        'pad_input' : args.pad_input,
        'bn_group' : args.bn_group,
    }

    ssd300 = SSD300(args, args.num_classes, **model_options)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)

    ssd300.train()
    ssd300.npu()
    dboxes = dboxes300_coco()
    loss_func = Loss(dboxes)
    loss_func.npu()

    global_batch_size = (args.N_gpu * args.batch_size)
    log_event(key=constants.MODEL_BN_SPAN, value=args.bn_group*args.batch_size)
    log_event(key=constants.GLOBAL_BATCH_SIZE, value=global_batch_size)

    # The requested LR is snapped to an integer multiple of base_lr (at least 1x),
    # scaled by global_batch_size / 32.
    base_lr = 2.5e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(1, round(requested_lr_multiplier * global_batch_size / 32))

    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = args.wd
    static_loss_scale = args.loss_scale

    optim = apex.optimizers.NpuFusedSGD(ssd300.parameters(),
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)

    ssd300, optim = amp.initialize(ssd300, optim, opt_level='O2', loss_scale=static_loss_scale, combine_grad=True)

    if args.distributed:
        if args.delay_allreduce:
            # NOTE(review): the DistributedDataParallel call below takes no delay
            # option, so this message does not reflect an actual setting.
            print_message(args.local_rank, "Delaying allreduces to the end of backward()")
        ssd300 = DistributedDataParallel(ssd300, device_ids=[args.local_rank])

    log_event(key=constants.OPT_BASE_LR, value=current_lr)
    log_event(key=constants.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=args.lr_decay_epochs)
    log_event(key=constants.OPT_LR_DECAY_STEPS, value=args.lr_decay_epochs)
    log_event(key=constants.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup)
        log_event(key=constants.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)

    # Separate evaluation copy of the model, initialised from the (unwrapped) train model.
    ssd300_eval = SSD300(args, args.num_classes, **model_options).npu()
    if args.use_fp16:
        convert_network(ssd300_eval, torch.half)

    train_model = ssd300.module if args.distributed else ssd300
    ssd300_eval.load_state_dict(train_model.state_dict())
    ssd300_eval.eval()

    print_message(args.local_rank, "epoch", "nbatch", "loss")

    iter_num = args.iteration
    avg_loss = 0.0

    start_elapsed_time = time.time()
    num_elapsed_samples = 0

    input_c = 4 if args.pad_input else 3
    example_shape = [args.batch_size, 300, 300, input_c] if args.nhwc else [args.batch_size, input_c, 300, 300]
    example_input = torch.randn(*example_shape).npu()
    if args.use_fp16:
        example_input = example_input.half()

    if args.jit:
        # Under DDP only the wrapped .module is traced, so the DDP wrapper itself stays in Python.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input, check_trace=False)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input, check_trace=False)
        ssd300_eval = torch.jit.trace(ssd300_eval, example_input, check_trace=False)

    # Dummy forward/backward on random input before the timed run starts.
    ploc, plabel = ssd300(example_input)
    loss = ploc[0,0,0] + plabel[0,0,0]
    dloss = torch.randn_like(loss)
    loss.backward(dloss)

    encoder = build_ssd300_coder()
    evaluator = AsyncEvaluator(num_threads=1)

    log_end(key=constants.INIT_STOP)

    barrier()
    log_start(key=constants.RUN_START)
    barrier()

    train_pipe = prebuild_pipeline(args)
    train_loader, epoch_size = build_pipeline(args, training=True, pipe=train_pipe)
    if args.rank == 0:
        print("epoch size is: ", epoch_size, " images")

    val_loader, inv_map, cocoGt = build_pipeline(args, training=False)
    if args.profile_gc_off:
        gc.disable()
        gc.collect()

    i_eval = 0
    block_start_epoch = 1
    log_start(key=constants.BLOCK_START,
              metadata={'first_epoch_num': block_start_epoch,
                        'epoch_count': args.evaluation[i_eval]})
    for epoch in range(args.epochs):
        optim.zero_grad()

        if epoch in args.evaluation:
            # Under DDP the trainable model lives in .module.
            train_model = ssd300.module if args.distributed else ssd300

            if args.distributed and args.allreduce_running_stats:
                if args.rank == 0:
                    print("averaging bn running means and vars")
                # Average BN running statistics so every rank holds identical buffers
                # before evaluating or saving.
                world_size = float(torch.distributed.get_world_size())
                for bn_name, bn_buf in train_model.named_buffers(recurse=True):
                    if ('running_mean' in bn_name) or ('running_var' in bn_name):
                        torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size

            if args.rank == 0 and args.save:
                print("saving model...")
                if not os.path.isdir('./models'):
                    os.mkdir('./models')
                torch.save({"model" : ssd300.state_dict()}, "./models/iter_{}.pt".format(iter_num))

            ssd300_eval.load_state_dict(train_model.state_dict())
            # Evaluation results are collected later through the async evaluator.
            coco_eval(args,
                      ssd300_eval,
                      val_loader,
                      cocoGt,
                      encoder,
                      inv_map,
                      epoch,
                      iter_num,
                      evaluator=evaluator)
            log_end(key=constants.BLOCK_STOP, metadata={'first_epoch_num': block_start_epoch})
            if epoch != max(args.evaluation):
                i_eval += 1
                block_start_epoch = epoch + 1
                log_start(key=constants.BLOCK_START,
                          metadata={'first_epoch_num': block_start_epoch,
                                    'epoch_count': (args.evaluation[i_eval] -
                                                    args.evaluation[i_eval - 1])})

        if epoch in args.lr_decay_epochs:
            current_lr *= args.lr_decay_factor
            print_message(args.rank, "lr decay step #" + str(bisect(args.lr_decay_epochs, epoch)))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        log_start(key=constants.EPOCH_START,
                  metadata={'epoch_num': epoch + 1,
                            'current_iter_num': iter_num})

        for img, bbox, label, _ in train_loader:
            if args.profile_start is not None and iter_num == args.profile_start:
                torch.npu.profiler.start()
                torch.npu.synchronize()
                if args.profile_nvtx:
                    torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)

            if args.profile is not None and iter_num == args.profile:
                if args.profile_start is not None and iter_num >= args.profile_start:
                    # Profiling was switched on above; switch it off before leaving.
                    if args.profile_nvtx:
                        torch.autograd._disable_profiler()
                    torch.npu.profiler.stop()
                return

            if args.warmup is not None:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr, args)

            # Check before the device transfer: calling .npu() on None would raise.
            if (img is None) or (bbox is None) or (label is None):
                print("No labels in batch")
                continue
            img = img.npu()
            bbox = bbox.npu()
            label = label.npu()

            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            N = img.shape[0]
            bbox.requires_grad = False
            label.requires_grad = False
            # Reshape flat boxes to (N, num_boxes, 4), then transpose to (N, 4, num_boxes).
            bbox = bbox.view(N, -1, 4).transpose(1,2).contiguous()
            label = label.view(N, -1).long()
            loss = loss_func(ploc, plabel, bbox, label)

            # A single host sync per iteration; the value is reused below.
            loss_value = loss.item()
            if not np.isfinite(loss_value):
                print("model exploded (corrupted by Inf or Nan)")
                # Non-zero status so launchers/CI see the run as failed.
                sys.exit(1)
            avg_loss = 0.999*avg_loss + 0.001*loss_value

            num_elapsed_samples += N
            if args.rank == 0 and iter_num % args.print_interval == 0:
                elapsed_time = time.time() - start_elapsed_time
                avg_samples_per_sec = num_elapsed_samples * args.N_gpu / elapsed_time
                print("Epoch:{:4d}, Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f}, avg. samples / sec: {:.2f}"
                      .format(epoch, iter_num, loss_value, avg_loss, avg_samples_per_sec), end="\n")
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            optim.zero_grad()

            # Poll only every 20 iterations: each poll may broadcast across ranks.
            if iter_num % 20 == 0:
                if check_async_evals(args, evaluator, args.threshold):
                    return True

            iter_num += 1

        log_end(key=constants.EPOCH_STOP, metadata={'epoch_num': epoch + 1})

    return False
def main():
    """Script entry point: set up logging, parse and validate arguments, then train.

    Logs the MLPerf RUN_STOP event with status 'success' when training reports
    that the accuracy target was reached, and 'aborted' otherwise.
    """
    configure_logger(constants.SSD)
    log_start(key=constants.INIT_START, log_all_ranks=True)
    args = parse_args()

    log_path = "test/output/%s/%s_%s.log" % (args.device_id, args.tag, args.device_id)
    # Logger silently drops file output when the directory is missing, so create it
    # up front (best effort, like Logger itself).
    try:
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
    except OSError:
        pass
    # Build both tees before installing either. Logger captures the current
    # sys.stdout as its terminal, so building the stderr tee after replacing
    # sys.stdout would chain it through the stdout tee and append every stderr
    # message to the log file twice.
    stdout_tee = Logger(log_path)
    stderr_tee = Logger(log_path)
    sys.stdout = stdout_tee
    sys.stderr = stderr_tee

    if args.local_rank == 0:
        print(args)

    args.evaluation.sort()
    args.lr_decay_epochs.sort()
    validate_arguments(args)

    torch.set_num_threads(1)
    torch.backends.cudnn.benchmark = not args.profile_cudnn_get

    success = train300_mlperf_coco(args)
    status = 'success' if success else 'aborted'
    log_end(key=constants.RUN_STOP, metadata={'status': status})


if __name__ == "__main__":
    main()