from typing import Optional
import torch
from ..utils import register_tensor_cast_op
@register_tensor_cast_op("shift_and_update_input_ids")
def _(
input_ids: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
next_tokens: torch.Tensor,
) -> torch.Tensor:
"""
Creates a new input_ids tensor by shifting each query's tokens to the left
and appending a new token at the end.
Args:
input_ids: A 1D tensor containing the concatenated tokens of all queries.
query_start_loc: A 1D tensor of shape `(batch_size + 1)`, where
`query_start_loc[i]` is the starting index of the i-th
query in `input_ids`. If not set, the input_ids have the
same length indicated by the input_ids shape:
(batch_size, query_length, hidden_size).
next_tokens: A 2D tensor of shape `(batch_size, sequence_length)`. The last
token from each sequence (`next_tokens[:, -1]`) will be used.
Returns:
A new 1D tensor `new_input_ids` with the transformed tokens.
"""
return torch.empty_like(input_ids)