import os
import argparse
import numpy as np
import cv2
import mmcv
import torch
import pickle as pk
import multiprocessing
from tqdm import tqdm
# Parsed command-line options. Assigned from ``parser.parse_args()`` in the
# ``__main__`` block and read as a module-level global by ``gen_input_bin``
# and ``preprocess``; it stays ``None`` until that assignment runs.
flags = None
def resize(img, size):
    """Scale ``img`` to fit inside ``size`` while keeping its aspect ratio.

    Args:
        img: Image array whose first two axes are (height, width).
        size: Target bounding box given as ``(width, height)``.

    Returns:
        A ``(resized_img, scale_ratio)`` tuple. ``scale_ratio`` is the single
        factor applied to both axes: the smaller of the width and height
        ratios, so the result never exceeds ``size`` in either dimension.
    """
    src_h, src_w = img.shape[0], img.shape[1]
    ratio = min(size[0] / src_w, size[1] / src_h)
    # Floor (not round) the scaled extent so the output stays within bounds.
    dst_size = (int(np.floor(src_w * ratio)), int(np.floor(src_h * ratio)))
    # The target size is handed to mmcv in (width, height) order.
    return mmcv.imresize(img, dst_size), ratio
def gen_input_bin(file_batches, batch):
    """Preprocess one batch of images into ``.bin`` tensors and ``.pk`` metadata.

    For each file name in ``file_batches[batch]`` the image is read from
    ``flags.image_src_path``, resized with ``resize`` to fit the model input,
    normalized, zero-padded (centered) to ``flags.model_input_height`` x
    ``flags.model_input_width``, has its last axis moved to the front, and is
    written with ``ndarray.tofile`` to ``flags.bin_file_path``. A pickle with
    the resized shape, the scale factor and the original shape is written to
    ``flags.meta_file_path``.

    Args:
        file_batches: Sequence of batches, each a sequence of file names
            relative to ``flags.image_src_path``.
        batch: Index of the batch in ``file_batches`` to process.

    Note:
        Relies on the module-level ``flags`` being populated. Any exception
        raised for one file aborts the remainder of this batch.
    """
    # Loop-invariant values: build the normalization constants and read the
    # target size once instead of once per image.
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    target_w = flags.model_input_width
    target_h = flags.model_input_height

    for file_name in file_batches[batch]:
        image = mmcv.imread(os.path.join(flags.image_src_path, file_name))
        ori_shape = image.shape

        image, scale_factor = resize(image, (target_w, target_h))
        # Shape after resizing but before padding; stored in the metadata.
        img_shape = image.shape

        image = mmcv.imnormalize(image, mean, std)

        # Center the resized image on the model canvas; when the slack is
        # odd, the extra pixel goes to the right/bottom side.
        h, w = image.shape[0], image.shape[1]
        pad_left = (target_w - w) // 2
        pad_top = (target_h - h) // 2
        pad_right = target_w - pad_left - w
        pad_bottom = target_h - pad_top - h
        image = cv2.copyMakeBorder(
            image, pad_top, pad_bottom, pad_left, pad_right,
            cv2.BORDER_CONSTANT, value=0)

        # Move the last axis first (presumably HWC -> CHW for the model).
        image = image.transpose(2, 0, 1)

        # splitext drops only the final extension, so names that contain
        # extra dots (e.g. "a.1.jpg" / "a.2.jpg") keep distinct output names
        # instead of collapsing onto the text before the first dot.
        stem = os.path.splitext(file_name)[0]

        image.tofile(os.path.join(flags.bin_file_path, stem + ".bin"))

        image_meta = {'img_shape': img_shape,
                      'scale_factor': scale_factor,
                      'ori_shape': ori_shape}
        with open(os.path.join(flags.meta_file_path, stem + ".pk"), "wb") as fp:
            pk.dump(image_meta, fp)
def _init_worker(parsed_flags):
    """Install the parsed options as the module-level ``flags`` in a worker."""
    global flags
    flags = parsed_flags


def preprocess():
    """Preprocess every file in ``flags.image_src_path`` using a process pool.

    The directory listing is split into batches of 100 file names and each
    batch is handed to ``gen_input_bin`` in a worker process. A progress bar
    advances once per finished batch, whether it succeeded or failed.

    Failures inside workers do not propagate to the caller; each one is
    printed as it happens and a summary line is printed at the end.
    """
    batch_size = 100
    files = os.listdir(flags.image_src_path)
    # Slice over the real listing length so no file is silently dropped when
    # the directory holds more than 5000 entries.
    file_batches = [files[i:i + batch_size]
                    for i in range(0, len(files), batch_size)]
    if not file_batches:
        # Pool(0) would raise ValueError; there is simply nothing to do.
        print("no files found in {}".format(flags.image_src_path))
        return

    # One process per batch is pointless beyond the number of CPUs, and would
    # be unbounded now that the batch count follows the directory size.
    num_workers = min(len(file_batches), os.cpu_count() or 1)
    failures = []

    # ``flags`` is only assigned under the ``__main__`` guard, which does not
    # run in workers started without forking the parent (e.g. "spawn"), so it
    # is passed to each worker explicitly through the pool initializer.
    with multiprocessing.Pool(num_workers,
                              initializer=_init_worker,
                              initargs=(flags,)) as pool, \
            tqdm(total=len(file_batches)) as pbar:

        def on_done(_result):
            pbar.update(1)

        def on_error(exc):
            # Callbacks run in the parent process, so the error can be
            # recorded and shown here instead of being discarded.
            failures.append(exc)
            tqdm.write("batch failed: {!r}".format(exc))
            pbar.update(1)

        for batch in range(len(file_batches)):
            pool.apply_async(gen_input_bin,
                             args=(file_batches, batch),
                             callback=on_done,
                             error_callback=on_error)
        # Wait for all batches before the context manager terminates the pool.
        pool.close()
        pool.join()

    if failures:
        print("{} of {} batches failed; please ensure bin files generated.".format(
            len(failures), len(file_batches)))
if __name__ == "__main__":
    # Command-line entry point: parse the options into the module-level
    # ``flags``, make sure both output directories exist, then run the
    # preprocessing over the whole source directory.
    parser = argparse.ArgumentParser(
        description='preprocess of MaskRCNN PyTorch model')
    parser.add_argument(
        "--image_src_path", default="/root/datasets/coco/val2017", help='image of dataset')
    parser.add_argument("--bin_file_path", default="val2017_bin",
                        help='Preprocessed image buffer')
    parser.add_argument("--meta_file_path",
                        default="val2017_bin_meta", help='Get image meta')
    parser.add_argument("--model_input_height", default=800,
                        type=int, help='input tensor height')
    parser.add_argument("--model_input_width", default=1216,
                        type=int, help='input tensor width')
    # Assigned at module level on purpose: the worker functions read ``flags``
    # as a global.
    flags = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of exists() + makedirs()
    # and raises immediately if the path exists but is not a directory.
    for out_dir in (flags.bin_file_path, flags.meta_file_path):
        os.makedirs(out_dir, exist_ok=True)

    preprocess()