import argparse
import collections
import contextlib
import copy
import importlib
import logging
import os
import sys
import warnings
from itertools import accumulate
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
if TYPE_CHECKING:
    # Only needed for the string annotations on the incremental-state helpers;
    # never imported at runtime.
    from fairseq.modules.multihead_attention import MultiheadAttention
# Optional fused L2-norm kernel (amp_C, part of NVIDIA's apex). When it is
# missing, clip_grad_norm_ falls back to per-tensor torch.norm calls.
try:
    from amp_C import multi_tensor_l2norm

    multi_tensor_l2norm_available = True
except ImportError:
    multi_tensor_l2norm_available = False
# Optional torch_xla support; ``xm`` stays None when torch_xla is not installed.
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None
logger = logging.getLogger(__name__)
# Separator used by split_paths for path strings that contain "://".
MANIFOLD_PATH_SEP = "|"
class FileContentsAction(argparse.Action):
    """argparse action whose value may be given inline or via a file.

    When the supplied value names an existing file (as reported by
    ``fairseq.file_io.PathManager.isfile``), the file's contents, stripped of
    surrounding whitespace, are stored on the namespace. Otherwise the raw
    value is stored unchanged.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # Only the default ``nargs`` is supported by this action.
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        from fairseq.file_io import PathManager

        argument = values
        if PathManager.isfile(values):
            with PathManager.open(values) as handle:
                argument = handle.read().strip()
        setattr(namespace, self.dest, argument)
def split_paths(paths: str, separator=os.pathsep) -> List[str]:
    """Split a string holding several paths into a list of paths.

    Strings that contain a URL scheme marker (``://``) are split on
    ``MANIFOLD_PATH_SEP``; every other string is split on ``separator``.
    """
    delimiter = MANIFOLD_PATH_SEP if "://" in paths else separator
    return paths.split(delimiter)
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Deprecated shim that forwards to ``checkpoint_utils.load_model_ensemble``.

    Emits a deprecation warning, then returns whatever
    ``load_model_ensemble`` returns for the given checkpoint files.
    """
    from fairseq import checkpoint_utils

    deprecation_warning(
        "utils.load_ensemble_for_inference is deprecated. "
        "Please use checkpoint_utils.load_model_ensemble instead."
    )
    overrides = model_arg_overrides
    return checkpoint_utils.load_model_ensemble(
        filenames, task=task, arg_overrides=overrides
    )
def apply_to_sample(f, sample):
    """Recursively apply ``f`` to every tensor inside ``sample``.

    Dicts (including ``OrderedDict``, whose ``__dict__`` is carried over),
    lists, tuples and sets are rebuilt with their tensor leaves replaced by
    ``f(leaf)``. Any other leaf is returned unchanged.

    Args:
        f: callable applied to each tensor.
        sample: a tensor or a nested container of tensors.

    Returns:
        The transformed structure. An empty sized sample (e.g. ``[]`` or
        ``{}``) yields ``{}``, regardless of its original type.
    """
    # 0-d tensors define __len__, but len() raises TypeError on them. They are
    # never "empty", so skip the emptiness check for them.
    is_scalar_tensor = torch.is_tensor(sample) and sample.dim() == 0
    if not is_scalar_tensor and hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            od = collections.OrderedDict(
                (key, _apply(value)) for key, value in x.items()
            )
            # Preserve any instance attributes attached to the OrderedDict.
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(elem) for elem in x]
        elif isinstance(x, tuple):
            return tuple(_apply(elem) for elem in x)
        elif isinstance(x, set):
            return {_apply(elem) for elem in x}
        else:
            return x

    return _apply(sample)
def move_to_cuda(sample, device=None):
    """Move every tensor in ``sample`` to an NPU device.

    The name is kept for API compatibility; tensors are moved to an ``npu``
    device with ``non_blocking=True``.

    Args:
        sample: a tensor or nested container accepted by ``apply_to_sample``.
        device: target device. ``None`` means ``torch.npu.current_device()``.
            An int (or digit string) is treated as an NPU ordinal, so an
            explicit ``0`` now really selects ``npu:0``. A ``torch.device`` or
            any other device string is used as-is.

    Returns:
        The same structure with all tensors on the target device.
    """
    # `device or ...` would silently replace an explicit ordinal 0.
    if device is None:
        device = torch.npu.current_device()
    # Build the target once. Formatting a torch.device / "npu:N" string into
    # f"npu:{device}" would produce an invalid "npu:npu:N".
    if isinstance(device, torch.device):
        target = device
    elif isinstance(device, str) and not device.isdigit():
        target = torch.device(device)
    else:
        target = torch.device(f"npu:{device}")

    def _move_to_cuda(tensor):
        return tensor.to(device=target, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)
def move_to_cpu(sample):
    """Copy every tensor in ``sample`` to the CPU.

    float16 and bfloat16 tensors are upcast to float32 before the copy; other
    dtypes are kept.
    """
    half_precision = {torch.bfloat16, torch.float16}

    def _move_to_cpu(tensor):
        upcast = (
            tensor.to(dtype=torch.float32) if tensor.dtype in half_precision else tensor
        )
        return upcast.cpu()

    return apply_to_sample(_move_to_cpu, sample)
def move_to_tpu(sample):
    """Move every tensor in ``sample`` to the device given by ``xm.xla_device()``."""
    import torch_xla.core.xla_model as xm

    xla_device = xm.xla_device()
    return apply_to_sample(lambda tensor: tensor.to(xla_device), sample)
def get_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Helper for getting incremental state for an nn.Module.

    Forwards directly to ``module.get_incremental_state(incremental_state, key)``.
    """
    getter = module.get_incremental_state
    return getter(incremental_state, key)
def set_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Helper for setting incremental state for an nn.Module.

    When ``incremental_state`` is None it is returned untouched. Otherwise
    ``module.set_incremental_state`` is called; its non-None result replaces
    the state, and a None result keeps the original object.
    """
    if incremental_state is None:
        return incremental_state
    updated = module.set_incremental_state(incremental_state, key, value)
    return incremental_state if updated is None else updated
def load_align_dict(replace_unk):
    """Load a source-to-target word dictionary used for unknown-word replacement.

    Args:
        replace_unk: ``None`` (no replacement), a non-empty path to a UTF-8
            text file whose lines hold ``<src> <tgt>`` separated by
            whitespace, or any other value (yields an empty dict).

    Returns:
        ``None`` if ``replace_unk`` is None, otherwise a dict mapping source
        words to target words. Blank lines in the file are ignored, and any
        columns after the second are ignored.

    Raises:
        ValueError: if a non-blank line has fewer than two columns.
    """
    if replace_unk is None:
        return None
    if not (isinstance(replace_unk, str) and len(replace_unk) > 0):
        return {}

    align_dict = {}
    with open(replace_unk, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            cols = line.split()
            if not cols:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            if len(cols) < 2:
                raise ValueError(
                    "{}:{}: expected '<src> <tgt>', got {!r}".format(
                        replace_unk, line_no, line.rstrip("\n")
                    )
                )
            align_dict[cols[0]] = cols[1]
    return align_dict
def print_embed_overlap(embed_dict, vocab_dict):
    """Log how many vocabulary symbols also appear in ``embed_dict``.

    ``vocab_dict`` must expose ``.symbols`` and support ``len()``.
    """
    shared = set(embed_dict.keys()).intersection(vocab_dict.symbols)
    logger.info(
        "found {}/{} types in embedding file".format(len(shared), len(vocab_dict))
    )
def parse_embedding(embed_path):
    """Parse embedding text file into a dictionary of word and embedding tensors.

    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932

    The optional header is recognised as a first line consisting of exactly
    two integers, and it is skipped. Files without a header keep their first
    word. Blank lines are ignored, and an empty file yields an empty dict.

    Returns:
        dict mapping each word to a 1-D float ``torch.Tensor`` of its weights.
    """
    embed_dict = {}
    with open(embed_path, encoding="utf-8") as f_embed:
        for line_no, line in enumerate(f_embed):
            if not line.strip():
                continue
            if line_no == 0:
                # Previously the first line was skipped unconditionally,
                # which dropped the first word of header-less files.
                header = line.split()
                if len(header) == 2 and all(p.isdigit() for p in header):
                    continue
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor(
                [float(weight) for weight in pieces[1:]]
            )
    return embed_dict
def load_embedding(embed_dict, vocab, embedding):
    """Copy pretrained vectors into the rows of ``embedding`` in place.

    Row ``idx`` is overwritten with ``embed_dict[vocab[idx]]`` whenever that
    token has a pretrained vector; all other rows are left untouched.

    Args:
        embed_dict: mapping from token to weight tensor.
        vocab: object supporting ``len()`` and integer indexing.
        embedding: module with a ``.weight`` parameter.

    Returns:
        ``embedding`` itself.
    """
    weight = embedding.weight.data
    # Index explicitly rather than iterating ``vocab`` directly, since only
    # len()/[] are relied upon.
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token not in embed_dict:
            continue
        weight[idx] = embed_dict[token]
    return embedding
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    """Replace ``unk`` tokens in a hypothesis using the aligned source tokens.

    Each hypothesis token equal to ``unk`` at position ``i`` is replaced by
    ``align_dict[src]`` (or ``src`` itself when missing), where ``src`` is the
    source token at ``alignment[i]``. The source token list gets an extra
    ``"<eos>"`` entry, presumably so an alignment pointing one past the last
    source word still resolves.

    Returns:
        The rewritten hypothesis, tokens joined by single spaces.
    """
    from fairseq import tokenizer

    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    src_tokens = tokenizer.tokenize_line(src_str) + ["<eos>"]

    def _resolve(pos, tok):
        if tok != unk:
            return tok
        source_word = src_tokens[alignment[pos]]
        return align_dict.get(source_word, source_word)

    return " ".join(_resolve(pos, tok) for pos, tok in enumerate(hypo_tokens))
def post_process_prediction(
    hypo_tokens,
    src_str,
    alignment,
    align_dict,
    tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=None,
):
    """Detokenize a hypothesis and optionally replace unknown words.

    The token ids are converted to a string with ``tgt_dict.string``. When
    ``align_dict`` is given, ``<unk>`` tokens are replaced via ``replace_unk``.
    When either unk replacement or BPE removal changed the text, the string is
    re-encoded with ``tgt_dict.encode_line(..., add_if_not_exist=True)``.

    Returns:
        Tuple ``(hypo_tokens, hypo_str, alignment)``.
    """
    hypo_str = tgt_dict.string(
        hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore
    )
    if align_dict is not None:
        unk_symbol = tgt_dict.unk_string()
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, unk_symbol)

    text_was_rewritten = align_dict is not None or remove_bpe is not None
    if text_was_rewritten:
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols keep the value
    ``padding_idx``. Positions are counted along dim 1. ``onnx_trace`` is
    accepted for API compatibility and is not used.
    """
    is_token = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(is_token, dim=1).type_as(is_token)
    positions = running_count * is_token
    return positions.long() + padding_idx
def strip_pad(tensor, pad):
    """Return the elements of ``tensor`` not equal to ``pad`` as a flat tensor."""
    keep = tensor != pad
    return tensor[keep]
def buffered_arange(max):
    """Return ``arange(max)`` as a slice of a cached, growable LongTensor.

    The buffer is stored as the attribute ``buffered_arange.buf`` and is only
    regrown when a larger ``max`` is requested. The result is a view into
    that shared buffer, so callers must not modify it in place.
    """
    buf = getattr(buffered_arange, "buf", None)
    if buf is None:
        buf = torch.LongTensor()
        buffered_arange.buf = buf
    if max > buf.numel():
        buf.resize_(max)
        torch.arange(max, out=buf)
    return buf[:max]
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Move the padding of a 2-D batch of token ids to the other side.

    With ``left_to_right=True`` left-padded rows become right-padded; with
    ``right_to_left=True`` right-padded rows become left-padded. Exactly one
    flag must be set. The rows are rotated, not re-sorted, so token order
    is preserved.

    Returns:
        ``src_tokens`` itself when nothing needs to move, otherwise a new
        tensor of the same shape.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        return src_tokens
    max_len = src_tokens.size(1)
    # Build the column indices directly on the tokens' device (previously a
    # CPU buffer moved via type_as, which shadowed the builtin `range`).
    positions = torch.arange(max_len, device=src_tokens.device).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
def item(tensor):
    """Convert a scalar-like value to a plain Python value where possible.

    XLA tensors are returned detached instead of being materialized. Objects
    with ``.item()`` are converted through it. Other indexable objects yield
    their first element, and anything else is returned unchanged.
    """
    is_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if is_xla:
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    return tensor[0] if hasattr(tensor, "__getitem__") else tensor
def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
    """Return the global L2 norm of ``grads``.

    Gradients are grouped by device in first-seen order. For ``npu`` devices
    the fused ``multi_tensor_l2norm`` kernel computes one norm per group, and
    the result is moved to ``torch.npu.current_device()``. Other devices
    contribute one float32 ``torch.norm`` per gradient. The per-part norms are
    then combined with a final L2 norm.

    NOTE(review): ``multi_tensor_l2norm`` comes from amp_C (NVIDIA apex);
    confirm it actually accepts NPU tensors on the target platform.
    """
    grads_by_device = {}
    for grad in grads:
        grads_by_device.setdefault(grad.device, []).append(grad)

    norms = []
    for device, device_grads in grads_by_device.items():
        if device.type != "npu":
            norms.extend(torch.norm(g, p=2, dtype=torch.float32) for g in device_grads)
            continue
        has_inf = torch.zeros((1, 1), dtype=torch.int, device=device)
        with torch.npu.device(device):
            fused = multi_tensor_l2norm(chunk_size, has_inf, [device_grads], False)
        norms.append(fused[0].to(torch.npu.current_device()))

    return torch.norm(torch.stack(norms))
@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    """Clip gradients in place by their global L2 norm and return that norm.

    Parameters carrying an ``expert`` attribute are left out of the norm
    computation, but their gradients are still scaled by the same clip
    coefficient. Parameters that are None or have no ``.grad`` are ignored.

    Args:
        params: a single tensor or an iterable of parameters.
        max_norm: clip threshold. Values <= 0 disable clipping (the norm is
            still computed and returned).
        aggregate_norm_fn: optional callable applied to the norm before
            clipping (e.g. to combine norms across workers).

    Returns:
        The (possibly aggregated) total norm, or a zero tensor when no
        non-expert gradient exists.
    """
    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)

    def has_grad(p):
        return p is not None and getattr(p, "grad", None) is not None

    dense_grads, expert_grads = [], []
    for p in params:
        if not has_grad(p):
            continue
        bucket = expert_grads if hasattr(p, "expert") else dense_grads
        bucket.append(p.grad.detach())

    if not dense_grads:
        return params[0].new_tensor(0.0) if params else torch.tensor(0.0)

    if len(dense_grads) == 1:
        total_norm = torch.norm(dense_grads[0], p=2, dtype=torch.float32)
    elif multi_tensor_l2norm_available:
        total_norm = multi_tensor_total_norm(dense_grads)
    else:
        if torch.npu.is_available():
            warnings.warn(
                "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                "you may get better performance by installing NVIDIA's apex library"
            )
            device = torch.npu.current_device()
        elif dense_grads[0].device.type == "xla":
            device = dense_grads[0].device
        else:
            device = torch.device("cpu")
        per_grad_norms = [
            torch.norm(g, p=2, dtype=torch.float32).to(device) for g in dense_grads
        ]
        total_norm = torch.norm(torch.stack(per_grad_norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        # Scale factor is capped at 1 so gradients are only ever shrunk.
        clip_coef = (float(max_norm) / (total_norm + 1e-6)).clamp_(max=1)
        for g in dense_grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is done in float32 and the result is cast back to ``t``'s type.
    Note: when ``t`` is already float32, ``.float()`` returns ``t`` itself, so
    ``t`` is overwritten in place as well.
    """
    as_float = t.float()
    as_float.fill_(float("-inf"))
    return as_float.type_as(t)
def _match_types(arg1, arg2):
"""Convert the numerical argument to the same type as the other argument"""
def upgrade(arg_number, arg_structure):
if isinstance(arg_structure, tuple):
return tuple([arg_number] * len(arg_structure))
elif isinstance(arg_structure, dict):
arg = copy.deepcopy(arg_structure)
for k in arg:
arg[k] = upgrade(arg_number, arg_structure[k])
return arg
else:
return arg_number
if isinstance(arg1, float) or isinstance(arg1, int):
return upgrade(arg1, arg2), arg2
elif isinstance(arg2, float) or isinstance(arg2, int):
return arg1, upgrade(arg2, arg1)
return arg1, arg2
def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources.

    Each argument may be None (no constraint), a number, a tuple or a dict.
    Numbers are broadcast to the other argument's structure (see
    ``_match_types``), and the element-wise minimum is taken. None entries
    inside tuples and dict values are treated as "unconstrained".

    Returns:
        The combined constraint, or None if every argument is None.
    """

    def nullsafe_min(values):
        # Minimum that ignores None entries; None only if all are None.
        minimum = None
        for value in values:
            if minimum is None:
                minimum = value
            elif value is not None and value < minimum:
                minimum = value
        return minimum

    def map_value_update(d1, d2):
        updated_value = copy.deepcopy(d1)
        for key in d2:
            if key not in updated_value:
                updated_value[key] = d2[key]
            else:
                # Plain min() raised TypeError when either side was None.
                updated_value[key] = nullsafe_min((d1[key], d2[key]))
        return updated_value

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
        elif arg is not None:
            max_positions, arg = _match_types(max_positions, arg)
            if isinstance(arg, (float, int)):
                max_positions = min(max_positions, arg)
            elif isinstance(arg, dict):
                max_positions = map_value_update(max_positions, arg)
            else:
                max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
    return max_positions
def import_user_module(args):
    """Import the user plugin directory given by ``args.user_dir``, once.

    The directory is resolved as an absolute path first. If it does not
    exist, it is looked up relative to this file's directory, then relative
    to that directory's parent. The resolved path is memoized on the
    function so repeated calls do not re-import it. On first import the
    parent directory is prepended to ``sys.path``, the package is imported,
    and its optional ``tasks``/``models`` sub-directories are registered via
    fairseq's ``import_tasks``/``import_models``.

    Raises:
        FileNotFoundError: if ``user_dir`` cannot be resolved.
        ImportError: if a different module with the same name is already
            in ``sys.modules``.
    """
    module_path = getattr(args, "user_dir", None)
    if module_path is not None:
        module_path = os.path.abspath(args.user_dir)
        # NOTE(review): the isfile(dirname(...)) check presumably allows paths
        # that live inside a file container (e.g. an archive) — confirm.
        if not os.path.exists(module_path) and not os.path.isfile(
            os.path.dirname(module_path)
        ):
            # Fall back to resolving user_dir relative to this file's directory...
            fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
            if os.path.exists(fairseq_rel_path):
                module_path = fairseq_rel_path
            else:
                # ...and then relative to its parent directory.
                fairseq_rel_path = os.path.join(
                    os.path.dirname(__file__), "..", args.user_dir
                )
                if os.path.exists(fairseq_rel_path):
                    module_path = fairseq_rel_path
                else:
                    raise FileNotFoundError(module_path)
        # Memoize resolved paths on the function object so each user dir is
        # only processed once per process.
        import_user_module.memo = getattr(import_user_module, "memo", set())
        if module_path not in import_user_module.memo:
            import_user_module.memo.add(module_path)
            module_parent, module_name = os.path.split(module_path)
            if module_name not in sys.modules:
                # Make the package importable by its directory name.
                sys.path.insert(0, module_parent)
                importlib.import_module(module_name)
                tasks_path = os.path.join(module_path, "tasks")
                if os.path.exists(tasks_path):
                    from fairseq.tasks import import_tasks

                    import_tasks(tasks_path, f"{module_name}.tasks")
                models_path = os.path.join(module_path, "models")
                if os.path.exists(models_path):
                    from fairseq.models import import_models

                    import_models(models_path, f"{module_name}.models")
            # NOTE(review): assumes the already-imported module is a package
            # (has __path__); a same-named plain module raises AttributeError.
            elif module_path in sys.modules[module_name].__path__:
                logger.info(f"--user-dir={module_path} has already been imported.")
            else:
                raise ImportError(
                    "Failed to import --user-dir={} because the corresponding module name "
                    "({}) is not globally unique. Please rename the directory to "
                    "something unique and try again.".format(module_path, module_name)
                )
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over ``dim``, always computed and returned in float32.

    When ``onnx_trace`` is set, the input is cast with ``.float()`` instead of
    passing ``dtype=``, presumably to keep the graph ONNX-exportable.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
def log_softmax(x, dim: int, onnx_trace: bool = False):
    """Log-softmax over ``dim``, always computed and returned in float32.

    When ``onnx_trace`` is set, the input is cast with ``.float()`` instead of
    passing ``dtype=``, presumably to keep the graph ONNX-exportable.
    """
    if not onnx_trace:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)
    return F.log_softmax(x.float(), dim=dim)
def get_perplexity(loss, round=2, base=2):
    """Return ``base ** loss`` rounded with fairseq's ``safe_round``.

    Returns 0.0 when ``loss`` is None, and ``inf`` if the computation
    overflows.
    """
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        return safe_round(base**loss, round)
    except OverflowError:
        return float("inf")
def deprecation_warning(message, stacklevel=3):
    """Emit ``message`` through ``warnings.warn`` with the given stack level.

    The default ``stacklevel=3`` points the warning above the immediate caller
    of this helper.
    """
    warnings.warn(message=message, stacklevel=stacklevel)
def relu_squared(x: torch.Tensor):
    """Return ``relu(x) ** 2`` elementwise."""
    rectified = F.relu(x)
    return torch.pow(rectified, 2)
def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`.

    "gelu_fast" is a deprecated alias that warns and returns ``gelu_accurate``.
    NOTE(review): "swish" returns the ``torch.nn.SiLU`` *class*, not a
    function, so callers must instantiate it before applying it.

    Raises:
        RuntimeError: for an unknown activation name.
    """
    from fairseq.modules import gelu, gelu_accurate

    if activation == "gelu_fast":
        deprecation_warning(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate

    table = {
        "relu": F.relu,
        "relu_squared": relu_squared,
        "gelu": gelu,
        "gelu_accurate": gelu_accurate,
        "tanh": torch.tanh,
        "linear": lambda x: x,
        "swish": torch.nn.SiLU,
    }
    if isinstance(activation, str) and activation in table:
        return table[activation]
    raise RuntimeError("--activation-fn {} not supported".format(activation))
def get_available_activation_fns() -> List:
    """Return the activation names advertised as available.

    Every name here is accepted by ``get_activation_fn``. "relu_squared" was
    previously missing even though ``get_activation_fn`` supports it; it is
    appended at the end so existing positions are unchanged. "swish" is
    still not listed, because ``get_activation_fn`` returns the
    ``torch.nn.SiLU`` class for it rather than a callable activation.
    """
    return [
        "relu",
        "gelu",
        "gelu_fast",  # deprecated alias of gelu_accurate
        "gelu_accurate",
        "tanh",
        "linear",
        "relu_squared",
    ]
@contextlib.contextmanager
def model_eval(model):
    """Temporarily switch ``model`` to eval mode.

    The previous training flag is restored on exit, including when the
    ``with`` body raises. Previously an exception left the model stuck in
    eval mode.
    """
    is_training = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(is_training)
def has_parameters(module):
    """Return True if ``module.parameters()`` yields at least one parameter."""
    missing = object()
    return next(module.parameters(), missing) is not missing
def get_rng_state():
    """Snapshot the RNG states for later use with ``set_rng_state``.

    Always contains ``"torch_rng_state"`` (CPU). Adds ``"xla_rng_state"``
    when torch_xla was importable, and ``"cuda_rng_state"`` (holding the NPU
    RNG state) when an NPU is available.
    """
    snapshot = dict(torch_rng_state=torch.get_rng_state())
    if xm is not None:
        snapshot["xla_rng_state"] = xm.get_rng_state()
    if torch.npu.is_available():
        snapshot["cuda_rng_state"] = torch.npu.get_rng_state()
    return snapshot
def set_rng_state(state):
    """Restore RNG states captured by ``get_rng_state``.

    NOTE(review): raises KeyError when ``state`` lacks the XLA/NPU entry
    while that backend is present in this process (e.g. a state captured on
    a different kind of machine).
    """
    cpu_state = state["torch_rng_state"]
    torch.set_rng_state(cpu_state)
    if xm is not None:
        xm.set_rng_state(state["xla_rng_state"])
    if torch.npu.is_available():
        torch.npu.set_rng_state(state["cuda_rng_state"])
class set_torch_seed(object):
    """Seed torch's CPU RNG, and the NPU RNG when available, with ``seed``.

    Usable as a context manager, but seeding happens in the constructor and
    ``__exit__`` does nothing. NOTE(review): the previous RNG state is NOT
    restored when the ``with`` block ends, and XLA RNGs are not seeded here.
    Confirm this is intended.
    """

    def __init__(self, seed):
        assert isinstance(seed, int)
        torch.manual_seed(seed)
        if torch.npu.is_available():
            torch.npu.manual_seed(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return None
def parse_alignment(line):
    """
    Parses a single line from the alignment file.
    Args:
        line (str): String containing the alignment of the format:
            <src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
            <src_idx_m>-<tgt_idx_m>. All indices are 0 indexed.
    Returns:
        torch.IntTensor: packed alignments of shape (2 * m), laid out as
        src_1, tgt_1, src_2, tgt_2, ...
    """
    flat = []
    for pair in line.strip().split():
        src_idx, tgt_idx = pair.split("-")
        flat.extend((int(src_idx), int(tgt_idx)))
    return torch.IntTensor(flat)
def get_token_to_word_mapping(tokens, exclude_list):
    """Map each token position to a running word counter.

    The counter increments on every token not in ``exclude_list``, so the
    first counted token maps to 1. Excluded tokens reuse the current count.
    """
    is_counted = (int(token not in exclude_list) for token in tokens)
    return dict(enumerate(accumulate(is_counted)))
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Derive hard word alignments from an attention matrix.

    For every target position that is neither ``pad`` nor ``eos``, the source
    position with maximal attention (ignoring ``pad``/``eos`` source columns)
    is chosen. Positions are mapped to word indices via
    ``get_token_to_word_mapping``. Those counts start at 1, so 1 is
    subtracted. ``attn`` is indexed as ``[tgt, src]``.

    Returns:
        List of ``(src_word, tgt_word)`` tuples. Empty if there are no valid
        target tokens or every source token is pad/eos.
    """
    tgt_valid = (
        ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_invalid = (
        ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1)
    )
    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])

    if len(tgt_valid) == 0 or len(src_invalid) >= len(src_sent):
        return []

    # Advanced indexing copies, so masking below leaves ``attn`` untouched.
    attn_valid = attn[tgt_valid]
    attn_valid[:, src_invalid] = float("-inf")
    _, src_indices = attn_valid.max(dim=1)
    return [
        (
            src_token_to_word[src_idx.item()] - 1,
            tgt_token_to_word[tgt_idx.item()] - 1,
        )
        for tgt_idx, src_idx in zip(tgt_valid, src_indices)
    ]
def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Return the attention sub-matrix over non-pad positions as strings.

    Each probability is formatted with six decimals; rows follow the non-pad
    target positions. ``eos`` is accepted but not used.

    Returns:
        List of rows (lists of strings). Empty when either side has no
        non-pad token.
    """
    # Target positions are kept as a column (no squeeze). Assuming 1-D
    # sentences, indexing with the 1-D source positions then broadcasts to a
    # (num_tgt, num_src) block of ``attn``.
    tgt_rows = tgt_sent.ne(pad).nonzero(as_tuple=False)
    src_cols = src_sent.ne(pad).nonzero(as_tuple=False).squeeze(dim=-1)
    if len(tgt_rows) == 0 or len(src_cols) == 0:
        return []
    block = attn[tgt_rows, src_cols]
    return [["{:.6f}".format(p) for p in row.tolist()] for row in block]
def new_arange(x, *size):
    """
    Return a Tensor of `size` filled with a range function on the device of x.
    If size is empty, using the size of the variable x.

    The range runs along the last dimension and is broadcast (then made
    contiguous) over the leading dimensions.
    """
    shape = size if size else x.size()
    ramp = torch.arange(shape[-1], device=x.device)
    return ramp.expand(*shape).contiguous()
def get_tpu_device():
    """Return ``xm.xla_device()``. Requires torch_xla (``xm`` must not be None)."""
    device = xm.xla_device()
    return device
def tpu_data_loader(itr):
    """Wrap ``itr`` in a torch_xla per-device loader with progress counting.

    All replicas first meet at the "tpu_data_loader" rendezvous and a step is
    marked. The resulting loader is wrapped in a ``CountingIterator`` that
    starts from ``itr.n`` (0 if absent) with ``total=len(itr)``.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")
    xm.mark_step()
    device = xm.xla_device()
    per_device = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    return iterators.CountingIterator(
        per_device, start=getattr(itr, "n", 0), total=len(itr)
    )
def is_xla_tensor(tensor):
    """Return True for tensors that should take the static-shape code paths.

    NOTE(review): despite the name, this checks for device type ``"npu"``,
    not ``"xla"``. Presumably this deliberately routes NPU tensors through
    the XLA-style paths (e.g. ``index_put``'s mask arithmetic). Confirm.
    """
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "npu"
def index_put(tensor, indices, value):
    """Set ``tensor[indices] = value`` in a device-appropriate way.

    On devices where ``is_xla_tensor`` is true, the write is done with mask
    arithmetic (``tensor * ~indices + value * indices``) and a NEW tensor is
    returned. The mask is unsqueezed on trailing dims and expanded to
    ``tensor``'s shape if needed. Presumably ``indices`` is a boolean mask
    there, since it is inverted with ``~``; verify against callers.
    Elsewhere ``tensor`` is modified in place and returned.
    """
    if not is_xla_tensor(tensor):
        tensor[indices] = value
        return tensor
    while indices.dim() < tensor.dim():
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
def xla_device_to_cpu(dat):
    """Convert XLA data to CPU via torch_xla's ``xm._maybe_convert_to_cpu``.

    Relies on a private torch_xla helper.
    """
    import torch_xla.core.xla_model as xm

    converted = xm._maybe_convert_to_cpu(dat)
    return converted
class CudaEnvironment(object):
    """Describes the accelerator of the current worker.

    NOTE(review): despite the class name, the values are hard-coded for an
    Ascend910 NPU (major 7, minor 0, 31.74853515625 GB) rather than queried
    from the device. Confirm this holds for every deployment.
    """

    def __init__(self):
        # The current device index is queried but its value is never used.
        cur_device = torch.npu.current_device()  # noqa: F841
        self.name = 'Ascend910'
        self.major, self.minor = 7, 0
        self.total_memory_in_GB = 31.74853515625
        # NOTE(review): prints to stdout on every construction.
        print(f'{self.name} {self.major} {self.minor} {self.total_memory_in_GB}', flush=True)

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Given a list of CudaEnvironments, log one line per worker (list index
        is used as the rank), framed by a banner line above and below.
        """
        center = "CUDA enviroments for all {} workers".format(len(cuda_env_list))
        stars = "*" * (40 - len(center) // 2)
        banner = stars + center + stars
        logger.info(banner)
        for rank, env in enumerate(cuda_env_list):
            logger.info(
                "rank {:3d}: capabilities = {:2d}.{:<2d} ; "
                "total memory = {:.3f} GB ; name = {:40s}".format(
                    rank, env.major, env.minor, env.total_memory_in_GB, env.name
                )
            )
        logger.info(banner)
def csv_str_list(x):
    """Split a comma-separated string into a list (empty fields are kept)."""
    return x.split(sep=",")
def eval_str_list(x, type=float):
    """Parse ``x`` into a list of ``type`` values.

    Strings are first evaluated with ``eval``. Iterable results are converted
    element-wise; a non-iterable result (e.g. a single number) becomes a
    one-element list. ``None`` is passed through.

    SECURITY: strings go through ``eval``; only use with trusted input such
    as command-line flags.
    """
    if x is None:
        return None
    value = eval(x) if isinstance(x, str) else x
    try:
        return [type(v) for v in value]
    except TypeError:
        return [type(value)]
def eval_str_dict(x, type=dict):
    """Evaluate a string into a Python object (typically a dict).

    Non-string values (including ``None``) are returned unchanged. ``type``
    is accepted for signature compatibility and is not used.

    SECURITY: strings go through ``eval``; only use with trusted input such
    as command-line flags.
    """
    if x is None:
        return None
    return eval(x) if isinstance(x, str) else x
def eval_bool(x, default=False):
    """Interpret ``x`` as a boolean.

    ``None`` yields ``default``. Strings/bytes are evaluated with ``eval`` and
    the result is passed to ``bool``; if evaluation raises TypeError,
    ``default`` is returned. Other values (e.g. an already-parsed ``True``)
    are converted with ``bool`` directly. Previously they were handed to
    ``eval``, which rejects non-strings, so ``default`` was silently returned.

    SECURITY: strings go through ``eval``; only use with trusted input such
    as command-line flags.
    """
    if x is None:
        return default
    if not isinstance(x, (str, bytes)):
        return bool(x)
    try:
        return bool(eval(x))
    except TypeError:
        return default
def reset_logging():
    """Replace all root-logger handlers with a single stdout handler.

    The root level is taken from the ``LOGLEVEL`` environment variable
    (default ``INFO``). Records are formatted as
    ``time | level | name | message``.
    """
    root = logging.getLogger()
    # Iterate over a snapshot: removeHandler mutates root.handlers, and
    # removing from the list being iterated skipped every other handler.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)
def safe_getattr(obj, k, default=None):
    """Look up ``k`` on ``obj`` with a fallback.

    For OmegaConf configs, returns ``obj[k]`` when the key exists and is not
    None, otherwise ``default``. For any other object this is plain
    ``getattr(obj, k, default)``, so an attribute explicitly set to None is
    returned as None, not replaced by ``default``.
    """
    from omegaconf import OmegaConf

    if not OmegaConf.is_config(obj):
        return getattr(obj, k, default)
    if k in obj and obj[k] is not None:
        return obj[k]
    return default
def safe_hasattr(obj, k):
    """Return True if attribute ``k`` exists on ``obj`` and is not None."""
    value = getattr(obj, k, None)
    return value is not None