import torch
import numpy as np
import os
import sys
import provider
# Directory containing this script; all relative resources below are resolved
# against it rather than against the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with the script directory and its 'models'
# and 'utils' subdirectories.
# NOTE(review): `import provider` above executes before these appends, so they
# do not influence how `provider` itself is located -- confirm whether the
# import was meant to come after this block.
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'models'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
# Test-set file list, read at import time from the index file under BASE_DIR.
# preprocess() uses it as an indexable sequence with len(); presumably a list
# of HDF5 file paths -- verify against provider.getDataFiles.
TEST_FILES = provider.getDataFiles(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
# Number of samples written into each output file pair by preprocess().
BATCH_SIZE = 1
def preprocess(save_path, label_save_path):
    """Export each test batch as a raw float32 ``.bin`` file plus a ``.npy`` label.

    Every file listed in ``TEST_FILES`` is loaded through
    ``provider.loadDataFile``, truncated to the first 1024 entries along its
    second axis, and split into consecutive batches of ``BATCH_SIZE`` samples
    along its first axis (a trailing partial batch is dropped). Each batch is
    passed through ``provider.rotate_point_cloud`` and
    ``provider.jitter_point_cloud`` before being written.

    Output files are numbered with a single counter that starts at 1 and keeps
    increasing across input files:

    * ``<save_path>/data<N>.bin``       -- batch data as headerless float32
    * ``<label_save_path>/label<N>.npy`` -- matching label slice (``np.save``
      appends the ``.npy`` suffix)

    Args:
        save_path: Existing directory that receives the ``.bin`` data files.
        label_save_path: Existing directory that receives the label files.

    Returns:
        None. The function is used for its file-writing side effects.
    """
    sample_no = 0  # running output index shared by all input files
    for test_file in TEST_FILES:
        current_data, current_label = provider.loadDataFile(test_file)
        # Keep only the first 1024 entries of axis 1 (presumably points per
        # cloud -- TODO confirm against the dataset layout).
        current_data = current_data[:, 0:1024, :]
        num_batches = current_data.shape[0] // BATCH_SIZE
        for batch_idx in range(num_batches):
            sample_no += 1
            start_idx = batch_idx * BATCH_SIZE
            end_idx = start_idx + BATCH_SIZE
            label = current_label[start_idx:end_idx]
            rotated = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered = provider.jitter_point_cloud(rotated)
            # Cast straight to float32 in NumPy; a round trip through a torch
            # tensor yields the same values and only adds overhead.
            points = np.ascontiguousarray(jittered, dtype=np.float32)
            points.tofile(os.path.join(save_path, "data" + str(sample_no) + ".bin"))
            np.save(os.path.join(label_save_path, 'label' + str(sample_no)), label)
if __name__ == "__main__":
    # Command line: <script> <save_path> <label_save_path>
    # Fail with a usage message instead of a bare IndexError when an
    # argument is missing; extra arguments are still ignored.
    if len(sys.argv) < 3:
        sys.exit("usage: {} <save_path> <label_save_path>".format(sys.argv[0]))
    # Resolve both output directories the same way so the paths used for
    # creation and for writing are identical.
    save_path = os.path.realpath(sys.argv[1])
    label_save_path = os.path.realpath(sys.argv[2])
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(label_save_path, exist_ok=True)
    preprocess(save_path, label_save_path)