import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from resnet import resnet50
import moco.loader
import moco.builder
from apex import amp
import sys
def flush_print(func):
    """Wrap ``func`` so that ``sys.stdout`` is flushed after every call.

    The wrapper forwards all positional and keyword arguments to ``func``
    unchanged and returns ``None``.
    """
    def _call_then_flush(*args, **kwargs):
        func(*args, **kwargs)
        sys.stdout.flush()

    return _call_then_flush


# Shadow the builtin inside this module so every print() is flushed at once.
print = flush_print(print)
# Sorted names of every lowercase, non-dunder callable exposed by
# ``torchvision.models``; used as the allowed values for ``--arch``.
model_names = sorted(
    name
    for name, member in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(member)
)
def _build_parser(arch_choices):
    """Create the command-line parser for the training script.

    ``arch_choices`` is the list of names accepted by ``--arch``; it is also
    rendered into that option's help text.
    """
    p = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    # --- dataset / architecture -------------------------------------------
    p.add_argument('data', metavar='DIR', help='path to dataset')
    p.add_argument(
        '-a', '--arch', metavar='ARCH', default='resnet50',
        choices=arch_choices,
        help='model architecture: {} (default: resnet50)'.format(
            ' | '.join(arch_choices)))

    # --- generic training options -----------------------------------------
    p.add_argument(
        '-j', '--workers', default=32, type=int, metavar='N',
        help='number of data loading workers (default: 32)')
    p.add_argument(
        '--epochs', default=200, type=int, metavar='N',
        help='number of total epochs to run')
    p.add_argument(
        '--start-epoch', default=0, type=int, metavar='N',
        help='manual epoch number (useful on restarts)')
    p.add_argument(
        '-b', '--batch-size', default=256, type=int, metavar='N',
        help='mini-batch size (default: 256), this is the total '
             'batch size of all GPUs on the current node when '
             'using Data Parallel or Distributed Data Parallel')
    p.add_argument(
        '--lr', '--learning-rate', default=0.03, type=float,
        metavar='LR', help='initial learning rate', dest='lr')
    p.add_argument(
        '--schedule', default=[120, 160], nargs='*', type=int,
        help='learning rate schedule (when to drop lr by 10x)')
    p.add_argument(
        '--momentum', default=0.9, type=float, metavar='M',
        help='momentum of SGD solver')
    p.add_argument(
        '--wd', '--weight-decay', default=1e-4, type=float,
        metavar='W', help='weight decay (default: 1e-4)',
        dest='weight_decay')
    p.add_argument(
        '-p', '--print-freq', default=10, type=int, metavar='N',
        help='print frequency (default: 10)')
    p.add_argument(
        '--resume', default='', type=str, metavar='PATH',
        help='path to latest checkpoint (default: none)')

    # --- distributed setup ------------------------------------------------
    p.add_argument(
        '--world-size', default=-1, type=int,
        help='number of nodes for distributed training')
    p.add_argument(
        '--rank', default=-1, type=int,
        help='node rank for distributed training')
    p.add_argument(
        '--dist-url', default='tcp://224.66.41.62:23456', type=str,
        help='url used to set up distributed training')
    p.add_argument(
        '--dist-backend', default='nccl', type=str,
        help='distributed backend')
    p.add_argument(
        '--seed', default=None, type=int,
        help='seed for initializing training. ')
    p.add_argument(
        '--gpu', default=None, type=int,
        help='GPU or NPU id to use.')
    p.add_argument(
        '--multiprocessing-distributed', action='store_true',
        help='Use multi-processing distributed training to launch '
             'N processes per node, which has N GPUs. This is the '
             'fastest way to use PyTorch for either single node or '
             'multi node data parallel training')

    # --- MoCo options -----------------------------------------------------
    p.add_argument(
        '--moco-dim', default=128, type=int,
        help='feature dimension (default: 128)')
    p.add_argument(
        '--moco-k', default=65536, type=int,
        help='queue size; number of negative keys (default: 65536)')
    p.add_argument(
        '--moco-m', default=0.999, type=float,
        help='moco momentum of updating key encoder (default: 0.999)')
    p.add_argument(
        '--moco-t', default=0.07, type=float,
        help='softmax temperature (default: 0.07)')
    p.add_argument('--mlp', action='store_true', help='use mlp head')
    p.add_argument(
        '--aug-plus', action='store_true',
        help='use moco v2 data augmentation')
    p.add_argument(
        '--cos', action='store_true',
        help='use cosine lr schedule')

    # --- mixed precision / device selection -------------------------------
    p.add_argument(
        "--amp", action="store_true",
        help="Use amp to train the model")
    p.add_argument(
        "--loss_scale", default=8.0, type=float,
        help="amp setting: loss scale, default 8.0")
    p.add_argument(
        "--opt_level", default="O1", type=str,
        help="amp setting: opt level, default O1")
    p.add_argument(
        "--device", default="npu", type=str,
        help="device: npu or gpu or cpu")
    p.add_argument(
        '--nDevices', default=8, type=int,
        help='number of GPUs or NPUs')
    return p


# Module-level parser consumed by main().
parser = _build_parser(model_names)
def main():
    """Parse the command line and start training.

    Seeds the RNGs when ``--seed`` is given, resolves the world size, sets
    ``args.distributed`` and then either spawns one ``main_worker`` process
    per device (``--multiprocessing-distributed``) or calls ``main_worker``
    directly in this process.
    """
    args = parser.parse_args()

    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU or NPU. This will completely '
                      'disable data parallelism.')

    # With env:// initialisation the world size comes from the environment
    # unless it was given explicitly.
    if args.world_size == -1 and args.dist_url == "env://":
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    print(args)

    ngpus_per_node = args.nDevices
    if not args.multiprocessing_distributed:
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # world_size is given in nodes on the command line; convert it to the
    # total number of processes before spawning one worker per device.
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def _build_moco_augmentation(aug_plus):
    """Return the list of per-crop transforms for the training dataset.

    ``aug_plus`` selects the "moco v2" recipe (colour jitter applied with
    probability 0.8 plus random Gaussian blur); otherwise the default recipe
    is returned. Both end with ``ToTensor`` and ImageNet-style normalisation.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if aug_plus:
        return [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([moco.loader.GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    return [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]


def main_worker(gpu, ngpus_per_node, args):
    """Run the full training loop on one device / in one process.

    Args:
        gpu: device index for this worker (``None`` when no device was
            selected); stored into ``args.gpu``.
        ngpus_per_node: number of devices per node, used to derive the global
            rank and to split the batch size and loader workers.
        args: parsed command-line namespace; ``gpu``, ``rank``,
            ``batch_size``, ``workers`` and ``start_epoch`` are updated in
            place.

    Raises:
        NotImplementedError: when neither distributed training nor a specific
            device was requested.
    """
    args.gpu = gpu

    # Suppress printing on every worker except local device 0.
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*_args, **_kwargs):
            # Accept keyword arguments too (e.g. file=...), like print().
            pass
        builtins.print = print_pass
        # This module rebinds ``print`` at module level, and a module global
        # shadows ``builtins.print``; it must be replaced as well or the
        # suppression has no effect on this file's own print() calls.
        globals()['print'] = print_pass

    if args.gpu is not None:
        print("Use {}: {} for training".format(args.device, args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # Resolve the torch device; anything other than "gpu"/"npu" stays on CPU.
    loc = "cpu"
    device = torch.device(loc)
    if args.device == "gpu":
        loc = "cuda:{}".format(args.gpu)
        device = torch.device(loc)
        torch.cuda.set_device(device)
    elif args.device == "npu":
        loc = "npu:{}".format(args.gpu)
        device = torch.device(loc)
        torch.npu.set_device(device)

    print("=> creating model '{}'".format(args.arch))
    if not args.distributed and args.gpu is None:
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    model = moco.builder.MoCo(
        resnet50,
        args.moco_dim, args.moco_k, args.moco_m, args.moco_t, args.mlp,
        dist=bool(args.distributed), device=args.device)
    model.to(device)
    if args.distributed:
        # --batch-size / --workers are per-node totals; split them per process.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
    print(args)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level,
                                          loss_scale=args.loss_scale)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], broadcast_buffers=False)

    # Optionally resume model/optimizer state and the epoch counter.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map the stored tensors onto this worker's device.
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Dataset yielding two independently augmented crops per image.
    traindir = os.path.join(args.data, 'train')
    augmentation = _build_moco_augmentation(args.aug_plus)
    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Re-seed the sampler so each epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        train(train_loader, model, criterion, optimizer, epoch, args, device)

        # Only one process per node writes the checkpoint.
        if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best=False, filename="model_moco_epoch_{}.pth.tar".format(epoch + 1))
def _sync_device(device_name):
    """Wait for queued device work to finish so wall-clock timings are valid."""
    if device_name == 'npu':
        torch.npu.synchronize()
    elif device_name in ('gpu', 'cuda'):
        # The rest of the script names the CUDA device "gpu"; "cuda" is still
        # accepted because the previous check used that spelling.
        torch.cuda.synchronize()


def _moco_optimize_step(model, criterion, optimizer, images, args, losses, top1, top5):
    """Run forward, metric update, backward and optimizer step for one batch."""
    output, target = model(im_q=images[0], im_k=images[1])
    loss = criterion(output, target)

    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    num_samples = images[0].size(0)
    losses.update(loss.item(), num_samples)
    top1.update(acc1[0], num_samples)
    top5.update(acc5[0], num_samples)

    optimizer.zero_grad()
    if args.amp:
        # apex scales the loss before backward when amp is enabled.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()


def train(train_loader, model, criterion, optimizer, epoch, args, device):
    """Train ``model`` for one epoch and periodically print progress.

    Every iteration performs one optimisation step. In non-distributed runs
    on "npu" or "gpu", iteration 5 is additionally executed under
    ``torch.autograd.profiler.profile`` and the report is written to
    ``<npu|cuda>_output_moco.txt`` / ``.prof`` in the working directory.

    Args:
        train_loader: iterable yielding ``(images, _)`` where ``images`` holds
            the two crops passed to the model as ``im_q`` and ``im_k``.
        model: callable returning ``(output, target)`` for the two crops.
        criterion: loss applied to ``(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress prefix.
        args: parsed command-line namespace.
        device: torch device the image crops are moved to.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    FPS = AverageMeter('FPS', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, FPS],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    # device name -> (profiler keyword arguments, output file stem)
    profiler_setup = {
        'npu': ({'use_npu': True}, 'npu_output_moco'),
        'gpu': ({'use_cuda': True}, 'cuda_output_moco'),
    }
    profile_step = 5

    _sync_device(args.device)
    end = time.time()
    for i, (images, _) in enumerate(train_loader):
        _sync_device(args.device)
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].to(device, non_blocking=True)
            images[1] = images[1].to(device, non_blocking=True)

        should_profile = (not args.distributed
                          and i == profile_step
                          and args.device in profiler_setup)
        if should_profile:
            profiler_kwargs, out_stem = profiler_setup[args.device]
            with torch.autograd.profiler.profile(**profiler_kwargs) as prof:
                _moco_optimize_step(model, criterion, optimizer, images, args,
                                    losses, top1, top5)
            # Use a context manager so the report file is closed and flushed.
            with open(out_stem + ".txt", 'w') as prof_file:
                print(prof.key_averages().table(sort_by="self_cpu_time_total"),
                      file=prof_file)
            prof.export_chrome_trace(out_stem + ".prof")
        else:
            _moco_optimize_step(model, criterion, optimizer, images, args,
                                losses, top1, top5)

        # Synchronise before reading the clock so the step time includes the
        # device work that was queued asynchronously.
        _sync_device(args.device)
        batch_time.update(time.time() - end)
        end = time.time()

        FPS.update(args.batch_size * args.nDevices / batch_time.avg)
        if i % args.print_freq == 0:
            progress.display(i)
def save_checkpoint(state, is_best, filename='model_moco_latest.pth.tar'):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true the written file is also copied to
    ``model_moco_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_moco_best.pth.tar')
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is a format-spec suffix (e.g. ':6.3f') applied to both the
        # current value and the average when the meter is rendered.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, average, weighted sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(
            ('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**vars(self))
class ProgressMeter(object):
    """Print a tab-separated progress line built from a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter."""
        columns = [self.prefix + self.batch_fmtstr.format(batch)]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Build a ``[{:Nd}/total]`` template padded to the width of the total."""
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every parameter group for ``epoch``.

    With ``args.cos`` the base rate ``args.lr`` follows a half-cosine over
    ``args.epochs``; otherwise it is multiplied by 0.1 once for each entry of
    ``args.schedule`` that ``epoch`` has reached.
    """
    if args.cos:
        new_lr = args.lr * (0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
    else:
        new_lr = args.lr
        for boundary in args.schedule:
            factor = 0.1 if epoch >= boundary else 1.
            new_lr = new_lr * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each requested ``k``.

    Args:
        output: 2-D score tensor; the top-k search runs along dimension 1.
        target: tensor of class indices with one entry per row of ``output``.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element float tensor per ``k``, in the order
        given by ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk highest scores per sample, transposed so that
        # row r holds every sample's r-th best prediction.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` inherits the transposed (non-contiguous) layout of
            # ``pred``, so ``view(-1)`` raises for k > 1; ``reshape`` copies
            # when a view is not possible.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()