import itertools
from typing import List, Optional, Tuple
from functools import partial
import torch
from torch import Tensor
from megatron.training import get_args
from megatron.core import InferenceParams, mpu
def qwen2vl_get_rope_index(
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [3, 4, 5, 6, 7]
text height position_ids: [3, 4, 5, 6, 7]
text width position_ids: [3, 4, 5, 6, 7]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
args = get_args()
spatial_merge_size = getattr(args.mm.model.image_encoder.vision_encoder, 'spatial_merge_size', 2)
image_token_id = 151655
video_token_id = 151656
vision_start_token_id = getattr(args.mm.model, 'vision_start_token_id', 151652)
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
image_grid_thw[video_index][0],
image_grid_thw[video_index][1],
image_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def get_llm_pos_ids_for_vision(
start_idx: int,
vision_idx: int,
spatial_merge_size: int,
t_index: List[int],
grid_hs: List[int],
grid_ws: List[int],
):
llm_pos_ids_list = []
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten()
t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long()
_llm_pos_ids = torch.stack([t_index, h_index, w_index])
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
return llm_pos_ids
def get_chunked_index(
token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int
) -> List[Tuple[int, int]]:
"""
Splits token index list into chunks based on token value ranges.
Given a list of token indices, returns a list of (start, end) index tuples representing
slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`.
For example, if `t_ntoken_per_chunk` is 1000, the function will create chunks such that:
- the first chunk contains token values < 1000,
- the second chunk contains values >= 1000 and < 2000, and so on.
Parameters:
token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
token index values.
t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
remove_index (`int`) An index id to subtract from `token_indices` before chunking
Returns:
`List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive)
and end (exclusive) indices of a chunk in `token_indices`.
"""
def _iter():
i, start_idx = 0, 0
current_chunk = 1
while i < len(token_indices):
if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
yield (start_idx, i)
start_idx = i
current_chunk += 1
i += 1
yield (start_idx, len(token_indices))
return list(_iter())
def qwen2_5_omni_get_rope_index(
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
audio_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.
Width: 2 patches, dividing each frame horizontally.
We also have some important parameters:
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [101, 102, 103, 104, 105]
text height position_ids: [101, 102, 103, 104, 105]
text width position_ids: [101, 102, 103, 104, 105]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
use_audio_in_video (`bool`, *optional*):
If set to `True`, use the audio in video.
audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*):
The length of feature shape of each audio in LLM.
second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
args = get_args()
spatial_merge_size = getattr(args.mm.model.image_encoder.vision_encoder, 'spatial_merge_size', 2)
image_token_id = 151655
video_token_id = 151656
audio_token_id = 151646
vision_start_token_id = getattr(args.mm.model, 'vision_start_token_id', 151652)
audio_start_token_id = getattr(args.mm.model, 'audio_start_token_id', 151647)
position_id_per_seconds = 25
seconds_per_chunk = 2
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
multimodal_nums = (
image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio = input_tokens.index(audio_token_id, st)
else:
ed_audio = len(input_tokens) + 1
min_ed = min(ed_image, ed_video, ed_audio)
if min_ed == ed_audio:
text_len = min_ed - st - 1
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
remain_audios -= 1
elif min_ed == ed_image:
text_len = min_ed - st - 1
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long()
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif min_ed == ed_video and not use_audio_in_video:
text_len = min_ed - st - 1
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = image_grid_thw[video_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * second_per_grid[video_idx].cpu().float() * position_id_per_seconds
).long()
llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = image_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif min_ed == ed_video and use_audio_in_video:
text_len = min_ed - st - 2
if text_len != 0:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1
audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
grid_t = image_grid_thw[video_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (
torch.arange(grid_t) * second_per_grid[video_idx].cpu().float() * position_id_per_seconds
).long()
video_llm_pos_ids = get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
video_chunk_indexes = get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
audio_chunk_indexes = get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
sub_len = 0
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None
if video_chunk_index is not None:
sub_len += video_chunk_index[1] - video_chunk_index[0]
llm_pos_ids_list.append(
video_llm_pos_ids[:, video_chunk_index[0]: video_chunk_index[1]]
)
if audio_chunk_index is not None:
sub_len += audio_chunk_index[1] - audio_chunk_index[0]
llm_pos_ids_list.append(
audio_llm_pos_ids[:, audio_chunk_index[0]: audio_chunk_index[1]]
)
video_len = image_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
return position_ids, mrope_position_deltas
def qwen2_5_vl_get_rope_index(
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.
Width: 2 patches, dividing each frame horizontally.
We also have some important parameters:
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [101, 102, 103, 104, 105]
text height position_ids: [101, 102, 103, 104, 105]
text width position_ids: [101, 102, 103, 104, 105]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
args = get_args()
spatial_merge_size = getattr(args.mm.model.image_encoder.vision_encoder, 'spatial_merge_size', 2)
image_token_id = 151655
video_token_id = 151656
vision_start_token_id = getattr(args.mm.model, 'vision_start_token_id', 151652)
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
image_grid_thw[video_index][0],
image_grid_thw[video_index][1],
image_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
time_tensor = expanded_range * second_per_grid_t * getattr(args.mm.model.image_encoder.vision_encoder, 'tokens_per_second', 2)
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def glm_get_rope_index(
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.
Width: 2 patches, dividing each frame horizontally.
We also have some important parameters:
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [101, 102, 103, 104, 105]
text height position_ids: [101, 102, 103, 104, 105]
text width position_ids: [101, 102, 103, 104, 105]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
spatial_merge_size = 2
image_token_id = 151343
vision_start_token_id = 151341
vision_end_token_id = 151342
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
input_tokens = input_ids.tolist()
input_token_type = []
video_check_flg = False
for token in input_tokens:
if token == vision_start_token_id:
video_check_flg = True
elif token == vision_end_token_id:
video_check_flg = False
if token == image_token_id and not video_check_flg:
input_token_type.append("image")
elif token == image_token_id and video_check_flg:
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group = []
for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
group = list(group)
start_index = group[0][0]
end_index = group[-1][0] + 1
input_type_group.append((key, start_index, end_index))
llm_pos_ids_list = []
video_frame_num = 1
image_index, video_index = 0, 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
image_index += 1
video_frame_num = 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
video_index += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
video_frame_num = 1
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def qwen2_position(config,
input_ids,
image_grid_thw=None,
video_grid_thw=None,
attention_mask=None,
**kwargs):
position_ids, rope_deltas = qwen2vl_get_rope_index(config, input_ids, image_grid_thw, video_grid_thw,
kwargs.get('second_per_grid_ts', None), attention_mask=attention_mask)
return position_ids, rope_deltas
def qwen2_5_position(config,
input_ids,
image_grid_thw=None,
video_grid_thw=None,
attention_mask=None,
**kwargs):
position_ids, rope_deltas = qwen2_5_vl_get_rope_index(config, input_ids, image_grid_thw, video_grid_thw,
kwargs.get('second_per_grid_ts', None), attention_mask=attention_mask)
return position_ids, rope_deltas
def qwen2_5_omni_position(config,
input_ids,
image_grid_thw=None,
video_grid_thw=None,
attention_mask=None,
use_audio_in_video=False,
**kwargs):
position_ids, rope_deltas = qwen2_5_omni_get_rope_index(config, input_ids, image_grid_thw, video_grid_thw,
kwargs.get('video_second_per_grid', None), attention_mask=attention_mask,
use_audio_in_video=use_audio_in_video,
audio_seqlens=torch.sum(kwargs.get('feature_attention_mask'), dim=1))
return position_ids, rope_deltas
def glm_position(config=None,
input_ids=None,
image_grid_thw=None,
video_grid_thw=None,
attention_mask=None,
**kwargs):
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
position_ids, rope_deltas = glm_get_rope_index(
config,
input_ids,
image_grid_thw,
video_grid_thw,
attention_mask=attention_mask_tensor,
)
return position_ids, rope_deltas
def _build_attentionmask_positionid_qwenllm(config, input_ids, attention_mask, inference_params, position_ids, pos_func,
*args, **kwargs):
seq_len = input_ids.shape[1]
if config.sequence_parallel:
seq_len *= mpu.get_tensor_model_parallel_world_size()
if inference_params is not None:
past_seen_tokens = attention_mask.shape[1] - 1 if inference_params.key_value_memory_dict else 0
else:
past_seen_tokens = 0
if position_ids is None:
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
rope_deltas = kwargs.get('rope_deltas', None)
inputs_embeds = kwargs.get('inputs_embeds', None)
cache_position = kwargs.get('cache_position', None)
if (
(cache_position is not None and cache_position[0] == 0)
or rope_deltas is None
):
position_ids, rope_deltas = pos_func(config,
input_ids,
attention_mask=attention_mask,
**kwargs)
rope_deltas = rope_deltas
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
args = get_args()
if args.context_parallel_size > 1 and args.context_parallel_algo == "megatron_cp_algo":
attention_mask = None
return attention_mask, position_ids
elif getattr(args.mm.model.text_decoder, 'use_remove_padding', False) and config.use_flash_attn:
from mindspeed.core.transformer.flash_attention.generate_mask.generate_mask import get_attention_mask
attention_mask = get_attention_mask(config)
return attention_mask, position_ids
seq_len = input_ids.shape[1]
past_seen_token = 0
cache_position = kwargs.get('cache_position', None)
if cache_position is None:
cache_position = torch.arange(
past_seen_token, past_seen_token + seq_len, device=input_ids.device)
dtype, device = torch.bfloat16, input_ids.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_ids.shape[1]
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_token + sequence_length + 1
)
batch_size = input_ids.shape[0]
if attention_mask is not None and attention_mask.dim() == 4:
return attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone()
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
return causal_mask < 0, position_ids
def _build_attentionmask_positionid_internllm(attention_mask, position_ids, dtype=torch.float32, device=torch.device("npu"), past_key_values_length=0, *args, **kwargs):
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
if get_args().context_parallel_size > 1 and get_args().context_parallel_algo == "megatron_cp_algo":
attention_mask = None
return attention_mask, position_ids
input_shape = attention_mask.shape
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to(device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask.bool(), position_ids
def _build_attentionmask_positionid_glm4v(config, input_ids, attention_mask, position_ids, pos_func,
*args, **kwargs):
rope_deltas = kwargs.get('rope_deltas', None)
inputs_embeds = kwargs.get('inputs_embeds', None)
cache_position = kwargs.get('cache_position', None)
if position_ids is None:
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
if (
(cache_position is not None and cache_position[0] == 0)
or rope_deltas is None
):
position_ids, rope_deltas = pos_func(config,
input_ids,
attention_mask=attention_mask,
**kwargs)
rope_deltas = rope_deltas
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
seq_len = input_ids.shape[1]
past_key_values = None
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + seq_len, device=input_ids.device
)
config._attn_implementation = "sdpa"
from transformers.masking_utils import create_causal_mask
causal_mask = create_causal_mask(
config=config,
input_embeds=input_ids,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
if causal_mask is None or get_args().use_flash_attn:
causal_mask = torch.triu(
torch.ones([seq_len, seq_len], dtype=torch.bool, device=input_ids.device), diagonal=1)
return causal_mask, position_ids
attention_mask_list = {'qwen2lm': partial(_build_attentionmask_positionid_qwenllm, pos_func=qwen2_position),
'qwen2_5_lm': partial(_build_attentionmask_positionid_qwenllm, pos_func=qwen2_5_position),
'internllm': _build_attentionmask_positionid_internllm,
'deepseek': _build_attentionmask_positionid_internllm,
'qwen3_lm': partial(_build_attentionmask_positionid_qwenllm, pos_func=qwen2_5_position),
'qwen2_5_omni_thinker': partial(_build_attentionmask_positionid_qwenllm, pos_func=qwen2_5_omni_position),
'glm4v_lm': partial(_build_attentionmask_positionid_glm4v, pos_func=glm_position),
'videoalign_lm': partial(_build_attentionmask_positionid_qwenllm, pos_func=qwen2_position),
}
def prepare_positionsids_mask_for_llm(config=None, input_ids=None, inference_params=None, attention_mask=None, position_ids=None, *args, **kwargs):
global_args = get_args()
llm_model_id = getattr(global_args.mm.model.text_decoder, 'model_id', None)
if llm_model_id and llm_model_id in attention_mask_list:
return attention_mask_list[llm_model_id](config=config, input_ids=input_ids, inference_params=inference_params,
attention_mask=attention_mask, position_ids=position_ids, **kwargs)
else:
raise AssertionError(f'the attention mask for llm is not build or model_id has error {llm_model_id}')