"""
QwenVLProcessor
"""
import re
from typing import Optional, Union, List
import PIL
import PIL.Image
import numpy as np
import mindspore as ms
from mindformers import MindFormerModuleType, MindFormerRegister, logger
from mindformers.dataset.transforms.vision_transforms import BatchPILize, BatchToTensor, BatchNormalize
from mindformers.models.base_processor import BaseProcessor
from mindformers.models.image_processing_utils import BaseImageProcessor
from mindformers.models.multi_modal.base_multi_modal_processor import BatchResizeV2
from mindformers.models.multi_modal.modal_content import (
ModalContentTransformTemplate,
BaseTextContentBuilder,
BaseImageContentBuilder
)
from mindformers.models.multi_modal.utils import DataRecord
from mindformers.tools.image_tools import load_image
from qwenvl_transform import QwenVLTransform
class QwenVLImageProcessor(BaseImageProcessor):
    """
    Image processor for QwenVL.

    Runs a fixed pipeline over a batch of images: PIL conversion, resize,
    tensor conversion and normalization.

    Args:
        image_size (int): The target size. An int is expanded to ``(image_size, image_size)``;
            any other value is forwarded to ``BatchResizeV2`` unchanged.
        interpolation (str): Interpolation name forwarded to ``BatchResizeV2``.
        mean: Mean forwarded to ``BatchNormalize``. When ``None``,
            ``(0.48145466, 0.4578275, 0.40821073)`` is used.
        std: Standard deviation forwarded to ``BatchNormalize``. When ``None``,
            ``(0.26862954, 0.26130258, 0.27577711)`` is used.
        is_hwc (bool): Forwarded to ``BatchNormalize``.
        **kwargs: Forwarded to ``BaseImageProcessor``.
    """

    def __init__(self,
                 image_size: Optional[int] = 224,
                 interpolation: Optional[str] = 'bicubic',
                 mean=None,
                 std=None,
                 is_hwc=False,
                 **kwargs):
        self.pilize = BatchPILize()
        super().__init__(**kwargs)
        if isinstance(image_size, int):
            image_size = (image_size,) * 2
        self.resize = BatchResizeV2(image_size, interpolation=interpolation)
        self.to_tensor = BatchToTensor()
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = BatchNormalize(mean, std, is_hwc)

    def preprocess(self, images: Union[ms.Tensor, PIL.Image.Image, np.ndarray, List[PIL.Image.Image]], **kwargs):
        r"""
        Preprocess Required By Base Processor.

        Args:
            images (ms.Tensor, PIL.Image, numpy.array, List[PIL.Image]): A batch of images.
            **kwargs: Accepted for interface compatibility; not used by this method.

        Return:
            An ``ms.Tensor`` holding the batch. A list result is concatenated along a new
            leading axis, a 4-rank result is wrapped as is, and any other result gets a
            leading axis of size 1.
        """
        images = self.pilize(images)
        images = self.resize(images)
        images = self.to_tensor(images)
        images = self.normalize(images)

        # Kept from the original implementation; kwargs is not read afterwards.
        kwargs.pop("other", None)

        if isinstance(images, list):
            # np.vstack is the function np.row_stack aliased; the alias is deprecated in NumPy 2.0.
            return ms.Tensor(np.vstack([np.expand_dims(item, axis=0) for item in images]))
        if len(images.shape) == 4:
            return ms.Tensor(images)
        return ms.Tensor(np.expand_dims(images, axis=0))

    @staticmethod
    def _bhwc_check(image_batch: Union[ms.Tensor, PIL.Image.Image, np.ndarray, List[PIL.Image.Image]]):
        r"""
        Check whether the last dimension of the input has size 3.

        Args:
            image_batch (ms.Tensor, PIL.Image, numpy.array, List[PIL.Image]): The images to inspect.

        Return:
            For arrays and tensors, True when the last dimension has size 3, otherwise False.
            Lists and PIL images always return True; any other type returns False.
        """
        if isinstance(image_batch, (np.ndarray, ms.Tensor)):
            # Read the shape directly; converting a tensor to numpy just for its shape is unnecessary.
            return image_batch.shape[-1] == 3
        return isinstance(image_batch, (list, PIL.Image.Image))
class QwenVLImageContentBuilder(BaseImageContentBuilder):
    """
    Content builder for the image modality of QwenVL.

    Configures ``BaseImageContentBuilder`` with custom tokens disabled and context
    padding disabled, then records the image location, the start/end token ids and
    a square resize transform.
    """

    def __init__(
            self,
            context_pad_token,
            context_length,
            image_size=448,
            image_location="",
            start_token="<img>",
            end_token="</img>",
            tokenizer=None,
            modal_content_max_size=1,
            mode="predict",
            max_length=2048
    ):
        # Collect the base-class options first so the fixed flags stand out.
        base_options = {
            "context_pad_token": context_pad_token,
            "context_length": context_length,
            "use_custom_token": False,
            "start_token": start_token,
            "end_token": end_token,
            "tokenizer": tokenizer,
            "need_padding_context": False,
            "modal_content_max_size": modal_content_max_size,
            "mode": mode,
            "max_length": max_length,
        }
        super().__init__(**base_options)

        self.image_location = image_location

        # Hard-coded ids; presumably the vocabulary ids of start_token / end_token -- TODO confirm.
        self.start_token_id = 151857
        self.end_token_id = 151858

        target_size = (image_size, image_size)
        self.image_mapping = BatchResizeV2(target_size, interpolation="cubic")

    def regular_input_for_predict(self, inputs, result_recorder: DataRecord = None, **kwargs):
        """Predict-mode input regularization is not implemented for this builder."""
        raise NotImplementedError

    def regular_input_for_train(self, inputs, result_recorder: DataRecord = None, **kwargs):
        """Delegate train-mode input regularization to the base class unchanged."""
        parent = super()
        return parent.regular_input_for_train(inputs, result_recorder=result_recorder, **kwargs)
class QwenVLTextContentBuilder(BaseTextContentBuilder):
    """
    Content builder for the text modality of QwenVL.

    Stores the reference and box tag strings and passes train inputs through untouched.
    """

    def __init__(self):
        super().__init__()
        # Tag pairs are stored as plain attributes.
        self.ref_start_tag, self.ref_end_tag = "<ref>", "</ref>"
        self.box_start_tag, self.box_end_tag = "<box>", "</box>"

    def regular_input_for_train(self, inputs, result_recorder: DataRecord = None, **kwargs):
        """Return ``inputs`` as is; text needs no regularization in train mode."""
        return inputs
@MindFormerRegister.register(MindFormerModuleType.TRANSFORMS)
class QwenVLContentTransformTemplate(ModalContentTransformTemplate):
    """
    Modal content transform template for QwenVL.

    Builds the per-modality content builders, assembles conversation text for
    supervised fine-tuning and derives the training labels from per-turn token ids.

    Recognised keyword arguments (all optional): ``system_message``, ``user_role_name``,
    ``user_prompt``, ``assistant_role_name`` and ``assistant_prompt``. All keyword
    arguments are also forwarded to the base class.
    """

    def __init__(self, output_columns, tokenizer, image_size=448, num_queries=256, dataset_dir="", mode="predict",
                 modal_content_padding_size=1, **kwargs):
        super().__init__(output_columns=output_columns, tokenizer=tokenizer, mode=mode,
                         modal_content_padding_size=modal_content_padding_size, **kwargs)
        self.dataset_dir = dataset_dir

        image_builder = QwenVLImageContentBuilder(
            "<imgpad>",
            num_queries,
            start_token="<img>",
            end_token="</img>",
            image_location=dataset_dir,
            modal_content_max_size=modal_content_padding_size,
            mode=mode,
            max_length=self.max_length,
            image_size=image_size,
        )
        text_builder = QwenVLTextContentBuilder()
        self.modal_builders = {"image": image_builder, "text": text_builder}

        # Conversation settings, each overridable through kwargs.
        self.system_message = kwargs.get("system_message", "You are a helpful assistant.")
        self.user_role_name = kwargs.get("user_role_name", "user")
        self.user_prompt = kwargs.get("user_prompt", "")
        self.assistant_role_name = kwargs.get("assistant_role_name", "assistant")
        self.assistant_prompt = kwargs.get("assistant_prompt", "")

        # Number of token ids the tokenizer produces for the assistant role name;
        # build_labels uses it to size the masked prefix of assistant turns.
        assistant_encoding = self.tokenizer(f"{self.assistant_role_name}")
        self.assistant_token_ids_length = len(assistant_encoding["input_ids"])
        self.ignore_token_id = -100

        self.prompt_map = {
            self.user_role_name: self.user_prompt,
            self.assistant_role_name: self.assistant_prompt
        }

    def build_conversation_input_text(self, raw_inputs, result_recorder: DataRecord):
        """Build the conversation text list; only train mode is supported."""
        if self.mode != "train":
            raise NotImplementedError("build_conversation_input_text is only support train mode.")
        return self.build_sft_conversation_input(raw_inputs, result_recorder)

    def build_sft_conversation_input(self, conversations: List[List], result_recorder: DataRecord):
        """
        Build the text segments of a supervised fine-tuning conversation.

        Each conversation entry is unpacked as ``(role, value)``. Entries whose role is
        neither the user nor the assistant role name are skipped with a warning. The
        role of every emitted segment, starting with ``"system"``, is stored in
        ``result_recorder`` under ``"role_info"``.
        """
        segments = [f"<|im_start|>system\n{self.system_message}<|im_end|>\n"]
        roles = ["system"]
        known_roles = (self.user_role_name, self.assistant_role_name)

        for conversation in conversations:
            from_, value = conversation
            if from_ not in known_roles:
                logger.warning("role_name `%s` is invalid in conversation %s, it will be ignored!", from_, conversation)
                continue
            prompt = self.prompt_map.get(from_)
            segments.append(f"<|im_start|>{from_}\n{prompt}{value}<|im_end|>\n")
            roles.append(from_)

        result_recorder.put("role_info", roles)
        return segments

    def build_labels(self, text_id_list, result_recorder: DataRecord, **kwargs):
        """
        Build the flat label list for QwenVL from per-segment token ids.

        Roles are read from ``"role_info"`` in ``result_recorder``. For system and user
        segments the slice ``[1:-2]`` is set to the ignore id; for assistant segments the
        slice ``[1:assistant_token_ids_length + 2]`` is set to the ignore id.

        Raises:
            ValueError: If a recorded role is not system, user or assistant.
        """
        recorded_roles = result_recorder.get("role_info")
        fully_masked_roles = (self.user_role_name, "system")
        assistant_prefix_end = self.assistant_token_ids_length + 2

        labels = []
        for position, role in enumerate(recorded_roles):
            # Work on a copy so the caller's token ids stay untouched.
            segment_labels = text_id_list[position].copy()
            if role in fully_masked_roles:
                segment_labels[1:-2] = self.ignore_token_id
            elif role == self.assistant_role_name:
                segment_labels[1:assistant_prefix_end] = self.ignore_token_id
            else:
                raise ValueError(f"role_name `{role}` is invalid")
            labels.extend(segment_labels)
        return labels

    def get_need_update_output_items(self, result: DataRecord):
        """Return the output items to overwrite: ``"images"`` padded by the image builder."""
        image_builder = self.modal_builders["image"]
        padded_images = image_builder.padding_images_to_max_content_size(result.get("images"))
        return {"images": padded_images}
class QwenVLProcessor(BaseProcessor):
    r"""QwenVL Processor,
    consists of an image processor for image input,
    and a tokenizer for text input.

    Args:
        image_processor (BaseImageProcessor): Used for process image data.
        tokenizer: Used for process text data.
        max_length (Optional[int]): The length of text tokens.
        image_padding_size (Optional[int]): Passed to ``QwenVLTransform`` as ``max_img_size``.
        prompt (Optional[str]): Passed to ``QwenVLTransform`` as ``prompt``.
        padding (Optional[str]): The padding strategy of tokenizer, [None, "max_length"].
        return_tensors (Optional[str]): The type of returned tensors for tokenizer, [None, "ms"].
    """

    def __init__(self, image_processor, tokenizer,
                 max_length=512,
                 image_padding_size=256,
                 prompt=None,
                 padding='max_length', return_tensors='ms', **kwargs):
        super().__init__(
            image_processor=image_processor,
            tokenizer=tokenizer,
            max_length=max_length,
            padding=padding,
            return_tensors=return_tensors, **kwargs)
        self.text_transform = QwenVLTransform(tokenizer,
                                              max_img_size=image_padding_size,
                                              max_length=max_length,
                                              prompt=prompt)
        self.padding_size = self.text_transform.max_img_size

    @staticmethod
    def process_text(text):
        """
        Process text, including padding text and extracting image path.

        Every ``<img>path</img>`` span is replaced by an empty ``<img></img>`` placeholder
        and the enclosed paths are collected in order of appearance.

        Args:
            text (str): The query text.

        Returns:
            A tuple ``(padded_text, img_path)``: the text with placeholders and the list of paths.

        Raises:
            ValueError: If the numbers of opening and closing tags differ, or if the tags
                are nested or out of order.
        """
        start_tag_index = [(match.start(), match.end()) for match in re.finditer(r'<img>', text)]
        end_tag_index = [(match.start(), match.end()) for match in re.finditer(r'</img>', text)]

        if len(start_tag_index) != len(end_tag_index):
            raise ValueError("the text has unclosed image tag")

        replaced_text = []
        img_path = []
        last_end = 0
        for (start_tag_start_idx, start_tag_end_idx), (end_tag_start_idx, end_tag_end_idx) in \
                zip(start_tag_index, end_tag_index):
            # An opening tag must follow the previous closing tag (no nesting such as
            # "<img><img></img></img>") and must precede its own closing tag.
            if start_tag_start_idx < last_end or start_tag_end_idx > end_tag_start_idx:
                raise ValueError("the text has error image tag")

            replaced_text.append(text[last_end:start_tag_start_idx])
            img_path.append(text[start_tag_end_idx:end_tag_start_idx])
            last_end = end_tag_end_idx

        replaced_text.append(text[last_end:])
        img_padding = "<img></img>"
        padded_text = img_padding.join(replaced_text)
        return padded_text, img_path

    def process_query(self, query_ele_list, task):
        """
        Parse query, tokenize and transform text, load images and generate image pos in text.

        Args:
            query_ele_list: The query elements, passed to ``tokenizer.from_list_format``.
            task (str): The task name; used as the key of the text in the transform input.

        Returns:
            A tuple ``(text_input_id, images, img_pos)`` where ``images`` is the image
            processor output converted with ``asnumpy()``.
        """
        query_text = self.tokenizer.from_list_format(query_ele_list)
        padded_text, img_path_list = self.process_text(query_text)
        text_input_id, img_pos = self.text_transform({"task": task, task: padded_text}, template={task: "{}"})
        image_in_a_text = self.image_processor([load_image(img_path_item) for img_path_item in img_path_list])
        return text_input_id, image_in_a_text.asnumpy(), img_pos

    @staticmethod
    def padding_images(batch_images_list, batch_img_pos_list, max_img_len):
        """
        Padding image and img_pos to max_img_len in a batch.

        Entries whose first dimension already equals ``max_img_len`` are kept as is. Other
        entries are extended along axis 0 by repeating their last element, which requires
        them to hold at least one element.

        Args:
            batch_images_list (list): Per-query image arrays.
            batch_img_pos_list (list): Per-query image position arrays.
            max_img_len (int): The target size of the first dimension.

        Returns:
            A tuple ``(padded_batch_images, padded_batch_img_pos)`` of lists.
        """
        padded_batch_images = []
        padded_batch_img_pos = []
        for image, image_pos in zip(batch_images_list, batch_img_pos_list):
            image_size = image.shape[0]
            if image_size == max_img_len:
                padded_batch_images.append(image)
                padded_batch_img_pos.append(image_pos)
                continue

            # Keep every element once and repeat the last one to fill the gap.
            repeat = [1] * image_size
            repeat[-1] = max_img_len - image_size + 1
            padded_batch_images.append(np.repeat(image, repeat, axis=0))
            padded_batch_img_pos.append(np.repeat(image_pos, repeat, axis=0))
        return padded_batch_images, padded_batch_img_pos

    def post_process(self, output_ids, queries):
        """Post process the origin output ids, it converts <imgpad> token to origin image path."""
        return [
            self.tokenizer.post_process(output_ids_item, query)
            for output_ids_item, query in zip(output_ids, queries)
        ]

    def __call__(self, image_input=None, text_input=None, task="caption"):
        """
        Process a batch of queries into model inputs.

        Args:
            image_input: Not used by this implementation; images are loaded from the
                paths embedded in ``text_input``.
            text_input: A single query (a non-empty list of dicts) or a list of such queries.
            task (str): The task name passed to ``process_query``.

        Returns:
            A dict with ``"input_ids"`` (stacked numpy array), ``"image"``
            (``ms.float32`` tensor) and ``"img_pos"`` (``ms.int32`` tensor).
        """
        # A single query is given as a list of dicts; wrap it into a batch of one.
        if isinstance(text_input, list) and text_input and isinstance(text_input[0], dict):
            text_input = [text_input]

        max_img_len = 0
        batch_text_ids = []
        batch_images_list = []
        batch_img_pos_list = []
        for text_input_item in text_input:
            text_in_a_query, image_in_a_query, img_pos_in_a_query = self.process_query(text_input_item, task)
            max_img_len = max(max_img_len, image_in_a_query.shape[0])
            batch_text_ids.append(text_in_a_query)
            batch_images_list.append(image_in_a_query)
            batch_img_pos_list.append(img_pos_in_a_query)

        padded_batch_image, padded_batch_img_pos = self.padding_images(batch_images_list,
                                                                       batch_img_pos_list,
                                                                       max_img_len)
        return {
            "input_ids": np.stack(batch_text_ids, axis=0),
            "image": ms.Tensor(np.stack(padded_batch_image, axis=0), ms.float32),
            "img_pos": ms.Tensor(np.stack(padded_batch_img_pos, axis=0), ms.int32)
        }