import sys
sys.path.append(r"./SpanBERT/code")
import argparse
import collections
import json
import logging
import re
from io import open
import numpy as np
import torch
import os
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from run_squad import *
from pytorch_pretrained_bert.tokenization import (BertTokenizer,
whitespace_tokenize)
# Configure the root logger at INFO with a timestamped format. Loggers that
# propagate to the root (the default) are printed with it. If the root logger
# already has handlers, logging.basicConfig leaves it unchanged.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def main(args):
    """Tokenize a SQuAD dev file and dump the model inputs as raw binary files.

    The examples in ``args.dev_file`` are read with ``read_squad_examples``.
    They are converted to features with ``convert_examples_to_features``,
    using the ``BertTokenizer`` loaded from ``args.model``. The features are
    batched in order with a ``DataLoader`` of size ``args.batch_size``.

    For every batch ``i`` this writes three files:

    * ``./input_ids/<i>.bin``
    * ``./input_mask/<i>.bin``
    * ``./segment_ids/<i>.bin``

    Each file holds the raw int64 (``torch.long``) values of that batch's
    tensor, as written by ``numpy.ndarray.tofile``. The files carry no header
    and no shape information. The three output directories are created if
    they are missing.

    Args:
        args: parsed command-line namespace. It must provide ``model``,
            ``do_lower_case``, ``dev_file``, ``version_2_with_negative``,
            ``max_seq_length``, ``doc_stride``, ``max_query_length`` and
            ``batch_size``.
    """
    tokenizer = BertTokenizer.from_pretrained(
        args.model, do_lower_case=args.do_lower_case)
    eval_examples = read_squad_examples(
        input_file=args.dev_file, is_training=False,
        version_2_with_negative=args.version_2_with_negative)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    # No shuffling and no sampler, so iteration is sequential: file <i>.bin
    # holds features [i * batch_size, (i + 1) * batch_size). The last batch
    # may be shorter when the feature count is not a multiple of batch_size.
    eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size)

    input_ids_path = "./input_ids"
    input_mask_path = "./input_mask"
    segment_ids_path = "./segment_ids"
    for path in (input_ids_path, input_mask_path, segment_ids_path):
        # exist_ok avoids the exists()/makedirs() race and makes reruns a no-op.
        os.makedirs(path, exist_ok=True)

    for batch_index, (input_ids, input_mask, segment_ids) in enumerate(
            tqdm(eval_dataloader, leave=False)):
        file_name = str(batch_index) + '.bin'
        input_ids.numpy().tofile(os.path.join(input_ids_path, file_name))
        segment_ids.numpy().tofile(os.path.join(segment_ids_path, file_name))
        input_mask.numpy().tofile(os.path.join(input_mask_path, file_name))
def build_arg_parser():
    """Return the ``argparse.ArgumentParser`` for this script.

    It defines every attribute that ``main`` reads from its ``args``
    namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='spanbert-base-cased', type=str,
                        help="Pretrained model name or path passed to "
                             "BertTokenizer.from_pretrained.")
    parser.add_argument("--dev_file", default='dev-v1.1.json', type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, "
                             "how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--version_2_with_negative', action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    # The previous help text was copied from --doc_stride and described stride, not batching.
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Number of features per batch; each batch is written as one .bin file "
                             "per input (input_ids, input_mask, segment_ids).")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)