from __future__ import absolute_import, division, print_function
import sys
sys.path.append(r"./SpanBERT/code")
import argparse
import collections
import json
import logging
import math
import os
import random
import time
import re
import string
from io import open
import numpy as np
import torch
from run_squad import *
from torch.utils.data import DataLoader, TensorDataset
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer,
whitespace_tokenize)
import pickle
# Root logging configuration: INFO and above, timestamped as MM/DD/YYYY HH:MM:SS.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
# Module-level logger used by evaluate() and main().
logger = logging.getLogger(__name__)
def _find_bin_dir(bin_dir):
    """Return the sub-directory of ``bin_dir`` that holds the per-batch logits files.

    Only directories directly under ``bin_dir`` are considered; plain files are
    ignored. Candidates are sorted so the choice does not depend on the
    arbitrary ordering of ``os.listdir`` when more than one exists.

    Raises:
        OSError: if ``bin_dir`` contains no sub-directory.
    """
    candidates = sorted(
        os.path.join(bin_dir, name) for name in os.listdir(bin_dir)
        if os.path.isdir(os.path.join(bin_dir, name)))
    if not candidates:
        raise OSError(
            "No sub-directory with inference outputs found in {!r}".format(bin_dir))
    return candidates[0]


def _load_batch_logits(bin_path, batch_idx, seq_len):
    """Load one batch's start/end logits from raw float32 ``.bin`` files.

    ``<batch_idx>_0.bin`` holds the start logits and ``<batch_idx>_1.bin`` the
    end logits. Each flat array is reshaped to ``(-1, seq_len)`` so that row
    ``i`` belongs to the ``i``-th example of the batch.
    """
    start = np.fromfile(os.path.join(bin_path, '{}_0.bin'.format(batch_idx)), dtype='float32')
    end = np.fromfile(os.path.join(bin_path, '{}_1.bin'.format(batch_idx)), dtype='float32')
    return start.reshape(-1, seq_len), end.reshape(-1, seq_len)


def evaluate(args, eval_dataset, eval_dataloader,
             eval_examples, eval_features, na_prob_thresh=1.0, pred_only=False,
             logits_seq_len=512):
    """Turn offline-inference logits into SQuAD predictions and, optionally, scores.

    For the ``idx``-th batch yielded by ``eval_dataloader``, logits are read from
    ``<subdir>/<idx>_0.bin`` (start) and ``<subdir>/<idx>_1.bin`` (end), where
    ``<subdir>`` is the alphabetically first sub-directory of ``args.bin_dir``.
    Row ``i`` of a batch is paired with the ``i``-th feature index of that
    batch. The dataloader's batch size and order must therefore match the run
    that produced the files.

    The helpers used here (``RawResult``, ``make_predictions``, ``get_raw_scores``,
    ...) are not defined in this file; they come from ``from run_squad import *``.

    Args:
        args: parsed command-line namespace. Reads ``bin_dir``, ``n_best_size``,
            ``max_answer_length``, ``do_lower_case``, ``verbose_logging`` and
            ``version_2_with_negative``.
        eval_dataset: the ``'data'`` list of the SQuAD JSON. Only used for
            scoring, so it is ignored when ``pred_only`` is true.
        eval_dataloader: yields ``(input_ids, input_mask, segment_ids,
            example_indices)`` batches. Only ``example_indices`` is used, as an
            index into ``eval_features``.
        eval_examples: examples passed through to ``make_predictions``.
        eval_features: features passed through to ``make_predictions``.
        na_prob_thresh: no-answer probability threshold. With
            ``version_2_with_negative`` and ``pred_only``, answers above it are
            blanked. It is also used when scoring.
        pred_only: skip scoring and return ``({}, preds, nbest_preds)``.
        logits_seq_len: number of logits per example in the ``.bin`` files. It
            must equal the sequence length the offline model was run with.

    Returns:
        ``(result, preds, nbest_preds)``. ``result`` is the metrics dict (empty
        when ``pred_only``). ``preds`` maps question ids to answer text, with
        ``''`` for answers judged unanswerable.

    Raises:
        OSError: if ``args.bin_dir`` has no sub-directory.
    """
    bin_path = _find_bin_dir(args.bin_dir)
    all_results = []
    for idx, (_input_ids, _input_mask, _segment_ids, example_indices) in enumerate(eval_dataloader):
        batch_start_logits, batch_end_logits = _load_batch_logits(
            bin_path, idx, logits_seq_len)
        for i, example_index in enumerate(example_indices):
            eval_feature = eval_features[example_index.item()]
            all_results.append(RawResult(unique_id=int(eval_feature.unique_id),
                                         start_logits=batch_start_logits[i].tolist(),
                                         end_logits=batch_end_logits[i].tolist()))

    preds, nbest_preds, na_probs = make_predictions(
        eval_examples, eval_features, all_results,
        args.n_best_size, args.max_answer_length,
        args.do_lower_case, args.verbose_logging,
        args.version_2_with_negative)

    if pred_only:
        if args.version_2_with_negative:
            for k in preds:
                if na_probs[k] > na_prob_thresh:
                    preds[k] = ''
        return {}, preds, nbest_preds

    if args.version_2_with_negative:
        qid_to_has_ans = make_qid_to_has_ans(eval_dataset)
        has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
        no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
        exact_raw, f1_raw = get_raw_scores(eval_dataset, preds)
        exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, na_prob_thresh)
        f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, na_prob_thresh)
        result = make_eval_dict(exact_thresh, f1_thresh)
        if has_ans_qids:
            has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
            merge_eval(result, has_ans_eval, 'HasAns')
        if no_ans_qids:
            no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
            merge_eval(result, no_ans_eval, 'NoAns')
        # find_all_best_thresh is expected to store 'best_f1_thresh' in result;
        # that tuned threshold then decides which predictions are blanked.
        find_all_best_thresh(result, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
        for k in preds:
            if na_probs[k] > result['best_f1_thresh']:
                preds[k] = ''
    else:
        exact_raw, f1_raw = get_raw_scores(eval_dataset, preds)
        result = make_eval_dict(exact_raw, f1_raw)

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info(" %s = %s", key, str(result[key]))
    return result, preds, nbest_preds
def main(args):
    """Build evaluation features for ``--dev_file`` and score the offline logits.

    Steps:
      1. Seed the RNGs.
      2. Load the SQuAD JSON and its examples.
      3. Build the evaluation features, or reuse the ``--data_file`` pickle cache.
      4. Wrap the features in a DataLoader. Its batching must mirror the run
         that produced the ``.bin`` logits, because ``evaluate`` pairs batch
         ``idx`` with the files named after ``idx``.
      5. If ``--do_eval`` is set, run ``evaluate`` and write ``predictions.json``
         into ``--output_dir``.

    With ``--eval_test``, ``evaluate`` runs in prediction-only mode (no scoring).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    tokenizer = BertTokenizer.from_pretrained(
        args.model, do_lower_case=args.do_lower_case)

    # Loaded unconditionally: evaluate() needs the examples, features and
    # dataloader in both modes, and --dev_file may hold a test set. Skipping
    # this for --eval_test would leave those names undefined below.
    with open(args.dev_file, encoding='utf-8') as f:
        eval_dataset = json.load(f)['data']
    eval_examples = read_squad_examples(
        input_file=args.dev_file, is_training=False,
        version_2_with_negative=args.version_2_with_negative)

    start_time = time.time()
    if os.path.exists(args.data_file):
        # NOTE(review): pickle.load can execute arbitrary code, so only point
        # --data_file at a cache written by this script. The cache is reused
        # regardless of --dev_file / --max_seq_length / --doc_stride /
        # --max_query_length; delete it after changing any of them.
        logger.info("Loading cached eval features from %s", args.data_file)
        with open(args.data_file, "rb") as handle:
            eval_features = pickle.load(handle)
    else:
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)
        with open(args.data_file, "wb") as handle:
            pickle.dump(eval_features, handle)
    print("after load eval_feartures: ", time.time() - start_time)

    logger.info("***** Dev *****")
    logger.info(" Num orig examples = %d", len(eval_examples))
    logger.info(" Num split examples = %d", len(eval_features))
    logger.info(" Batch size = %d", args.eval_batch_size)
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    # Position of each feature in eval_features; evaluate() maps logits rows back through it.
    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
    # No shuffling: batch order must match the order of the .bin files.
    eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size)

    if args.do_eval:
        na_prob_thresh = 1.0
        if args.version_2_with_negative:
            # Reuse a previously tuned no-answer threshold if one was recorded;
            # the last 'best_f1_thresh' line in the file wins.
            eval_result_file = os.path.join(args.output_dir, "eval_results.txt")
            if os.path.isfile(eval_result_file):
                with open(eval_result_file) as f:
                    for line in f:
                        if line.startswith('best_f1_thresh'):
                            na_prob_thresh = float(line.strip().split()[-1])
                            logger.info("na_prob_thresh = %.6f", na_prob_thresh)
        result, preds, _ = evaluate(
            args, eval_dataset, eval_dataloader, eval_examples, eval_features,
            na_prob_thresh=na_prob_thresh,
            pred_only=args.eval_test)
        with open(os.path.join(args.output_dir, "predictions.json"), "w") as writer:
            writer.write(json.dumps(preds, indent=4) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default=None, type=str, required=True)
parser.add_argument("--output_dir", default='./', type=str,
help="The output directory where the model checkpoints and predictions will be written.")
parser.add_argument("--bin_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.")
parser.add_argument("--dev_file", default=None, type=str,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--doc_stride", default=128, type=int,
help="When splitting up a long document into chunks, "
"how much stride to take between chunks.")
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--eval_test", action='store_true', help='Wehther to run eval on the test set.')
parser.add_argument("--eval_batch_size", default=1, type=int, help="Total batch size for predictions.")
parser.add_argument("--eval_metric", default='f1', type=str)
parser.add_argument("--n_best_size", default=20, type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json "
"output file.")
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. "
"This is needed because the start "
"and end predictions are not conditioned on one another.")
parser.add_argument("--verbose_logging", action='store_true',
help="If true, all of the warnings related to data processing will be printed. "
"A number of warnings are expected for a normal SQuAD evaluation.")
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--version_2_with_negative', action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument("--data_file", default="data.pkl"
, type=str, help="filename to save data")
args = parser.parse_args()
main(args)