import os
import sys
import multiprocessing
from PIL import Image
import numpy as np
from tqdm import tqdm
# Preprocessing parameters consumed by gen_input_bin.
transforms = dict(
    # Target length of the shorter image side (passed to resize as an int).
    resize=256,
    # Side length of the square centre crop taken after resizing.
    centercrop=224,
    # Per-channel values subtracted after pixel values are divided by 255.
    mean=[0.485, 0.456, 0.406],
    # Per-channel divisors applied after the mean subtraction.
    std=[0.229, 0.224, 0.225],
)
def center_crop(img, output_size):
    """Crop the central region of ``img``.

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, top, right, bottom)``.
        output_size: Either an int (square crop) or a ``(height, width)``
            pair.

    Returns:
        Whatever ``img.crop`` returns for the centred box.
    """
    if isinstance(output_size, int):
        side = int(output_size)
        output_size = (side, side)

    target_h, target_w = output_size
    full_w, full_h = img.size

    # Half of the surplus on each axis; round() uses round-half-to-even.
    top = int(round((full_h - target_h) / 2.))
    left = int(round((full_w - target_w) / 2.))

    box = (left, top, left + target_w, top + target_h)
    return img.crop(box)
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize ``img`` either to a shorter-side length or to an explicit size.

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height), interpolation)`` method.
        size: An int gives the target length of the shorter side, with the
            other side scaled by the same ratio and truncated to an int.
            Any other value is treated as a sequence that is reversed
            before being handed to ``img.resize``.
        interpolation: Resampling filter forwarded to ``img.resize``.

    Returns:
        ``img`` itself when its shorter side already equals ``size``;
        otherwise the result of ``img.resize``.
    """
    if not isinstance(size, int):
        # The sequence is reversed so a (height, width) pair becomes the
        # (width, height) order that img.resize is called with below.
        return img.resize(size[::-1], interpolation)

    width, height = img.size
    if min(width, height) == size:
        # Shorter side already matches: nothing to do.
        return img

    if width < height:
        target = (size, int(size * height / width))
    else:
        target = (int(size * width / height), size)
    return img.resize(target, interpolation)
def gen_input_bin(file_batches, batch):
    """Convert one batch of images into raw float32 ``.bin`` files.

    Each image is opened, converted to RGB, resized, centre-cropped, scaled
    by 1/255, normalised with ``transforms['mean']`` / ``transforms['std']``
    and written with ``ndarray.tofile`` in channel-first (C, H, W) order.

    The input and output directories are read from the module-level names
    ``src_path`` and ``save_path``; they must be defined in the process
    that executes this function.

    Args:
        file_batches: List of batches, each a list of file names relative
            to ``src_path``.
        batch: Index into ``file_batches`` selecting the batch to process.
    """
    # Loop-invariant per-channel constants, shaped (C, 1, 1) so they
    # broadcast over a channel-first image.
    mean = np.array(transforms['mean'], dtype=np.float32)[:, None, None]
    std = np.array(transforms['std'], dtype=np.float32)[:, None, None]

    for name in file_batches[batch]:
        print(name)
        # convert() returns an independent image, so the underlying file
        # handle can be released as soon as the with-block exits.
        with Image.open(os.path.join(src_path, name)) as raw:
            image = raw.convert('RGB')
        image = resize(image, transforms['resize'])
        image = center_crop(image, transforms['centercrop'])

        img = np.array(image, dtype=np.float32)
        img = img.transpose(2, 0, 1)  # HWC -> CHW
        img = img / 255.
        img -= mean
        img /= std

        # splitext strips only the final extension, so names containing
        # extra dots ("a.b.jpg") keep a unique stem instead of collapsing
        # to the text before the first dot.
        stem = os.path.splitext(name)[0]
        img.tofile(os.path.join(save_path, stem + ".bin"))
def _init_worker(src, dst):
    """Set the module-level directories that gen_input_bin reads in a worker."""
    global src_path, save_path
    src_path = src
    save_path = dst


def preprocess(src_path, save_path):
    """Convert every file in ``src_path`` to a ``.bin`` file in ``save_path``.

    The directory listing is split into batches of 500 names and each batch
    is handed to ``gen_input_bin`` in a ``multiprocessing.Pool``.

    Args:
        src_path: Directory containing the input images.
        save_path: Directory that receives the generated ``.bin`` files.

    Returns:
        None. Batches that raise are reported on stderr; the remaining
        batches are still processed.
    """
    batch_size = 500
    files = os.listdir(src_path)
    if not files:
        # Pool(0) raises ValueError, so bail out before building the pool.
        print("No files found in {}; nothing to do.".format(src_path))
        return

    # Step over the real listing length so no files are dropped.
    file_batches = [
        files[start:start + batch_size]
        for start in range(0, len(files), batch_size)
    ]

    # One process per batch is wasteful beyond the number of cores.
    workers = min(len(file_batches), multiprocessing.cpu_count())

    # gen_input_bin reads src_path/save_path as module globals. The
    # initializer sets them in every worker so this also works when the
    # module is imported or when processes are spawned rather than forked.
    pool = multiprocessing.Pool(
        workers, initializer=_init_worker, initargs=(src_path, save_path))
    pending = [
        pool.apply_async(gen_input_bin, args=(file_batches, batch))
        for batch in range(len(file_batches))
    ]
    pool.close()
    pool.join()

    # apply_async swallows worker exceptions unless the result is fetched.
    failed = 0
    for batch, result in enumerate(pending):
        try:
            result.get()
        except Exception as exc:
            failed += 1
            sys.stderr.write(
                "batch {} failed: {!r}\n".format(batch, exc))
    if failed:
        sys.stderr.write(
            "{} of {} batches failed.\n".format(failed, len(file_batches)))

    print("Done! please ensure bin files generated.")
if __name__ == '__main__':
    # Command-line entry point: parse the two directories, make sure the
    # output directory exists, then run the conversion.
    import argparse

    cli = argparse.ArgumentParser('data preprocess.')
    cli.add_argument(
        '--src_path',
        type=str,
        required=True,
        help='path to original dataset.',
    )
    cli.add_argument(
        '--save_path',
        type=str,
        required=True,
        help='a directory to save bin files.',
    )
    opts = cli.parse_args()

    # These two names must stay at module level: gen_input_bin reads them
    # as globals rather than receiving them as arguments.
    src_path = os.path.realpath(opts.src_path)
    save_path = os.path.realpath(opts.save_path)

    output_dir_exists = os.path.isdir(save_path)
    if not output_dir_exists:
        os.makedirs(save_path)

    preprocess(src_path, save_path)