import os
import sys
import torch
import re

# Compare the release numerically. A plain string comparison is lexicographic,
# so e.g. '1.11.0' >= '1.8.1' is False and torch_npu would silently be skipped.
# Local-version suffixes ('+cpu', '+git...') and pre-release tags are ignored.
_torch_release = tuple(
    int(number)
    for number in re.findall(r'\d+', torch.__version__.split('+')[0])[:3]
)
if _torch_release >= (1, 8, 1):
    import torch_npu
import numpy as np
import datetime
import logging
import models.provider as provider
import importlib
import shutil
import argparse
import apex
import time
from pathlib import Path
from tqdm import tqdm
from models.ModelNetDataLoader import ModelNetDataLoader
from apex import amp
import torch.distributed as dist
# Default device string; main() rebinds this global once --npu is known.
CALCULATE_DEVICE = "npu:0"

# The directory containing this script doubles as the project root.
BASE_DIR = ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Put <root>/models on the import path so that modules stored there can be
# imported by their bare name (main() loads the model via importlib).
sys.path.append(os.path.join(ROOT_DIR, 'models'))
def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below.

    Note:
        ``--process_data`` is a ``store_true`` flag whose default is already
        ``True``, so its value is ``True`` whether or not the flag is given.
    """
    cli = argparse.ArgumentParser('training')
    add = cli.add_argument

    # Device selection.
    add('--use_cpu', action='store_true', default=False, help='use cpu mode')
    add('--gpu', type=str, default='0', help='specify gpu device')

    # Model and optimisation hyper-parameters.
    add('--batch_size', type=int, default=24, help='batch size in training')
    add('--model', default='pointnet2_cls_ssg',
        help='model name [default: pointnet2_cls_ssg]')
    add('--num_category', default=40, type=int, choices=[10, 40],
        help='training on ModelNet10/40')
    add('--epoch', default=200, type=int, help='number of epoch in training')
    add('--learning_rate', default=0.001, type=float,
        help='learning rate in training')
    add('--num_point', type=int, default=1024, help='Point Number')
    add('--optimizer', type=str, default='Adam', help='optimizer for training')
    add('--log_dir', type=str, default='pointnet2_cls_ssg',
        help='experiment root')
    add('--decay_rate', type=float, default=1e-4, help='decay rate')

    # Dataset handling switches.
    add('--use_normals', action='store_true', default=False,
        help='use normals')
    add('--process_data', action='store_true', default=True,
        help='save data offline')
    add('--use_uniform_sample', action='store_true', default=False,
        help='use uniform sampiling')

    # NPU / distributed settings.
    add('--npu', default=None, type=int, help='NPU id to use.')
    add('--world_size', default=-1, type=int,
        help='number of nodes for distributed training')
    add('--rank', default=-1, type=int,
        help='node rank for distributed training')
    add('--dist_backend', default='nccl', type=str,
        help='distributed backend')
    add('--dist_url', default='env://', type=str,
        help='url used to set up distributed training')

    # Profiling and data location.
    add('--store_prof', action='store_true', default=False, help='save_prof')
    add('--data', type=str, default='./modelnet40_normal_resampled',
        help='data_path')

    return cli.parse_args()
def inplace_relu(m):
    """Turn on in-place mode for *m* if its class name contains ``ReLU``.

    Objects of any other class are left untouched. main() passes this
    function to ``classifier.apply``.
    """
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True
def test(model, loader, num_class=40):
    """Evaluate *model* on *loader* and return its accuracies.

    Reads the module-level ``args`` (``args.use_cpu``) to decide whether the
    batches are moved to the NPU, so ``args`` must be bound before calling.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        loader: iterable of ``(points, target)`` batches that supports ``len``.
        num_class: number of classes, i.e. rows of the per-class table.

    Returns:
        tuple: ``(instance_acc, class_acc)`` where ``instance_acc`` is the
        mean of the per-batch accuracies and ``class_acc`` is the mean, over
        the classes that occurred at least once, of each class's average
        per-batch accuracy. Both are ``0.0`` when the loader yields nothing.
    """
    mean_correct = []
    # Per class: [sum of per-batch accuracies, batches containing it, mean].
    class_acc = np.zeros((num_class, 3))
    classifier = model.eval()

    for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
        if not args.use_cpu:
            points, target = points.npu(), target.npu()

        points = points.transpose(2, 1)
        pred, _ = classifier(points)
        pred_choice = pred.data.max(1)[1]

        for cat in np.unique(target.cpu()):
            mask = target == cat
            classacc = pred_choice[mask].eq(target[mask].long().data).cpu().sum()
            class_acc[cat, 0] += classacc.item() / float(points[mask].size()[0])
            class_acc[cat, 1] += 1

        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))

    # Classes that never showed up have a zero batch count; dividing for them
    # would give NaN and poison the mean, so only seen classes are averaged.
    seen = class_acc[:, 1] > 0
    class_acc[seen, 2] = class_acc[seen, 0] / class_acc[seen, 1]
    class_acc = np.mean(class_acc[seen, 2]) if seen.any() else 0.0
    instance_acc = np.mean(mean_correct) if mean_correct else 0.0

    return instance_acc, class_acc
class AverageMeter(object):
    """Keep the latest value of a series together with its weighted mean.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* with weight *n* and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def main(args):
    """Train the classifier named by ``args.model`` on a single NPU.

    Creates ``./log/classification/<log_dir>/{checkpoints,logs}``, logs to
    ``logs/<model>.txt``, copies the model sources and this script into the
    experiment directory, then trains for ``args.epoch`` epochs with apex AMP
    (opt level O2). After every epoch the model is evaluated on the test
    split and ``checkpoints/best_model.pth`` is written whenever the test
    instance accuracy matches or beats the best value seen so far.

    Args:
        args: namespace produced by ``parse_args()``.
    """
    global CALCULATE_DEVICE

    def log_string(msg):
        """Write *msg* to the log file and echo it on stdout."""
        logger.info(msg)
        print(msg)

    # ----- experiment directories -----
    if args.log_dir is None:
        run_name = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    else:
        run_name = args.log_dir
    exp_dir = Path('./log/').joinpath('classification').joinpath(run_name)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    log_dir = exp_dir.joinpath('logs/')
    for directory in (checkpoints_dir, log_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # ----- logging -----
    # The ``args`` passed by the caller is used as-is; re-parsing the command
    # line here would silently discard whatever the caller supplied.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # ----- device -----
    if args.npu is None:
        args.npu = 0
    CALCULATE_DEVICE = "npu:{}".format(args.npu)
    torch.npu.set_device(CALCULATE_DEVICE)
    print("use ", CALCULATE_DEVICE)

    # ----- data -----
    data_path = args.data
    train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
    test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
    trainDataLoader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)

    # ----- model -----
    num_class = args.num_category
    model = importlib.import_module(args.model)
    # Snapshot the sources used for this run next to its logs.
    shutil.copy('./models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))
    shutil.copy('./train_classification_1P.py', str(exp_dir))

    # NOTE(review): the model is moved to the NPU unconditionally, as before,
    # so --use_cpu only affects the loss module and the input batches.
    classifier = model.get_model(num_class, normal_channel=args.use_normals).npu()
    criterion = model.get_loss()
    classifier.apply(inplace_relu)

    if not args.use_cpu:
        classifier = classifier.npu()
        criterion = criterion.npu()

    log_string('No existing model, starting training from scratch...')
    start_epoch = 0

    # ----- optimizer / AMP / LR schedule -----
    if args.optimizer == 'Adam':
        optimizer = apex.optimizers.NpuFusedAdam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # Was ``apex.optimiziers`` (typo), which raised AttributeError.
        optimizer = apex.optimizers.NpuFusedSGD(classifier.parameters(), lr=0.01, momentum=0.9)

    classifier, optimizer = amp.initialize(
        classifier, optimizer, opt_level="O2", loss_scale=128, combine_grad=True)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    global_epoch = 0
    best_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0

    # ----- training -----
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        loss_sum = 0.0
        num_batches = 0
        mean_correct = []
        classifier = classifier.train()

        tot_time = AverageMeter()
        end = time.time()

        for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0),
                                               total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()

            # Augment with the provider helpers on a NumPy copy of the batch.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)

            if not args.use_cpu:
                points, target = points.npu(), target.npu()

            # NOTE: --store_prof is parsed but no profiling happens here; the
            # former profiling snippet was an inert string literal.
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            loss_sum += loss.item()
            num_batches += 1

            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()

            current_batch_time = time.time() - end
            end = time.time()
            FPS = args.batch_size / current_batch_time
            # Skip the first five steps so that warm-up does not skew the
            # throughput figures. NOTE(review): the source indentation was
            # lost; tot_time.update is assumed to belong to this branch.
            if batch_id > 4:
                log_string("Epoch %d step %d FPS: %f" % (epoch, batch_id, FPS))
                tot_time.update(current_batch_time)

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # it first skips the initial value of the schedule (PyTorch >= 1.1).
        scheduler.step()

        # tot_time.avg stays 0 when the epoch had five batches or fewer.
        epoch_FPS = args.batch_size / tot_time.avg if tot_time.avg > 0 else 0.0
        log_string("Epoch %d avg FPS: %f" % (epoch, epoch_FPS))
        # Average over the number of batches, not the last batch index.
        log_string("Epoch %d train loss: %f " % (epoch, loss_sum / max(num_batches, 1)))
        train_instance_acc = np.mean(mean_correct) if mean_correct else 0.0
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        # ----- evaluation and checkpointing -----
        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

        is_best = instance_acc >= best_instance_acc
        if is_best:
            best_instance_acc = instance_acc
            best_epoch = epoch + 1
        if class_acc >= best_class_acc:
            best_class_acc = class_acc

        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
        log_string('Acc@1: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

        if is_best:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': best_epoch,
                'instance_acc': instance_acc,
                'class_acc': class_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
        global_epoch += 1

    logger.info('End of training...')
# Script entry point: parse the command line and start training.
if __name__ == '__main__':
    # Bound at module level on purpose: test() reads this global ``args``
    # (args.use_cpu), so the name must not be renamed or made local.
    args = parse_args()
    main(args)