import contextlib
import logging
import os
from collections import OrderedDict
from argparse import ArgumentError
import torch
from fairseq import metrics, options, utils
from fairseq.data import (
Dictionary,
LanguagePairDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
)
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset
from . import LegacyFairseqTask, register_task
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _lang_token(lang: str):
return "__{}__".format(lang)
def _lang_token_index(dic: Dictionary, lang: str):
    """Look up the index of the language token for *lang* in *dic*.

    An assertion fails when the lookup yields the dictionary's unknown-symbol
    index, i.e. the language token is not present in *dic*.
    """
    token = _lang_token(lang)
    token_idx = dic.index(token)
    assert (
        token_idx != dic.unk_index
    ), "cannot find language token for lang {}".format(lang)
    return token_idx
@register_task("multilingual_translation")
class MultilingualTranslationTask(LegacyFairseqTask):
    """A task for training multiple translation models simultaneously.

    We iterate round-robin over batches from multiple language pairs, ordered
    according to the `--lang-pairs` argument.

    The training loop is roughly:

        for i in range(len(epoch)):
            for lang_pair in args.lang_pairs:
                batch = next_batch_for_lang_pair(lang_pair)
                loss = criterion(model_for_lang_pair(lang_pair), batch)
                loss.backward()
            optimizer.step()

    In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
    (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
    implements the `FairseqMultiModel` interface.

    During inference it is required to specify a single `--source-lang` and
    `--target-lang`, which indicates the inference language direction.
    `--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to
    the same value as training.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
                            help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language (only needed for inference)')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language (only needed for inference)')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left (default: True)')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left (default: False)')
        # Each option is guarded on its own: with a single shared ``try`` a
        # conflict on the first option would silently skip registering the
        # second one as well.
        for option, help_text in (
            ('--max-source-positions', 'max number of tokens in the source sequence'),
            ('--max-target-positions', 'max number of tokens in the target sequence'),
        ):
            try:
                parser.add_argument(option, default=1024, type=int, metavar='N',
                                    help=help_text)
            except ArgumentError:
                # The option is already registered on this parser; keep the
                # existing definition (deliberate best-effort).
                pass
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
                            metavar='SRCTGT',
                            help='replace beginning-of-sentence in source sentence with source or target '
                                 'language token. (src/tgt)')
        parser.add_argument('--decoder-langtok', action='store_true',
                            help='replace beginning-of-sentence in target sentence with target language token')
        # fmt: on

    def __init__(self, args, dicts, training):
        """Store the per-language dictionaries and derive the active language pairs.

        Args:
            args: parsed task arguments.
            dicts: ordered mapping from language code to its dictionary.
            training: True when training; False when a single
                source/target direction was requested.
        """
        super().__init__(args)
        self.dicts = dicts
        self.training = training
        if training:
            self.lang_pairs = args.lang_pairs
        else:
            # Outside of training only the requested direction is active.
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        # Pairs iterated by valid_step(); defaults to every active pair.
        self.eval_lang_pairs = self.lang_pairs
        # Pairs iterated by train_step(); defaults to every active pair.
        self.model_lang_pairs = self.lang_pairs
        self.langs = list(dicts.keys())

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load the dictionaries via :meth:`prepare` and build the task instance."""
        dicts, training = cls.prepare(args, **kwargs)
        return cls(args, dicts, training)

    @classmethod
    def update_args(cls, args):
        """Normalize *args* in place.

        Converts the ``left_pad_source`` / ``left_pad_target`` values with
        ``utils.eval_bool`` and turns a comma-separated ``lang_pairs`` string
        into a list.

        Raises:
            ValueError: if ``args.lang_pairs`` is None.
        """
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError(
                "--lang-pairs is required. List all the language pairs in the training objective."
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

    @classmethod
    def prepare(cls, args, **kargs):
        """Normalize *args* and load one dictionary per language.

        The languages are the sorted set of all codes appearing on either side
        of a pair in ``args.lang_pairs``. Each dictionary is read from
        ``dict.<lang>.txt`` in the first data path.

        Returns:
            A ``(dicts, training)`` tuple, where ``dicts`` is an OrderedDict
            keyed by language and ``training`` is True only when neither
            ``args.source_lang`` nor ``args.target_lang`` is set.
        """
        cls.update_args(args)
        sorted_langs = sorted(
            {lang for lang_pair in args.lang_pairs for lang in lang_pair.split("-")}
        )
        training = args.source_lang is None and args.target_lang is None

        # The data paths do not depend on the language, so resolve them once.
        paths = utils.split_paths(args.data)
        assert len(paths) > 0

        add_lang_tokens = args.encoder_langtok is not None or args.decoder_langtok

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            dicts[lang] = cls.load_dictionary(
                os.path.join(paths[0], "dict.{}.txt".format(lang))
            )
            # Every dictionary must agree with the first one on the special
            # symbol indices.
            reference = dicts[sorted_langs[0]]
            assert dicts[lang].pad() == reference.pad()
            assert dicts[lang].eos() == reference.eos()
            assert dicts[lang].unk() == reference.unk()
            if add_lang_tokens:
                # Each dictionary receives the token of every language.
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
        return dicts, training

    def get_encoder_langtok(self, src_lang, tgt_lang):
        """Return the index of the token the encoder input should carry.

        Without ``--encoder-langtok`` this is the source dictionary's EOS
        index; otherwise it is the index, in the source dictionary, of the
        source ("src") or target (any other value) language token.
        """
        if self.args.encoder_langtok is None:
            return self.dicts[src_lang].eos()
        if self.args.encoder_langtok == "src":
            return _lang_token_index(self.dicts[src_lang], src_lang)
        return _lang_token_index(self.dicts[src_lang], tgt_lang)

    def get_decoder_langtok(self, tgt_lang):
        """Return the index of the token the decoder input should start with.

        Without ``--decoder-langtok`` this is the target dictionary's EOS
        index; otherwise the index of the target language token.
        """
        if not self.args.decoder_langtok:
            return self.dicts[tgt_lang].eos()
        return _lang_token_index(self.dicts[tgt_lang], tgt_lang)

    def alter_dataset_langtok(
        self,
        lang_pair_dataset,
        src_eos=None,
        src_lang=None,
        tgt_eos=None,
        tgt_lang=None,
    ):
        """Wrap *lang_pair_dataset* so language tokens replace EOS/BOS symbols.

        Returns the dataset unchanged when neither ``--encoder-langtok`` nor
        ``--decoder-langtok`` is active. Otherwise returns a
        ``TransformEosLangPairDataset``; a side whose option is inactive, or
        whose required arguments are missing, is passed as ``None``.
        """
        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
            return lang_pair_dataset

        new_src_eos = None
        if (
            self.args.encoder_langtok is not None
            and src_eos is not None
            and src_lang is not None
            and tgt_lang is not None
        ):
            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
        else:
            # No source-side replacement requested/possible.
            src_eos = None

        new_tgt_bos = None
        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
        else:
            # No target-side replacement requested/possible.
            tgt_eos = None

        return TransformEosLangPairDataset(
            lang_pair_dataset,
            src_eos=src_eos,
            new_src_eos=new_src_eos,
            tgt_bos=tgt_eos,
            new_tgt_bos=new_tgt_bos,
        )

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split.

        The data path is selected round-robin from the configured paths based
        on *epoch*. One dataset per active language pair is loaded and the
        result is stored in ``self.datasets[split]`` as a
        ``RoundRobinZipDatasets``.
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        def language_pair_dataset(lang_pair):
            src, tgt = lang_pair.split("-")
            langpair_dataset = load_langpair_dataset(
                data_path,
                split,
                src,
                self.dicts[src],
                tgt,
                self.dicts[tgt],
                combine=True,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )
            return self.alter_dataset_langtok(
                langpair_dataset,
                src_eos=self.dicts[src].eos(),
                src_lang=src,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )

        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in self.lang_pairs
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        """Build a single-direction dataset from raw source tokens.

        Raises:
            NotImplementedError: if *constraints* is given.
        """
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the multilingual_translation task is not supported"
            )

        lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
        return RoundRobinZipDatasets(
            OrderedDict(
                [
                    (
                        lang_pair,
                        self.alter_dataset_langtok(
                            LanguagePairDataset(
                                src_tokens, src_lengths, self.source_dictionary
                            ),
                            src_eos=self.source_dictionary.eos(),
                            src_lang=self.args.source_lang,
                            tgt_eos=self.target_dictionary.eos(),
                            tgt_lang=self.args.target_lang,
                        ),
                    )
                ]
            ),
            eval_key=lang_pair,
        )

    def build_model(self, args, from_checkpoint=False):
        """Build the model after checking *args* against the task's own args.

        Raises:
            ValueError: if ``lang_pairs``, ``encoder_langtok`` or
                ``decoder_langtok`` differ between the task args and *args*,
                or if the built model is not a ``FairseqMultiModel``.
        """

        def check_args():
            messages = []
            if (
                len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs))
                != 0
            ):
                messages.append(
                    "--lang-pairs should include all the language pairs {}.".format(
                        args.lang_pairs
                    )
                )
            if self.args.encoder_langtok != args.encoder_langtok:
                messages.append(
                    "--encoder-langtok should be {}.".format(args.encoder_langtok)
                )
            if self.args.decoder_langtok != args.decoder_langtok:
                messages.append(
                    "--decoder-langtok should {} be set.".format(
                        "" if args.decoder_langtok else "not"
                    )
                )

            if len(messages) > 0:
                raise ValueError(" ".join(messages))

        # *args* may be a different object than self.args, so normalize it too
        # before comparing.
        self.update_args(args)
        check_args()

        from fairseq import models

        model = models.build_model(args, self, from_checkpoint)
        if not isinstance(model, FairseqMultiModel):
            raise ValueError(
                "MultilingualTranslationTask requires a FairseqMultiModel architecture"
            )
        return model

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
        """Compute the loss for one language pair and run the backward pass.

        When *ignore_grad* is set the loss is multiplied by zero before
        ``optimizer.backward`` is called.
        """
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run forward/backward for every language pair present in *sample*.

        Pairs from ``self.model_lang_pairs`` that are missing from *sample*,
        are None, or are empty are skipped.

        Returns:
            ``(agg_loss, agg_sample_size, agg_logging_output)`` where the
            logging output holds each key summed over all pairs and,
            additionally, per pair under ``"<lang_pair>:<key>"``.
        """
        from collections import defaultdict

        model.train()
        agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
        # The membership test mirrors valid_step(): a missing pair is skipped
        # rather than raising KeyError.
        curr_lang_pairs = [
            lang_pair
            for lang_pair in self.model_lang_pairs
            if lang_pair in sample
            and sample[lang_pair] is not None
            and len(sample[lang_pair]) != 0
        ]
        last_idx = len(curr_lang_pairs) - 1

        for idx, lang_pair in enumerate(curr_lang_pairs):
            # Use model.no_sync() for every pair but the last one when running
            # distributed. NOTE(review): presumably this defers gradient
            # synchronization until the final backward pass — confirm against
            # the model wrapper's no_sync() documentation.
            defer_sync = (
                self.args.distributed_world_size > 1
                and hasattr(model, "no_sync")
                and idx < last_idx
            )
            # ExitStack() serves as a no-op context manager.
            sync_context = model.no_sync() if defer_sync else contextlib.ExitStack()
            with sync_context:
                loss, sample_size, logging_output = self._per_lang_pair_train_loss(
                    lang_pair,
                    model,
                    update_num,
                    criterion,
                    sample,
                    optimizer,
                    ignore_grad,
                )
            agg_loss += loss.detach().item()
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[k] += logging_output[k]
                agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        """Evaluate the criterion for one language pair."""
        return criterion(model.models[lang_pair], sample[lang_pair])

    def valid_step(self, sample, model, criterion):
        """Evaluate every language pair present in *sample* without gradients.

        Pairs from ``self.eval_lang_pairs`` that are missing from *sample*,
        are None, or are empty are skipped.

        Returns:
            ``(agg_loss, agg_sample_size, agg_logging_output)`` with the same
            layout as :meth:`train_step`.
        """
        from collections import defaultdict

        model.eval()
        with torch.no_grad():
            agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float)
            for lang_pair in self.eval_lang_pairs:
                if (
                    lang_pair not in sample
                    or sample[lang_pair] is None
                    or len(sample[lang_pair]) == 0
                ):
                    continue
                loss, sample_size, logging_output = self._per_lang_pair_valid_loss(
                    lang_pair, model, criterion, sample
                )
                # detach() rather than the legacy .data, matching train_step().
                agg_loss += loss.detach().item()
                agg_sample_size += sample_size
                for k in logging_output:
                    agg_logging_output[k] += logging_output[k]
                    agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k]
        return agg_loss, agg_sample_size, agg_logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        """Generate with *generator*, choosing the BOS token from the task args.

        With ``--decoder-langtok`` the BOS token is the target language token;
        otherwise it is the target dictionary's EOS index.
        """
        with torch.no_grad():
            if self.args.decoder_langtok:
                bos_token = _lang_token_index(
                    self.target_dictionary, self.args.target_lang
                )
            else:
                bos_token = self.target_dictionary.eos()
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=bos_token,
            )

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and log the summed size statistics."""
        with metrics.aggregate():
            super().reduce_metrics(logging_outputs, criterion)
            for k in ["sample_size", "nsentences", "ntokens"]:
                metrics.log_scalar(k, sum(log[k] for log in logging_outputs))

    @property
    def source_dictionary(self):
        """Dictionary for the source side.

        During training this is the first loaded dictionary; otherwise the
        dictionary of ``args.source_lang``.
        """
        if self.training:
            return next(iter(self.dicts.values()))
        return self.dicts[self.args.source_lang]

    @property
    def target_dictionary(self):
        """Dictionary for the target side.

        During training this is the first loaded dictionary; otherwise the
        dictionary of ``args.target_lang``.
        """
        if self.training:
            return next(iter(self.dicts.values()))
        return self.dicts[self.args.target_lang]

    def max_positions(self):
        """Return the max sentence length allowed by the task.

        The result maps each language pair to a
        ``(max_source_positions, max_target_positions)`` tuple. When no split
        has been loaded yet, the only key is the ``source-target`` pair taken
        from the task args.
        """
        limits = (self.args.max_source_positions, self.args.max_target_positions)
        if not self.datasets:
            return {
                "%s-%s" % (self.args.source_lang, self.args.target_lang): limits
            }
        return OrderedDict(
            [
                (key, limits)
                for split in self.datasets
                for key in self.datasets[split].datasets.keys()
            ]
        )