05360171创建于 2022年3月18日历史提交
diff --git a/pointnet/dataset.py b/pointnet/dataset.py
index 678460a..f531b65 100644
--- a/pointnet/dataset.py
+++ b/pointnet/dataset.py
@@ -1,13 +1,11 @@
-from __future__ import print_function
 import torch.utils.data as data
 import os
 import os.path
 import torch
 import numpy as np
 import sys
-from tqdm import tqdm 
 import json
-from plyfile import PlyData, PlyElement
+
 
 def get_segmentation_classes(root):
     catfile = os.path.join(root, 'synsetoffset2category.txt')
@@ -27,7 +25,7 @@ def get_segmentation_classes(root):
         for fn in fns:
             token = (os.path.splitext(os.path.basename(fn))[0])
             meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
-    
+
     with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
         for item in cat:
             datapath = []
@@ -43,6 +41,7 @@ def get_segmentation_classes(root):
             print("category {} num segmentation classes {}".format(item, num_seg_classes))
             f.write("{}\t{}\n".format(item, num_seg_classes))
 
+
 def gen_modelnet_id(root):
     classes = []
     with open(os.path.join(root, 'train.txt'), 'r') as f:
@@ -53,6 +52,7 @@ def gen_modelnet_id(root):
         for i in range(len(classes)):
             f.write('{}\t{}\n'.format(classes[i], i))
 
+
 class ShapeNetDataset(data.Dataset):
     def __init__(self,
                  root,
@@ -68,12 +68,11 @@ class ShapeNetDataset(data.Dataset):
         self.data_augmentation = data_augmentation
         self.classification = classification
         self.seg_classes = {}
-        
+
         with open(self.catfile, 'r') as f:
             for line in f:
                 ls = line.strip().split()
                 self.cat[ls[0]] = ls[1]
-        #print(self.cat)
         if not class_choice is None:
             self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
 
@@ -81,7 +80,6 @@ class ShapeNetDataset(data.Dataset):
 
         self.meta = {}
         splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
-        #from IPython import embed; embed()
         filelist = json.load(open(splitfile, 'r'))
         for item in self.cat:
             self.meta[item] = []
@@ -89,8 +87,9 @@ class ShapeNetDataset(data.Dataset):
         for file in filelist:
             _, category, uuid = file.split('/')
             if category in self.cat.values():
-                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
-                                        os.path.join(self.root, category, 'points_label', uuid+'.seg')))
+                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid + '.pts'),
+                                                         os.path.join(self.root, category, 'points_label',
+                                                                      uuid + '.seg')))
 
         self.datapath = []
         for item in self.cat:
@@ -98,34 +97,31 @@ class ShapeNetDataset(data.Dataset):
                 self.datapath.append((item, fn[0], fn[1]))
 
         self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
-        print(self.classes)
         with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
             for line in f:
                 ls = line.strip().split()
                 self.seg_classes[ls[0]] = int(ls[1])
         self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
-        print(self.seg_classes, self.num_seg_classes)
 
     def __getitem__(self, index):
         fn = self.datapath[index]
         cls = self.classes[self.datapath[index][0]]
         point_set = np.loadtxt(fn[1]).astype(np.float32)
         seg = np.loadtxt(fn[2]).astype(np.int64)
-        #print(point_set.shape, seg.shape)
 
         choice = np.random.choice(len(seg), self.npoints, replace=True)
-        #resample
+
         point_set = point_set[choice, :]
 
-        point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center
-        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)
-        point_set = point_set / dist #scale
+        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
+        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
+        point_set = point_set / dist
 
         if self.data_augmentation:
-            theta = np.random.uniform(0,np.pi*2)
-            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
-            point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation
-            point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter
+            theta = np.random.uniform(0, np.pi * 2)
+            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
+            point_set += np.random.normal(0, 0.02, size=point_set.shape)
 
         seg = seg[choice]
         point_set = torch.from_numpy(point_set)
@@ -140,11 +136,12 @@ class ShapeNetDataset(data.Dataset):
     def __len__(self):
         return len(self.datapath)
 
+
 class ModelNetDataset(data.Dataset):
     def __init__(self,
                  root,
                  npoints=2500,
-                 split='train',
+                 split='trainval',
                  data_augmentation=True):
         self.npoints = npoints
         self.root = root
@@ -156,7 +153,7 @@ class ModelNetDataset(data.Dataset):
                 self.fns.append(line.strip())
 
         self.cat = {}
-        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
+        with open('./misc/modelnet_id.txt', 'r') as f:
             for line in f:
                 ls = line.strip().split()
                 self.cat[ls[0]] = int(ls[1])
@@ -167,49 +164,48 @@ class ModelNetDataset(data.Dataset):
     def __getitem__(self, index):
         fn = self.fns[index]
         cls = self.cat[fn.split('/')[0]]
+
         with open(os.path.join(self.root, fn), 'rb') as f:
             plydata = PlyData.read(f)
         pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
         choice = np.random.choice(len(pts), self.npoints, replace=True)
         point_set = pts[choice, :]
 
-        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
+        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
         dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
-        point_set = point_set / dist  # scale
+        point_set = point_set / dist
 
         if self.data_augmentation:
             theta = np.random.uniform(0, np.pi * 2)
             rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
-            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
-            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter
+            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
+            point_set += np.random.normal(0, 0.02, size=point_set.shape)
 
         point_set = torch.from_numpy(point_set.astype(np.float32))
         cls = torch.from_numpy(np.array([cls]).astype(np.int64))
         return point_set, cls
 
-
     def __len__(self):
         return len(self.fns)
 
+
 if __name__ == '__main__':
     dataset = sys.argv[1]
     datapath = sys.argv[2]
 
     if dataset == 'shapenet':
-        d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
+        d = ShapeNetDataset(root=datapath, class_choice=['Chair'])
         print(len(d))
         ps, seg = d[0]
-        print(ps.size(), ps.type(), seg.size(),seg.type())
+        print(ps.size(), ps.type(), seg.size(), seg.type())
 
-        d = ShapeNetDataset(root = datapath, classification = True)
+        d = ShapeNetDataset(root=datapath, classification=True)
         print(len(d))
         ps, cls = d[0]
-        print(ps.size(), ps.type(), cls.size(),cls.type())
-        # get_segmentation_classes(datapath)
+        print(ps.size(), ps.type(), cls.size(), cls.type())
 
     if dataset == 'modelnet':
         gen_modelnet_id(datapath)
         d = ModelNetDataset(root=datapath)
         print(len(d))
         print(d[0])
-
diff --git a/pointnet/model.py b/pointnet/model.py
index 48de610..d3081e1 100644
--- a/pointnet/model.py
+++ b/pointnet/model.py
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 
 
 class STN3d(nn.Module):
-    def __init__(self):
+    def __init__(self, device):
         super(STN3d, self).__init__()
         self.conv1 = torch.nn.Conv1d(3, 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
@@ -25,6 +25,7 @@ class STN3d(nn.Module):
         self.bn4 = nn.BatchNorm1d(512)
         self.bn5 = nn.BatchNorm1d(256)
 
+        self.device = device
 
     def forward(self, x):
         batchsize = x.size()[0]
@@ -38,23 +39,28 @@ class STN3d(nn.Module):
         x = F.relu(self.bn5(self.fc2(x)))
         x = self.fc3(x)
 
-        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
-        if x.is_cuda:
+        iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
+            batchsize, 1)
+        if self.device == 'gpu':
             iden = iden.cuda()
+        elif self.device == 'npu':
+            iden = iden.npu()
+        else:
+            iden = iden.cpu()
         x = x + iden
         x = x.view(-1, 3, 3)
         return x
 
 
 class STNkd(nn.Module):
-    def __init__(self, k=64):
+    def __init__(self, k=64, device='cpu'):
         super(STNkd, self).__init__()
         self.conv1 = torch.nn.Conv1d(k, 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
-        self.fc3 = nn.Linear(256, k*k)
+        self.fc3 = nn.Linear(256, k * k)
         self.relu = nn.ReLU()
 
         self.bn1 = nn.BatchNorm1d(64)
@@ -64,6 +70,7 @@ class STNkd(nn.Module):
         self.bn5 = nn.BatchNorm1d(256)
 
         self.k = k
+        self.device = device
 
     def forward(self, x):
         batchsize = x.size()[0]
@@ -77,17 +84,23 @@ class STNkd(nn.Module):
         x = F.relu(self.bn5(self.fc2(x)))
         x = self.fc3(x)
 
-        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
-        if x.is_cuda:
+        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
+            batchsize, 1)
+        if self.device == 'gpu':
             iden = iden.cuda()
+        elif self.device == 'npu':
+            iden = iden.npu()
+        else:
+            iden = iden.cpu()
         x = x + iden
         x = x.view(-1, self.k, self.k)
         return x
 
+
 class PointNetfeat(nn.Module):
-    def __init__(self, global_feat = True, feature_transform = False):
+    def __init__(self, global_feat=True, feature_transform=False, device='cpu'):
         super(PointNetfeat, self).__init__()
-        self.stn = STN3d()
+        self.stn = STN3d(device)
         self.conv1 = torch.nn.Conv1d(3, 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
@@ -97,7 +110,7 @@ class PointNetfeat(nn.Module):
         self.global_feat = global_feat
         self.feature_transform = feature_transform
         if self.feature_transform:
-            self.fstn = STNkd(k=64)
+            self.fstn = STNkd(k=64, device=device)
 
     def forward(self, x):
         n_pts = x.size()[2]
@@ -109,9 +122,9 @@ class PointNetfeat(nn.Module):
 
         if self.feature_transform:
             trans_feat = self.fstn(x)
-            x = x.transpose(2,1)
+            x = x.transpose(2, 1)
             x = torch.bmm(x, trans_feat)
-            x = x.transpose(2,1)
+            x = x.transpose(2, 1)
         else:
             trans_feat = None
 
@@ -126,11 +139,12 @@ class PointNetfeat(nn.Module):
             x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
             return torch.cat([x, pointfeat], 1), trans, trans_feat
 
+
 class PointNetCls(nn.Module):
-    def __init__(self, k=2, feature_transform=False):
+    def __init__(self, k=2, feature_transform=False, device='cpu'):
         super(PointNetCls, self).__init__()
         self.feature_transform = feature_transform
-        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
+        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform, device=device)
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, k)
@@ -148,11 +162,11 @@ class PointNetCls(nn.Module):
 
 
 class PointNetDenseCls(nn.Module):
-    def __init__(self, k = 2, feature_transform=False):
+    def __init__(self, k=2, feature_transform=False, device='cpu'):
         super(PointNetDenseCls, self).__init__()
         self.k = k
-        self.feature_transform=feature_transform
-        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
+        self.feature_transform = feature_transform
+        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform, device=device)
         self.conv1 = torch.nn.Conv1d(1088, 512, 1)
         self.conv2 = torch.nn.Conv1d(512, 256, 1)
         self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -169,22 +183,28 @@ class PointNetDenseCls(nn.Module):
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.relu(self.bn3(self.conv3(x)))
         x = self.conv4(x)
-        x = x.transpose(2,1).contiguous()
-        x = F.log_softmax(x.view(-1,self.k), dim=-1)
+        x = x.transpose(2, 1).contiguous()
+        x = F.log_softmax(x.view(-1, self.k), dim=-1)
         x = x.view(batchsize, n_pts, self.k)
         return x, trans, trans_feat
 
-def feature_transform_regularizer(trans):
+
+def feature_transform_regularizer(trans, device):
     d = trans.size()[1]
     batchsize = trans.size()[0]
     I = torch.eye(d)[None, :, :]
-    if trans.is_cuda:
+    if device == 'gpu':
         I = I.cuda()
-    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
+    elif device == 'npu':
+        I = I.npu()
+    else:
+        I = I.cpu()
+    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2)))
     return loss
 
+
 if __name__ == '__main__':
-    sim_data = Variable(torch.rand(32,3,2500))
+    sim_data = Variable(torch.rand(32, 3, 2500))
     trans = STN3d()
     out = trans(sim_data)
     print('stn', out.size())
@@ -204,10 +224,10 @@ if __name__ == '__main__':
     out, _, _ = pointfeat(sim_data)
     print('point feat', out.size())
 
-    cls = PointNetCls(k = 5)
+    cls = PointNetCls(k=5)
     out, _, _ = cls(sim_data)
     print('class', out.size())
 
-    seg = PointNetDenseCls(k = 3)
+    seg = PointNetDenseCls(k=3)
     out, _, _ = seg(sim_data)
     print('seg', out.size())