from dataclasses import dataclass
from typing import Any, Literal
import torch
from transformers import DataCollatorForSeq2Seq
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Expand a 2d attention mask with sequence indices into a 4d additive attention mask.

    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len).
    It handles packed sequences: a token may only attend to tokens carrying the same non-zero index.
    It also applies a lower-triangular (causal) constraint to prevent future peeking.

    e.g.

    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```

    where `o` equals `0.0` and `x` equals `min_dtype`.

    Args:
        attention_mask_with_indices: 2d tensor of shape (batch_size, seq_len). Each non-zero value
            identifies the packed sequence a token belongs to; ``0`` marks padding.
        dtype: Floating-point dtype of the returned mask.

    Returns:
        Tensor of shape (batch_size, 1, seq_len, seq_len) and dtype ``dtype``, on the same device as the
        input. Entries are ``0.0`` where attention is allowed and ``torch.finfo(dtype).min`` elsewhere.

    Raises:
        TypeError: If ``dtype`` is not a floating-point dtype (raised by ``torch.finfo``).
    """
    _, seq_len = attention_mask_with_indices.size()
    # Build every helper tensor on the input's device; otherwise GPU inputs fail with a device mismatch.
    device = attention_mask_with_indices.device
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype, device=device)
    # Keys that are padding (index 0) must never be attended to: [bsz, 1, 1, seq_len].
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Broadcasting key ids along rows against query ids along columns lets entry (q, k)
    # test whether query q and key k belong to the same packed sequence.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Causal constraint: query q may only see keys k <= q.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Turn the boolean "may attend" mask into an additive one: 0.0 keeps, dtype-min blocks.
    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
    return attention_mask_4d
@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
    r"""Seq2seq data collator that can replace the 2d attention mask with a 4d additive mask.

    After the parent collator has built the batch, the ``attention_mask`` entry is passed through
    ``prepare_4d_attention_mask`` when ``block_diag_attn`` is enabled. The exception is when
    ``attn_implementation`` is ``"flash_attention_2"``; there, the mask is returned as produced
    by the parent collator.
    """

    # Whether to convert the 2d attention mask into a 4d mask (skipped for flash_attention_2).
    block_diag_attn: bool = False
    # Attention backend name; "flash_attention_2" keeps the 2d mask untouched.
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    # Floating-point dtype used for the generated 4d mask.
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        batch = super().__call__(features)
        wants_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if not wants_4d_mask:
            return batch
        batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)
        return batch