import os
import argparse
import numpy as np
import cv2
import multiprocessing
from tqdm import tqdm
def resize(img, size):
    """Resize an image while keeping its aspect ratio.

    The shorter side is scaled to ``size[0]``. If the longer side would then
    exceed ``size[1]``, both sides are scaled down again so that the longer
    side equals ``size[1]``.

    Args:
        img: Image array whose first two dimensions are (height, width).
        size: ``(short_side, max_long_side)`` in pixels, e.g. ``(800, 1333)``.

    Returns:
        The resized image produced by ``cv2.resize`` with
        ``cv2.INTER_LINEAR`` interpolation.
    """
    short_side, max_long_side = size
    old_h, old_w = img.shape[:2]

    # Scale so that the shorter side becomes exactly ``short_side``.
    scale_ratio = short_side / min(old_w, old_h)
    if old_h < old_w:
        new_h, new_w = short_side, int(np.floor(scale_ratio * old_w))
    else:
        new_h, new_w = int(np.floor(scale_ratio * old_h)), short_side

    # Clamp the longer side; this may leave the shorter side below
    # ``short_side``.
    if max(new_h, new_w) > max_long_side:
        scale = max_long_side / max(new_h, new_w)
        new_h = new_h * scale
        new_w = new_w * scale

    # Round half up to integer pixel counts.
    new_w = int(new_w + 0.5)
    new_h = int(new_h + 0.5)

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
def gen_input_bin(file_batches, batch):
    """Preprocess one batch of images and dump each one as a raw ``.bin`` file.

    For every file name in ``file_batches[batch]`` the image is read with
    OpenCV, resized via ``resize(image, (800, 1333))``, converted to float32,
    mean-subtracted per channel, zero-padded on the bottom and right up to the
    model input size, transposed with ``(2, 0, 1)`` (channels first) and
    written with ``ndarray.tofile``.

    This function reads the module-level ``flags`` namespace
    (``image_src_path``, ``bin_file_path``, ``model_input_height``,
    ``model_input_width``), which is only assigned in the ``__main__``
    section of this file.
    NOTE(review): worker processes therefore have to inherit ``flags`` from
    the parent (the case with the "fork" start method) -- confirm before
    running on platforms that use "spawn".

    Args:
        file_batches: List of lists of image file names.
        batch: Index into ``file_batches`` selecting the batch to process.
    """
    # The per-channel constants are loop-invariant, so build them once.
    # They are kept as float64 row vectors of shape (1, 3) derived from
    # float32 values, which is the exact form the cv2 arithmetic calls below
    # were written against.
    mean = np.float64(
        np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape(1, -1))
    std_inv = 1 / np.float64(
        np.array([1., 1., 1.], dtype=np.float32).reshape(1, -1))

    for file in file_batches[batch]:
        image = cv2.imread(os.path.join(flags.image_src_path, file),
                           cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals an unreadable/non-image file by returning
            # None; skip it instead of aborting the rest of the batch.
            print("skip unreadable image: {}".format(file))
            continue

        image = resize(image, (800, 1333))

        # astype already returns a new array, so no explicit copy is needed
        # before the in-place arithmetic.
        img = image.astype(np.float32)
        cv2.subtract(img, mean, img)
        cv2.multiply(img, std_inv, img)

        # Pad only on the bottom and right so the image stays anchored at
        # the top-left corner of the model input.
        pad_bottom = flags.model_input_height - img.shape[0]
        pad_right = flags.model_input_width - img.shape[1]
        img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right,
                                 cv2.BORDER_CONSTANT, value=0)

        img = img.transpose(2, 0, 1)

        # splitext drops only the final extension, so names containing
        # extra dots (e.g. "a.b.jpg") keep their full stem and cannot
        # collide with each other.
        bin_name = os.path.splitext(file)[0] + ".bin"
        img.tofile(os.path.join(flags.bin_file_path, bin_name))
def preprocess(src_path, save_path):
    """Convert the images listed in ``src_path`` using a process pool.

    The directory listing is split into batches of 100 file names and each
    batch is handed to ``gen_input_bin`` in a worker process, while a tqdm
    bar counts finished batches. Exceptions raised inside workers are
    collected and printed once all batches have finished.

    Args:
        src_path: Directory that is listed to discover the input file names.
        save_path: Not used by this function; ``gen_input_bin`` writes to
            ``flags.bin_file_path``.
    """
    batch_size = 100
    files = os.listdir(src_path)
    # Cover the whole listing rather than a fixed number of entries.
    file_batches = [files[i:i + batch_size]
                    for i in range(0, len(files), batch_size)]
    if not file_batches:
        # multiprocessing.Pool(0) raises ValueError, so stop early.
        print("no files found in {}".format(src_path))
        return

    # One process per batch is pointless beyond the number of CPUs and does
    # not scale to large directories.
    workers = min(len(file_batches), multiprocessing.cpu_count())
    failures = []
    pbar = tqdm(total=len(file_batches))

    def on_error(exc):
        # Keep the exception so it can be reported after the pool finishes.
        failures.append(exc)
        pbar.update(1)

    thread_pool = multiprocessing.Pool(workers)
    try:
        for batch in range(len(file_batches)):
            thread_pool.apply_async(gen_input_bin,
                                    args=(file_batches, batch),
                                    callback=lambda _: pbar.update(1),
                                    error_callback=on_error)
        thread_pool.close()
        thread_pool.join()
    finally:
        # Harmless after a normal close()/join(); stops the workers if the
        # submission loop or join() was interrupted.
        thread_pool.terminate()
        pbar.close()

    for exc in failures:
        print("batch failed: {!r}".format(exc))
    if failures:
        print("{} of {} batches failed; please ensure bin files "
              "generated.".format(len(failures), len(file_batches)))
if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then run the preprocessing.
    parser = argparse.ArgumentParser(
        description='preprocess of Retinanet PyTorch model')
    parser.add_argument("--image_src_path",
                        default="/root/datasets/coco/val2017",
                        help='image of dataset')
    parser.add_argument("--bin_file_path",
                        default="./val2017_bin/",
                        help='Preprocessed image buffer')
    parser.add_argument("--model_input_height",
                        default=1344, type=int,
                        help='input tensor height')
    parser.add_argument("--model_input_width",
                        default=1344, type=int,
                        help='input tensor width')
    # `flags` must stay a module-level name: gen_input_bin reads it directly.
    flags = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(flags.bin_file_path, exist_ok=True)

    preprocess(flags.image_src_path, flags.bin_file_path)