from __future__ import absolute_import, division, print_function
import argparse
import os
import parse
from tqdm import tqdm
import numpy as np
def dump_input_data(save_dir, input_data, seq):
    """Save every input field of every sample as its own ``.npy`` file.

    Sample ``i``'s field ``name`` is written to
    ``<save_dir>/<name>/<i>.npy`` after being cut to its first ``seq``
    entries along the second axis.

    Args:
        save_dir: Root directory for the dumped arrays.
        input_data: Sized, indexable collection of dict-like samples. Each
            value must support ``value[:, :seq]`` and ``.numpy()``
            (presumably torch tensors -- confirm against the loader).
        seq: Number of leading entries to keep along the second axis.
    """
    # Directories are created lazily, on the first sample that uses a field,
    # instead of from ``input_data[0].keys()``. This keeps an empty dataset
    # from raising IndexError and also covers fields that only appear in
    # later samples.
    prepared_dirs = set()
    for data_idx in tqdm(range(len(input_data))):
        for data_name, data in input_data[data_idx].items():
            sub_dir = os.path.join(save_dir, data_name)
            if sub_dir not in prepared_dirs:
                os.makedirs(sub_dir, exist_ok=True)
                prepared_dirs.add(sub_dir)
            save_path = os.path.join(sub_dir, f"{data_idx}.npy")
            np.save(save_path, data[:, :seq].numpy())
def dump_label(save_dir, gt_label):
    """Write all ground-truth labels to ``<save_dir>/label.npy``.

    Args:
        save_dir: Directory that receives ``label.npy``.
        gt_label: Iterable of objects exposing ``.numpy()``; the converted
            values are combined into a single array with ``np.array``.
    """
    label_arrays = []
    for label in gt_label:
        label_arrays.append(label.numpy())
    np.save(os.path.join(save_dir, "label.npy"), np.array(label_arrays))
def om_pre(ar):
    """Load the dataset described by ``ar`` and dump inputs and labels.

    Args:
        ar: Parsed argument namespace. ``batch_size`` is overwritten with 1;
            ``save_dir`` and ``max_seq_length`` are read here, and the whole
            namespace is handed to ``parse.load_data_model``.
    """
    # Batch size is forced to 1 before loading (presumably so every loaded
    # item is a single sample -- confirm against parse.load_data_model).
    ar.batch_size = 1
    # Only the first and third returned values are used here.
    inputs, _ignored, labels = parse.load_data_model(ar)

    dump_input_data(ar.save_dir, inputs, ar.max_seq_length)
    dump_label(ar.save_dir, labels)

    sample_count = len(inputs)
    print('data num: %d' % sample_count)
def main():
    """Parse the command-line options and run the preprocessing step."""
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--prefix_dir", str, './albert_pytorch',
         "prefix dir for ori model code"),
        ("--pth_dir", str, './albert_pytorch/outputs/SST-2/',
         "dir of pth, load args.bin and model.bin"),
        ("--data_dir", str, './albert_pytorch/dataset/SST-2/',
         "dir of dataset"),
        ("--save_dir", str, '',
         "save dir for preprocessed data"),
        ("--max_seq_length", int, 128,
         "seq length for input data."),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value,
                            help=help_text)
    ar = parser.parse_args()

    # Derived settings that are not exposed as command-line options.
    ar.pth_arg_path = os.path.join(ar.pth_dir, "training_args.bin")
    ar.data_type = 'dev'

    om_pre(ar)


# Run the preprocessing only when this file is executed as a script.
if __name__ == "__main__":
    main()