# Copyright 2024-2026 OpenAccess AI Collective and the LlamaFactory team.



from dataclasses import dataclass

from typing import Any, Literal



import torch

from transformers import DataCollatorForSeq2Seq





def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""Expand 2d attention mask to 4d attention mask.

    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len).
    It handles packed sequences (each position attends only to earlier positions carrying the same non-zero
    index) and applies a lower-triangular (causal) mask to prevent future peeking. Index `0` marks padding.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.

    Args:
        attention_mask_with_indices: 2d tensor of per-position sequence indices, shape (batch_size, seq_len).
        dtype: floating dtype of the returned additive mask; masked entries use `torch.finfo(dtype).min`.

    Returns:
        Additive attention mask of shape (batch_size, 1, seq_len, seq_len) and dtype `dtype`, located on the
        same device as `attention_mask_with_indices`.
    """
    _, seq_len = attention_mask_with_indices.size()
    # Build every helper tensor on the input's device so the function also works for non-CPU inputs
    # (mixing a CPU tril mask with e.g. a CUDA index tensor raises a device-mismatch error).
    device = attention_mask_with_indices.device
    min_dtype = torch.finfo(dtype).min
    zero_tensor = torch.tensor(0, dtype=dtype, device=device)

    # Key positions that are not padding: [bsz, 1, 1, seq_len].
    non_padding_mask = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
    # Key-side and query-side indices; comparing them broadcasts to [bsz, 1, seq_len, seq_len]
    # and is True only where query and key belong to the same packed sequence.
    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
    # Causal mask: a query may only attend to keys at the same or earlier positions.
    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))
    attention_mask_4d = (indices == indices_t) & non_padding_mask & tril_mask
    # Convert the boolean "allowed" mask into an additive mask: 0 where allowed, min_dtype elsewhere.
    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
    return attention_mask_4d





@dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
    r"""Seq2seq data collator that can replace the 2d attention mask with a block-diagonal 4d mask.

    After the parent `DataCollatorForSeq2Seq` collates the features, the batch's `attention_mask` is
    expanded with `prepare_4d_attention_mask` when `block_diag_attn` is enabled and the attention
    implementation is anything other than `"flash_attention_2"`.
    """

    # When True, rewrite `attention_mask` into a 4d additive mask (unless using flash_attention_2).
    block_diag_attn: bool = False
    # Attention backend; only compared against "flash_attention_2" here. NOTE(review): presumably the
    # flash-attention path consumes the index-valued 2d mask elsewhere — confirm against the model code.
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    # Floating dtype used for the generated 4d additive mask.
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        """Collate `features` via the parent collator, then optionally expand the attention mask."""
        batch = super().__call__(features)

        wants_4d_mask = self.block_diag_attn and self.attn_implementation != "flash_attention_2"
        if not wants_4d_mask:
            # Leave the collated 2d mask untouched.
            return batch

        batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)
        return batch