from collections import OrderedDict
import torch
import torch.onnx
import argparse
import sys
sys.path.append('./models/models')
import pointnet2_cls_ssg as pointnet2_cls
from pointnet2_utils import farthest_point_sample
from pointnet2_utils import sample_and_group
def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``target_model`` (int),
        ``pth_dir`` (str) and ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser('off_line_pred')
    # The option is required, so a default value would never be used.
    parser.add_argument('--target_model', type=int, required=True,
                        help='target trans_models')
    parser.add_argument('--pth_dir', type=str, default='', required=False,
                        help='directory containing best_model.pth')
    parser.add_argument('--batch_size', type=int, default=1, required=False,
                        help='batch size')
    return parser.parse_args(argv)
def proc_node_module(checkpoint, AttrName):
    """Return a copy of a checkpoint's state dict without ``module.`` prefixes.

    Args:
        checkpoint: Mapping that holds a state dict under ``AttrName``.
        AttrName: Key of the state dict inside ``checkpoint``.

    Returns:
        A new ``OrderedDict`` with the same values and ordering, where a
        leading ``"module."`` has been removed from every key that starts
        with it. ``checkpoint`` itself is not modified.

    Raises:
        KeyError: If ``AttrName`` is not present in ``checkpoint``.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint[AttrName].items()
    )
def model_convert(dir, batch_size=None):
    """Export part 1 of the PointNet++ classifier to ONNX.

    Loads ``best_model.pth`` from ``dir``, builds the part-1 model, derives
    example inputs from random points and writes
    ``Pointnetplus_part1_bs<batch_size>.onnx`` to the current working
    directory.

    Args:
        dir: Directory that contains ``best_model.pth``.
        batch_size: Batch dimension of the example inputs. When ``None``
            the value is read from the module-level ``args`` namespace
            that the ``__main__`` guard creates.
    """
    import os  # local import: the file header does not import os

    if batch_size is None:
        batch_size = args.batch_size

    # os.path.join keeps an empty ``dir`` relative to the working directory;
    # plain string concatenation would resolve it to '/best_model.pth'.
    checkpoint_path = os.path.join(str(dir), 'best_model.pth')
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = proc_node_module(checkpoint, 'model_state_dict')

    model = pointnet2_cls.get_model_part1(normal_channel=False)
    model.load_state_dict(state_dict)
    model.eval()

    # Sampling/grouping parameters for the example inputs; presumably they
    # mirror the first layer of the network -- TODO confirm against the model.
    npoint = 512
    radius = 0.2
    nsample = 32

    # Random cloud of shape (batch_size, 3, 1024), moved to
    # (batch_size, 1024, 3) before sampling.
    dummy_input = torch.randn(batch_size, 3, 1024)
    test_input = dummy_input.permute(0, 2, 1)
    centroid = farthest_point_sample(test_input, npoint)
    # No per-point features are supplied for this part, hence ``None``.
    new_xyz, new_points = sample_and_group(
        npoint, radius, nsample, test_input, None, centroid)
    # Reverse the last three axes of the 4-D grouped tensor for the model.
    new_points = new_points.permute(0, 3, 2, 1)

    torch.onnx.export(
        model,
        (new_xyz, new_points),
        "Pointnetplus_part1_bs{}.onnx".format(batch_size),
        input_names=["xyz", "samp_points"],
        verbose=True,
        output_names=["l1_xyz", "l1_point"],
        opset_version=11,
    )
def model_convert2(dir, batch_size=None):
    """Export part 2 of the PointNet++ classifier to ONNX.

    Loads ``best_model.pth`` from ``dir``, builds the part-2 model, derives
    example inputs from random coordinates and features and writes
    ``Pointnetplus_part2_bs<batch_size>.onnx`` to the current working
    directory.

    Args:
        dir: Directory that contains ``best_model.pth``.
        batch_size: Batch dimension of the example inputs. When ``None``
            the value is read from the module-level ``args`` namespace
            that the ``__main__`` guard creates.
    """
    import os  # local import: the file header does not import os

    if batch_size is None:
        batch_size = args.batch_size

    # os.path.join keeps an empty ``dir`` relative to the working directory;
    # plain string concatenation would resolve it to '/best_model.pth'.
    checkpoint_path = os.path.join(str(dir), 'best_model.pth')
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = proc_node_module(checkpoint, 'model_state_dict')

    model = pointnet2_cls.get_model_part2(normal_channel=False)
    model.load_state_dict(state_dict)
    model.eval()

    # Sampling/grouping parameters for the example inputs; presumably they
    # mirror the second layer of the network -- TODO confirm against the model.
    npoint = 128
    radius = 0.4
    nsample = 64

    # Random stand-ins shaped (batch_size, 3, 512) and (batch_size, 128, 512),
    # each moved to points-first layout before sampling.
    dummy_xyz_input = torch.randn(batch_size, 3, 512)
    dummy_point_input = torch.randn(batch_size, 128, 512)
    test_input = dummy_xyz_input.permute(0, 2, 1)
    test_points = dummy_point_input.permute(0, 2, 1)

    centroid = farthest_point_sample(test_input, npoint)
    new_xyz, new_points = sample_and_group(
        npoint, radius, nsample, test_input, test_points, centroid)
    # Reverse the last three axes of the 4-D grouped tensor and swap the last
    # two axes of the sampled coordinates for the model.
    new_points = new_points.permute(0, 3, 2, 1)
    new_xyz = new_xyz.permute(0, 2, 1)

    torch.onnx.export(
        model,
        (new_xyz, new_points),
        "Pointnetplus_part2_bs{}.onnx".format(batch_size),
        input_names=["l1_xyz", "l1_points"],
        verbose=True,
        output_names=["class", "l3_point"],
        opset_version=11,
    )
if __name__ == '__main__':
    # Kept as a module-level name: the converters read ``args`` as a global.
    args = parse_args()
    # 1 selects the part-1 exporter; every other value selects part 2.
    export = model_convert if args.target_model == 1 else model_convert2
    export(args.pth_dir)