# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.

from functools import wraps
import torch


def _get_ltor_masks_and_position_ids(
        data: torch.Tensor,
        eod_token: int,
        reset_position_ids: bool,
        reset_attention_mask: bool,
        eod_mask_loss: bool,
        create_attention_mask: bool,
):
    """Build masks, position ids and segment boundaries for a left-to-right model.

    Args:
        data (torch.Tensor): The data tensor that holds the tokens from the dataset.
            It is used as a boolean mask over 1-D tensors of ``data.numel()``
            elements, so it is expected to be 1-D.

        eod_token (int): ID of the token that is considered the EOD

        reset_position_ids (bool): Switch to restart the position IDs after every EOD token

        reset_attention_mask (bool): Switch to block attention across EOD boundaries

        eod_mask_loss (bool): Switch to zero the loss mask at EOD tokens

        create_attention_mask (bool): Switch to enable the attention masks generation. Can be
            disabled if attention kernel generates masks by itself.

    Returns:
        torch.Tensor or None: Boolean attention mask of shape ``(1, seq_length, seq_length)``
            where ``True`` marks a position that must NOT be attended to, or ``None`` when
            ``create_attention_mask`` is False

        torch.Tensor: Float mask of shape ``(seq_length,)`` used for the loss value during training

        tuple[torch.Tensor, torch.Tensor]: ``(position_ids, actual_seq_len)``. ``position_ids``
            holds the position ID of every token. ``actual_seq_len`` holds the cumulative end
            offset (index + 1) of every EOD token followed by ``seq_length``.
    """
    seq_length = data.numel()

    if create_attention_mask:
        attention_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=data.device)
        ).unsqueeze(0)
    else:
        attention_mask = None

    # Loss mask.
    loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids. `arange` returns a fresh tensor, so it is safe to modify in place.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)

    # Indices of the EOD tokens. This is always computed because the returned
    # segment boundaries depend on it regardless of the reset switches. Boolean
    # indexing yields a copy, so later in-place edits of `position_ids` do not
    # affect it.
    eod_index = position_ids[data == eod_token]

    if reset_position_ids or reset_attention_mask:
        # Plain Python ints keep the slicing and arithmetic below free of
        # 0-dim tensor indices.
        prev_index = 0
        for i in eod_index.tolist():
            # Tokens after this EOD must not attend to anything up to and including it.
            if reset_attention_mask and attention_mask is not None:
                attention_mask[0, (i + 1):, : (i + 1)] = 0
            # Restart positions at zero for the tokens following this EOD.
            if reset_position_ids:
                position_ids[(i + 1):] -= i + 1 - prev_index
                prev_index = i + 1

    if attention_mask is not None:
        # Convert attention mask to binary (True == masked out):
        attention_mask = attention_mask < 0.5

    # Build the terminator on the same device/dtype as the EOD indices so that
    # the concatenation also works when `data` does not live on the CPU.
    # NOTE(review): when the last token is an EOD, `seq_length` appears twice
    # in the result (kept as in the original) - confirm the consumer accepts it.
    seq_length_tensor = torch.tensor(
        [seq_length], dtype=eod_index.dtype, device=eod_index.device
    )
    actual_seq_len = torch.cat([eod_index + 1, seq_length_tensor])

    return attention_mask, loss_mask, (position_ids, actual_seq_len)


def collate_wrapper(fn):
    """Wrap a collate function so the batch also carries ``actual_seq_len``.

    Each incoming sample stores a ``(position_ids, actual_seq_len)`` pair under
    the ``'position_ids'`` key. The wrapper hands ``fn`` samples that hold only
    the first element of that pair, then concatenates the per-sample
    ``actual_seq_len`` tensors into ``batch['actual_seq_len']``. Sample ``i``
    is shifted by ``i`` times the last boundary of the first sample.

    Args:
        fn: Collate function taking a list of sample dicts and returning a
            batch that supports item assignment.

    Returns:
        The wrapped collate function.
    """
    @wraps(fn)
    def wrapper(samples):
        # First pass: pull out the per-sample boundary tensors.
        boundaries = []
        for sample in samples:
            boundaries.append(sample['position_ids'][1])

        # Second pass: rebuild every sample with plain position ids.
        plain_samples = []
        for sample in samples:
            plain = {}
            for key, val in sample.items():
                plain[key] = val[0] if key == 'position_ids' else val
            plain_samples.append(plain)

        batch = fn(plain_samples)

        # The first sample's final boundary is used as the stride between
        # consecutive samples (assumes equal-length samples - TODO confirm).
        stride = boundaries[0][-1]
        shifted = []
        for idx, bounds in enumerate(boundaries):
            shifted.append(bounds + idx * stride)
        batch['actual_seq_len'] = torch.cat(shifted)
        return batch

    return wrapper