import os
import sys
from PIL import Image
import numpy as np
import multiprocessing
from torchvision import transforms
from tqdm import tqdm
def gen_input_bin(save_path, file_batches, batch):
    """Preprocess one batch of images and dump each as a raw ``.bin`` file.

    Each file named in ``file_batches[batch]`` is opened from the directory
    held in the module-level global ``src_path``, converted to RGB, passed
    through Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize, cast to
    float32 and written with ``ndarray.tofile`` (raw values, no header).

    Args:
        save_path: Directory the ``.bin`` files are written into.
        file_batches: List of batches; each batch is a list of file names
            relative to ``src_path``.
        batch: Index of the batch in ``file_batches`` to process.

    Raises:
        Whatever PIL, torchvision or the filesystem raise for a file. The
        first failing file aborts the remainder of the batch.

    Note:
        ``src_path`` is NOT a parameter; it must exist as a module global in
        the process that runs this function.
    """
    batch_files = file_batches[batch]
    if not batch_files:
        # Nothing to do; skip building the transform pipeline.
        return

    # The transforms are stateless, so build them once per batch rather than
    # once per image.
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    for file in batch_files:
        # convert() returns an independent image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(os.path.join(src_path, file)) as raw:
            image = raw.convert('RGB')
        data = np.array(pipeline(image), dtype=np.float32)
        # splitext strips only the final extension, so "a.b.jpg" becomes
        # "a.b.bin" instead of colliding with every other "a.*" file.
        stem = os.path.splitext(file)[0]
        data.tofile(os.path.join(save_path, stem + ".bin"))
def _init_worker(worker_src_path):
    """Publish the source directory as the module global gen_input_bin reads."""
    global src_path
    src_path = worker_src_path


def preprocess(src_path, save_path, batch_size=500):
    """Convert every file in ``src_path`` to a ``.bin`` file using a process pool.

    The directory listing is split into batches of ``batch_size`` names and
    each batch is handed to ``gen_input_bin`` in a worker process. A tqdm bar
    advances once per finished batch, whether it succeeded or failed.

    Args:
        src_path: Directory containing the input files.
        save_path: Directory the ``.bin`` files are written into.
        batch_size: Number of files per task (default 500).

    Returns:
        None. Batches whose worker raised are listed on stderr; the function
        itself does not raise for them.
    """
    files = os.listdir(src_path)
    if not files:
        # Pool(0) raises ValueError, so bail out before creating the pool.
        print("no files found in {}; nothing to do.".format(src_path))
        return

    # Batch over the real number of files instead of a fixed 50000 cap, so
    # nothing past that count is silently skipped.
    file_batches = [files[i:i + batch_size]
                    for i in range(0, len(files), batch_size)]
    # One process per batch is unbounded for large directories; never start
    # more workers than there are CPUs.
    workers = min(len(file_batches), os.cpu_count() or 1)

    failures = []
    pbar = tqdm(total=len(file_batches))

    def on_done(_result):
        pbar.update(1)

    # gen_input_bin reads src_path as a module global. The initializer sets it
    # in every worker so this does not depend on the fork start method or on
    # the caller having defined the global.
    thread_pool = multiprocessing.Pool(workers,
                                       initializer=_init_worker,
                                       initargs=(src_path,))
    try:
        for index, chunk in enumerate(file_batches):
            # Bind index now; a plain closure would see only the last value.
            def on_error(exc, index=index):
                failures.append((index, exc))
                pbar.update(1)

            # Send only this task's batch (wrapped as a one-element list with
            # index 0) instead of pickling the whole listing for every task.
            thread_pool.apply_async(gen_input_bin,
                                    args=(save_path, [chunk], 0),
                                    callback=on_done,
                                    error_callback=on_error)
        thread_pool.close()
        thread_pool.join()
    finally:
        # No-op after a clean join; stops the workers if we were interrupted.
        thread_pool.terminate()
        pbar.close()

    for index, exc in sorted(failures, key=lambda item: item[0]):
        print("batch {} failed: {!r}".format(index, exc), file=sys.stderr)
    if failures:
        print("{} of {} batches failed; some bin files are missing.".format(
            len(failures), len(file_batches)), file=sys.stderr)
    else:
        print("all {} batches finished; please ensure bin files generated.".format(
            len(file_batches)))
if __name__ == '__main__':
    # Usage: <script> <src_path> <save_path>
    if len(sys.argv) < 3:
        # Exit with a readable message instead of a bare IndexError traceback.
        sys.exit("usage: {} <src_path> <save_path>".format(sys.argv[0]))
    # NOTE: src_path must remain a module-level name; gen_input_bin reads it
    # as a global rather than receiving it as an argument.
    src_path = os.path.realpath(sys.argv[1])
    save_path = os.path.realpath(sys.argv[2])
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    preprocess(src_path, save_path)