#-*-coding:utf-8-*-
"""
Copyright 2018 The Google AI Language Team Authors.
BASED ON Google_BERT.
@Author:zhoukaiyin
Adjust code for chinese ner
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
from bert import modeling
from bert import optimization
from bert import tokenization
import tensorflow as tf
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.ops import math_ops
from train import tf_metrics
import pickle
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
"data_dir", None,
"The input datadir.",
)
flags.DEFINE_string(
"bert_config_file", None,
"The config json file corresponding to the pre-trained BERT model."
)
flags.DEFINE_string(
"task_name", "NER", "The name of the task to train."
)
flags.DEFINE_bool("crf", False, "use crf!")
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written."
)
## Other parameters
flags.DEFINE_string(
"init_checkpoint", None,
"Initial checkpoint (usually from a pre-trained BERT model)."
)
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text."
)
flags.DEFINE_integer(
"max_seq_length", 256,
"The maximum total input sequence length after WordPiece tokenization."
)
flags.DEFINE_bool(
"do_train", True,
"Whether to run training."
)
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on the test set.")
flags.DEFINE_bool('do_export', False, 'Export model ')
flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("num_train_epochs", 4.0, "Total number of training epochs to perform.")
flags.DEFINE_float(
"warmup_proportion", 0.1,
"Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10% of training.")
flags.DEFINE_integer("save_checkpoints_steps", 1000,
"How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_integer(
"num_tpu_cores", 8,
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text = text
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_ids, ):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_ids = label_ids
# self.label_mask = label_mask
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_data(cls, input_file):
"""Reads a BIO data."""
with open(input_file) as f:
lines = []
words = []
labels = []
for line in f:
contends = line.strip()
word = line.strip().split(' ')[0]
label = line.strip().split(' ')[-1]
if contends.startswith("-DOCSTART-"):
words.append('')
continue
# if len(contends) == 0 and words[-1] == '。':
if len(contends) == 0:
l = ' '.join([label for label in labels if len(label) > 0])
w = ' '.join([word for word in words if len(word) > 0])
lines.append([l, w])
words = []
labels = []
continue
words.append(word)
labels.append(label)
return lines
class NerProcessor(DataProcessor):
def get_train_examples(self, data_dir):
return self._create_example(
self._read_data(os.path.join(data_dir, "train.txt")), "train"
)
def get_dev_examples(self, data_dir):
return self._create_example(
self._read_data(os.path.join(data_dir, "dev.txt")), "dev"
)
def get_test_examples(self, data_dir):
return self._create_example(
self._read_data(os.path.join(data_dir, "test.txt")), "test")
def get_labels(self):
# prevent potential bug for chinese text mixed with english text
# return ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]","[SEP]"]
# labels = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X","[CLS]","[SEP]"]
labels = ['O', "B-COMAPNY", "I-COMAPNY", "B-REAL", "I-REAL", "B-AMOUT", "I-AMOUT", "[CLS]", "[SEP]"]
return labels
def _create_example(self, lines, set_type):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(InputExample(guid=guid, text=text, label=label))
return examples
def write_tokens(tokens, mode):
if mode == "test":
path = os.path.join(FLAGS.output_dir, "token_" + mode + ".txt")
wf = open(path, 'a')
for token in tokens:
if token != "**NULL**":
wf.write(token + '\n')
wf.close()
def convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, mode):
textlist = example.text.split(' ')
labellist = example.label.split(' ')
tokens = []
labels = []
# print(textlist)
for i, word in enumerate(textlist):
token = tokenizer.tokenize(word)
# print(token)
tokens.extend(token)
label_1 = labellist[i]
# print(label_1)
for m in range(len(token)):
if m == 0:
labels.append(label_1)
else:
labels.append("X")
# print(tokens, labels)
# tokens = tokenizer.tokenize(example.text)
if len(tokens) >= max_seq_length - 1:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
# append("O") or append("[CLS]") not sure!
label_ids.append(label_map["[CLS]"])
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
label_ids.append(label_map[labels[i]])
ntokens.append("[SEP]")
segment_ids.append(0)
# append("O") or append("[SEP]") not sure!
label_ids.append(label_map["[SEP]"])
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
# label_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
# we don't concerned about it!
label_ids.append(0)
ntokens.append("**NULL**")
# label_mask.append(0)
# print(len(input_ids))
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
# assert len(label_mask) == max_seq_length
if ex_index < 5:
tf.logging.info("*** Example ***")
tf.logging.info("guid: %s" % (example.guid))
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens]))
tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))
# tf.logging.info("label_mask: %s" % " ".join([str(x) for x in label_mask]))
feature = InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_ids=label_ids,
# label_mask = label_mask
)
write_tokens(ntokens, mode)
return feature
def filed_based_convert_examples_to_features(
examples, label_list, max_seq_length, tokenizer, output_file, mode=None
):
if os.path.exists(output_file):
return
label_map = {}
for (i, label) in enumerate(label_list, 1):
label_map[label] = i
with open('./output/label2id.pkl', 'wb') as w:
pickle.dump(label_map, w)
writer = tf.python_io.TFRecordWriter(output_file)
for (ex_index, example) in enumerate(examples):
if ex_index % 5000 == 0:
tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
feature = convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, mode)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(feature.input_ids)
features["input_mask"] = create_int_feature(feature.input_mask)
features["segment_ids"] = create_int_feature(feature.segment_ids)
features["label_ids"] = create_int_feature(feature.label_ids)
# features["label_mask"] = create_int_feature(feature.label_mask)
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
def serving_input_receiver_fn():
input_str = tf.placeholder(tf.string, name='inputs')
seq_length = 256
name_to_features = {
"input_ids": tf.VarLenFeature(tf.int64),
"input_mask": tf.VarLenFeature(tf.int64),
"segment_ids": tf.VarLenFeature(tf.int64),
# "label_ids": tf.FixedLenFeature([seq_length], tf.int64),
}
def _decode_record(record, name_to_features):
example = tf.parse_single_example(record, name_to_features)
for name in list(example.keys()):
t = example[name]
t = tf.sparse_tensor_to_dense(t, default_value=0)
t = tf.expand_dims(t, axis=0)
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
features = _decode_record(input_str, name_to_features)
receiver_tensors = {"inputs": input_str}
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
name_to_features = {
"input_ids": tf.FixedLenFeature([seq_length], tf.int64),
"input_mask": tf.FixedLenFeature([seq_length], tf.int64),
"segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
"label_ids": tf.FixedLenFeature([seq_length], tf.int64),
# "label_ids":tf.VarLenFeature(tf.int64),
# "label_mask": tf.FixedLenFeature([seq_length], tf.int64),
}
def _decode_record(record, name_to_features):
example = tf.parse_single_example(record, name_to_features)
for name in list(example.keys()):
t = example[name]
if t.dtype == tf.int64:
t = tf.to_int32(t)
example[name] = t
return example
def input_fn(params):
batch_size = params["batch_size"]
d = tf.data.TFRecordDataset(input_file)
if is_training:
d = d.repeat()
d = d.shuffle(buffer_size=100)
d = d.apply(tf.contrib.data.map_and_batch(
lambda record: _decode_record(record, name_to_features),
batch_size=batch_size,
drop_remainder=drop_remainder
))
return d
return input_fn
def create_model(bert_config, is_training, input_ids, input_mask,
segment_ids, num_labels, use_one_hot_embeddings):
model = modeling.BertModel(
config=bert_config,
is_training=is_training,
input_ids=input_ids,
input_mask=input_mask,
token_type_ids=segment_ids,
use_one_hot_embeddings=use_one_hot_embeddings
)
output_layer = model.get_sequence_output()
hidden_size = output_layer.shape[-1].value
output_weight = tf.get_variable(
"output_weights", [num_labels, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02)
)
output_bias = tf.get_variable(
"output_bias", [num_labels], initializer=tf.zeros_initializer()
)
if is_training:
output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
output_layer = tf.reshape(output_layer, [-1, hidden_size])
logits = tf.matmul(output_layer, output_weight, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
print("logits====>", logits)
return logits
##########################################################################
def crf_loss(logits, labels, mask, num_labels, mask2len):
"""
:param logits:
:param labels:
:param mask2len:each sample's length
:return:
"""
# TODO
with tf.variable_scope("crf_loss"):
trans = tf.get_variable(
"transition",
shape=[num_labels, num_labels],
initializer=tf.contrib.layers.xavier_initializer()
)
log_likelihood, transition = tf.contrib.crf.crf_log_likelihood(logits, labels, transition_params=trans,
sequence_lengths=mask2len)
loss = tf.math.reduce_mean(-log_likelihood)
return loss
def softmax_loss(logits, labels, num_labels, mask):
print("logits ===>", logits)
print("labels===>", labels)
logits = tf.reshape(logits, [-1, num_labels])
labels = tf.reshape(labels, [-1])
mask = tf.cast(mask, dtype=tf.float32)
one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
loss = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=one_hot_labels)
loss *= tf.reshape(mask, [-1])
loss = tf.reduce_sum(loss)
total_size = tf.reduce_sum(mask)
total_size += 1e-12 # to avoid division by 0 for all-0 weights
loss /= total_size
return loss
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
num_train_steps, num_warmup_steps, use_tpu,
use_one_hot_embeddings):
def model_fn(features, mode, params):
tf.logging.info("*** Features ***")
for name in sorted(features.keys()):
tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
input_ids = features["input_ids"]
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]
# label_mask = features["label_mask"]
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
if mode != tf.estimator.ModeKeys.PREDICT:
label_ids = features["label_ids"]
logits = create_model(
bert_config, is_training, input_ids, input_mask, segment_ids,
num_labels, use_one_hot_embeddings)
if mode != tf.estimator.ModeKeys.PREDICT:
logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, num_labels])
else:
shapes = input_ids.get_shape().as_list()
print("input shapes:", shapes)
logits = tf.reshape(logits, [1, -1, num_labels])
tvars = tf.trainable_variables()
scaffold_fn = None
if init_checkpoint:
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
if use_tpu:
def tpu_scaffold():
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
return tf.train.Scaffold()
scaffold_fn = tpu_scaffold
else:
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
for var in tvars:
init_string = ""
if var.name in initialized_variable_names:
init_string = ", *INIT_FROM_CKPT*"
tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
init_string)
output_spec = None
if mode != tf.estimator.ModeKeys.PREDICT:
if FLAGS.crf:
mask2len = tf.reduce_sum(input_mask, axis=1)
loss = crf_loss(logits, label_ids, input_mask, num_labels, mask2len)
else:
loss = softmax_loss(logits, label_ids, num_labels, input_mask)
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = optimization.create_optimizer(
loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
scaffold_fn=scaffold_fn)
elif mode == tf.estimator.ModeKeys.EVAL:
def metric_fn(label_ids, logits):
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
precision = tf_metrics.precision(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro")
recall = tf_metrics.recall(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro")
f = tf_metrics.f1(label_ids, predictions, num_labels, [2, 3, 4, 5, 6, 7], average="macro")
return {
"eval_precision": precision,
"eval_recall": recall,
"eval_f": f,
# "eval_loss": loss,
}
eval_metrics = (metric_fn, [label_ids, logits])
# eval_metrics = (metric_fn, [label_ids, logits])
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode,
loss=loss,
eval_metrics=eval_metrics,
scaffold_fn=scaffold_fn)
else:
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, predictions=predictions, scaffold_fn=scaffold_fn
)
return output_spec
return model_fn
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
processors = {
"ner": NerProcessor
}
if not FLAGS.do_train and not FLAGS.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(FLAGS.max_seq_length, bert_config.max_position_embeddings))
task_name = FLAGS.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
tpu_cluster_resolver = None
if FLAGS.use_tpu and FLAGS.tpu_name:
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
master=FLAGS.master,
model_dir=FLAGS.output_dir,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=FLAGS.iterations_per_loop,
num_shards=FLAGS.num_tpu_cores,
per_host_input_for_training=is_per_host))
train_examples = None
num_train_steps = None
num_warmup_steps = None
if FLAGS.do_train:
train_examples = processor.get_train_examples(FLAGS.data_dir)
num_train_steps = int(
len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
model_fn = model_fn_builder(
bert_config=bert_config,
num_labels=len(label_list) + 1,
init_checkpoint=FLAGS.init_checkpoint,
learning_rate=FLAGS.learning_rate,
num_train_steps=num_train_steps,
num_warmup_steps=num_warmup_steps,
use_tpu=FLAGS.use_tpu,
use_one_hot_embeddings=FLAGS.use_tpu)
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
predict_batch_size=FLAGS.predict_batch_size)
if FLAGS.do_train:
train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
filed_based_convert_examples_to_features(
train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
tf.logging.info("***** Running training *****")
tf.logging.info(" Num examples = %d", len(train_examples))
tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
tf.logging.info(" Num steps = %d", num_train_steps)
train_input_fn = file_based_input_fn_builder(
input_file=train_file,
seq_length=FLAGS.max_seq_length,
is_training=True,
drop_remainder=True)
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
if FLAGS.do_eval:
eval_examples = processor.get_dev_examples(FLAGS.data_dir)
eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
filed_based_convert_examples_to_features(
eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
tf.logging.info("***** Running evaluation *****")
tf.logging.info(" Num examples = %d", len(eval_examples))
tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
eval_steps = None
if FLAGS.use_tpu:
eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
eval_drop_remainder = True if FLAGS.use_tpu else False
eval_input_fn = file_based_input_fn_builder(
input_file=eval_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=eval_drop_remainder)
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
tf.logging.info("***** Eval results *****")
for key in sorted(result.keys()):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if FLAGS.do_predict:
token_path = os.path.join(FLAGS.output_dir, "token_test.txt")
with open('./output/label2id.pkl', 'rb') as rf:
label2id = pickle.load(rf)
id2label = {value: key for key, value in label2id.items()}
if os.path.exists(token_path):
os.remove(token_path)
predict_examples = processor.get_test_examples(FLAGS.data_dir)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
filed_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file, mode="test")
tf.logging.info("***** Running prediction*****")
tf.logging.info(" Num examples = %d", len(predict_examples))
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
if FLAGS.use_tpu:
# Warning: According to tpu_estimator.py Prediction on TPU is an
# experimental feature and hence not supported here
raise ValueError("Prediction in TPU not supported")
predict_drop_remainder = True if FLAGS.use_tpu else False
predict_input_fn = file_based_input_fn_builder(
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=predict_drop_remainder)
result = estimator.predict(input_fn=predict_input_fn)
output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")
with open(output_predict_file, 'w') as writer:
for prediction in result:
output_line = "\n".join(id2label[id] for id in prediction if id != 0) + "\n"
writer.write(output_line)
if FLAGS.do_export and FLAGS.do_predict:
estimator.export_saved_model('./export_models', serving_input_receiver_fn)
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
tf.app.run()