from functools import wraps
import torch
def _get_ltor_masks_and_position_ids(
    data: torch.Tensor,
    eod_token: int,
    reset_position_ids: bool,
    reset_attention_mask: bool,
    eod_mask_loss: bool,
    create_attention_mask: bool,
):
    """Build masks, position ids and document boundaries for a left-to-right model.

    Args:
        data (torch.Tensor): The data tensor that holds the tokens from the
            dataset. It is treated as a single 1-D sequence of
            ``data.numel()`` tokens.
        eod_token (int): ID of the token that is considered the EOD
        reset_position_ids (bool): Switch to restart the position IDs after
            every EOD token
        reset_attention_mask (bool): Switch to block attention from tokens
            after an EOD token to tokens up to and including that EOD token
        eod_mask_loss (bool): Switch to zero the loss mask at EOD tokens
        create_attention_mask (bool): Switch to enable the attention masks
            generation. Can be disabled if attention kernel generates masks
            by itself.

    Returns:
        torch.Tensor or None: Boolean attention mask of shape
            ``(1, seq_length, seq_length)`` where ``True`` marks a blocked
            position, or ``None`` when ``create_attention_mask`` is False
        torch.Tensor: Float loss mask of shape ``(seq_length,)``
        tuple: ``(position_ids, actual_seq_len)``. ``position_ids`` is a long
            tensor of shape ``(seq_length,)``. ``actual_seq_len`` holds the
            end offset (EOD index + 1) of every document found, followed by
            ``seq_length``; when neither reset switch is enabled it is just
            ``[seq_length]``.
    """
    seq_length = data.numel()

    if create_attention_mask:
        # Lower-triangular (causal) mask with a leading broadcast dimension.
        attention_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=data.device)
        ).unsqueeze(0)
    else:
        attention_mask = None

    # Loss mask.
    loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids. ``arange`` returns a fresh tensor, so it can be modified
    # in place below without an extra clone.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)

    # ``eod_index`` must always be defined because ``actual_seq_len`` below is
    # built from it. Document boundaries are only looked up when one of the
    # reset switches asks for them.
    if reset_position_ids or reset_attention_mask:
        # Boolean-mask indexing returns a copy, so the in-place updates of
        # ``position_ids`` in the loop do not alter ``eod_index``.
        eod_index = position_ids[data == eod_token]
    else:
        eod_index = torch.tensor([], dtype=torch.long, device=data.device)

    # Walk the documents in order; ``prev_doc_start`` is the start offset of
    # the document that the current EOD token terminates.
    prev_doc_start = 0
    for eod_pos in eod_index.tolist():
        next_doc_start = eod_pos + 1
        if reset_attention_mask and attention_mask is not None:
            # Tokens after this EOD must not attend to anything up to it.
            attention_mask[0, next_doc_start:, :next_doc_start] = 0
        if reset_position_ids:
            # Shift the remaining positions so the next document starts at 0.
            position_ids[next_doc_start:] -= next_doc_start - prev_doc_start
            prev_doc_start = next_doc_start

    if attention_mask is not None:
        # Convert attention mask to binary: True means "masked out".
        attention_mask = attention_mask < 0.5

    # Build the trailing boundary on the same device/dtype as ``eod_index`` so
    # that the concatenation also works when ``data`` is not on the CPU.
    # NOTE: when the last token is itself EOD, ``seq_length`` appears twice.
    seq_length_tensor = torch.tensor(
        [seq_length], dtype=eod_index.dtype, device=eod_index.device
    )
    actual_seq_len = torch.cat([eod_index + 1, seq_length_tensor])

    return attention_mask, loss_mask, (position_ids, actual_seq_len)
def collate_wrapper(fn):
    """Wrap a collate function so the batch also carries ``actual_seq_len``.

    Each incoming sample is expected to store a pair under ``'position_ids'``:
    element 0 is what gets forwarded to ``fn`` under the same key, element 1
    is that sample's tensor of sequence boundaries.

    Args:
        fn: Collate callable that takes the list of samples and returns a
            mapping supporting item assignment.

    Returns:
        A callable with the same metadata as ``fn`` that returns ``fn``'s
        batch with an extra ``'actual_seq_len'`` entry.
    """

    @wraps(fn)
    def wrapper(samples):
        # Pull the per-sample boundary tensors out of the paired entry.
        boundaries = []
        for sample in samples:
            boundaries.append(sample['position_ids'][1])

        # Rebuild every sample so 'position_ids' holds only the first element
        # of the pair; the caller's sample objects are left untouched.
        unpacked = []
        for sample in samples:
            entry = {}
            for key, val in sample.items():
                entry[key] = val[0] if key == 'position_ids' else val
            unpacked.append(entry)

        batch = fn(unpacked)

        # The last boundary of the first sample is used as the stride between
        # consecutive samples when shifting boundaries to batch-wide offsets.
        stride = boundaries[0][-1]
        shifted = []
        for idx, bounds in enumerate(boundaries):
            shifted.append(bounds + idx * stride)

        batch['actual_seq_len'] = torch.cat(shifted)
        return batch

    return wrapper