import sys
import os
import stat
import time
import io
import gc
import logging
from bisect import bisect
from contextlib import redirect_stdout
from opt_loss import OptLoss
from mlperf_logger import configure_logger, log_start, log_end, log_event, set_seeds, get_rank, barrier
from mlperf_logging.mllog import constants
import torch
from torch.autograd import Variable
from base_model import Loss
from apex import amp
from ssd300 import SSD300
from master_params import create_flat_master
from parse_config import parse_args, validate_arguments, validate_group_bn
from data.build_pipeline import prebuild_pipeline, build_pipeline
from box_coder import dboxes300_coco, build_ssd300_coder
from async_evaluator import AsyncEvaluator
from eval import coco_eval
from apex.optimizers import NpuFusedSGD
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import torch.utils.data.distributed
import torch.distributed as dist
# torch_npu is only imported for PyTorch >= 1.8. The release is compared
# numerically: as plain strings, '1.10' and '1.11' sort *below* '1.8', which
# would wrongly skip the import on those releases.
import re

_torch_release = re.match(r'(\d+)\.(\d+)', torch.__version__)
# An unparsable version string is treated as "new enough" so the import is
# still attempted.
if _torch_release is None or tuple(int(part) for part in _torch_release.groups()) >= (1, 8):
    import torch_npu

try:
    from torch_npu.utils.profiler import Profile
except Exception:
    print("Profile not in torch_npu.utils.profiler now.. Auto Profile disabled.", flush=True)

    class Profile:
        """No-op stand-in used when ``torch_npu.utils.profiler.Profile`` cannot be imported.

        It accepts any constructor arguments and exposes the ``start`` /
        ``end`` methods the training loop calls, all of which do nothing.
        """

        def __init__(self, *args, **kwargs):
            pass

        def start(self):
            """Do nothing (profiling is disabled)."""
            pass

        def end(self):
            """Do nothing (profiling is disabled)."""
            pass
# APEX is a hard requirement of this script: fail early with an actionable
# message, keeping the original import error attached as the cause.
try:
    import apex_C
    import apex
    from apex.parallel.LARC import LARC
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import convert_network
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError as err:
    raise ImportError("Please install APEX from https://github.com/nvidia/apex") from err
class Logger(object):
    """File-like tee: forwards every message to a stream and appends it to a log file.

    The wrapped stream is whatever ``sys.stdout`` is at construction time.
    File logging is best-effort: if the log file cannot be opened or written
    (for example because its directory does not exist) the message is still
    forwarded to the wrapped stream and the failure is ignored.
    """

    # Path of the log file; an empty string disables file logging.
    logfile = ""

    def __init__(self, filename=""):
        """Remember the log path and capture the current ``sys.stdout``.

        Args:
            filename: Path of the file to append to; "" disables file logging.
        """
        self.logfile = filename
        self.terminal = sys.stdout

    def write(self, message):
        """Write ``message`` to the wrapped stream and append it to the log file.

        Args:
            message: Text to emit.
        """
        self.terminal.write(message)
        if self.logfile == "":
            return
        try:
            # Created with owner-only permissions; O_APPEND keeps every write
            # at the end of the file even when several writers share it.
            fd = os.open(self.logfile,
                         os.O_RDWR | os.O_CREAT | os.O_APPEND,
                         stat.S_IRWXU)
            # The context manager flushes and closes the file on every call,
            # so no buffered text or file descriptor is left behind.
            with os.fdopen(fd, "a") as log:
                log.write(message)
        except (OSError, ValueError):
            # Best-effort logging: never let a log-file problem (missing
            # directory, permissions, encoding) break the program's output.
            pass

    def flush(self):
        """Flush the wrapped stream; the log file is already flushed per write."""
        self.terminal.flush()
def print_message(rank, *print_args):
    """Print ``print_args`` on rank 0 only; every other rank stays silent.

    Args:
        rank: Rank of the calling process.
        *print_args: Positional arguments forwarded unchanged to ``print``.
    """
    if rank != 0:
        return
    print(*print_args)
def load_checkpoint(model, checkpoint):
    """Load the weights stored under the ``"model"`` key of a checkpoint into ``model``.

    Keys that start with ``module.`` have that prefix removed before the
    state dict is handed to ``model.load_state_dict``.

    Args:
        model: Module that receives the weights.
        checkpoint: Path (or file-like object) accepted by ``torch.load``.
    """
    print("loading model checkpoint", checkpoint)
    state = torch.load(checkpoint)["model"]
    prefix = 'module.'
    # Snapshot the affected names first: the dict is mutated while renaming.
    prefixed_names = [name for name in list(state.keys()) if name.startswith(prefix)]
    for name in prefixed_names:
        state[name[len(prefix):]] = state.pop(name)
    model.load_state_dict(state)
def check_async_evals(args, evaluator, threshold):
    """Log finished asynchronous evaluations and report whether training may stop.

    Rank 0 walks ``evaluator.finished_tasks()``, emits the MLPerf
    eval start/accuracy/stop records for each entry and notes whether any
    accuracy reached ``threshold``. In distributed runs rank 0's verdict is
    broadcast so that every rank returns the same answer.

    Args:
        args: Run configuration; ``rank`` and ``distributed`` are read.
        evaluator: Object whose ``finished_tasks()`` returns a mapping of
            epoch to accuracy.
        threshold: Accuracy at or above which training counts as finished.

    Returns:
        True when the target accuracy was reached, otherwise False.
    """
    target_reached = False
    if args.rank == 0:
        completed = evaluator.finished_tasks()
        for epoch, current_accuracy in completed.items():
            log_start(key=constants.EVAL_START, metadata={'epoch_num' : epoch})
            log_event(key=constants.EVAL_ACCURACY,
                      value=current_accuracy,
                      metadata={'epoch_num' : epoch})
            log_end(key=constants.EVAL_STOP, metadata={'epoch_num' : epoch})
            if current_accuracy >= threshold:
                target_reached = True

    if not args.distributed:
        return target_reached

    # Only rank 0 inspected the results, so share its flag with all ranks.
    with torch.no_grad():
        finish_tensor = torch.tensor([int(target_reached)],
                                     dtype=torch.int32,
                                     device=torch.device('npu'))
        torch.distributed.broadcast(finish_tensor, src=0)
        return finish_tensor.item() >= 1
def lr_warmup(optim, warmup_iter, iter_num, epoch, base_lr, args):
    """Apply the linear warm-up learning rate while ``iter_num < warmup_iter``.

    The rate rises by ``base_lr / (warmup_iter * 2 ** args.warmup_factor)``
    per iteration and would meet ``base_lr`` at ``iter_num == warmup_iter``.
    Once warm-up is over the optimizer is left untouched.

    Args:
        optim: Optimizer whose ``param_groups`` entries get their ``'lr'`` set.
        warmup_iter: Number of warm-up iterations.
        iter_num: Current global iteration.
        epoch: Current epoch (accepted but not used).
        base_lr: Learning rate targeted at the end of warm-up.
        args: Run configuration; ``warmup_factor`` is read.
    """
    if not iter_num < warmup_iter:
        return
    per_iter_increment = base_lr / (warmup_iter * (2 ** args.warmup_factor))
    remaining_iters = warmup_iter - iter_num
    warm_lr = base_lr - remaining_iters * per_iter_increment
    for group in optim.param_groups:
        group['lr'] = warm_lr
def setup_distributed(args):
    """Initialise (optional) multi-device training and fill in rank info on ``args``.

    The run is distributed when the ``WORLD_SIZE`` environment variable is
    present and greater than 1. In that case the NPU given by
    ``args.local_rank`` is selected and an ``hccl`` process group is created.

    Sets ``args.distributed``, ``args.local_seed``, ``args.N_gpu`` and
    ``args.rank``; ``args.bn_group`` is passed to ``validate_group_bn``.

    Args:
        args: Parsed run configuration (mutated in place).

    Returns:
        The same ``args`` object.
    """
    world_size_env = os.environ.get('WORLD_SIZE')
    args.distributed = world_size_env is not None and int(world_size_env) > 1

    # Rendezvous endpoint is fixed to the local host.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29688'

    if args.distributed:
        torch.npu.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend='hccl',
            world_size=int(os.environ['WORLD_SIZE']),
            rank=args.local_rank,
        )

    # Seeding happens after the process group exists and before ranks are read.
    args.local_seed = set_seeds(args)

    if args.distributed:
        args.N_gpu = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    else:
        args.N_gpu, args.rank = 1, 0

    validate_group_bn(args.bn_group)
    return args
def train300_mlperf_coco(args):
    """Train SSD300 on COCO on NPU, emitting MLPerf log records along the way.

    Steps, in order: distributed setup, model/loss/optimizer construction
    (apex ``NpuFusedSGD`` wrapped by ``amp.initialize`` at opt level O2),
    optional DDP wrapping and JIT tracing, one dummy forward/backward pass,
    data pipeline construction, then the epoch loop. At every epoch listed in
    ``args.evaluation`` the weights are copied into a separate eval model and
    handed to ``coco_eval`` together with an ``AsyncEvaluator``; every 20
    iterations the finished evaluations are checked against ``args.threshold``.

    Args:
        args: Parsed run configuration. ``args.evaluation`` and
            ``args.lr_decay_epochs`` are indexed/searched as sequences of
            epoch numbers.

    Returns:
        True as soon as ``check_async_evals`` reports the target accuracy,
        False when all ``args.epochs`` epochs finish first.

    Note:
        The process is terminated with ``sys.exit()`` when the iteration
        ``args.profile`` is reached or when the loss becomes non-finite.
    """
    args = setup_distributed(args)
    if not args.distributed:
        torch.npu.set_device(args.device_id)

    # ---- model and loss ----
    model_options = {
        'use_nhwc' : args.nhwc,
        'pad_input' : args.pad_input,
        'bn_group' : args.bn_group,
    }
    ssd300 = SSD300(args, args.num_classes, **model_options)
    if args.checkpoint is not None:
        load_checkpoint(ssd300, args.checkpoint)
    ssd300.train()
    ssd300.npu()
    dboxes = dboxes300_coco()
    loss_func = Loss(dboxes)
    loss_func.npu()

    # ---- optimizer ----
    global_batch_size = (args.N_gpu * args.batch_size)
    log_event(key=constants.MODEL_BN_SPAN, value=args.bn_group*args.batch_size)
    log_event(key=constants.GLOBAL_BATCH_SIZE, value=global_batch_size)

    # The effective learning rate is base_lr times an integer multiplier (>= 1)
    # derived from the requested lr and the global batch size relative to 32.
    base_lr = 2.5e-3
    requested_lr_multiplier = args.lr / base_lr
    adjusted_multiplier = max(1, round(requested_lr_multiplier * global_batch_size / 32))
    current_lr = base_lr * adjusted_multiplier
    current_momentum = 0.9
    current_weight_decay = args.wd
    static_loss_scale = args.loss_scale
    optim = apex.optimizers.NpuFusedSGD(ssd300.parameters(),
                                        lr=current_lr,
                                        momentum=current_momentum,
                                        weight_decay=current_weight_decay)
    ssd300, optim = amp.initialize(ssd300, optim, opt_level='O2', loss_scale=static_loss_scale,combine_grad=True)

    if args.distributed:
        if args.delay_allreduce:
            print_message(args.local_rank, "Delaying allreduces to the end of backward()")
        ssd300 = DistributedDataParallel(ssd300, device_ids=[args.local_rank])

    log_event(key=constants.OPT_BASE_LR, value=current_lr)
    log_event(key=constants.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=args.lr_decay_epochs)
    log_event(key=constants.OPT_LR_DECAY_STEPS, value=args.lr_decay_epochs)
    log_event(key=constants.OPT_WEIGHT_DECAY, value=current_weight_decay)
    if args.warmup is not None:
        log_event(key=constants.OPT_LR_WARMUP_STEPS, value=args.warmup)
        log_event(key=constants.OPT_LR_WARMUP_FACTOR, value=args.warmup_factor)

    # ---- separate evaluation copy of the model ----
    ssd300_eval = SSD300(args, args.num_classes, **model_options).npu()
    if args.use_fp16:
        convert_network(ssd300_eval, torch.half)
    # Under DDP the real network lives in ``.module``.
    train_model = ssd300.module if args.distributed else ssd300
    ssd300_eval.load_state_dict(train_model.state_dict())
    ssd300_eval.eval()

    print_message(args.local_rank, "epoch", "nbatch", "loss")

    iter_num = args.iteration
    avg_loss = 0.0
    start_elapsed_time = time.time()
    num_elapsed_samples = 0

    # ---- optional JIT tracing + one dummy forward/backward pass ----
    input_c = 4 if args.pad_input else 3
    example_shape = [args.batch_size, 300, 300, input_c] if args.nhwc else [args.batch_size, input_c, 300, 300]
    example_input = torch.randn(*example_shape).npu()
    if args.use_fp16:
        example_input = example_input.half()
    if args.jit:
        # Trace the wrapped module rather than the DDP wrapper itself.
        module_to_jit = ssd300.module if args.distributed else ssd300
        if args.distributed:
            ssd300.module = torch.jit.trace(module_to_jit, example_input, check_trace=False)
        else:
            ssd300 = torch.jit.trace(module_to_jit, example_input, check_trace=False)
        ssd300_eval = torch.jit.trace(ssd300_eval, example_input, check_trace=False)

    # Dummy step executed before INIT_STOP is logged, so it is part of init.
    ploc, plabel = ssd300(example_input)
    loss = ploc[0,0,0] + plabel[0,0,0]
    dloss = torch.randn_like(loss)
    loss.backward(dloss)

    encoder = build_ssd300_coder()
    evaluator = AsyncEvaluator(num_threads=1)
    log_end(key=constants.INIT_STOP)

    # ---- data pipelines (built after RUN_START is logged) ----
    barrier()
    log_start(key=constants.RUN_START)
    barrier()
    train_pipe = prebuild_pipeline(args)
    train_loader, epoch_size = build_pipeline(args, training=True, pipe=train_pipe)
    if args.rank == 0:
        print("epoch size is: ", epoch_size, " images")
    val_loader, inv_map, cocoGt = build_pipeline(args, training=False)
    if args.profile_gc_off:
        gc.disable()
        gc.collect()

    # ---- training ----
    i_eval = 0
    block_start_epoch = 1
    log_start(key=constants.BLOCK_START,
              metadata={'first_epoch_num': block_start_epoch,
                        'epoch_count': args.evaluation[i_eval]})

    # Same text as the original literal (no space after the comma that follows
    # "Average Loss"), written as two adjacent literals instead of a
    # backslash continuation inside the string.
    progress_template = ("Epoch:{:4d}, Iteration: {:6d}, Loss function: {:5.3f}, Average Loss: {:.3f},"
                         "avg. samples / sec: {:.2f}")

    for epoch in range(args.epochs):
        optim.zero_grad()

        if epoch in args.evaluation:
            train_model = ssd300.module if args.distributed else ssd300
            if args.distributed and args.allreduce_running_stats:
                if args.rank == 0:
                    print("averaging bn running means and vars")
                # Average the BN running statistics over all ranks so every
                # rank evaluates/saves the same buffers.
                world_size = float(torch.distributed.get_world_size())
                for bn_name, bn_buf in train_model.named_buffers(recurse=True):
                    if ('running_mean' in bn_name) or ('running_var' in bn_name):
                        torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
                        bn_buf /= world_size

            if args.rank == 0:
                if args.save:
                    print("saving model...")
                    os.makedirs('./models', exist_ok=True)
                    torch.save({"model" : ssd300.state_dict()}, "./models/iter_{}.pt".format(iter_num))

            ssd300_eval.load_state_dict(train_model.state_dict())
            coco_eval(args,
                      ssd300_eval,
                      val_loader,
                      cocoGt,
                      encoder,
                      inv_map,
                      epoch,
                      iter_num,
                      evaluator=evaluator)
            log_end(key=constants.BLOCK_STOP, metadata={'first_epoch_num': block_start_epoch})
            if epoch != max(args.evaluation):
                # Open the next block, spanning up to the next evaluation epoch.
                i_eval += 1
                block_start_epoch = epoch + 1
                log_start(key=constants.BLOCK_START,
                          metadata={'first_epoch_num': block_start_epoch,
                                    'epoch_count': (args.evaluation[i_eval] -
                                                    args.evaluation[i_eval - 1])})

        if epoch in args.lr_decay_epochs:
            current_lr *= args.lr_decay_factor
            print_message(args.rank, "lr decay step #" + str(bisect(args.lr_decay_epochs, epoch)))
            for param_group in optim.param_groups:
                param_group['lr'] = current_lr

        log_start(key=constants.EPOCH_START,
                  metadata={'epoch_num': epoch + 1,
                            'current_iter_num': iter_num})

        profile = Profile(start_step=int(os.getenv('PROFILE_START_STEP', 10)),
                          profile_type=os.getenv('PROFILE_TYPE'))

        for data in train_loader:
            (img, bbox, label, _) = data

            # Skip incomplete batches *before* the device transfer: calling
            # ``.npu()`` on None would raise and make this guard unreachable.
            if (img is None) or (bbox is None) or (label is None):
                print("No labels in batch")
                continue

            img = img.npu()
            bbox = bbox.npu()
            label = label.npu()

            if args.profile_start is not None and iter_num == args.profile_start:
                torch.npu.profiler.start()
                torch.npu.synchronize()
                if args.profile_nvtx:
                    torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)

            if args.profile is not None and iter_num == args.profile:
                if args.profile_start is not None and iter_num >=args.profile_start:
                    # Profiling was switched on above, so switch it off again.
                    if args.profile_nvtx:
                        torch.autograd._disable_profiler()
                    torch.npu.profiler.stop()
                sys.exit()

            if args.warmup is not None:
                lr_warmup(optim, args.warmup, iter_num, epoch, current_lr, args)

            profile.start()
            ploc, plabel = ssd300(img)
            ploc, plabel = ploc.float(), plabel.float()

            N = img.shape[0]
            bbox.requires_grad = False
            label.requires_grad = False
            # (N, -1, 4) -> (N, 4, -1) for the boxes; labels become (N, -1) int64.
            bbox = bbox.view(N, -1, 4).transpose(1,2).contiguous()
            label = label.view(N, -1).long()
            loss = loss_func(ploc, plabel, bbox, label)

            # Read the scalar once; it is needed for the check, the running
            # average and the progress line.
            loss_value = loss.item()
            if np.isfinite(loss_value):
                avg_loss = 0.999*avg_loss + 0.001*loss_value
            else:
                print("model exploded (corrupted by Inf or Nan)")
                # NOTE(review): exits with status 0 even though training
                # failed — confirm whether callers rely on that.
                sys.exit()

            num_elapsed_samples += N
            if args.rank == 0 and iter_num % args.print_interval == 0:
                end_elapsed_time = time.time()
                elapsed_time = end_elapsed_time - start_elapsed_time
                avg_samples_per_sec = num_elapsed_samples * args.N_gpu / elapsed_time
                print(progress_template.format(epoch, iter_num, loss_value,
                                               avg_loss, avg_samples_per_sec), end="\n")
                start_elapsed_time = time.time()
                num_elapsed_samples = 0

            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            optim.zero_grad()
            profile.end()

            # Checked only every 20 iterations; in distributed runs the check
            # involves a broadcast.
            if iter_num % 20 == 0:
                finished = check_async_evals(args, evaluator, args.threshold)
                if finished:
                    return True

            iter_num += 1

        log_end(key=constants.EPOCH_STOP, metadata={'epoch_num': epoch + 1})

    return False
def main():
    """Script entry point: configure logging, parse arguments and run training.

    stdout and stderr are both redirected through a single ``Logger`` that
    tees output to ``test/output/<device_id>/<tag>_<device_id>.log``. The
    final MLPerf ``RUN_STOP`` record carries ``'success'`` when training
    reported reaching the target and ``'aborted'`` otherwise.
    """
    configure_logger(constants.SSD)
    log_start(key=constants.INIT_START, log_all_ranks=True)
    args = parse_args()

    # Use ONE Logger for both streams. A second instance created after
    # sys.stdout has been replaced would wrap the first one, and every stderr
    # message would then be appended to the log file twice.
    tee = Logger("test/output/%s/%s_%s.log"%(args.device_id,args.tag,args.device_id))
    sys.stdout = tee
    sys.stderr = tee

    if args.local_rank == 0:
        print(args)

    # Later code indexes and bisects these lists, so keep them ordered.
    args.evaluation.sort()
    args.lr_decay_epochs.sort()
    validate_arguments(args)

    torch.set_num_threads(1)
    torch.backends.cudnn.benchmark = not args.profile_cudnn_get

    success = train300_mlperf_coco(args)
    status = 'success' if success else 'aborted'
    log_end(key=constants.RUN_STOP, metadata={'status': status})


if __name__ == "__main__":
    main()