import argparse
from tqdm import tqdm
import torch
from torch import nn
from torch import optim
import apex
from apex import amp
from model import Backbone, Arcface
from utils import separate_bn_paras
def get_data(args):
    """Create one synthetic batch of images and labels on the CPU.

    Args:
        args: Object exposing an integer ``batch`` attribute (the batch size).

    Returns:
        A ``(images, labels)`` tuple: ``images`` is a float32 tensor of shape
        ``(args.batch, 3, 112, 112)`` drawn uniformly from ``[0, 1)``, and
        ``labels`` is an int64 tensor of shape ``(args.batch,)`` whose values
        are drawn from ``{0, 1}``.
    """
    batch_size = args.batch
    image_shape = (batch_size, 3, 112, 112)
    # Images are drawn before labels so the RNG is consumed in a fixed order.
    images = torch.rand(image_shape, dtype=torch.float32)
    labels = torch.randint(2, (batch_size,), dtype=torch.long)
    return images, labels
class NewModel(nn.Module):
    """Backbone followed by an Arcface head, wrapped as a single module.

    The backbone is built with ``num_layers=100``, ``drop_ratio=0.6`` and
    ``mode='ir_se'``; the head with ``embedding_size=512`` and
    ``classnum=85742``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (backbone, then head) is kept so parameter and
        # state-dict ordering stay the same.
        self.backbone = Backbone(num_layers=100, drop_ratio=0.6, mode='ir_se')
        self.head = Arcface(embedding_size=512, classnum=85742)

    def forward(self, images, labels):
        """Run ``images`` through the backbone, then feed the result and
        ``labels`` to the head and return the head's output."""
        return self.head(self.backbone(images), labels)
def _str_to_bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy: with the former ``type=str`` it was the
    # only value that disabled amp, so existing invocations keep working.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (e.g. True/False), got %r' % (value,))


def prepare_args():
    """Parse the command-line options of the profiling script.

    Options:
        -device: device string, default ``'cuda:0'``.
        -amp: whether to use amp, default ``True``. Parsed as a real boolean,
            so ``-amp False`` / ``-amp 0`` disable it (a plain ``str`` type
            would turn any non-empty text, including ``"False"``, into a
            truthy value).
        -batch: batch size, default ``256``.

    Returns:
        The ``argparse.Namespace`` with ``device`` (str), ``amp`` (bool) and
        ``batch`` (int) attributes, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='get prof')
    parser.add_argument("-device", help="device", default='cuda:0', type=str)
    parser.add_argument("-amp", help="use amp", default=True, type=_str_to_bool)
    parser.add_argument("-batch", help="batch size", default=256, type=int)
    return parser.parse_args()
def _train_step(model, optimizer, loss_func, args, device):
    """Run one full training iteration on a freshly generated batch.

    Generates data with ``get_data(args)``, moves it to ``device``, runs the
    forward pass and the loss, clears the gradients, backpropagates (through
    ``amp.scale_loss`` when ``args.amp`` is truthy) and steps the optimizer.
    """
    imgs, labels = get_data(args)
    imgs = imgs.to(device)
    labels = labels.to(device)
    thetas = model(imgs, labels)
    loss = loss_func(thetas, labels)
    optimizer.zero_grad()
    if args.amp:
        # Backpropagate the amp-scaled loss instead of the raw loss.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()


if __name__ == '__main__':
    # Script flow: build model and optimizer, optionally wrap them with amp,
    # run 5 un-profiled warm-up steps, then profile one step and export the
    # trace to "output.prof".
    args = prepare_args()
    device = torch.device(args.device)
    # Any device string containing 'npu' selects the NPU code paths;
    # everything else is handled as CUDA.
    on_npu = 'npu' in args.device
    if on_npu:
        torch.npu.set_device(device)
    else:
        torch.cuda.set_device(device)

    model = NewModel()
    model = model.to(device)
    print('model head create over ')

    # Weight decay is applied to the non-BN backbone parameters and the head
    # kernel only; the BN parameters get their own group without it.
    paras_only_bn, paras_wo_bn = separate_bn_paras(model.backbone)
    param_groups = [
        {'params': paras_wo_bn + [model.head.kernel], 'weight_decay': 5e-4},
        {'params': paras_only_bn}
    ]
    if on_npu and args.amp:
        optimizer = apex.optimizers.NpuFusedSGD(param_groups, lr=0.001, momentum=0.9)
    else:
        optimizer = optim.SGD(param_groups, lr=0.001, momentum=0.9)
    print('optimizer create over')

    loss_func = nn.CrossEntropyLoss().to(device)

    if on_npu and args.amp:
        # NOTE(review): combine_grad is presumably an NPU-specific apex
        # option -- confirm against the apex build in use.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128.0, combine_grad=True)
    elif 'cuda' in args.device and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128.0)

    print('start warm up train')
    # Warm-up iterations run outside the profiler.
    for _ in tqdm(range(5)):
        _train_step(model, optimizer, loss_func, args, device)

    print('start get prof')
    # NOTE(review): 'use_npu' is presumably accepted only by an NPU-enabled
    # torch build -- confirm.
    if on_npu:
        k_v = {'use_npu': True}
    else:
        k_v = {'use_cuda': True}
    # Exactly one training step is recorded.
    with torch.autograd.profiler.profile(**k_v) as prof:
        _train_step(model, optimizer, loss_func, args, device)
    # The trace is exported after the profiler context has closed.
    prof.export_chrome_trace("output.prof")