"""QwenVL models' APIs."""
import math
from typing import Optional
import numpy as np
import mindspore as ms
from mindspore import dtype as mstype
from mindspore import nn, Parameter, Tensor
from mindspore import ops
from mindspore.common.initializer import initializer, TruncatedNormal, Normal
from mindspore.ops import operations as P
from mindformers import (
MultiHeadAttention,
MindFormerRegister,
MindFormerModuleType,
PreTrainedModel,
TransformerOpParallelConfig
)
from mindformers.models import build_network
from mindformers.models.utils import lazy_inline
from mindformers.models.vit.vit_modules import get_2d_sincos_pos_embed
from mindformers.modules.activation import GELU
from mindformers.modules.layers import LayerNorm, Linear
from mindformers.tools.logger import logger
from mindformers.models.utils import LayerSetting
from qwenvl_config import QwenVLConfig, QwenVLVisionConfig
class AbsPos(nn.Cell):
    r"""
    Resize a square-grid position embedding when the source and target grids differ.

    Args:
        src_size(int): number of positions in the source embedding; the integer part of its square root is
            used as the source grid side.
        tgt_size(int): number of positions wanted; the integer part of its square root is used as the target
            grid side.

    Returns:
        x: Tensor, the bicubically resized and flattened embedding when the two grid sides differ, otherwise
            the input unchanged.
    """

    def __init__(self, src_size: int, tgt_size: int):
        super().__init__()
        # Both sizes are given as token counts; keep only the grid side length.
        self.src_size = int(math.sqrt(src_size))
        self.tgt_size = int(math.sqrt(tgt_size))

        self.cast = P.Cast()
        self.reshape = P.Reshape().shard(((1, 1),))
        self.transpose = P.Transpose().shard(((1, 1, 1, 1),))
        self.flatten = nn.Flatten(start_dim=0, end_dim=2)

        self.resize_shape = ms.Tensor([self.tgt_size, self.tgt_size], ms.int32)
        self.resize = P.ResizeBicubic(align_corners=False, half_pixel_centers=False)

    def construct(self, x: Tensor):
        """forward of AbsPos"""
        if self.src_size == self.tgt_size:
            return x

        input_dtype = x.dtype
        # (N, C) -> (1, side, side, C) -> (1, C, side, side) so the resize runs over the two grid axes.
        grid = self.reshape(x, (1, self.src_size, self.src_size, -1))
        grid = self.transpose(grid, (0, 3, 1, 2))
        # The resize is done in float32 and cast back to the caller's dtype afterwards.
        resized = self.resize(self.cast(grid, ms.float32), self.resize_shape)
        resized = self.cast(resized, input_dtype)
        resized = self.transpose(resized, (0, 2, 3, 1))
        return self.flatten(resized)
class Resampler(nn.Cell):
    """
    A 2D perceiver-resampler network with one cross attention layer driven by num_queries learnable queries
    and a 2d sin_cos pos_emb.

    Args:
        image_size(int): Size of image.
        patch_size(int): Patch size of image.
        hidden_size(int): The dimension of embed.
        num_queries(int): Nums of query tokens.
        output_dim(int): The target embed dim of output.
        parallel_config(TransformerOpParallelConfig): The parallel configure.
        compute_dtype: The type of Linear computation module.
        param_init_type: The parameter initialization type of the module.
        softmax_compute_type: The type of softmax computation module.

    Returns:
        out: A tensor with the shape of (bs, num_queries, output_dim)
    """

    def __init__(self, image_size: int,
                 patch_size: int,
                 hidden_size: int,
                 num_queries: int,
                 output_dim: int,
                 parallel_config: TransformerOpParallelConfig,
                 compute_dtype,
                 param_init_type,
                 softmax_compute_type):
        super().__init__()
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        self.num_queries = num_queries
        self.query_grid_size = int(math.sqrt(num_queries))

        self.embed_dim = output_dim
        # One attention head per 128 output channels.
        self.num_heads = self.embed_dim // 128
        self.kv_dim = hidden_size

        # The sin-cos position table is fixed, so it stays out of the optimizer.
        self.pos_embed = Parameter(get_2d_sincos_pos_embed(self.embed_dim, self.query_grid_size), requires_grad=False,
                                   parallel_optimizer=False)

        # The query tokens are the learnable part of the resampler and must stay trainable.
        # QwenVL.freeze_component only walks trainable_params(), so a parameter created with
        # requires_grad=False could never be switched back on when the resampler is unfrozen.
        self.query = Parameter(initializer(TruncatedNormal(mean=0.0, sigma=0.02), [self.num_queries, self.embed_dim]),
                               parallel_optimizer=False)

        # Project the key/value stream only when its width differs from the output width.
        if self.kv_dim is not None and self.kv_dim != self.embed_dim:
            self.kv_proj = Linear(in_channels=self.kv_dim, out_channels=self.embed_dim,
                                  has_bias=False,
                                  compute_dtype=compute_dtype,
                                  param_init_type=param_init_type)
            self.kv_proj.shard(strategy_matmul=((dp, 1),
                                                (mp, 1)),
                               strategy_bias=((dp, mp), (mp,))
                               )
        else:
            self.kv_proj = nn.Identity()

        self.img_grid_size = image_size // patch_size
        self.attn = MultiHeadAttention(hidden_size=self.embed_dim,
                                       num_heads=self.num_heads,
                                       batch_size=None,
                                       src_seq_length=self.num_queries,
                                       tgt_seq_length=self.img_grid_size ** 2,
                                       hidden_dropout_rate=0.0,
                                       attention_dropout_rate=0.0,
                                       softmax_compute_type=softmax_compute_type,
                                       use_past=False,
                                       param_init_type=param_init_type,
                                       parallel_config=parallel_config.dp_mp_config)

        self.ln_q = LayerNorm((self.embed_dim,), eps=1e-6)
        self.ln_q.shard(((1, 1,),))

        self.ln_kv = LayerNorm((self.embed_dim,), eps=1e-6)
        self.ln_kv.shard(((dp, 1, 1),))

        # Resizes the query-grid position table to the image patch grid for the key side.
        self.abs_pos = AbsPos(self.pos_embed.shape[0], self.img_grid_size ** 2)

        self.shape = P.Shape()
        self.tile = P.Tile().shard(((dp, 1, 1),))
        self.add = P.Add().shard(((dp, 1, 1), (1, 1, 1)))
        self.expand_dims = P.ExpandDims().shard(((dp, 1),))
        self.expand_dims_no_shard = P.ExpandDims().shard(((1, 1),))

    def construct(self, x, attn_mask=None):
        """forward of Resampler"""
        bs, _, _ = self.shape(x)
        pos_embed = self.abs_pos(self.pos_embed)

        x = self.kv_proj(x)
        x = self.ln_kv(x)

        q = self.ln_q(self.query)
        # Queries use the original position table; keys use the table resized to the image grid.
        query_tensor = self.add(self.tile(q, (bs, 1, 1)), self.expand_dims_no_shard(self.pos_embed, 0))
        key_tensor = self.add(x, self.expand_dims_no_shard(pos_embed, 0))
        # Values are the normalized features without the position term.
        out = self.attn(
            query_tensor,
            key_tensor,
            x,
            attention_mask=attn_mask
        )[0]
        return out
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class QwenVLVisionModel(PreTrainedModel):
    r"""Vision tower of Qwen-VL: patch convolution, transformer blocks, resampler and output projection.

    Args:
        config (QwenVLVisionConfig): The config of VisionConfig for QwenVL
        num_queries(int): num of query tokens

    Returns:
        input_x: A tensor with the shape of (bs, num_queries, output_dim)
    """

    def __init__(self, config: QwenVLVisionConfig, num_queries: int = 256, **kwargs):
        super().__init__(config, **kwargs)
        self.num_queries = num_queries

        parallel_config = config.parallel_config
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        hidden_size = config.hidden_size
        patch_size = config.patch_size
        dtype = config.compute_dtype

        # Patch embedding: kernel and stride both equal the patch size.
        patch_conv = nn.Conv2d(in_channels=3, out_channels=hidden_size, kernel_size=patch_size,
                               stride=patch_size, has_bias=False, pad_mode='pad')
        self.conv1 = patch_conv.to_float(dtype)
        self.conv1.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv1.bias_add.shard(((dp, 1, 1, 1), (1,)))
        self.conv1.pipeline_stage = config.start_stage

        scale = hidden_size ** -0.5
        # The table holds 256 positions; AbsPos resizes it to the actual patch grid in construct.
        pos_init = Tensor(np.random.normal(0, 1, size=(256, hidden_size))).astype(dtype)
        self.positional_embedding = Parameter(scale * pos_init, parallel_optimizer=False)
        self.positional_embedding.pipeline_stage = config.start_stage

        self.ln_pre = LayerNorm((hidden_size,), eps=1e-6)
        self.ln_pre.shard(((dp, 1, 1),))
        self.ln_pre.pipeline_stage = config.start_stage
        logger.info(f"ln_pre pipeline_stage: {self.ln_pre.pipeline_stage}")

        self.transformer = QwenVLTransformer(image_size=config.image_size,
                                             patch_size=patch_size,
                                             hidden_size=hidden_size,
                                             intermediate_size=config.intermediate_size,
                                             n_head=config.num_attention_heads,
                                             layers=config.num_hidden_layers,
                                             dtype=config.dtype,
                                             softmax_compute_type=config.softmax_compute_type,
                                             compute_dtype=config.compute_dtype,
                                             param_init_type=config.param_init_type,
                                             gelu_dtype=config.gelu_dtype,
                                             parallel_config=config.parallel_config,
                                             use_flash_attention=config.use_flash_attention,
                                             enable_fa_opt=config.enable_fa_opt)

        if config.stage_num > 0:
            # Apply the per-layer setting to every transformer block.
            self.layer_setting = LayerSetting(config.num_hidden_layers,
                                              config.offset,
                                              config.parallel_config,
                                              1,
                                              config.start_stage,
                                              config.stage_num)
            for layer_id, layer in enumerate(self.transformer.resblocks):
                self.layer_setting(layer, layer_id)

        self.attn_pool = Resampler(image_size=config.image_size,
                                   patch_size=patch_size,
                                   hidden_size=hidden_size,
                                   num_queries=self.num_queries,
                                   output_dim=config.output_dim,
                                   parallel_config=config.parallel_config,
                                   compute_dtype=config.compute_dtype,
                                   param_init_type=config.param_init_type,
                                   softmax_compute_type=config.softmax_compute_type)

        # The resampler and the final norm share the last vision stage when pipelining is on.
        pipelined = config.parallel_config.pipeline_stage > 1
        tail_stage = 0
        if pipelined and config.stage_num > 0:
            tail_stage = config.start_stage + config.stage_num - 1
        if pipelined:
            self.attn_pool.pipeline_stage = tail_stage

        self.transpose = P.Transpose().shard(((dp, 1, 1),))

        self.ln_post = LayerNorm((config.output_dim,), eps=1e-6)
        if pipelined:
            self.ln_post.pipeline_stage = tail_stage
        self.ln_post.shard(((dp, 1, 1),))

        proj_init = Tensor(np.random.normal(0, 1, size=(config.output_dim, config.output_dim))).astype(dtype)
        self.proj = Parameter(scale * proj_init)
        self.proj.pipeline_stage = config.start_stage

        self.dtype = dtype
        self.cast = P.Cast()
        self.add = P.Add().shard(((dp, 1, 1), (1, 1)))

        img_grid_size = config.image_size // patch_size
        self.abs_pos = AbsPos(self.positional_embedding.shape[0], img_grid_size ** 2)

        self.matmul = P.BatchMatMul().shard(((dp, 1, mp), (mp, 1)))

    def construct(self, input_x: ms.Tensor):
        """forward of QwenVLVisionModel"""
        patches = self.conv1(input_x)
        # Collapse the spatial axes into one token axis, then move channels last.
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        tokens = self.transpose(patches, (0, 2, 1))

        tokens = self.add(tokens, self.abs_pos(self.positional_embedding))
        tokens = self.ln_pre(tokens)
        tokens = self.transformer(tokens)

        pooled = self.attn_pool(tokens)
        pooled = self.ln_post(pooled)
        pooled = self.cast(pooled, self.dtype)
        return self.matmul(pooled, self.proj)
class MLP(nn.Cell):
    """
    Two-layer feed-forward block for ViT: Linear -> GELU -> Linear.

    Args:
        layers(int): number of transformer layers, used only to scale the output projection init.
        hidden_size(int): input and output width.
        intermediate_size(int): width between the two Linear layers.
        compute_dtype: compute dtype handed to both Linear layers.
        param_init_type: parameter init dtype handed to both Linear layers.
        gelu_dtype: dtype the activation is evaluated in.
        parallel_config: object providing ``data_parallel`` and ``model_parallel``.
    """

    def __init__(self, layers: int, hidden_size: int, intermediate_size: int,
                 compute_dtype, param_init_type, gelu_dtype, parallel_config):
        super().__init__()
        dp, mp = parallel_config.data_parallel, parallel_config.model_parallel
        matmul_strategy = ((dp, 1), (mp, 1))
        bias_strategy = ((dp, mp), (mp,))

        fc_std = (2 * hidden_size) ** -0.5
        proj_std = (hidden_size ** -0.5) * ((2 * layers) ** -0.5)

        self.c_fc = Linear(hidden_size, intermediate_size,
                           weight_init=Normal(mean=0.0, sigma=fc_std),
                           compute_dtype=compute_dtype,
                           param_init_type=param_init_type)
        self.c_fc.shard(strategy_matmul=matmul_strategy, strategy_bias=bias_strategy)

        self.c_proj = Linear(intermediate_size, hidden_size,
                             weight_init=Normal(mean=0.0, sigma=proj_std),
                             compute_dtype=compute_dtype,
                             param_init_type=param_init_type)
        self.c_proj.shard(strategy_matmul=matmul_strategy, strategy_bias=bias_strategy)

        self.gelu = GELU(approximate=False)
        self.gelu_dtype = gelu_dtype
        self.cast = P.Cast()
        self.dtype = P.DType()

    def construct(self, x):
        """forward of MLP"""
        hidden = self.c_fc(x)
        linear_dtype = self.dtype(hidden)
        # The activation runs in gelu_dtype; the result is cast back before the output projection.
        activated = self.gelu(self.cast(hidden, self.gelu_dtype))
        return self.c_proj(self.cast(activated, linear_dtype))
class VisualFlashAttention(nn.Cell):
    """
    Flash Attention wrapper for the visual module.

    With ``enable_fa_opt`` set, query/key/value are zero-padded on the last axis up to the next multiple
    of 128 before calling ``fa`` and the result is sliced back to the original head size.

    Args:
        fa: the wrapped flash attention callable.
        parallel_config: object providing ``data_parallel`` and ``model_parallel``.
        size_per_head(int): size of the last axis of query/key/value.
        enable_fa_opt(bool): whether to pad the head size as described above.

    Raises:
        ValueError: if ``enable_fa_opt`` is set while ``size_per_head`` is already a multiple of 128.
    """

    def __init__(self, fa, parallel_config, size_per_head, enable_fa_opt=False):
        super().__init__()
        self.fa = fa
        self.enable_fa_opt = enable_fa_opt

        if self.enable_fa_opt:
            dp = parallel_config.data_parallel
            mp = parallel_config.model_parallel
            self.stridden_slice = P.StridedSlice().shard(((dp, 1, 1, 1),))
            self.concat = P.Concat(axis=-1).shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
            self.zeros = P.Zeros().shard(((dp, 1, 1, 1),))

            # Round the head size up to the next multiple of 128.
            full_chunks, leftover = divmod(size_per_head, 128)
            padded_size = (full_chunks + 1) * 128 if leftover > 0 else size_per_head
            self.pad_length = padded_size - size_per_head
            if self.pad_length == 0:
                raise ValueError("size_per_head is divisible by 128, please disable enable_fa_opt")

    def construct(self, query, key, value, attention_mask):
        """forward for VisualFlashAttention"""
        batch, heads, seq_len, head_dim = query.shape
        if not self.enable_fa_opt:
            return self.fa(query, key, value, attention_mask)

        zero_pad = self.zeros((batch, heads, seq_len, self.pad_length), P.DType()(query))
        padded_query = self.concat([query, zero_pad])
        padded_key = self.concat([key, zero_pad])
        padded_value = self.concat([value, zero_pad])

        padded_out = self.fa(padded_query, padded_key, padded_value, attention_mask)
        # Drop the padded tail of the last axis again.
        return self.stridden_slice(padded_out, (0, 0, 0, 0), (batch, heads, seq_len, head_dim),
                                   (1, 1, 1, 1))
class VisualAttention(MultiHeadAttention):
    """
    MultiHeadAttention whose flash attention op is wrapped by VisualFlashAttention.

    The wrapping only happens when the parent enabled flash attention and no attention mask is used.

    Args:
        use_attention_mask(bool): when True, the parent's flash attention op is left untouched.
        enable_fa_opt(bool): forwarded to VisualFlashAttention.
        *args, **kwargs: forwarded to MultiHeadAttention; ``parallel_config`` is read from kwargs.
    """

    def __init__(self, *args, use_attention_mask=False, enable_fa_opt=False, **kwargs):
        super().__init__(*args, **kwargs)
        wrap_flash_attention = self.use_flash_attention and not use_attention_mask
        if wrap_flash_attention:
            self.flash_attention = VisualFlashAttention(self.flash_attention,
                                                        kwargs.get('parallel_config'),
                                                        self.size_per_head,
                                                        enable_fa_opt=enable_fa_opt)
class ResidualAttentionBlock(nn.Cell):
    r"""
    Pre-norm transformer block of QwenVLVisionModel: self-attention and MLP, each with a residual add.

    Args:
        image_size(int): size of image
        patch_size(int): patch size of image
        hidden_size(int): The dimension of embed.
        intermediate_size(int): The linear width in MLP.
        n_head(int): The number of attention heads.
        layers(int): The number of transformer layers for weight initialization.
        dtype(mstype): Stored on the block as ``self.dtype``.
        softmax_compute_type(mstype): The type of softmax computation module.
        compute_dtype(mstype): The type of linear computation module.
        param_init_type(mstype): The parameter initialization type of the module.
        gelu_dtype(mstype): The type of gelu activation computation module.
        parallel_config: The parallel configure.
        use_flash_attention: Whether to use flash attention
        enable_fa_opt: Whether to enable padding MatMul operation flash attention.
        attn_mask (Optional[ms.Tensor]): attention mask.
    """

    def __init__(self, image_size: int, patch_size: int, hidden_size: int, intermediate_size: int, n_head: int,
                 layers: int,
                 dtype: mstype,
                 softmax_compute_type,
                 compute_dtype,
                 param_init_type,
                 gelu_dtype,
                 parallel_config,
                 use_flash_attention: bool = False,
                 enable_fa_opt: bool = False,
                 attn_mask: Optional[ms.Tensor] = None):
        super().__init__()
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        self.dtype = dtype

        # Self-attention over the patch grid: source and target lengths are both grid_side ** 2.
        num_patches = (image_size // patch_size) ** 2
        self.attn = VisualAttention(hidden_size=hidden_size,
                                    num_heads=n_head,
                                    batch_size=None,
                                    src_seq_length=num_patches,
                                    tgt_seq_length=num_patches,
                                    hidden_dropout_rate=0.0,
                                    attention_dropout_rate=0.0,
                                    softmax_compute_type=softmax_compute_type,
                                    use_past=False,
                                    compute_dtype=compute_dtype,
                                    param_init_type=param_init_type,
                                    parallel_config=parallel_config.dp_mp_config,
                                    use_flash_attention=use_flash_attention,
                                    use_attention_mask=False,
                                    enable_fa_opt=enable_fa_opt)

        norm_strategy = ((dp, 1, 1), (1,), (1,))
        self.ln_1 = LayerNorm((hidden_size,), eps=1e-6, param_init_type=param_init_type)
        self.ln_1.layer_norm.shard(norm_strategy)

        self.mlp = MLP(layers=layers,
                       hidden_size=hidden_size,
                       intermediate_size=intermediate_size,
                       compute_dtype=compute_dtype,
                       param_init_type=param_init_type,
                       gelu_dtype=gelu_dtype,
                       parallel_config=parallel_config)

        self.ln_2 = LayerNorm((hidden_size,), eps=1e-6, param_init_type=param_init_type)
        self.ln_2.layer_norm.shard(norm_strategy)

        self.attn_mask = attn_mask
        self.add = P.Add().shard(((dp, 1, mp), (dp, 1, mp)))

    def construct(self, input_x: ms.Tensor):
        r"""Construct"""
        attn_out = self.attention(self.ln_1(input_x))
        hidden = self.add(input_x, attn_out)
        mlp_out = self.mlp(self.ln_2(hidden))
        return self.add(hidden, mlp_out)

    def attention(self, input_x: ms.Tensor):
        r"""Attention"""
        # The same tensor is used as query, key and value; only the first returned element is kept.
        attn_outputs = self.attn(input_x, input_x, input_x, self.attn_mask)
        return attn_outputs[0]
class QwenVLTransformer(nn.Cell):
    r"""
    Stack of ``layers`` identically configured ResidualAttentionBlock cells, applied in sequence.

    Args:
        image_size(int): Size of image.
        patch_size(int): Patch size of image.
        hidden_size (int): The dimension of input features.
        intermediate_size(int): The linear width in MLP.
        n_head (int): The number of attention heads.
        layers (int): The number of transformer layers.
        dtype (mstype): Forwarded to every block.
        softmax_compute_type: Forwarded to every block.
        compute_dtype: Forwarded to every block.
        param_init_type: Forwarded to every block.
        gelu_dtype: Forwarded to every block.
        parallel_config: Forwarded to every block.
        use_flash_attention (bool): Forwarded to every block.
        enable_fa_opt (bool): Forwarded to every block.
        attn_mask (ms.Tensor): Attention mask, forwarded to every block.
    """

    def __init__(self, image_size: int, patch_size: int, hidden_size: int, intermediate_size: int, n_head: int,
                 layers: int,
                 dtype: mstype,
                 softmax_compute_type,
                 compute_dtype,
                 param_init_type,
                 gelu_dtype,
                 parallel_config,
                 use_flash_attention: bool = False,
                 enable_fa_opt: bool = False,
                 attn_mask: Optional[ms.Tensor] = None):
        super().__init__()
        # Every block receives exactly the same keyword arguments.
        block_kwargs = dict(image_size=image_size,
                            patch_size=patch_size,
                            hidden_size=hidden_size,
                            intermediate_size=intermediate_size,
                            n_head=n_head,
                            layers=layers,
                            dtype=dtype,
                            softmax_compute_type=softmax_compute_type,
                            compute_dtype=compute_dtype,
                            param_init_type=param_init_type,
                            gelu_dtype=gelu_dtype,
                            parallel_config=parallel_config,
                            use_flash_attention=use_flash_attention,
                            enable_fa_opt=enable_fa_opt,
                            attn_mask=attn_mask)
        blocks = [ResidualAttentionBlock(**block_kwargs) for _ in range(layers)]
        self.resblocks = nn.SequentialCell(*blocks)

    def construct(self, input_x):
        r"""Construct"""
        return self.resblocks(input_x)
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class QwenVL(PreTrainedModel):
    """
    Provide QwenVL training loss or logits through network.

    The vision encoder turns images into ``num_queries`` embeddings per image; these are scattered into
    the text embeddings at the positions given by ``img_pos`` before the language model is called.

    Args:
        config (QwenVLConfig): The config of QwenVL model.
    """

    @lazy_inline
    def __init__(self, config: QwenVLConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config

        # concat_image_text reshapes img_pos with config.num_queries, so the vision encoder must emit
        # the same number of query tokens. QwenVLVisionModel accepts ``num_queries`` explicitly.
        # A fresh dict is used for each build_network call.
        self.vision_encoder = build_network(config.vision_model,
                                            default_args={"num_queries": self.config.num_queries})
        self.llm_model = build_network(config.llm_model, default_args={"num_queries": self.config.num_queries})

        self.image_start_id = self.config.image_start_id
        self.image_pad_id = self.config.image_pad_id
        self.num_queries = self.config.num_queries
        self.image_size = self.config.vision_model.model_config.image_size

        self.is_first_iteration = True
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id
        self.ignore_token_id = ms.Tensor(config.ignore_token_id, mstype.int32)
        self.use_past = config.use_past

        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.cast = P.Cast()

        parallel_config = config.parallel_config
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
        self.slice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
        self.masked_fill = P.MaskedFill().shard(
            ((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1), ()))
        self.tensor_scatter_update = ops.TensorScatterUpdate().shard(((1, 1, 1),
                                                                      (1, 1, 1),
                                                                      (1, 1, 1)))
        self.gather = P.Gather().shard(((1, 1, 1), ()))
        self.equal = P.Equal().shard(((parallel_config.data_parallel, 1), ()))
        self.ones = P.Ones()
        self.img_pos_add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))

        # Built lazily by generate_base_index_adder once the batch size is known.
        self.base_index_adder = None

        self.freeze_component()

    def freeze_component(self):
        """freeze components according to config"""
        if self.config.freeze_vision:
            logger.info("freeze vision encoder")
            for param in self.vision_encoder.trainable_params():
                # Resampler (attn_pool) parameters stay trainable unless freeze_resampler is set.
                if not self.config.freeze_resampler and "vision_encoder.attn_pool" in param.name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        if self.config.freeze_llm:
            logger.info("freeze llm model")
            for param in self.llm_model.trainable_params():
                param.requires_grad = False

    def generate_base_index_adder(self, batch_size):
        """Build the per-sample offset ``[[i, 0]]`` tensor of shape (batch_size, 1, 1, 2).

        The tensor is cached and rebuilt when the requested batch size differs from the cached one, so
        a later call with another batch size does not reuse a stale tensor.
        """
        cached = self.base_index_adder
        if cached is None or cached.shape[0] != batch_size:
            self.base_index_adder = ms.Tensor(
                [[i, 0] for i in range(batch_size)], ms.int32).reshape(batch_size, 1, 1, 2)

    def update_model_kwargs_before_generate(self, input_ids, model_kwargs: dict):
        """Prepare the batch index offsets for the batch size of ``input_ids``; ``model_kwargs`` is unused."""
        batch_size, _ = input_ids.shape
        self.generate_base_index_adder(batch_size)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """prepare inputs for generation in inference"""
        batch_size, _ = input_ids.shape
        is_first_iteration = self.is_first_iteration
        slot_mapping = kwargs.get('slot_mapping')

        if self.config.is_dynamic and "origin_inputs" in kwargs:
            input_ids = kwargs.get("origin_inputs")
            is_first_iteration = True
            # Drop the -1 entries from the slot mapping.
            if slot_mapping is not None:
                slot_mapping = np.delete(slot_mapping, np.where(slot_mapping == -1))

        if is_first_iteration or not self.use_past:
            images = kwargs.pop("images")
            img_pos = kwargs.pop("img_pos", None)
            if img_pos is not None:
                img_pos = ms.Tensor(img_pos, mstype.int32)
        else:
            # Incremental step: construct does not run the vision encoder, so placeholders are passed.
            # One image per sample, matching the placeholder img_pos and prepare_inputs_for_predict_layout.
            img_shape = (batch_size, 1, 3, self.image_size, self.image_size)
            images = self.ones(img_shape, ms.float32)
            img_pos = self.ones((batch_size, 1, self.config.num_queries, 2), mstype.int32)

        return {
            "input_ids": ms.Tensor(input_ids, mstype.int32),
            "images": images,
            "img_pos": img_pos,
            # slot_mapping is optional in construct; pass None through instead of failing in from_numpy.
            "slot_mapping": None if slot_mapping is None else Tensor.from_numpy(slot_mapping)
        }

    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """prepare inputs for predict layout"""
        input_ids = Tensor(input_ids, mstype.int32)
        bs = input_ids.shape[0]

        # Random placeholders are used for any input the caller does not provide.
        if "images" in kwargs:
            images = Tensor(kwargs.get("images"))
        else:
            images = Tensor(np.random.random((bs, 1, 3, self.image_size, self.image_size)), ms.float32)

        if "img_pos" in kwargs:
            img_pos = Tensor(kwargs.get("img_pos"))
        else:
            img_pos = Tensor(np.random.randint(0, self.num_queries, (bs, 1, self.num_queries, 2)), ms.int32)

        self.generate_base_index_adder(bs)
        slot_mapping = Tensor(np.ones(shape=tuple([bs])), mstype.int32)
        # The tuple follows the positional order of construct's parameters.
        return input_ids, images, img_pos, None, None, None, None, None, None, None, None, None, slot_mapping

    def set_dynamic_inputs(self):
        """set inputs when is_dynamic=True"""
        dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_images = Tensor(shape=[None, None, None, None, None], dtype=mstype.float32)
        dynamic_img_pos = Tensor(shape=[None, None, None, None], dtype=mstype.int32)
        dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)

        # The arguments follow the positional order of construct's parameters.
        self.set_inputs(dynamic_input_ids, dynamic_images, dynamic_img_pos, None, None, None,
                        None, None, dynamic_batch_valid_length, None, None,
                        dynamic_block_tables, dynamic_slot_mapping)
        self.llm_model.set_dynamic_inputs()
        logger.info("Set dynamic inputs for Qwen-VL")

    def add_flags_custom(self, is_first_iteration):
        """add customized attributes for specific cells in the model."""
        self.add_flags(is_first_iteration=is_first_iteration)
        self.llm_model.add_flags_custom(is_first_iteration=is_first_iteration)

    def kvcache(self, layer_idx):
        """Return the language model's kv cache for layer ``layer_idx``."""
        return self.llm_model.kvcache(layer_idx)

    def concat_image_text(self, text_embeds, image_embeds, img_pos):
        """update the value at a specific position of the text embedding with the image embedding"""
        if self.training:
            img_pos = img_pos.reshape((-1, self.num_queries, 2))
        else:
            # Outside training, the per-sample offset tensor is added to img_pos first.
            img_pos = self.img_pos_add(img_pos, self.base_index_adder).reshape((-1, self.num_queries, 2))
        image_embeds = self.cast(image_embeds, text_embeds.dtype)
        text_embeds = self.tensor_scatter_update(text_embeds, img_pos, image_embeds)
        return text_embeds

    def construct(self, input_ids, images, img_pos: Tensor = None, labels=None,
                  input_position=None, position_ids=None, attention_mask=None, init_reset=None, batch_valid_length=None,
                  batch_index=None, zactivate_len=None, block_tables=None, slot_mapping=None):
        """forward of QwenVL"""
        bs, seq_len = self.shape(input_ids)

        if self.training:
            # Training feeds all but the last token.
            tokens = self.slice(input_ids, (0, 0), (bs, seq_len - 1), (1, 1))
            if labels is None:
                # Derive labels from input_ids with pad positions replaced by ignore_token_id.
                pad_input_ids_pos = self.equal(input_ids, self.pad_token_id)
                labels = self.masked_fill(input_ids, pad_input_ids_pos, self.ignore_token_id)
                # NOTE(review): the file's indentation was lost, so it is unclear whether these two
                # lines belong inside this branch or should also apply to caller-provided labels.
                # Confirm against the original layout.
                pad_label_pos = self.equal(labels, self.pad_token_id)
                labels = self.masked_fill(labels, pad_label_pos, self.ignore_token_id)
        else:
            tokens = input_ids

        input_embeds = self.llm_model.to_embeddings(tokens)
        if attention_mask is None:
            attention_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)

        # The vision encoder only runs on the first iteration or in training.
        if self.is_first_iteration or self.training:
            if images.ndim == 5:
                # Merge the first two axes so the vision encoder receives a 4D batch.
                images_shape = self.shape(images)
                new_shape = (images_shape[0] * images_shape[1], images_shape[2], images_shape[3], images_shape[4])
                images = self.reshape(images, new_shape)

            image_embeds = self.vision_encoder(images)
            input_embeds = self.concat_image_text(input_embeds, image_embeds, img_pos)

        return self.llm_model(
            input_ids=None,
            labels=labels,
            input_position=input_position,
            position_ids=position_ids,
            attention_mask=attention_mask,
            input_embeds=input_embeds,
            init_reset=init_reset,
            batch_valid_length=batch_valid_length,
            batch_index=batch_index,
            zactivate_len=zactivate_len,
            block_tables=block_tables,
            slot_mapping=slot_mapping
        )