import argparse
import os
import multiprocessing
import ast
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoTokenizer
def load_imagenet_classnames(filepath, num_classes=1000):
    """Load an ordered list of class names from a Python-literal file.

    The file content is parsed with ``ast.literal_eval``, which accepts only
    Python literals (no code execution). The result must hold exactly
    ``num_classes`` entries and be indexable by the integers
    ``0 .. num_classes - 1``. Typically it is an ``{index: name}`` dict or a
    list of names.

    Args:
        filepath: Path to the UTF-8 encoded class-names file.
        num_classes: Number of classes expected (1000 for ImageNet-1k).

    Returns:
        List of ``num_classes`` class names ordered by class index.

    Raises:
        ValueError: If the parsed literal does not have exactly ``num_classes``
            entries, or lacks an entry for some index in that range.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        d = ast.literal_eval(f.read())
    if len(d) != num_classes:
        raise ValueError(f"Expected {num_classes} classes, got {len(d)}")
    try:
        return [d[i] for i in range(num_classes)]
    except (KeyError, IndexError) as err:
        # E.g. a dict keyed by strings ("0", "1", ...) rather than ints.
        raise ValueError(f"Class names file has no entry for class index {err}") from err
def gen_image_input_bin(file_batches, batch, input_dir, image_output_dir, pytorch_ckpt_path):
    """Preprocess one batch of images and write each result as a ``.bin`` file.

    The image processor is loaded once per call from ``pytorch_ckpt_path`` and
    reused for every file in the selected batch. Each image is converted to
    RGB and run through the processor. Its ``pixel_values`` array is then
    dumped with ``ndarray.tofile`` to ``<image_output_dir>/<stem>.bin``, where
    ``<stem>`` is the input file name without its final extension.

    Args:
        file_batches: List of file-name lists (every batch).
        batch: Index into ``file_batches`` selecting the batch to process.
        input_dir: Directory the file names are relative to.
        image_output_dir: Directory that receives the ``.bin`` outputs.
        pytorch_ckpt_path: Checkpoint path passed to
            ``AutoImageProcessor.from_pretrained``.
    """
    processor = AutoImageProcessor.from_pretrained(pytorch_ckpt_path, use_fast=False)
    for file_name in file_batches[batch]:
        with Image.open(os.path.join(input_dir, file_name)) as pil_img:
            rgb_img = pil_img.convert("RGB")
            img_numpy = processor(images=rgb_img, return_tensors="np")
        # splitext strips only the final extension, so names containing extra
        # dots (e.g. "a.v1.jpg" vs "a.v2.jpg") do not collide on "a.bin".
        stem = os.path.splitext(file_name)[0]
        img_numpy["pixel_values"].tofile(os.path.join(image_output_dir, stem + ".bin"))
def gen_text_input_bin(text_output_dir, pytorch_ckpt_path, classnames_file):
    """Tokenize one zero-shot prompt per class name and dump the token ids.

    Each class name from ``classnames_file`` is substituted into the template
    "This is a photo of {}.". The prompts are tokenized with the checkpoint's
    tokenizer, padded and truncated to exactly 64 tokens. The ``input_ids``
    array is written with ``ndarray.tofile`` to
    ``<text_output_dir>/IMAGENET_CLASSNAMES_10000.bin``.

    Args:
        text_output_dir: Directory that receives the ``.bin`` output.
        pytorch_ckpt_path: Checkpoint path passed to
            ``AutoTokenizer.from_pretrained``.
        classnames_file: File read by ``load_imagenet_classnames``.
    """
    tokenizer = AutoTokenizer.from_pretrained(pytorch_ckpt_path)
    template = "This is a photo of {}."
    prompts = [template.format(label) for label in load_imagenet_classnames(classnames_file)]
    encoded = tokenizer(
        prompts,
        padding="max_length",
        max_length=64,
        truncation=True,
        return_tensors="np",
    )
    # NOTE(review): the "10000" in the file name does not match the 1000
    # prompts written; left unchanged in case downstream code expects it.
    out_path = os.path.join(text_output_dir, "IMAGENET_CLASSNAMES_10000.bin")
    encoded["input_ids"].tofile(out_path)
def preprocess(input_dir, image_output_dir, text_output_dir, pytorch_ckpt_path, classnames_file, num_batches=10):
    """Convert all images in ``input_dir`` to ``.bin`` files, then the class prompts.

    The regular files in ``input_dir`` are split into at most ``num_batches``
    batches. Each batch is handed to ``gen_image_input_bin`` in a process pool
    sized to the number of batches. Worker errors are printed rather than
    raised, so a failing file stops only the rest of its own batch. Once all
    image batches finish, ``gen_text_input_bin`` writes the tokenized prompts.

    Args:
        input_dir: Directory containing the input images.
        image_output_dir: Directory for the per-image ``.bin`` files.
        text_output_dir: Directory for the tokenized-prompt ``.bin`` file.
        pytorch_ckpt_path: Checkpoint used for both processor and tokenizer.
        classnames_file: Class-names file passed to ``gen_text_input_bin``.
        num_batches: Upper bound on the number of batches / worker processes.
    """
    # Only regular files: a subdirectory would make Image.open fail and abort
    # the remainder of its batch. Sorting gives a deterministic batch split.
    file_names = sorted(
        name for name in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, name))
    )
    # Pool(0) raises ValueError, so skip the image stage when there is nothing to do.
    if file_names:
        total_nums = len(file_names)
        # Ceiling division keeps the batch count <= num_batches. Floor division
        # (total // 10) could produce up to 2 * num_batches - 1 batches.
        batch_size = -(-total_nums // max(1, num_batches))
        file_batches = [file_names[i:i + batch_size] for i in range(0, total_nums, batch_size)]
        with tqdm(total=len(file_batches)) as pbar:
            pbar.set_description("Preprocessing")
            # Pool.__exit__ calls terminate(), so close() + join() must run
            # inside the with-block to let the workers finish normally.
            with multiprocessing.Pool(len(file_batches)) as process_pool:
                for batch in range(len(file_batches)):
                    process_pool.apply_async(
                        gen_image_input_bin,
                        args=(file_batches, batch, input_dir, image_output_dir, pytorch_ckpt_path),
                        callback=lambda _result: pbar.update(),
                        error_callback=lambda e: print(f"Process error: {e}"),
                    )
                process_pool.close()
                process_pool.join()
    gen_text_input_bin(text_output_dir, pytorch_ckpt_path, classnames_file)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, make sure the output
    # directories exist, then run the image and text preprocessing.
    parser = argparse.ArgumentParser(description="Image and text preprocessing script")
    parser.add_argument("--data_dir", type=str, required=True, help="Directory of the image dataset")
    parser.add_argument("--image_save_dir", type=str, required=True, help="Directory to save processed image")
    parser.add_argument("--text_save_dir", type=str, required=True, help="Directory to save processed text")
    parser.add_argument("--pytorch_ckpt_path", type=str, required=True, help="Path to the PyTorch model checkpoint")
    # The file is parsed with ast.literal_eval, so it must contain a Python
    # literal, not one name per line.
    parser.add_argument(
        "--classnames_file",
        type=str,
        required=True,
        help="Class names file containing a Python literal (dict {index: name} or list of names)",
    )
    args = parser.parse_args()
    if not os.path.isdir(args.data_dir):
        parser.error(f"--data_dir is not a directory: {args.data_dir}")
    # exist_ok avoids the check-then-create race. makedirs still raises if the
    # path exists but is not a directory.
    os.makedirs(os.path.realpath(args.image_save_dir), exist_ok=True)
    os.makedirs(os.path.realpath(args.text_save_dir), exist_ok=True)
    preprocess(args.data_dir, args.image_save_dir, args.text_save_dir, args.pytorch_ckpt_path, args.classnames_file)