from typing import Any, List, Dict, Optional, Tuple
import torch
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers.generation.streamers import TextStreamer
from mindspeed_mm.tasks.inference.pipeline.pipeline_mixin.generation_mixin import GenerationMixin
from mindspeed_mm.models.text_encoder import Tokenizer
from mindspeed_mm.tasks.inference.pipeline.parallel_wrapper import ParallelWrapper
# Inference pipeline for a Qwen2-VL style vision-language model.
# It wires together a tokenizer, a multimodal processor and a ParallelWrapper
# around the model provider, and relies on GenerationMixin for `generate`.
# Calling an instance runs one prompt (optionally with an image or a video);
# `evaluate` runs one structured message built from typed content items.
class Qwen2VlPipeline(GenerationMixin):
def __init__(self, infer_config):
    """Build the tokenizer, model wrapper and multimodal processor.

    Args:
        infer_config: Inference configuration object. The constructor reads
            its ``tokenizer``, ``generation_config`` and ``text_decoder``
            attributes and, when they exist, ``min_pixels`` / ``max_pixels``.
    """
    self.infer_config = infer_config
    tokenizer_config = infer_config.tokenizer
    self.tokenizer = Tokenizer(tokenizer_config).get_tokenizer()

    # Imported at call time rather than at module load.
    # NOTE(review): presumably to avoid a circular or heavyweight import - confirm.
    from pretrain_vlm import model_provider
    self.model = ParallelWrapper(model_provider)

    # The processor is loaded from the tokenizer's checkpoint path, using
    # local files only (no download attempt).
    self.image_processor = AutoProcessor.from_pretrained(
        tokenizer_config.from_pretrained,
        local_files_only=True,
    )
    self.generation_config = infer_config.generation_config
    self.model_config = infer_config.text_decoder
    self.main_input_name = 'input_ids'

    # Optional pixel limits; None when the config does not define them.
    self.min_pixels = getattr(infer_config, "min_pixels", None)
    self.max_pixels = getattr(infer_config, "max_pixels", None)
def __call__(self, prompt=None, images=None, videos=None, return_ids=False, skip_special_tokens=True):
    """Generate a response for one prompt, optionally with an image or a video.

    Args:
        prompt: Text prompt. Falls back to ``infer_config.prompts`` when empty.
        images: A single image or a list of images; only the first list
            element is used. Falls back to ``infer_config.image_path`` (or
            ``None`` if the config has no such attribute) when empty.
        videos: A single video or a list of videos; only the first list
            element is used. Falls back to ``infer_config.video_path`` (or
            ``None``) when empty.
        return_ids: When true, the newly generated tokens are decoded and
            returned. When false, a ``TextStreamer`` is passed to ``generate``
            and nothing is returned.
        skip_special_tokens: Forwarded to the streamer / ``batch_decode``.

    Returns:
        The decoded continuation of the first sequence when ``return_ids`` is
        true and ``generate`` returned something; otherwise ``None``.
    """
    if images:
        image = images[0] if isinstance(images, list) else images
    else:
        image = getattr(self.infer_config, "image_path", None)

    if videos:
        video = videos[0] if isinstance(videos, list) else videos
    else:
        video = getattr(self.infer_config, "video_path", None)

    if not prompt:
        prompt = self.infer_config.prompts

    inputs = self.prepare_inputs(prompt=prompt, images=image, videos=video)
    if inputs is None:
        # No visual input: tokenize the raw prompt directly.
        inputs = {'input_ids': self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.infer_config.device)}

    streamer = None if return_ids else TextStreamer(
        self.tokenizer, skip_prompt=True, skip_special_tokens=skip_special_tokens
    )
    temperature = self.generation_config.temperature
    try:
        generated_ids = self.generate(
            **inputs,
            do_sample=bool(temperature > 0),
            temperature=temperature,
            max_new_tokens=self.generation_config.max_new_tokens,
            streamer=streamer,
        )
    finally:
        # Always clear the inference state, even when generation fails:
        # prepare_inputs_for_generation treats a non-None value as "decode
        # step" and would otherwise trim the next call's prompt.
        self.model.inference_params = None

    if not return_ids or generated_ids is None:
        return None

    # Drop the prompt tokens so only the continuation is decoded.
    continuations = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs['input_ids'], generated_ids)
    ]
    decoded = self.image_processor.tokenizer.batch_decode(
        continuations, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False
    )
    return decoded[0]
def prepare_inputs(self, prompt=None, images=None, videos=None, messages=None):
    """Turn a prompt plus optional visual input into processor tensors.

    Args:
        prompt: Text placed after the visual entries of the user turn. Ignored
            when ``messages`` is given.
        images: Image entry for the user turn (optional).
        videos: Video entry for the user turn (optional).
        messages: Pre-built batch of chat messages; when given it is used
            as-is instead of building one from the other arguments.

    Returns:
        ``None`` when there is no image, no video and no messages; otherwise
        the processor output moved to ``infer_config.device``, with
        ``pixel_values`` given an extra leading dimension.
    """
    if not (images or messages or videos):
        return None

    if not messages:
        user_content = []
        if images:
            user_content.append({"type": "image", "image": images})
        if videos:
            user_content.append({"type": "video", "video": videos})
        user_content.append({"type": "text", "text": prompt})
        # One conversation containing a single user turn.
        messages = [[{"role": "user", "content": user_content}]]

    processor = self.image_processor
    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    batch = processor(
        text=chat_text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(self.infer_config.device)

    if image_inputs:
        batch['pixel_values'] = batch['pixel_values'].unsqueeze(0)
    if video_inputs:
        # Video tensors are exposed under the image keys.
        # NOTE(review): when both an image and a video are present this
        # overwrites the image tensors - confirm mixed input is unsupported.
        batch['pixel_values'] = batch['pixel_values_videos'].unsqueeze(0)
        batch['image_grid_thw'] = batch['video_grid_thw']
    return batch
def get_rope_index(
    self,
    input_ids: torch.LongTensor,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

    Explanation:
        Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
        Examples:
            input_ids: [T T T T T], here T is for text.
            temporal position_ids: [0, 1, 2, 3, 4]
            height position_ids: [0, 1, 2, 3, 4]
            width position_ids: [0, 1, 2, 3, 4]

        For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
        and 1D rotary position embedding for text part.
        Examples:
            Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
            input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
            vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
            vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
            vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
            text temporal position_ids: [3, 4, 5, 6, 7]
            text height position_ids: [3, 4, 5, 6, 7]
            text width position_ids: [3, 4, 5, 6, 7]
            Here we calculate the text start position_ids as the max vision position_ids plus 1.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            When omitted, every token is treated as not masked.

    Returns:
        position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
        mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`)
    """
    if image_grid_thw is None and video_grid_thw is None:
        # Text-only input: plain 1D positions replicated on the three axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )
        return position_ids, mrope_position_deltas

    spatial_merge_size = self.infer_config.image_encoder.vision_encoder.spatial_merge_size
    image_token_id = self.generation_config.image_token_id
    video_token_id = self.generation_config.video_token_id
    vision_start_token_id = self.generation_config.vision_start_token_id

    if attention_mask is None:
        # Without a mask every token counts; the scatter into position_ids
        # below needs a real mask tensor to index with.
        attention_mask = torch.ones_like(input_ids)

    batch_size, seq_length = input_ids.shape[0], input_ids.shape[1]
    # Masked-out slots keep the fill value 1.
    position_ids = torch.ones(3, batch_size, seq_length, dtype=input_ids.dtype, device=input_ids.device)
    mrope_position_deltas = []
    # The grid rows are consumed in order across the whole batch.
    image_index, video_index = 0, 0

    for i, row in enumerate(input_ids):
        keep = attention_mask[i] == 1
        row_ids = row[keep]

        # The token right after each vision-start marker tells image from video.
        vision_start_indices = torch.argwhere(row_ids == vision_start_token_id).squeeze(1)
        vision_tokens = row_ids[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())

        input_tokens = row_ids.tolist()
        llm_pos_ids_list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Locate the next image / video placeholder at or after `st`;
            # "past the end" marks a modality with nothing left.
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                grid = image_grid_thw[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                grid = video_grid_thw[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t = grid[0].item()
            llm_grid_h = grid[1].item() // spatial_merge_size
            llm_grid_w = grid[2].item() // spatial_merge_size

            # Text preceding this vision span: 1D positions on all three axes.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # The vision span itself: separate temporal / height / width indices.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            # Trailing text after the last vision span.
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., i, keep] = llm_positions.to(position_ids.device)
        # Offset between the next rope position and the padded sequence length.
        mrope_position_deltas.append(int(llm_positions.max()) + 1 - seq_length)

    mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
    return position_ids, mrope_position_deltas
def prepare_inputs_for_generation(self, **kwargs):
    """Assemble the model inputs for one generation step.

    Reads ``input_ids``, ``inputs_embeds``, ``cache_position``,
    ``video_grid_thw``, ``image_grid_thw``, ``rope_deltas``,
    ``attention_mask``, ``position_ids`` and ``pixel_values`` from ``kwargs``.

    When ``self.model.inference_params`` is set, only the last token is fed,
    ``pixel_values`` is dropped and an all-ones attention mask covering the
    whole sequence is used. Position ids are computed via ``get_rope_index``
    when ``cache_position`` starts at 0 and are otherwise derived from
    ``cache_position`` and ``rope_deltas``.

    Returns:
        dict with ``input_ids``, ``position_ids``, ``attention_mask``,
        ``pixel_values``, ``image_grid_thw``, ``cache_position`` and
        ``rope_deltas``.
    """
    input_ids = kwargs.get("input_ids", None)
    if "cache_position" not in kwargs:
        # Default to 0..seq_len-1 on the same device as the inputs.
        if "inputs_embeds" in kwargs:
            reference = kwargs["inputs_embeds"][0, :, 0]
        else:
            reference = input_ids[0, :]
        kwargs["cache_position"] = torch.arange(
            reference.shape[0], dtype=torch.int64, device=reference.device
        )

    cache_position = kwargs.get("cache_position", None)
    video_grid_thw = kwargs.get("video_grid_thw", None)
    image_grid_thw = kwargs.get("image_grid_thw", None)
    rope_deltas = kwargs.get("rope_deltas", None)
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)
    pixel_values = kwargs.get("pixel_values", None)

    if self.model.inference_params is not None:
        # Inference state already exists: feed only the newest token and
        # skip the vision inputs.
        batch_size, seq_length = input_ids.shape
        input_ids = input_ids[:, [-1]]
        pixel_values = None
        # Created directly on the inputs' device instead of a hard-coded one.
        attention_mask = torch.ones((batch_size, seq_length), device=input_ids.device)

    if attention_mask is not None and position_ids is None:
        if cache_position is None or cache_position[0] == 0:
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
        else:
            # Later steps: shift plain positions by the cache offset plus the
            # rope delta carried over from the first step.
            batch_size, seq_length = input_ids.shape
            delta = cache_position[0] + rope_deltas if rope_deltas is not None else 0
            position_ids = torch.arange(seq_length, device=input_ids.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "image_grid_thw": image_grid_thw,
        "cache_position": cache_position,
        "rope_deltas": rope_deltas,
    }
def _update_model_kwargs_for_generation(self, model_kwargs: Dict[str, Any], model_inputs: Dict[str, Any]):
    """Advance ``cache_position`` and carry ``rope_deltas`` to the next step.

    Args:
        model_kwargs: Keyword arguments kept across generation steps; updated
            in place and returned.
        model_inputs: Inputs of the step that just ran. Without caching its
            ``cache_position`` entry is removed from this dict.

    Returns:
        The updated ``model_kwargs``.
    """
    caching_enabled = model_kwargs.get("use_cache", True)
    if not caching_enabled:
        # No cache: keep the full position history and append the next one.
        history = model_inputs.pop("cache_position")
        upcoming = history[-1:] + 1
        model_kwargs["cache_position"] = torch.cat((history, upcoming))
    else:
        # With a cache only the next position is needed.
        model_kwargs["cache_position"] = model_inputs["cache_position"][-1:] + 1

    if "rope_deltas" in model_inputs:
        model_kwargs["rope_deltas"] = model_inputs["rope_deltas"]
    return model_kwargs
def evaluate(self, message):
    """Generate a response for one structured user message.

    Args:
        message: List of ``{'type': ..., 'value': ...}`` items, converted to
            chat content by ``_prepare_content``.

    Returns:
        The decoded continuation of the first sequence, or ``None`` when
        ``generate`` returns ``None``.
    """
    messages = [{'role': 'user', 'content': self._prepare_content(message)}]
    inputs = self.prepare_inputs(messages=[messages])

    temperature = self.generation_config.temperature
    try:
        generated_ids = self.generate(
            **inputs,
            do_sample=bool(temperature > 0),
            temperature=temperature,
            max_new_tokens=self.generation_config.max_new_tokens,
        )
    finally:
        # Always clear the inference state, even when generation fails, so a
        # later call is not mistaken for a decode step.
        self.model.inference_params = None

    if generated_ids is None:
        return None

    # Drop the prompt tokens so only the continuation is decoded.
    continuations = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], generated_ids)
    ]
    decoded = self.image_processor.tokenizer.batch_decode(
        continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
def _prepare_content(self, inputs: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Convert ``{'type', 'value'}`` items into chat-template content entries.

    Image items become ``file://`` references and carry ``min_pixels`` /
    ``max_pixels`` when those are set on the pipeline; text items are passed
    through as text entries.

    Args:
        inputs: list[dict[str, str]], each dict has keys: ['type', 'value'].

    Returns:
        The content list, in the same order as ``inputs``.

    Raises:
        ValueError: If an item's type is neither 'image' nor 'text'.
    """
    content = []
    for part in inputs:
        kind = part['type']
        if kind == 'text':
            content.append({'type': 'text', 'text': part['value']})
            continue
        if kind != 'image':
            raise ValueError(f"Invalid message type: {part['type']}, {part}")

        entry = {'type': 'image', 'image': 'file://' + part['value']}
        # Only forward the pixel limits that are actually configured.
        for key in ('min_pixels', 'max_pixels'):
            limit = getattr(self, key)
            if limit is not None:
                entry[key] = limit
        content.append(entry)
    return content