import os
import sys
import torch
from collections import OrderedDict
from torch import nn
from torch.autograd import Variable
sys.path.append('./utils')
from utils.model import RandPointCNN
from utils.util_funcs import knn_indices_func_gpu, knn_indices_func_cpu
from utils.util_layers import Dense
AbbPointCNN = lambda a, b, c, d, e: RandPointCNN(a, b, 3, c, d, e, knn_indices_func_gpu)
NUM_CLASS = 40
class Classifier(nn.Module):
def __init__(self):
super(Classifier, self).__init__()
self.pcnn1 = AbbPointCNN(3, 32, 8, 1, -1)
self.pcnn2 = nn.Sequential(
AbbPointCNN(32, 64, 8, 2, -1),
AbbPointCNN(64, 96, 8, 4, -1),
AbbPointCNN(96, 128, 12, 4, 120),
AbbPointCNN(128, 160, 12, 6, 120)
)
self.fcn = nn.Sequential(
Dense(160, 128),
Dense(128, 64, drop_rate=0.5),
Dense(64, NUM_CLASS, with_bn=False, activation=None)
)
def forward(self, x):
x = self.pcnn1(x)
x = self.pcnn2(x)[1]
logits = self.fcn(x)
logits_mean = torch.mean(logits, dim=1)
return logits_mean
def proc_nodes_module(checkpoint):
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
if "module." in k:
name = k.replace("module.", "")
else:
name = k
new_state_dict[name] = v
return new_state_dict
def pth2onnx(input_path, output_path):
model = Classifier()
checkpoint = torch.load(input_path, map_location='cpu')
checkpoint = proc_nodes_module(checkpoint)
model.load_state_dict(checkpoint)
model.eval()
input_names = ["P_sampled", "P_patched"]
output_names = ["out"]
input_0 = torch.randn((1, 1024, 3), dtype=torch.float32)
input_1 = torch.randn((1, 1024, 3), dtype=torch.float32)
dummy_input = [input_0, input_1]
torch.onnx.export(model, dummy_input, output_path, input_names=input_names,
output_names=output_names, opset_version=11)
if __name__ == "__main__":
input_file = sys.argv[1]
output_file = sys.argv[2]
pth2onnx(input_file, output_file)