import os
import argparse
from tqdm import tqdm
import torchvision as tv
import torch.utils.data
import torch.nn.functional as F
def preprocess(dataset_path, data_bin_path, label_path, batch_size):
    """Export the CIFAR-10 test split as raw per-batch .bin files plus a label file.

    Every image is resized to 128x128, converted to a tensor and normalized
    with mean 0.5 / std 0.5 per channel. Batches are written in loader order
    (no shuffling) to ``<data_bin_path>/<batch index>.bin`` via numpy's
    ``tofile``, i.e. as raw array bytes without any header.

    Args:
        dataset_path: Root directory handed to ``torchvision.datasets.CIFAR10``;
            the dataset is downloaded there if it is not already present.
        data_bin_path: Directory that receives the ``.bin`` files. It is
            created (including missing parents) when it does not exist.
        label_path: Text file to create. Line ``k`` holds the space-separated
            labels of batch ``k``. Only real samples are listed, so the line
            of a zero-padded final batch has fewer than ``batch_size`` entries.
        batch_size: Number of images per ``.bin`` file. A short final batch is
            zero-padded along the batch dimension so every file has the same
            size.

    Raises:
        FileExistsError: If ``label_path`` already exists (the file is opened
            in exclusive-create mode so earlier results are never overwritten),
            or if ``data_bin_path`` exists but is not a directory.
    """
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((128, 128)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    valid_set = tv.datasets.CIFAR10(dataset_path, transform=val_tx, train=False, download=True)
    # No pin_memory: batches are only dumped to disk, never copied to a device,
    # so page-locking them would be wasted work.
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=batch_size, shuffle=False, drop_last=False)

    # makedirs + exist_ok also creates missing parents and avoids the
    # check-then-create race of isdir()/mkdir().
    os.makedirs(data_bin_path, exist_ok=True)

    with open(label_path, 'x') as f:
        # Wrap the loader itself (not enumerate) so tqdm can read its length
        # and display a total / ETA.
        for batch_idx, (images, target) in enumerate(tqdm(valid_loader)):
            label = ' '.join(str(cls) for cls in target.tolist())
            f.write(label + '\n')

            save_path = os.path.join(data_bin_path, f"{batch_idx}.bin")

            missing = batch_size - images.shape[0]
            if missing > 0:
                # F.pad lists (before, after) pairs starting from the LAST
                # dimension, so the fourth pair addresses dim 0 (the batch
                # axis): append `missing` all-zero samples at the end.
                images = F.pad(input=images, pad=(0, 0, 0, 0, 0, 0, 0, missing),
                               mode='constant', value=0)
            images.numpy().tofile(save_path)
if __name__ == "__main__":
    # Command-line entry point: collect the paths and batch size, then run
    # the preprocessing step once.
    cli = argparse.ArgumentParser(description='bit_preprocess')

    # The three mandatory string options share the same shape, so register
    # them from a small table (order preserved for the --help output).
    for flag, help_text in (
        ('--dataset_path', 'dataset path'),
        ('--save_path', 'bin file save path'),
        ('--label_path', 'path to save label'),
    ):
        cli.add_argument(flag, type=str, help=help_text, required=True)
    cli.add_argument('--batch_size', type=int, default=1, help='om batch size')

    opts = cli.parse_args()
    preprocess(
        dataset_path=opts.dataset_path,
        data_bin_path=opts.save_path,
        label_path=opts.label_path,
        batch_size=opts.batch_size,
    )