import argparse
import os
import time
import PIL.Image
import numpy as np
import psutil
import setproctitle
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.distributions as dist
from torch.utils.data import DataLoader
import torch.multiprocessing as mp
import torch.nn.functional as F
from torchvision import models, transforms
import torchvision.datasets as dset
from torch.nn.parallel import DistributedDataParallel as DDP
from apex import amp
import torchvision
import torchvision_npu
# Request non-JIT operator compilation on the NPU for the whole process.
# NOTE(review): the exact effect is defined by torch_npu -- confirm in its docs.
torch.npu.set_compile_mode(jit_compile=False)
# Output side length given to RandomResizedCrop in get_transforms().
IMG_RESIZE = 224
# Root directory handed to ImageFolder in main_worker().
IMAGENET_DATASET_PATH = './imagenet/train'
def parse_args():
    """Parse the command line and select the torchvision image backend.

    Every option is a plain ``--flag value`` pair with a typed default.
    As a side effect the parsed ``--backend`` value is passed to
    ``torchvision.set_image_backend`` before returning.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # (flag, type, default) -- registered in this order.
    option_specs = (
        ('--network', str, 'resnet50'),
        ('--dataset', str, 'imagenet'),
        ('--batch_size', int, 512),
        ('--nEpoch', int, 3),
        ('--seed', int, 49),
        ('--opt', str, 'sgd'),
        ('--world_size', int, 8),
        ('--local_rank', int, 1),
        ('--num_workers', int, 16),
        ('--backend', str, 'PIL'),
        ('--distribute_func', str, 'launch'),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    parsed = cli.parse_args()
    torchvision.set_image_backend(parsed.backend)
    return parsed
def main():
    """Entry point: parse options, seed the RNGs and start the worker(s).

    ``--distribute_func launch`` runs ``main_worker`` directly in this process
    (one process per device is expected to be started externally), while
    ``--distribute_func spawn`` forks ``args.world_size`` workers through
    ``torch.multiprocessing.spawn``.

    Raises:
        ValueError: if ``--distribute_func`` is neither ``launch`` nor
            ``spawn``. Previously such a value made the script exit silently
            without training anything.
    """
    args = parse_args()

    torch.manual_seed(args.seed)
    # NOTE(review): presumably redirected to the NPU by the transfer_to_npu
    # import at the top of the file -- confirm.
    torch.cuda.manual_seed(args.seed)

    # NOTE(review): deterministic=True together with benchmark=True looks
    # contradictory; kept as in the original, confirm which one is intended.
    cudnn.deterministic = True
    cudnn.benchmark = True
    torch.npu.set_compile_mode(jit_compile=False)

    if args.distribute_func == 'launch':
        main_worker(args=args)
    elif args.distribute_func == 'spawn':
        # spawn() passes the process index as the first positional argument,
        # which main_worker receives as rank_id.
        mp.spawn(main_worker, nprocs=args.world_size, args=(args,))
    else:
        raise ValueError(
            "unsupported --distribute_func {!r}; expected 'launch' or 'spawn'".format(
                args.distribute_func
            )
        )
def get_transforms():
    """Return the training augmentation pipeline.

    The pipeline is a random resized crop to ``IMG_RESIZE`` followed by a
    random horizontal flip. No tensor conversion or normalisation is part
    of it.
    """
    return transforms.Compose([
        transforms.RandomResizedCrop(IMG_RESIZE),
        transforms.RandomHorizontalFlip(),
    ])
def main_worker(rank_id=-1, args=None):
    """Set up one training process and run every epoch on its NPU.

    Steps, in order: pin the process to a slice of the available CPU cores,
    rename the process, join the HCCL process group (multi-process runs only),
    build the network and optimizer, wrap them with apex amp (and DDP when
    distributed), build the data loader and finally call ``train`` once per
    epoch.

    Args:
        rank_id: process index supplied by ``torch.multiprocessing.spawn``.
            Only used when ``args.distribute_func == 'spawn'``, in which case
            it overwrites ``args.local_rank``.
        args: namespace produced by ``parse_args``.

    Raises:
        ValueError: if ``args.network``, ``args.opt`` or ``args.dataset`` is
            not one of the supported values.
    """
    # Function-scope import: the module-level name ``dist`` is bound to
    # ``torch.distributions``, which has no ``init_process_group``.
    import torch.distributed as torch_dist

    if args.distribute_func == 'spawn':
        args.local_rank = rank_id

    # Give every rank its own contiguous, equally sized slice of the cores
    # this process is currently allowed to run on.
    proc = psutil.Process()
    cpu_list = proc.cpu_affinity()
    core_per_proc = len(cpu_list) // args.world_size
    first_core = args.local_rank * core_per_proc
    assigned_cores = cpu_list[first_core:first_core + core_per_proc]
    proc.cpu_affinity(assigned_cores)
    print('============== use core:{}'.format(assigned_cores))

    process = '{}_{}_{}'.format(args.network, args.backend, args.local_rank)
    print('============= process:{}'.format(process))
    setproctitle.setproctitle(process)

    distributed = args.world_size > 1
    if distributed:
        torch_dist.init_process_group(
            backend='hccl', world_size=args.world_size, rank=args.local_rank
        )

    if args.network == 'resnet18':
        net = models.resnet18(num_classes=1000)
    elif args.network == 'resnet50':
        net = models.resnet50(num_classes=1000)
    elif args.network == 'mobilenetv2':
        net = models.MobileNetV2(num_classes=1000)
    else:
        raise ValueError('unsupported --network {!r}'.format(args.network))

    loc = 'npu:{}'.format(args.local_rank)
    torch.npu.set_device(loc)
    net = net.to(loc)

    if args.opt == 'sgd':
        # build_SGD needs the parameter names to split off the 'bn' group.
        optimizer = build_SGD(
            parameters=list(net.named_parameters()),
            lr=1.6, momentum=0.9, weight_decay=0.0001
        )
    elif args.opt == 'adam':
        optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), weight_decay=1e-4)
    else:
        raise ValueError('unsupported --opt {!r}'.format(args.opt))

    # amp must wrap the model/optimizer before DDP wraps the model.
    net, optimizer = amp.initialize(
        net, optimizer, opt_level='O2', loss_scale=1024, verbosity=1
    )
    if distributed:
        net = DDP(net, device_ids=[args.local_rank], broadcast_buffers=False)

    train_transforms = get_transforms()
    if args.dataset == 'imagenet':
        # NOTE(review): the loader is hard-coded to the cv2 loader regardless
        # of --backend; also confirm the module path (``dataset`` vs
        # ``datasets``) against the installed torchvision_npu.
        dataset = dset.ImageFolder(
            loader=torchvision_npu.dataset.folder._cv2_loader,
            root=IMAGENET_DATASET_PATH,
            transform=train_transforms,
        )
    else:
        raise ValueError('unsupported --dataset {!r}'.format(args.dataset))

    print('===================== dist')
    dataloader_fn = MultiEpochsDataLoader
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        train_loader = dataloader_fn(
            dataset, batch_size=args.batch_size,
            num_workers=args.num_workers, pin_memory=True,
            sampler=train_sampler, drop_last=True,
            collate_fn=fast_collate
        )
    else:
        train_loader = dataloader_fn(
            dataset, batch_size=args.batch_size,
            num_workers=args.num_workers, shuffle=False,
            drop_last=True, pin_memory=False,
            collate_fn=fast_collate
        )

    print('================== loop')
    # The parser defines ``--nEpoch`` (singular), so that is the attribute.
    for epoch in range(args.nEpoch):
        if distributed:
            # Reseed the sampler's shuffling for this epoch.
            train_sampler.set_epoch(epoch)
        train(args, epoch, net, train_loader, optimizer)
def train(args, epoch, net, train_loader, optimizer):
    """Run one pass over ``train_loader`` and report throughput on rank 0.

    Args:
        args: parsed options; ``local_rank``, ``batch_size`` and
            ``world_size`` are read here.
        epoch: epoch index, used only in the progress message.
        net: model to train (switched to training mode here).
        train_loader: iterable yielding ``(images, targets)`` batches.
        optimizer: optimizer previously initialised through apex amp.

    Every 50 steps (step 0 excluded) the process with ``local_rank == 0``
    prints the samples-per-second figure for the window since the last
    report, computed as ``batch_size * world_size * 50 / elapsed``.
    """
    net.train()
    device = 'npu:{}'.format(args.local_rank)
    report_every = 50
    window_start = time.time()

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device, non_blocking=True).to(torch.float)
        labels = labels.to(torch.int32).to(device, non_blocking=True)

        # NOTE(review): nll_loss expects log-probabilities; if the network
        # emits raw logits, cross_entropy was presumably intended -- confirm.
        loss = F.nll_loss(net(images), labels)

        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        is_report_step = step != 0 and step % report_every == 0
        if not (is_report_step and args.local_rank == 0):
            continue

        elapsed = time.time() - window_start
        samples = args.batch_size * args.world_size * report_every
        print('epoch:{}, step:{}, fps:{}'.format(epoch, step, samples / elapsed))
        window_start = time.time()
def build_SGD(parameters, lr, momentum, weight_decay, nesterov=False):
    """Create an SGD optimizer that exempts 'bn' parameters from weight decay.

    Args:
        parameters: iterable of ``(name, parameter)`` pairs, e.g. the result
            of ``module.named_parameters()``. It is materialised once, so a
            one-shot iterator is accepted.
        lr: learning rate shared by both parameter groups.
        momentum: SGD momentum factor.
        weight_decay: weight decay applied to every parameter whose name does
            not contain ``'bn'``.
        nesterov: whether to enable Nesterov momentum.

    Returns:
        torch.optim.SGD: optimizer with two parameter groups, in this order:
        parameters whose name contains ``'bn'`` (``weight_decay`` 0), then
        all remaining parameters (``weight_decay`` as given).

    Note:
        The split is purely by substring match on the parameter name; layers
        whose names do not contain ``'bn'`` land in the second group.
    """
    # Materialise first: both comprehensions below iterate the pairs, and a
    # generator would be exhausted by the first one.
    named_params = list(parameters)
    bn_params = [param for name, param in named_params if 'bn' in name]
    rest_params = [param for name, param in named_params if 'bn' not in name]

    return torch.optim.SGD(
        [
            {'params': bn_params, 'weight_decay': 0},
            {'params': rest_params, 'weight_decay': weight_decay},
        ],
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        nesterov=nesterov,
    )
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps one underlying iterator alive across epochs.

    The batch sampler is wrapped in ``_RepeatSampler`` so it never runs dry,
    and a single base-class iterator is created once in ``__init__``. Each
    call to ``__iter__`` then pulls exactly one epoch's worth of batches from
    that shared iterator instead of building a new one.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The base class's private "initialized" flag is lowered while
        # batch_sampler is swapped and raised again afterwards.
        # NOTE(review): presumably DataLoader rejects reassigning
        # batch_sampler once the flag is set -- confirm for the torch
        # version in use.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Created once; reused by every __iter__ call.
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is the _RepeatSampler wrapper; its .sampler is the
        # original batch sampler, whose length is returned here.
        wrapped_batch_sampler = self.batch_sampler.sampler
        return len(wrapped_batch_sampler)

    def __iter__(self):
        batches_left = len(self)
        while batches_left > 0:
            yield next(self.iterator)
            batches_left -= 1
class _RepeatSampler(object):
"""
sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)
def fast_collate(batch):
    """Collate ``(image, target)`` samples into a uint8 image batch.

    Args:
        batch: sequence of ``(image, target)`` pairs. Images are either PIL
            images or array-likes laid out height x width (x channels).

    Returns:
        tuple: ``(images, targets)`` where ``images`` is a ``torch.uint8``
        tensor of shape ``(len(batch), 3, H, W)`` and ``targets`` is a
        ``torch.int64`` tensor of the targets. ``H`` and ``W`` are taken
        from the first image only.
    """
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    # The output size comes from the first image. PIL reports (width, height),
    # array-likes report (height, width, ...).
    first = images[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size[0], first.size[1]
    else:
        height, width = first.shape[0], first.shape[1]

    out = torch.zeros(
        (len(images), 3, height, width), dtype=torch.uint8
    ).contiguous(memory_format=torch.contiguous_format)

    for idx, image in enumerate(images):
        pixels = np.asarray(image, dtype=np.uint8)
        if pixels.ndim < 3:
            # 2-D input: add a trailing channel axis so the move below works.
            pixels = np.expand_dims(pixels, axis=-1)
        # Move the channel axis to the front: HWC -> CHW.
        channels_first = np.moveaxis(pixels, 2, 0)
        # Copy so the tensor is built from an array that owns its data; the
        # in-place add onto zeros broadcasts a single channel over all three.
        out[idx] += torch.from_numpy(channels_first.copy())

    return out, labels
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()