import os
import sys
import argparse
import multiprocessing
import cv2
import numpy as np
from tqdm import tqdm
# Preprocessing parameters shared by the bin generators in this file.
model_config = dict(
    # Size handed to resize_with_aspectratio().
    resize=224,
    # Side length of the square window taken by center_crop().
    centercrop=224,
    # Per-channel values subtracted from the RGB float image in
    # amct_input_bin().
    mean=[123.675, 116.28, 103.53],
)
def center_crop(img, out_height, out_width):
    """Return the centered ``out_height`` x ``out_width`` window of ``img``.

    Args:
        img: Array whose ``shape`` unpacks into exactly three values
            (height, width, channels).
        out_height: Number of rows to keep.
        out_width: Number of columns to keep.

    Returns:
        A slice of ``img`` covering the centered window; all channels are
        kept.
    """
    rows, cols, _ = img.shape
    # Window edges are computed with true division and then truncated,
    # so odd margins put the extra pixel on the bottom/right side.
    y_start, y_stop = int((rows - out_height) / 2), int((rows + out_height) / 2)
    x_start, x_stop = int((cols - out_width) / 2), int((cols + out_width) / 2)
    return img[y_start:y_stop, x_start:x_stop]
def resize_with_aspectratio(img, size, scale=87.5, inter_pol=cv2.INTER_LINEAR):
    """Resize an image, keeping its aspect ratio when ``size`` is an int.

    Args:
        img: Image array; for an int ``size`` its ``shape`` must unpack into
            (height, width, channels).
        size: Either an int or a two-element sequence.
            * int: the shorter side is resized to ``int(100 * size / scale)``
              and the longer side is scaled by the same factor.
            * sequence: reversed before being passed to ``cv2.resize`` as
              ``dsize``, i.e. it is treated as (height, width) and the
              aspect ratio is not preserved.
        scale: Percentage relating ``size`` to the resized shorter side
            (with the default 87.5, ``size=224`` yields 256).
        inter_pol: OpenCV interpolation flag used for both branches.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    if not isinstance(size, int):
        # The original branch referenced an undefined name (`interpolation`)
        # and called ndarray.resize, which is not an image resize. Use
        # cv2.resize, whose dsize is (width, height), hence the reversal.
        return cv2.resize(img, tuple(size[::-1]), interpolation=inter_pol)

    height, width, _ = img.shape
    short_side = int(100. * size / scale)
    if height > width:
        # Portrait: width is the shorter side.
        w = short_side
        h = int(short_side * height / width)
    else:
        # Landscape or square: height is the shorter side.
        h = short_side
        w = int(short_side * width / height)
    return cv2.resize(img, (w, h), interpolation=inter_pol)
def gen_input_bin(file_batches, batch, src_path, save_path):
    """Preprocess one batch of images and dump each as a raw uint8 ``.bin``.

    Each image is read with OpenCV, converted BGR -> RGB, resized with
    ``resize_with_aspectratio`` (INTER_AREA), center-cropped, and written
    with ``ndarray.tofile`` to ``save_path`` under the source file's name
    with its extension replaced by ``.bin``.

    Args:
        file_batches: List of lists of file names.
        batch: Index into ``file_batches`` selecting the batch to process.
        src_path: Directory containing the source images.
        save_path: Directory that receives the ``.bin`` files.
    """
    resize_to = model_config['resize']
    crop_to = model_config['centercrop']
    for file_name in tqdm(file_batches[batch]):
        image = cv2.imread(os.path.join(src_path, file_name))
        if image is None:
            # cv2.imread signals an unreadable / non-image file by returning
            # None; without this check cvtColor raises and the rest of the
            # batch is lost. Skip the file and say so instead.
            print("skip unreadable image: {}".format(os.path.join(src_path, file_name)),
                  file=sys.stderr)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = resize_with_aspectratio(image, resize_to, inter_pol=cv2.INTER_AREA)
        img = center_crop(img, crop_to, crop_to)
        img = np.asarray(img, dtype='uint8')
        # splitext strips only the final extension, so names that contain
        # extra dots (e.g. "a.v1.jpg", "a.v2.jpg") no longer collide on the
        # same output file.
        stem = os.path.splitext(file_name)[0]
        img.tofile(os.path.join(save_path, stem + ".bin"))
def preprocess(source_path, dest_path):
    """Convert every file in ``source_path`` to a ``.bin`` in ``dest_path``.

    The sorted directory listing is split into batches of 500 files and each
    batch is handed to ``gen_input_bin`` in a ``multiprocessing.Pool``.

    Args:
        source_path: Directory containing the source images.
        dest_path: Directory that receives the generated ``.bin`` files.

    Returns:
        None. Exceptions raised inside worker processes do not propagate;
        they are printed to stderr after all batches have finished.
    """
    batch_size = 500
    files = sorted(os.listdir(source_path))
    if not files:
        # Nothing to convert; avoid spinning up a worker pool for no work.
        print("no files found in {}; nothing to do.".format(source_path), file=sys.stderr)
        return

    file_batches = [files[i:i + batch_size] for i in range(0, len(files), batch_size)]

    # One process per batch oversubscribes the machine on large datasets
    # (e.g. 100 processes for 50k files); extra batches simply queue up.
    workers = min(len(file_batches), multiprocessing.cpu_count())
    failures = []
    thread_pool = multiprocessing.Pool(workers)
    for batch in range(len(file_batches)):
        # apply_async drops worker exceptions unless the result is inspected
        # or an error_callback is supplied; collect them for reporting.
        thread_pool.apply_async(
            gen_input_bin,
            args=(file_batches, batch, source_path, dest_path),
            error_callback=failures.append,
        )
    thread_pool.close()
    thread_pool.join()

    for exc in failures:
        print("batch failed: {!r}".format(exc), file=sys.stderr)
    # Original reminder kept verbatim; the output should still be checked.
    print("in thread, except will not report! please ensure bin files generated.")
def amct_input_bin(src_path, save_path):
    """Build one float32 batch file from the first 64 images of ``src_path``.

    Each image is read with OpenCV, converted BGR -> RGB, resized with
    ``resize_with_aspectratio`` (INTER_AREA), center-cropped, cast to
    float32, mean-subtracted with ``model_config['mean']`` and transposed
    from HWC to CHW. The images are stacked along a new leading axis and
    written with ``ndarray.tofile``; the output file is named after the last
    image in the batch (text before the first '.') plus ``.bin``.

    Args:
        src_path: Directory containing the source images.
        save_path: Directory that receives the batch ``.bin`` file.

    Raises:
        ValueError: If ``src_path`` contains no files or one of the selected
            images cannot be read.
    """
    # os.listdir order is arbitrary; sort so the same 64 files (and the same
    # output file name) are chosen on every run, matching preprocess().
    image_names = sorted(os.listdir(src_path))[:64]
    if not image_names:
        raise ValueError("no images found in {}".format(src_path))

    mean = np.array(model_config['mean'], dtype='float32')
    resize_to = model_config['resize']
    crop_to = model_config['centercrop']
    data = []
    for image_name in tqdm(image_names):
        file_path = os.path.join(src_path, image_name)
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread returns None on failure; fail with the offending
            # path rather than an opaque error from cvtColor.
            raise ValueError("cannot read image: {}".format(file_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = resize_with_aspectratio(image, resize_to, inter_pol=cv2.INTER_AREA)
        img = center_crop(img, crop_to, crop_to)
        img = np.asarray(img, dtype='float32')
        img -= mean
        img = img.transpose([2, 0, 1])  # HWC -> CHW
        data.append(img)

    batch_data = np.stack(data, axis=0)
    # Name the file after the last image explicitly instead of relying on
    # the loop variable surviving the loop.
    out_name = image_names[-1].split('.')[0] + ".bin"
    batch_data.tofile(os.path.join(save_path, out_name))
def main():
    """Parse command-line options and run the selected preprocessing mode.

    Creates the output directory when it does not exist, then calls
    ``amct_input_bin`` if ``--amct`` was given and ``preprocess`` otherwise.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--src_path', type=str, default='./ImageNet/val', help='path to images.')
    cli.add_argument('--save_path', type=str, default='./rep_dataset', help='path to save bin files.')
    cli.add_argument('--amct', action='store_true', help='if True, will generate quantization data.')
    opts = cli.parse_args()

    out_dir = opts.save_path
    if not os.path.isdir(out_dir):
        os.makedirs(os.path.realpath(out_dir))

    # Both modes take (source directory, output directory).
    run = amct_input_bin if opts.amct else preprocess
    run(opts.src_path, out_dir)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()