import argparse
import sys
import os
import torch
import numpy as np
from tqdm import tqdm
from torchvision import datasets, transforms
from facenet_pytorch import fixed_image_standardization
from torch.utils.data import DataLoader, SequentialSampler
def face_preprocess(crop_dir, save_dir):
trans = transforms.Compose([
np.float32,
transforms.ToTensor(),
fixed_image_standardization
])
dataset = datasets.ImageFolder(crop_dir, transform=trans)
embed_loader = DataLoader(
dataset,
num_workers=workers,
batch_size=batch_size,
sampler=SequentialSampler(dataset)
)
for idx, (xb, yb) in tqdm(enumerate(embed_loader)):
out_path_xb = os.path.join(save_dir, 'xb_results', '{}.bin'.format(idx))
out_path_yb = os.path.join(save_dir, 'yb_results', '{}.bin'.format(idx))
os.makedirs(os.path.dirname(out_path_xb), exist_ok=True)
os.makedirs(os.path.dirname(out_path_yb), exist_ok=True)
if xb.shape[0] < batch_size:
xb_zeros = np.zeros([batch_size - int(xb.shape[0]), int(xb.shape[1]), int(xb.shape[2]), int(xb.shape[3])])
xb = np.concatenate([xb.numpy(), xb_zeros], axis=0)
xb = torch.from_numpy(xb)
xb.detach().cpu().numpy().tofile(out_path_xb)
yb.detach().cpu().numpy().tofile(out_path_yb)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--crop_dir', type=str, help='cropped image save path')
parser.add_argument('--save_dir', type=str, help='preprocess bin files save path')
parser.add_argument('--batch_size', type=int, default=1, help='preprocess bin files save path')
arg = parser.parse_args()
batch_size = arg.batch_size
workers = 0 if os.name == 'nt' else 8
face_preprocess(arg.crop_dir, arg.save_dir)