import torch
import os
import sys
import argparse
import glob
import numpy as np
from tqdm import tqdm
from pytorch_pretrained_bert import BertConfig
sys.path.append("./BertSum/src")
from models.model_builder import Summarizer
from models import data_loader
from models.data_loader import load_dataset
from models.stats import Statistics
from others.utils import test_rouge
from others.logging import logger
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def rouge_results_to_str(results_dict):
return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format(
results_dict["rouge_1_f_score"] * 100,
results_dict["rouge_2_f_score"] * 100,
results_dict["rouge_l_f_score"] * 100,
results_dict["rouge_1_recall"] * 100,
results_dict["rouge_2_recall"] * 100,
results_dict["rouge_l_recall"] * 100
)
def pre_postprocess(args):
file_names = os.listdir(args.result_dir)
data_info = {}
for file_name in sorted(file_names):
_, data_idx, batch_idx, suffix = file_name.split("_")
if data_idx not in data_info:
data_info[data_idx] = {}
if data_info[data_idx].get(batch_idx) is None:
data_info[data_idx][batch_idx] = []
data_info[data_idx][batch_idx].append(suffix)
summary_outs = []
summary_scores = []
for data_idx in tqdm(data_info):
outputs = []
sent_scores = []
for batch_idx in data_info[data_idx]:
prefix = f'data_{data_idx}_{batch_idx}'
score = np.fromfile(os.path.join(args.result_dir, f'{prefix}_0.bin'),
dtype="float32")
output = np.fromfile(os.path.join(args.result_dir, f'{prefix}_1.bin'),
dtype="bool")
outputs.append(output)
sent_scores.append(score)
summary_outs.append(torch.tensor(outputs))
summary_scores.append(torch.tensor(sent_scores))
test_iter = data_loader.Dataloader(args,
load_dataset(args, 'test', shuffle=False),
args.batch_size, device,
shuffle=False, is_test=True)
for data_idx, batch in enumerate(test_iter):
labels = batch.labels
summary_outs[data_idx] = summary_outs[data_idx][:, 0:labels.shape[1]]
summary_scores[data_idx] = summary_scores[data_idx][:, 0:labels.shape[1]]
return summary_scores, summary_outs
def test(args, step, device, cal_lead=False, cal_oracle=False):
test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
args.batch_size, device,
shuffle=False, is_test=True)
def _get_ngrams(n, text):
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def _block_tri(c, p):
tri_c = _get_ngrams(3, c.split())
for s in p:
tri_s = _get_ngrams(3, s.split())
if len(tri_c.intersection(tri_s)) > 0:
return True
return False
stats = Statistics()
can_path = '%s_step%d.candidate' % (args.result_path, step)
gold_path = '%s_step%d.gold' % (args.result_path, step)
sent,output = pre_postprocess(args)
Loss = torch.nn.BCELoss(reduction='none')
sum = 0
k = 0
with open(can_path, 'w') as save_pred:
with open(gold_path, 'w') as save_gold:
with torch.no_grad():
for batch in test_iter:
labels = batch.labels
gold = []
pred = []
if (cal_lead):
selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
elif (cal_oracle):
selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
range(batch.batch_size)]
else:
if labels.shape[0] != sent[k].shape[0]:
k = k + 1
sum = sum + 1
continue
loss = Loss(sent[k], labels.float())
if loss.shape[1] != output[k].shape[1]:
k = k + 1
continue
loss = (loss * output[k].float()).sum()
batch_stats = Statistics(float(loss.cpu().data.numpy()), len(labels))
stats.update(batch_stats)
sent_scores = sent[k] + output[k].float()
sent_scores = sent_scores.cpu().data.numpy()
selected_ids = np.argsort(-sent_scores, 1)
for i, idx in enumerate(selected_ids):
_pred = []
if(len(batch.src_str[i]) == 0):
continue
for j in selected_ids[i][:len(batch.src_str[i])]:
if(j >= len(batch.src_str[i])):
continue
candidate = batch.src_str[i][j].strip()
if(args.block_trigram):
if(not _block_tri(candidate, _pred)):
_pred.append(candidate)
else:
_pred.append(candidate)
if ((not cal_oracle) and (not args.recall_eval) and len(_pred) == 3):
break
_pred = '<q>'.join(_pred)
if(args.recall_eval):
_pred = ' '.join(_pred.split()[:len(batch.tgt_str[i].split())])
pred.append(_pred)
gold.append(batch.tgt_str[i])
for i in range(len(gold)):
save_gold.write(gold[i].strip()+'\n')
for i in range(len(pred)):
save_pred.write(pred[i].strip()+'\n')
k = k + 1
if(step != -1 and args.report_rouge):
print(can_path)
print(gold_path)
rouges = test_rouge(args.temp_dir, can_path, gold_path)
logger.info('Rouges at step %d \n%s' % (step, rouge_results_to_str(rouges)))
return stats
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-encoder", default='classifier', type=str, choices=['classifier', 'transformer', 'rnn', 'baseline'])
parser.add_argument("-mode", default='test', type=str, choices=['train', 'validate', 'test'])
parser.add_argument("-bert_data_path", default='./bert_data')
parser.add_argument("-model_path", default='./models/')
parser.add_argument("-result_path", default='./results/cnndm')
parser.add_argument("-temp_dir", default='./temp')
parser.add_argument("-bert_config_path", default='BertSum/bert_config_uncased_base.json')
parser.add_argument("-batch_size", default=600, type=int)
parser.add_argument("-use_interval", type=str2bool, nargs='?', const=True, default=True)
parser.add_argument("-hidden_size", default=128, type=int)
parser.add_argument("-ff_size", default=512, type=int)
parser.add_argument("-heads", default=4, type=int)
parser.add_argument("-inter_layers", default=2, type=int)
parser.add_argument("-rnn_size", default=512, type=int)
parser.add_argument("-param_init", default=0, type=float)
parser.add_argument("-param_init_glorot", type=str2bool, nargs='?', const=True, default=True)
parser.add_argument("-dropout", default=0.1, type=float)
parser.add_argument("-optim", default='adam', type=str)
parser.add_argument("-lr", default=1, type=float)
parser.add_argument("-beta1", default= 0.9, type=float)
parser.add_argument("-beta2", default=0.999, type=float)
parser.add_argument("-decay_method", default='', type=str)
parser.add_argument("-warmup_steps", default=8000, type=int)
parser.add_argument("-max_grad_norm", default=0, type=float)
parser.add_argument("-save_checkpoint_steps", default=5, type=int)
parser.add_argument("-accum_count", default=1, type=int)
parser.add_argument("-world_size", default=1, type=int)
parser.add_argument("-report_every", default=1, type=int)
parser.add_argument("-train_steps", default=1000, type=int)
parser.add_argument("-recall_eval", type=str2bool, nargs='?', const=True, default=False)
parser.add_argument('-visible_gpus', default='-1', type=str)
parser.add_argument('-gpu_ranks', default='0', type=str)
parser.add_argument('-log_file', default='../logs/cnndm.log')
parser.add_argument('-dataset', default='')
parser.add_argument('-seed', default=666, type=int)
parser.add_argument("-test_all", type=str2bool, nargs='?', const=True, default=False)
parser.add_argument("-test_from", default='')
parser.add_argument("-train_from", default='')
parser.add_argument("-report_rouge", type=str2bool, nargs='?', const=True, default=True)
parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True)
parser.add_argument("-result_dir", default="")
args = parser.parse_args()
device = "cpu" if args.visible_gpus == '-1' else "cuda"
device_id = -1 if device == "cpu" else 0
test(args,0,device)