import numpy as np
import torch
import torch.nn.functional as F
def padding(indice, max_length, pad_idx=0):
    """Right-pad every sequence to ``max_length`` and stack them in a tensor.

    Args:
        indice: iterable of sequences. Each one is concatenated with a list
            of fill values, so the items are expected to be Python lists.
        max_length: target length for every row.
        pad_idx: fill value appended to rows shorter than ``max_length``.

    Returns:
        ``torch.tensor`` built from the padded rows. Rows that are already
        longer than ``max_length`` are passed through untouched (they are
        not truncated).
    """
    padded_rows = []
    for row in indice:
        # Clamp at zero so over-long rows receive no filler at all.
        missing = max(0, max_length - len(row))
        padded_rows.append(row + [pad_idx] * missing)
    return torch.tensor(padded_rows)
def bert_sequence_label_gp_collate_fn(batch):
    """Collate samples with dynamic padding of token ids and label grids.

    Args:
        batch: list of dicts, each providing "input_ids" (1-D sequence) and
            "labels" (array-like with at least three dimensions; its first
            three axes are padded).

    Returns:
        dict with
        - "input_ids": tensor of the token ids, zero-padded along the first
          axis to the longest sample in the batch;
        - "labels": tensor of the labels, zero-padded along their first
          three axes to the per-axis maximum in the batch.
    """

    def sequence_padding(inputs,
                         length=None,
                         value=0,
                         seq_dims=1,
                         mode='post'):
        """Pad (or truncate) the leading ``seq_dims`` axes to a common size.

        Args:
            inputs: sequence of array-likes with the same number of axes.
            length: target size per padded axis; a scalar is treated as a
                one-element list. When ``None`` the per-axis maximum over
                ``inputs`` is used.
            value: constant used to fill the padding.
            seq_dims: number of leading axes to pad.
            mode: ``'post'`` pads after the data, ``'pre'`` pads before it.

        Returns:
            ``np.ndarray`` stacking the padded samples.

        Raises:
            ValueError: if ``mode`` is neither ``'post'`` nor ``'pre'``.
        """
        if mode not in ('post', 'pre'):
            raise ValueError('"mode" argument must be "post" or "pre".')

        # Convert up front: nested Python lists cannot be indexed with a
        # tuple of slices, which is required as soon as seq_dims > 1.
        arrays = [np.asarray(x) for x in inputs]

        if length is None:
            length = np.max([a.shape[:seq_dims] for a in arrays], axis=0)
        elif not hasattr(length, '__getitem__'):
            length = [length]

        slices = tuple(np.s_[:length[i]] for i in range(seq_dims))
        # Axes beyond seq_dims are never padded.
        untouched = [(0, 0)] * arrays[0].ndim

        outputs = []
        for arr in arrays:
            arr = arr[slices]  # truncate anything longer than the target
            pad_width = list(untouched)
            for i in range(seq_dims):
                missing = length[i] - arr.shape[i]
                pad_width[i] = (0, missing) if mode == 'post' else (missing, 0)
            outputs.append(
                np.pad(arr, pad_width, 'constant', constant_values=value))
        return np.array(outputs)

    token_ids = [sample["input_ids"] for sample in batch]
    labels = [sample["labels"] for sample in batch]

    token_ids_padded = torch.from_numpy(sequence_padding(token_ids))
    labels_padded = torch.from_numpy(sequence_padding(labels, seq_dims=3))

    return {"input_ids": token_ids_padded, "labels": labels_padded}
def seq2seq_collate_fn(batch):
    """Collate seq2seq samples into padded tensors.

    Args:
        batch: list of dicts, each providing "input_ids", "segment_ids" and
            "flash_atten".

    Returns:
        dict with
        - "input_ids": token ids right-padded with 0 to the longest
          "input_ids" in the batch;
        - "segment_ids": segment ids padded to that same length;
        - "labels": "input_ids" with the first column dropped;
        - "cu_seqlens": present only when every sample has a truthy
          "flash_atten". An int32 tensor holding a leading 0 followed by the
          running total of per-sample token counts.
    """
    token_ids = [sample["input_ids"] for sample in batch]
    max_length = max(len(ids) for ids in token_ids)
    token_type_ids = [sample["segment_ids"] for sample in batch]

    token_ids_padded = padding(token_ids, max_length)
    token_type_ids_padded = padding(token_type_ids, max_length)

    collated = {
        "input_ids": token_ids_padded,
        "segment_ids": token_type_ids_padded,
        # Targets are the inputs shifted left by one position.
        "labels": token_ids_padded[:, 1:].contiguous(),
    }

    if all(sample["flash_atten"] for sample in batch):
        # Ids greater than 0 count as real tokens; this relies on 0 being
        # the pad id used above.
        seqlens_in_batch = (token_ids_padded > 0).sum(
            dim=-1, dtype=torch.int32)
        # Prepend 0 so entry i is the offset at which sample i starts.
        collated["cu_seqlens"] = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return collated
def bert_cls_collate_fn(batch):
    """Collate classification samples into padded tensors.

    Args:
        batch: list of dicts, each providing "input_ids", "segment_ids" and
            "labels".

    Returns:
        dict with "input_ids" and "segment_ids" right-padded with 0 to the
        longest "input_ids" in the batch, and "labels" as a ``torch.long``
        tensor built from the per-sample labels.
    """
    input_rows = [sample["input_ids"] for sample in batch]
    longest = max(map(len, input_rows))
    segment_rows = [sample["segment_ids"] for sample in batch]
    class_ids = torch.tensor(
        [sample["labels"] for sample in batch], dtype=torch.long)

    collated = {}
    collated["input_ids"] = padding(input_rows, longest)
    collated["segment_ids"] = padding(segment_rows, longest)
    collated["labels"] = class_ids
    return collated
def bert_sequence_label_collate_fn(batch):
    """Collate sequence-labelling samples with dynamic padding.

    Args:
        batch: list of dicts, each providing "input_ids" and "labels".

    Returns:
        dict with "input_ids" and "labels", both right-padded with 0 to the
        length of the longest "input_ids" in the batch.
    """
    input_rows = []
    for sample in batch:
        input_rows.append(sample["input_ids"])
    # The pad target comes from the token ids only; labels follow it.
    longest = max(map(len, input_rows))

    label_rows = []
    for sample in batch:
        label_rows.append(sample["labels"])

    padded_inputs = padding(input_rows, longest)
    padded_labels = padding(label_rows, longest)
    return {"input_ids": padded_inputs, "labels": padded_labels}