"""BERT finetuning runner."""
from __future__ import absolute_import, division, print_function
import argparse
import csv
import logging
import os
import random
import sys
import time
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
# Names of the artefacts this script writes under the output directory.
PRED_FILE = "predictions.tsv"
EVAL_FILE = "eval_results.txt"
TEST_FILE = "test_results.txt"

# Root logging configuration: timestamped INFO-level records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class InputExample(object):
    """One raw (untokenized) example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Store the fields of a single example.

        Args:
          guid: Unique id for the example.
          text_a: string. Untokenized text of the first sequence. This is the
            only text needed for single-sequence tasks.
          text_b: (Optional) string. Untokenized text of the second sequence;
            only needed for sequence-pair tasks.
          label: (Optional) string. Label of the example. Expected for train
            and dev examples, left as None for test examples.
        """
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label
class InputFeatures(object):
    """Model-ready features for one example (ids, mask, segments, label)."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        """Keep the four per-example feature values as attributes."""
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses override the `get_*_examples` methods and `get_labels`; the
    shared `_read_tsv` helper loads a tab-separated file into a list of rows.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
          input_file: Path of the UTF-8 encoded TSV file to read.
          quotechar: Passed through to `csv.reader`. With the default of None
            (and no explicit quoting mode) the csv module disables quoting,
            so quote characters are kept as ordinary text.

        Returns:
          A list of rows, each row being a list of cell strings.
        """
        # newline="" lets the csv module do its own line-ending handling, as
        # the csv documentation requires for file objects handed to a reader.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        Columns 3 and 4 hold the two sentences and column 0 the label; the
        label is left as None for the "test" set. The guid is built from the
        set name and the row position.
        """
        is_test = set_type == "test"
        return [
            InputExample(
                guid="%s-%s" % (set_type, row_idx),
                text_a=row[3],
                text_b=row[4],
                label=None if is_test else row[0],
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir, eval_set="MNLI-m"):
        """See base class. `eval_set` picks the matched or mismatched dev file."""
        if eval_set is None or eval_set == "MNLI-m":
            file_name = "dev_matched.tsv"
        else:
            assert eval_set == "MNLI-mm"
            file_name = "dev_mismatched.tsv"
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, file_name)), "dev")

    def get_test_examples(self, data_dir, eval_set="MNLI-m"):
        """See base class. `eval_set` picks matched, mismatched or AX data."""
        if eval_set is None or eval_set == "MNLI-m":
            file_name, set_type = "test_matched.tsv", "test"
        elif eval_set == "MNLI-mm":
            file_name, set_type = "test_mismatched.tsv", "test"
        else:
            assert eval_set == "AX"
            file_name, set_type = "ax.tsv", "ax"
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, file_name)), set_type)

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Build examples from TSV rows, skipping the header row.

        "ax" rows carry the sentence pair in columns 1 and 2; every other set
        uses columns 8 and 9. Only sets other than "ax" and "test" get a
        label, taken from the last column.
        """
        if set_type == "ax":
            col_a, col_b = 1, 2
        else:
            col_a, col_b = 8, 9
        has_label = set_type not in ("ax", "test")
        return [
            InputExample(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[col_a],
                text_b=row[col_b],
                label=row[-1] if has_label else None,
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build single-sentence examples from TSV rows.

        The first row is skipped only for the "test" set. Test rows take the
        sentence from column 1 and have no label; other rows take the
        sentence from column 3 and the label from column 1.
        """
        is_test = set_type == 'test'
        text_col = 1 if is_test else 3
        collected = []
        for row_idx, row in enumerate(lines):
            if is_test and row_idx == 0:
                continue
            collected.append(
                InputExample(
                    guid="%s-%s" % (set_type, row_idx),
                    text_a=row[text_col],
                    text_b=None,
                    label=None if is_test else row[1],
                )
            )
        return collected
class Sst2Processor(DataProcessor):
    """Processor for the SST-2 data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build single-sentence examples from TSV rows, skipping the header.

        Test rows take the sentence from column 1 and have no label; other
        rows take the sentence from column 0 and the label from column 1.
        """
        is_test = set_type == "test"
        text_col = 1 if is_test else 0
        return [
            InputExample(
                guid="%s-%s" % (set_type, row_idx),
                text_a=row[text_col],
                text_b=None,
                label=None if is_test else row[1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class. A single `None` entry, giving a label count of one."""
        return [None]

    def _create_examples(self, lines, set_type):
        """Build sentence-pair examples from TSV rows, skipping the header.

        Columns 7 and 8 hold the two sentences; the last column is used as
        the label except for the "test" set, where the label is None.
        """
        is_test = set_type == "test"
        return [
            InputExample(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[7],
                text_b=row[8],
                label=None if is_test else row[-1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build question-pair examples from TSV rows, skipping the header.

        Test rows use columns 1 and 2 and carry no label. Other rows use
        columns 3 and 4 with the label in column 5; rows too short to
        provide those columns are silently dropped.
        """
        collected = []
        for row_idx, row in enumerate(lines):
            if row_idx == 0:
                continue
            guid = "%s-%s" % (set_type, row[0])
            if set_type == "test":
                text_a, text_b, label = row[1], row[2], None
            else:
                try:
                    text_a, text_b, label = row[3], row[4], row[5]
                except IndexError:
                    # Malformed (short) labelled row: skip it.
                    continue
            collected.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return collected
class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Build sentence-pair examples from TSV rows, skipping the header.

        Columns 1 and 2 hold the two texts; the last column is used as the
        label except for the "test" set, where the label is None.
        """
        is_test = set_type == "test"
        return [
            InputExample(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[1],
                text_b=row[2],
                label=None if is_test else row[-1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Build sentence-pair examples from TSV rows, skipping the header.

        Columns 1 and 2 hold the two texts; the last column is used as the
        label except for the "test" set, where the label is None.
        """
        collected = []
        for row_idx, row in enumerate(lines):
            if row_idx == 0:
                continue
            label = row[-1] if set_type != "test" else None
            collected.append(
                InputExample(
                    guid="%s-%s" % (set_type, row[0]),
                    text_a=row[1],
                    text_b=row[2],
                    label=label,
                )
            )
        return collected
class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        train_rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(train_rows, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        dev_rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(dev_rows, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        test_rows = self._read_tsv(os.path.join(data_dir, "test.tsv"))
        return self._create_examples(test_rows, "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Build sentence-pair examples from TSV rows, skipping the header.

        Columns 1 and 2 hold the two texts; the last column is used as the
        label except for the "test" set, where the label is None.
        """
        is_test = set_type == "test"
        return [
            InputExample(
                guid="%s-%s" % (set_type, row[0]),
                text_a=row[1],
                text_b=row[2],
                label=None if is_test else row[-1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]
# Lower-cased task name -> processor class that reads that task's TSV files.
PROCESSORS = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

# Task name -> "classification" or "regression"; selects label dtype and loss.
OUTPUT_MODES = {
    "cola": "classification",
    "mnli": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}

# Task name -> key of the metric dict used to pick the best dev checkpoint.
EVAL_METRICS = {
    "cola": "mcc",
    "mnli": "acc",
    "mrpc": "acc_and_f1",
    "sst-2": "acc",
    "sts-b": "corr",
    "qqp": "acc_and_f1",
    "qnli": "acc",
    "rte": "acc",
    "wnli": "acc",
}
def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer, output_mode):
    """Convert `InputExample`s into fixed-length `InputFeatures`.

    Each example is tokenized, truncated so that it fits in `max_seq_length`
    together with the special tokens, laid out as
    ``[CLS] text_a [SEP]`` (segment 0) optionally followed by
    ``text_b [SEP]`` (segment 1), and zero-padded to `max_seq_length`.

    Args:
      examples: Sequence of `InputExample` objects.
      label_list: Labels for the task; their positions become the label ids
        in classification mode.
      max_seq_length: Length of every produced id / mask / segment list.
      tokenizer: Object providing `tokenize` and `convert_tokens_to_ids`.
      output_mode: "classification" or "regression".

    Returns:
      A list of `InputFeatures`, one per example. `label_id` is None when the
      example has no label.

    Raises:
      KeyError: If a labelled example is seen with an unknown `output_mode`,
        or its label is not in `label_list` (classification mode).
    """
    if output_mode == 'classification':
        label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens_a = tokenizer.tokenize(example.text_a)
        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            # Leave room for [CLS], [SEP], [SEP].
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Leave room for [CLS] and [SEP].
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[:(max_seq_length - 2)]

        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)
        if tokens_b:
            tokens += tokens_b + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # Mask is 1 for real tokens and 0 for padding.
        input_mask = [1] * len(input_ids)

        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if example.label is None:
            label_id = None
        elif output_mode == "classification":
            label_id = label_map[example.label]
        elif output_mode == "regression":
            label_id = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            if example.label is None:
                logger.info("label: <UNK>")
            else:
                # %s (not %d) so float regression targets are not truncated
                # to an integer in the log output.
                logger.info("label: %s (id = %s)", example.label, label_id)

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))
    return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Trim two token lists in place until their combined length fits.

    One token at a time is removed from the end of whichever list is
    currently longer (from `tokens_b` when they are equally long) until
    ``len(tokens_a) + len(tokens_b) <= max_length``.
    """
    while len(tokens_a) + len(tokens_b) > max_length:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()
def simple_accuracy(preds, labels):
    """Return the mean of the element-wise equality of `preds` and `labels`."""
    matches = preds == labels
    return matches.mean()
def acc_and_f1(preds, labels):
    """Return accuracy, F1 and their average under "acc", "f1", "acc_and_f1"."""
    accuracy = simple_accuracy(preds, labels)
    f1_value = f1_score(y_true=labels, y_pred=preds)
    scores = {"acc": accuracy, "f1": f1_value}
    scores["acc_and_f1"] = (accuracy + f1_value) / 2
    return scores
def pearson_and_spearman(preds, labels):
    """Return Pearson and Spearman correlations plus their average.

    The result dict has the keys "pearson", "spearmanr" and "corr" (the mean
    of the two coefficients).
    """
    scores = {
        "pearson": pearsonr(preds, labels)[0],
        "spearmanr": spearmanr(preds, labels)[0],
    }
    scores["corr"] = (scores["pearson"] + scores["spearmanr"]) / 2
    return scores
def compute_metrics(task_name, preds, labels):
    """Compute the metric dict for `task_name` from predictions and labels.

    Accuracy-only tasks return {"acc": ...}; "mrpc" and "qqp" return accuracy
    and F1; "cola" returns {"mcc": ...}; "sts-b" returns the correlation
    metrics. Any other task name raises KeyError.
    """
    assert len(preds) == len(labels)
    if task_name in ("sst-2", "mnli", "qnli", "rte", "wnli"):
        return {"acc": simple_accuracy(preds, labels)}
    if task_name in ("mrpc", "qqp"):
        return acc_and_f1(preds, labels)
    if task_name == "cola":
        return {"mcc": matthews_corrcoef(labels, preds)}
    if task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    raise KeyError(task_name)
def evaluate(task_name, model, device, eval_dataloader, eval_label_ids, num_labels):
    """Run `model` over `eval_dataloader` and compute predictions and metrics.

    Args:
      task_name: Key into `OUTPUT_MODES`; also passed to `compute_metrics`.
      model: Called as ``model(input_ids, segment_ids, input_mask,
        labels=None)``; its return value is used as the logits tensor.
      device: Device the batch tensors are moved to.
      eval_dataloader: Yields 4-tuples (ids, mask, segments, labels) or
        3-tuples without labels.
      eval_label_ids: Tensor of all gold labels, or None when there are none.
      num_labels: Number of classes, used to reshape classification logits.

    Returns:
      ``(preds, result)``. `preds` is the argmax over classes for
      classification or the squeezed raw outputs for regression. `result` is
      the metric dict plus "eval_loss", or an empty dict when
      `eval_label_ids` is None.
    """
    output_mode = OUTPUT_MODES[task_name]
    model.eval()
    eval_loss = 0
    nb_eval_steps = 0
    batch_logits = []
    for eval_example in eval_dataloader:
        if len(eval_example) == 4:
            input_ids, input_mask, segment_ids, label_ids = eval_example
        else:
            input_ids, input_mask, segment_ids = eval_example
            label_ids = None
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask, labels=None)

        if label_ids is not None:
            label_ids = label_ids.to(device)
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
            eval_loss += tmp_eval_loss.mean().item()

        # Counted for every batch, labelled or not, so that averaging the
        # loss below is well defined for unlabelled evaluation sets too.
        nb_eval_steps += 1
        batch_logits.append(logits.detach().cpu().numpy())

    eval_loss = eval_loss / nb_eval_steps
    # A single concatenation instead of re-copying the growing array with
    # np.append after every batch.
    preds = np.concatenate(batch_logits, axis=0)
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)

    if eval_label_ids is not None:
        result = compute_metrics(task_name, preds, eval_label_ids.numpy())
        result['eval_loss'] = eval_loss
    else:
        result = {}

    logger.info("***** Eval results *****")
    for key in sorted(result.keys()):
        logger.info(" %s = %s", key, str(result[key]))
    return preds, result
def main(args):
    """Fine-tune and/or evaluate a BERT sequence classifier on a GLUE task.

    With `args.do_train`, a fresh model is trained for every learning rate in
    the sweep (just `args.learning_rate` when given). Every `eval_step`
    batches the model is evaluated on the dev set when `args.do_eval` is set,
    and weights, config and vocabulary are saved to `args.output_dir`
    whenever the dev metric improves (or at every checkpoint when
    `args.do_eval` is off).

    With `args.do_eval`, the model saved in `args.output_dir` is reloaded at
    the end and run on the dev set, or on the unlabeled test set when
    `args.eval_test` is set; predictions and metrics are written to
    `PRED_FILE` and `TEST_FILE` in `args.output_dir`.

    Args:
      args: Parsed command-line namespace (see the argument parser at the
        bottom of this file). `args.train_batch_size` is overwritten with the
        per-step batch size after dividing by the accumulation steps.

    Raises:
      ValueError: On an invalid `gradient_accumulation_steps`, when neither
        `do_train` nor `do_eval` is set, or for an unknown task name.
      ImportError: When `fp16` is requested but apex cannot be imported.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    # Only count GPUs when CUDA is actually in use; otherwise --no_cuda on a
    # multi-GPU host would still trigger DataParallel and CUDA seeding.
    n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    logger.info("device: {} n_gpu: {}, 16-bits training: {}".format(
        device, n_gpu, args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    # From here on train_batch_size is the size of one forward/backward pass.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)
    if args.do_train:
        logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "train.log"), 'w'))
    else:
        logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), 'w'))
    logger.info(args)

    task_name = args.task_name.lower()
    if task_name not in PROCESSORS:
        raise ValueError("Task not found: %s" % (task_name))

    processor = PROCESSORS[task_name]()
    label_list = processor.get_labels()
    id2label = {i: label for i, label in enumerate(label_list)}
    num_labels = len(label_list)
    eval_metric = EVAL_METRICS[task_name]

    tokenizer = BertTokenizer.from_pretrained(args.model, do_lower_case=args.do_lower_case)

    # The dev set is needed for in-training evaluation and for the final
    # evaluation whenever the test set was not requested.
    if args.do_train or (not args.eval_test):
        if task_name == "mnli":
            eval_examples = processor.get_dev_examples(args.data_dir, eval_set=args.eval_set)
        else:
            eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, OUTPUT_MODES[task_name])
        logger.info("***** Dev *****")
        logger.info(" Num examples = %d", len(eval_examples))
        logger.info(" Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        if OUTPUT_MODES[task_name] == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        elif OUTPUT_MODES[task_name] == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
            if args.fp16:
                all_label_ids = all_label_ids.half()
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size)
        eval_label_ids = all_label_ids

    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, OUTPUT_MODES[task_name])

        # 'sorted' / 'random_sorted' order examples by unpadded length (the
        # sum of the attention mask); otherwise examples are shuffled.
        if args.train_mode == 'sorted' or args.train_mode == 'random_sorted':
            train_features = sorted(train_features, key=lambda f: np.sum(f.input_mask))
        else:
            random.shuffle(train_features)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        if OUTPUT_MODES[task_name] == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        elif OUTPUT_MODES[task_name] == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
            if args.fp16:
                all_label_ids = all_label_ids.half()
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        train_dataloader = DataLoader(train_data, batch_size=args.train_batch_size, drop_last=True)
        # Batches are materialised once so their order can be reshuffled per
        # epoch without changing which examples share a batch.
        train_batches = [batch for batch in train_dataloader]
        eval_step = max(1, len(train_batches) // args.eval_per_epoch)

        num_train_optimization_steps = \
            len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        logger.info("***** Training *****")
        logger.info(" Num examples = %d", len(train_examples))
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", num_train_optimization_steps)

        best_result = None
        lrs = [args.learning_rate] if args.learning_rate else \
            [1e-6, 2e-6, 3e-6, 5e-6, 1e-5, 2e-5, 3e-5, 5e-5]
        for lr in lrs:
            cache_dir = args.cache_dir if args.cache_dir else \
                PYTORCH_PRETRAINED_BERT_CACHE
            # Every learning rate starts again from the pre-trained weights.
            model = BertForSequenceClassification.from_pretrained(
                args.model, cache_dir=cache_dir, num_labels=num_labels)
            if args.fp16:
                model.half()
            model.to(device)
            if n_gpu > 1:
                model = torch.nn.DataParallel(model)

            param_optimizer = list(model.named_parameters())
            # Parameters whose names match these get no weight decay.
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer
                            if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer
                            if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

            if args.fp16:
                try:
                    from apex.optimizers import FP16_Optimizer
                    from apex.optimizers import FusedAdam
                except ImportError:
                    raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                                      "to use distributed and fp16 training.")
                optimizer = FusedAdam(optimizer_grouped_parameters,
                                      lr=lr,
                                      bias_correction=False,
                                      max_grad_norm=1.0)
                if args.loss_scale == 0:
                    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                else:
                    optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            else:
                optimizer = BertAdam(optimizer_grouped_parameters,
                                     lr=lr,
                                     warmup=args.warmup_proportion,
                                     t_total=num_train_optimization_steps)

            global_step = 0
            nb_tr_steps = 0
            nb_tr_examples = 0
            tr_loss = 0
            start_time = time.time()
            for epoch in range(int(args.num_train_epochs)):
                model.train()
                logger.info("Start epoch #{} (lr = {})...".format(epoch, lr))
                if args.train_mode == 'random' or args.train_mode == 'random_sorted':
                    random.shuffle(train_batches)
                for step, batch in enumerate(train_batches):
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, label_ids = batch
                    logits = model(input_ids, segment_ids, input_mask, labels=None)

                    if OUTPUT_MODES[task_name] == "classification":
                        loss_fct = CrossEntropyLoss()
                        loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                    elif OUTPUT_MODES[task_name] == "regression":
                        loss_fct = MSELoss()
                        loss = loss_fct(logits.view(-1), label_ids.view(-1))

                    if n_gpu > 1:
                        loss = loss.mean()
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()

                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1

                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            # The fp16 optimizer has no built-in schedule, so
                            # the warmup schedule is applied by hand here.
                            lr_this_step = lr * \
                                warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

                    if (step + 1) % eval_step == 0:
                        logger.info('Epoch: {}, Step: {} / {}, used_time = {:.2f}s, loss = {:.6f}'.format(
                            epoch, step + 1, len(train_dataloader),
                            time.time() - start_time, tr_loss / nb_tr_steps))

                        save_model = False
                        if args.do_eval:
                            preds, result = evaluate(task_name, model, device,
                                                     eval_dataloader, eval_label_ids, num_labels)
                            # evaluate() switches to eval mode; switch back.
                            model.train()
                            result['global_step'] = global_step
                            result['epoch'] = epoch
                            result['learning_rate'] = lr
                            result['batch_size'] = args.train_batch_size
                            logger.info("First 20 predictions:")
                            for pred, label in zip(preds[:20], eval_label_ids.numpy()[:20]):
                                if OUTPUT_MODES[task_name] == 'classification':
                                    sign = u'\u2713' if pred == label else u'\u2718'
                                    logger.info("pred = %s, label = %s %s" % (id2label[pred], id2label[label], sign))
                                else:
                                    logger.info("pred = %.4f, label = %.4f" % (pred, label))

                            # best_result is shared across the whole learning
                            # rate sweep, so only a global improvement saves.
                            if (best_result is None) or (result[eval_metric] > best_result[eval_metric]):
                                best_result = result
                                save_model = True
                                logger.info("!!! Best dev %s (lr=%s, epoch=%d): %.2f" %
                                            (eval_metric, str(lr), epoch, result[eval_metric] * 100.0))
                        else:
                            # Without dev evaluation every checkpoint is kept.
                            save_model = True

                        if save_model:
                            # Unwrap DataParallel before saving.
                            model_to_save = model.module if hasattr(model, 'module') else model
                            output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
                            torch.save(model_to_save.state_dict(), output_model_file)
                            model_to_save.config.to_json_file(output_config_file)
                            tokenizer.save_vocabulary(args.output_dir)
                            if best_result:
                                # When this branch runs, best_result is the
                                # result of the evaluation just performed.
                                output_eval_file = os.path.join(args.output_dir, EVAL_FILE)
                                with open(output_eval_file, "w") as writer:
                                    for key in sorted(best_result.keys()):
                                        writer.write("%s = %s\n" % (key, str(best_result[key])))

    if args.do_eval:
        if args.eval_test:
            if task_name == "mnli":
                eval_examples = processor.get_test_examples(args.data_dir, eval_set=args.eval_set)
            else:
                eval_examples = processor.get_test_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer, OUTPUT_MODES[task_name])
            logger.info("***** Test *****")
            logger.info(" Num examples = %d", len(eval_examples))
            logger.info(" Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            # Test data carries no labels, so the dataset has three tensors.
            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
            eval_dataloader = DataLoader(eval_data, batch_size=args.eval_batch_size)
            eval_label_ids = None

        # Reload whatever model was last saved to the output directory.
        model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
        if args.fp16:
            model.half()
        model.to(device)
        preds, result = evaluate(task_name, model, device, eval_dataloader, eval_label_ids, num_labels)

        pred_file = os.path.join(args.output_dir, PRED_FILE)
        with open(pred_file, "w") as f_out:
            f_out.write("index\tprediction\n")
            for i, pred in enumerate(preds):
                if OUTPUT_MODES[task_name] == 'classification':
                    f_out.write("%d\t%s\n" % (i, id2label[pred]))
                else:
                    f_out.write("%d\t%.6f\n" % (i, pred))

        output_eval_file = os.path.join(args.output_dir, TEST_FILE)
        with open(output_eval_file, "w") as writer:
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))
# Script entry point: parse the command line and hand the namespace to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model", default=None, type=str, required=True)
    parser.add_argument("--task_name", default=None, type=str, required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--eval_per_epoch", default=10, type=int,
                        help="How many times it evaluates on dev set per epoch")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--train_mode", type=str, default='random_sorted', choices=['random', 'sorted', 'random_sorted'])
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--eval_test", action='store_true', help="Whether to eval on the test set.")
    # Only consulted for the MNLI task; the help text used to be a typo'd
    # copy of the --eval_test description.
    parser.add_argument("--eval_set", type=str, default=None,
                        help="MNLI only: evaluation split to use, 'MNLI-m' (default), 'MNLI-mm', "
                             "or 'AX' (test set only).")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int, help="Total batch size for eval.")
    parser.add_argument("--learning_rate", default=None, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()
    main(args)