import argparse
import glob
import os
import sys
import time
import torch
import numpy as np
sys.path.append('./ACE/')
from flair.utils.from_params import Params
from flair.config_parser import ConfigParser, log
from flair.trainers import ReinforcementTrainer
from flair.custom_data_loader import ColumnDataLoader
from flair.training_utils import Metric, Result
longest_token_sequence_in_batch = 124
def get_trainer(config_path):
config = Params.from_file(config_path)
config = ConfigParser(config, all=False, zero_shot=False, other_shot=False, predict=False)
student = config.create_student(nocrf=False)
corpus = config.corpus
trainer = ReinforcementTrainer(student, None, corpus, config=config.config, **config.config["ReinforcementTrainer"],
is_test=True)
base_path = config.get_target_path
trainer.embeddings_storage_mode = 'cpu'
trainer.model.eval()
trainer.model.to('cpu')
trainer.model = trainer.model.load(base_path / "best-model.pt", device='cpu')
training_state = torch.load(base_path / 'training_state.pt', map_location=torch.device('cpu'))
trainer.best_action = training_state['best_action']
trainer.model.selection = trainer.best_action
for embedding in trainer.model.embeddings.embeddings:
if hasattr(embedding, 'tokenizer') and embedding.tokenizer is None:
from transformers import AutoTokenizer
name = embedding.name
if '_v2doc' in name:
name = name.replace('_v2doc', '')
if '_extdoc' in name:
name = name.replace('_extdoc', '')
embedding.tokenizer = AutoTokenizer.from_pretrained(name, do_lower_case=True)
if hasattr(embedding, 'model') and hasattr(embedding.model, 'encoder') and not hasattr(
embedding.model.encoder,
'config'):
embedding.model.encoder.config = embedding.model.config
return trainer
def get_loader(trainer):
loader = ColumnDataLoader(list(trainer.corpus.test), int(args.batch_size),
use_bert=trainer.use_bert,
tokenizer=trainer.bert_tokenizer, model=trainer.model,
sentence_level_batch=trainer.sentence_level_batch, sort_data=True)
loader.assign_tags(trainer.model.tag_type, trainer.model.tag_dictionary)
return loader
def run_postprocess(args):
bin_file_num = len(glob.glob(os.path.join(args.bin_file_path, "*.bin")))
if os.path.exists(args.res_file_path) == False:
os.mkdir(args.res_file_path)
res_file_path = os.path.join(args.res_file_path, "accuracy.txt")
if bin_file_num >= 1:
log.info(f"post process {bin_file_num} file!")
else:
log.info("please input right path!")
exit()
trainer = get_trainer(args.config)
loader = get_loader(trainer)
log.info("---------------ace model starts postprocessing-----------------")
metric = Metric("Evaluation")
for i in range(bin_file_num):
file_path = os.path.join(args.bin_file_path, "in_" + str(i) + "_0.bin")
if not os.path.exists(file_path):
continue
features = np.fromfile(file_path, dtype=np.float32).reshape(
[int(args.batch_size), longest_token_sequence_in_batch, 20])
if i == len(loader)-1:
features = features[:len(loader[i]), :, :]
tags, _ = trainer.model._obtain_labels(torch.tensor(features), loader[i])
for (sentence, sent_tags) in zip(loader[i], tags):
for (token, tag) in zip(sentence.tokens, sent_tags):
token: Token = token
token.add_tag_label("predicted", tag)
for sentence in loader[i]:
gold_tags = [
(tag.tag, str(tag)) for tag in sentence.get_spans("ner")
]
predicted_tags = [
(tag.tag, str(tag)) for tag in sentence.get_spans("predicted")
]
for tag, prediction in predicted_tags:
if (tag, prediction) in gold_tags:
metric.add_tp(tag)
else:
metric.add_fp(tag)
for tag, gold in gold_tags:
if (tag, gold) not in predicted_tags:
metric.add_fn(tag)
else:
metric.add_tn(tag)
detailed_result = (
f"\nMICRO_AVG: acc {metric.micro_avg_accuracy()} - f1-score {metric.micro_avg_f_score()}"
f"\nMACRO_AVG: acc {metric.macro_avg_accuracy()} - f1-score {metric.macro_avg_f_score()}"
)
for class_name in metric.get_classes():
detailed_result += (
f"\n{class_name:<10} tp: {metric.get_tp(class_name)} - fp: {metric.get_fp(class_name)} - "
f"fn: {metric.get_fn(class_name)} - tn: {metric.get_tn(class_name)} - precision: "
f"{metric.precision(class_name):.4f} - recall: {metric.recall(class_name):.4f} - "
f"accuracy: {metric.accuracy(class_name):.4f} - f1-score: "
f"{metric.f_score(class_name):.4f}"
)
result = Result(
main_score=metric.micro_avg_f_score(),
log_line=f"{metric.precision()}\t{metric.recall()}\t{metric.micro_avg_f_score()}",
log_header="PRECISION\tRECALL\tF1",
detailed_results=detailed_result,
)
log.info(result.log_line)
log.info(result.detailed_results)
with open(res_file_path, "w", encoding="utf-8") as file:
file.write(result.detailed_results)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./config/')
parser.add_argument('--bin_file_path', default='./result/')
parser.add_argument('--res_file_path', default='./res_data/')
parser.add_argument('--batch_size', default=1)
args = parser.parse_args()
start = time.time()
run_postprocess(args)
elapsed = (time.time() - start)
print("Time used:", elapsed, "s")