"""Transform for QwenVL"""
import numpy as np
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from qwenvl_tokenizer import IMG_TOKEN_SPAN
# Label value for token positions that are not training targets (user/system
# turn bodies, the role prefix of assistant turns, and sequence padding);
# presumably matches the loss function's ignore_index — confirm.
IGNORE_TOKEN_ID = -100
@MindFormerRegister.register(MindFormerModuleType.TRANSFORMS)
class QwenVLTransform:
    """
    Text transform for QwenVL: builds prompts from a task-tagged sample dict,
    tokenizes them and computes the coordinates of the image tokens.

    Supported values of the sample's ``'task'`` key:

    * ``'caption'`` (or any value other than ``'vqa'``/``'sft'``) -- the
      sample's ``'caption'`` entry (a string or a list of strings) is formatted
      into the caption template.
    * ``'vqa'`` -- ``'question'`` and ``'answers'`` (a string or a list of
      strings) are formatted into the VQA template.
    * ``'sft'`` -- ``'raw_data'`` (list of turn strings), ``'raw_data_role'``
      (role of each turn) and ``'img_idx'`` are turned into input ids padded
      or truncated to ``max_length`` plus training labels.

    Calling the transform returns ``(input_ids, coord)`` for caption/VQA
    samples and ``(input_ids, coord, labels)`` for SFT samples, all as
    ``np.int32`` arrays. Samples with several captions/answers produce one row
    per caption/answer.
    """

    def __init__(self, tokenizer,
                 prompt=None,
                 max_img_size=IMG_TOKEN_SPAN,
                 padding="max_length",
                 max_length=512,
                 max_annotation=None,
                 random_seed=2022, truncation=True, add_special_tokens=True):
        """
        Args:
            tokenizer: QwenVL tokenizer. It is called as
                ``tokenizer(text, ...)["input_ids"]`` and must expose
                ``image_pad_tag``, ``img_start_id``, ``im_start_id``,
                ``im_end_id`` and ``pad_token_id``.
            prompt: instruction used in the caption template. ``None`` selects
                ``'Describe the image in English'`` for the template, while
                ``self.prompt`` keeps the value that was passed in.
            max_img_size: number of pad tokens reserved per image.
            padding: padding mode forwarded to the tokenizer for VQA samples.
            max_length: ``max_length`` forwarded to the tokenizer; SFT
                sequences are padded/truncated to exactly this length.
            max_annotation, random_seed, truncation, add_special_tokens:
                stored as attributes; not read by any method of this class.
        """
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.max_img_size = max_img_size
        self.max_length = max_length
        self.padding = padding
        self.random_seed = random_seed
        self.truncation = truncation
        self.add_special_tokens = add_special_tokens
        self.max_annotation = max_annotation
        if prompt is None:
            prompt = 'Describe the image in English'
        self.img_padding = self.tokenizer.image_pad_tag * self.max_img_size
        placeholder = '{}'
        self.template = {
            'caption': f'<img></img>{prompt}: {placeholder}',
            'vqa': f'<img></img>{placeholder} Answer: {placeholder}',
        }

    def __call__(self, caption, template=None):
        """
        Preprocess one sample dict.

        Args:
            caption: sample dict carrying a ``'task'`` key.
            template: optional dict of format strings keyed by task; defaults
                to ``self.template``.

        Returns:
            Tuple of ``np.int32`` arrays: ``(input_ids, coord)`` or, for SFT
            samples, ``(input_ids, coord, labels)``.
        """
        if template is None:
            template = self.template
        out = self.pre_caption(caption, template)
        # np.asarray stacks lists of equally-shaped per-item arrays exactly like
        # np.stack does, but (unlike np.stack) also accepts an empty coord
        # array, i.e. a sample without any image.
        return tuple(np.asarray(item).astype(np.int32) for item in out)

    def pre_caption_for_caption_task(self, caption, template):
        """
        Tokenize a caption sample.

        Args:
            caption: sample dict (its ``'caption'`` entry is used), a caption
                string, or a list of caption strings.
            template: dict holding the ``'caption'`` format string.

        Returns:
            The tokenizer's ``input_ids`` for a single caption, or a list of
            them (one per caption) when several captions are given.

        Raises:
            ValueError: if the caption list is empty.
        """
        task = 'caption'
        if isinstance(caption, dict):
            caption = caption[task]
        if isinstance(caption, list):
            if not caption:
                raise ValueError("caption list is empty")
            if len(caption) > 1:
                # Items are plain strings, which pre_caption() cannot take (it
                # needs a dict with 'task'), so recurse on this helper instead.
                return [self.pre_caption_for_caption_task(cap, template) for cap in caption]
            # A one-element list is a single caption; formatting the list itself
            # would put its repr ("['...']") into the prompt.
            caption = caption[0]
        text = template[task].format(caption)
        # NOTE(review): unlike the VQA path, no padding mode is passed to the
        # tokenizer here — confirm this is intentional.
        return self.tokenizer(text, max_length=self.max_length)["input_ids"]

    def pre_caption_for_vqa_task(self, caption, template):
        """
        Tokenize a VQA sample.

        Args:
            caption: dict with ``'question'`` and ``'answers'`` (a string or a
                list of strings).
            template: dict holding the ``'vqa'`` format string
                (question first, answer second).

        Returns:
            The tokenizer's ``input_ids`` for a single answer, or a list of
            them (one per answer) when several answers are given.

        Raises:
            ValueError: if the answers list is empty.
        """
        task = 'vqa'
        question = caption['question']
        answers = caption['answers']
        if isinstance(answers, list):
            if not answers:
                raise ValueError("answers list is empty")
            if len(answers) > 1:
                # One (question, answer) pair per answer; recurse on this helper
                # because the per-answer dict carries no 'task' key.
                return [self.pre_caption_for_vqa_task({'question': question, 'answers': ans}, template)
                        for ans in answers]
            # A one-element list is a single answer, not a list repr to format.
            answers = answers[0]
        text = template[task].format(question, answers)
        return self.tokenizer(text, max_length=self.max_length, padding=self.padding)["input_ids"]

    def generate_input_and_label_for_sft_data(self, caption, raw_data, raw_data_role):
        """
        Tokenize a multi-turn SFT conversation and build its training labels.

        Every ``'{}'`` in a turn is replaced by the image padding string
        (``image_pad_tag * max_img_size``). Each turn is then tokenized on its
        own and appended to ``input_ids``. With ``L`` the turn's token count,
        its labels are:

        * user/system turns: ``[im_start] + [IGNORE_TOKEN_ID] * (L - 3)
          + [im_end] + tokenize('\\n')``;
        * assistant turns, with ``P`` the ids of ``'<|im_start|>' + role``:
          ``[im_start] + [IGNORE_TOKEN_ID] * len(P) + turn[len(P) + 1:-2]
          + [im_end] + tokenize('\\n')``.

        NOTE(review): label/input alignment assumes each turn tokenizes as
        ``<|im_start|>role ... <|im_end|>`` followed by a single newline
        token — confirm against the dataset format.

        Both sequences are then padded (pad id / ``IGNORE_TOKEN_ID``) and
        truncated to ``self.max_length``.

        Args:
            caption: sample dict; may define ``'user_role_name'`` (default
                ``'user'``) and ``'assistant_role_name'`` (default
                ``'assistant'``).
            raw_data: list of turn strings.
            raw_data_role: role of each turn, indexed in step with ``raw_data``.

        Returns:
            ``(input_ids, labels)``, two lists of length ``self.max_length``.

        Raises:
            ValueError: if a turn's role is not the user role, ``'system'`` or
                the assistant role.
        """
        user_role_name = caption.get('user_role_name')
        if user_role_name is None:
            user_role_name = 'user'
        assistant_role_name = caption.get('assistant_role_name')
        if assistant_role_name is None:
            assistant_role_name = 'assistant'
        # The image padding string is identical for every turn; build it once.
        img_padding = self.tokenizer.image_pad_tag * self.max_img_size

        input_ids = []
        labels = []
        for i, cap in enumerate(raw_data):
            role = raw_data_role[i]
            tokenized_cap = self.tokenizer(cap.replace('{}', img_padding))['input_ids']
            input_ids.extend(tokenized_cap)
            if role in (user_role_name, 'system'):
                # Only the im_start / im_end / newline labels are kept.
                ignore_token = [IGNORE_TOKEN_ID] * (len(tokenized_cap) - 3)
                labels.extend(self._add_start_end_label(ignore_token))
            elif role == assistant_role_name:
                # Ignore the "<|im_start|>role" prefix plus the token after it;
                # keep the response itself, minus the trailing im_end + newline
                # that _add_start_end_label appends.
                role_prefix_ids = self.tokenizer('<|im_start|>' + role)['input_ids']
                has_ignored_label = [IGNORE_TOKEN_ID] * len(role_prefix_ids) \
                    + tokenized_cap[len(role_prefix_ids) + 1:-2]
                labels.extend(self._add_start_end_label(has_ignored_label))
            else:
                raise ValueError(f"raw_data_role {raw_data_role[i]} is invalid")

        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        labels += [IGNORE_TOKEN_ID] * (self.max_length - len(labels))
        return input_ids[:self.max_length], labels[:self.max_length]

    def generate_img_pos(self, input_ids, img_idx):
        """
        Compute image coordinates for an SFT sequence.

        Each ``img_start_id`` token marks an image whose pad tokens start one
        position later. Entry ``i`` of ``img_idx`` uses the ``i``-th such start.
        Entries equal to ``-1``, and every entry when ``input_ids`` contains no
        ``img_start_id``, use position ``max_length - max_img_size - 1``
        instead.

        Args:
            input_ids: token id sequence.
            img_idx: one entry per image slot; ``-1`` marks a placeholder slot.

        Returns:
            ``(len(img_idx), max_img_size, 2)`` int32 array from
            ``_generate_coord``.
        """
        img_starts = [pos + 1 for pos, token_id in enumerate(input_ids)
                      if token_id == self.tokenizer.img_start_id]
        placeholder_pos = self.max_length - self.max_img_size - 1
        if not img_starts:
            new_img_idx = [placeholder_pos] * len(img_idx)
        else:
            # NOTE(review): raises IndexError when fewer img_start tokens survive
            # (e.g. after truncation) than there are entries up to the last
            # non -1 one — confirm expected inputs.
            new_img_idx = [img_starts[i] if idx != -1 else placeholder_pos
                           for i, idx in enumerate(img_idx)]
        return self._generate_coord(new_img_idx)

    def pre_caption(self, caption, template):
        """
        Dispatch a sample dict to the preprocessing for its ``'task'``.

        Args:
            caption: sample dict with a ``'task'`` key. ``'sft'`` and
                ``'vqa'`` select those paths; any other value is treated as a
                caption sample.
            template: dict of format strings keyed by task (unused for SFT).

        Returns:
            ``(input_ids, coord, labels)`` for SFT samples. Otherwise
            ``(input_ids, coord)``, where ``input_ids`` is an int32 array; for
            samples with several captions/answers, both entries are lists with
            one element per caption/answer.

        Raises:
            ValueError: if ``'task'`` is missing, an SFT sample lacks
                ``raw_data``/``raw_data_role``/``img_idx``, or its images
                cannot fit into ``max_length``.
        """
        task = caption.get('task')
        if task is None:
            raise ValueError("task is required in the output of dataloader")
        if task == 'sft':
            return self._pre_caption_for_sft(caption)
        if task == 'vqa':
            input_ids = self.pre_caption_for_vqa_task(caption, template)
        else:
            input_ids = self.pre_caption_for_caption_task(caption, template)
        # The task helpers return a list of id sequences for multi-item samples.
        if len(input_ids) > 0 and isinstance(input_ids[0], (list, tuple, np.ndarray)):
            ids_list = [np.array(ids, dtype=np.int32) for ids in input_ids]
            return ids_list, [self._coord_from_input_ids(ids) for ids in ids_list]
        input_ids = np.array(input_ids, dtype=np.int32)
        return input_ids, self._coord_from_input_ids(input_ids)

    def _pre_caption_for_sft(self, caption):
        """Validate an SFT sample and return ``(input_ids, coord, labels)``."""
        raw_data = caption.get('raw_data')
        raw_data_role = caption.get('raw_data_role')
        img_idx = caption.get('img_idx')
        # Presence check first: len(img_idx) below would raise TypeError on None.
        if raw_data is None or raw_data_role is None or img_idx is None:
            raise ValueError("raw_data, raw_data_role and img_idx are required")
        if len(img_idx) * self.max_img_size > self.max_length - 1:
            raise ValueError(
                f"Sum of query token length of max_img_size images has exceeds the max_length of text. "
                f"max_img_size={self.max_img_size}, max_length={self.max_length}")
        input_ids, labels = self.generate_input_and_label_for_sft_data(caption, raw_data, raw_data_role)
        coord = self.generate_img_pos(input_ids, img_idx)
        return input_ids, coord, labels

    def _coord_from_input_ids(self, input_ids):
        """Return image coordinates for every ``img_start_id`` in a 1-D id array."""
        img_start_pos = np.where(input_ids == self.tokenizer.img_start_id)[0] + 1
        return self._generate_coord(img_start_pos)

    def _add_start_end_label(self, input_list):
        """Return ``[im_start_id] + input_list + [im_end_id] + <ids of '\\n'>``."""
        newline_ids = self.tokenizer('\n')["input_ids"]
        return [self.tokenizer.im_start_id] + input_list + [self.tokenizer.im_end_id] + newline_ids

    def _generate_coord(self, img_start_pos):
        """
        Build a ``(len(img_start_pos), max_img_size, 2)`` int32 coordinate array.

        Row ``k`` holds ``[0, img_start_pos[k] + j]`` for ``j`` in
        ``range(max_img_size)``. An empty ``img_start_pos`` yields shape
        ``(0, max_img_size, 2)``.
        """
        starts = np.asarray(img_start_pos, dtype=np.int64).reshape(-1, 1)
        coord = np.zeros((starts.shape[0], self.max_img_size, 2), np.int32)
        coord[:, :, 1] = starts + np.arange(self.max_img_size)
        return coord