"""
Internvl Model
"""
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindformers.tools import logger
from mindformers import MindFormerRegister, MindFormerModuleType
from mindformers.models import build_network
from mindformers.models.utils import lazy_inline
from mindformers.modules.layers import LayerNorm, Linear
from mindformers.models.multi_modal.base_model import BaseXModalToTextModel
from research.internvl2.internvl_configuration import InternVLChatConfig
class MLP(nn.Cell):
r"""
Module to resize the image dimension into the llm model dimension
Args:
config (LlavaConfig): The config of Llava model.
"""
def __init__(self, config: InternVLChatConfig):
super(MLP, self).__init__(config)
vision_config = config.vision_model.model_config
text_config = config.text_model.model_config
self.layer_norm = LayerNorm(vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2)
self.adapter1 = Linear(
in_channels=vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2,
out_channels=text_config.hidden_size,
compute_dtype=config.compute_dtype,
param_init_type=config.param_init_type
)
self.activation_func = P.GeLU()
self.adapter2 = Linear(
in_channels=text_config.hidden_size,
out_channels=text_config.hidden_size,
compute_dtype=config.compute_dtype,
param_init_type=config.param_init_type
)
dp = config.parallel_config.data_parallel
mp = config.parallel_config.model_parallel
self.adapter1.shard(((dp, 1), (mp, 1)), strategy_bias=((dp, mp), (mp,)))
self.adapter2.shard(((dp, mp), (1, mp)), strategy_bias=((dp, 1), (1,)))
self.activation_func.shard(((dp, 1, mp),))
def construct(self, x):
"""adapter forward method"""
output = self.layer_norm(x)
output = self.adapter1(output)
output = self.activation_func(output)
output = self.adapter2(output)
return output
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class InternVLChatModel(BaseXModalToTextModel):
"""
Provide InternVL2 training loss or logits through network.
Args:
config (InternVLChatConfig): The config of InternVL model.
"""
config_class = InternVLChatConfig
main_input_name = 'pixel_values'
_supports_flash_attn_2 = True
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']
@lazy_inline
def __init__(self, config: InternVLChatConfig, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.vision_config = config.vision_model.model_config
self.text_config = config.text_model.model_config
image_size = config.force_image_size or self.vision_config.image_size
patch_size = self.vision_config.patch_size
dp = config.parallel_config.data_parallel
self.patch_size = patch_size
self.select_layer = config.select_layer
self.template = config.template
self.image_size = image_size
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
self.downsample_ratio = config.downsample_ratio
self.select_layer = config.select_layer
self.num_queries = config.num_queries
self.base_index_adder = None
self.tensor_scatter_update = P.TensorScatterUpdate()
self.expand_dims = P.ExpandDims()
self.vision_model = build_network(config.vision_model)
self.language_model = build_network(config.text_model)
self.img_context_token_id = config.img_context_token_id
self.img_pos_add = P.Add().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
self.vit_hidden_size = self.vision_config.hidden_size
self.text_hidden_size = self.text_config.hidden_size
self.mlp1 = MLP(config)
self.is_first_iteration = True
self.not_equal = P.NotEqual().shard(((dp, 1), ()))
self.add = P.Add().shard(((dp, 1), ()))
self.shape = P.Shape()
self.reshape = P.Reshape().add_prim_attr("skip_redistribution", True)
self.cast = P.Cast()
self.slice = P.StridedSlice().shard(((dp, 1),))
self.ones = P.Ones()
self.gather = P.Gather(1).shard(((dp, 1, 1), (dp,)))
self.prefill_gather_flatten = P.Gather().shard(((dp, 1, 1), (dp,)))
self.sub_batch_valid_len = P.Sub().shard(((1,), ()))
self.is_first_iteration = True
self.transpose_4dim = P.Transpose().shard(((dp, 1, 1, 1),))
self.slice_3dim = P.StridedSlice().shard(((dp, 1, 1),))
def set_dynamic_inputs(self):
"""set inputs when is_dynamic=True"""
dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_images = Tensor(shape=[None, None, None, None], dtype=mstype.float32)
dynamic_img_pos = Tensor(shape=[None, self.num_queries, 2], dtype=mstype.int32)
dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
self.set_inputs(dynamic_input_ids, dynamic_images, dynamic_img_pos, None, None, None,
None, None, dynamic_batch_valid_length, None, None,
dynamic_block_tables, dynamic_slot_mapping)
logger.info("Set dynamic inputs for Internvl2")
def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
"""prepare inputs for predict layout"""
input_ids = Tensor(input_ids, mstype.int32)
bs, seq = input_ids.shape
if "images" in kwargs:
images = Tensor(kwargs.get("images"))
else:
images = Tensor(np.random.random((13, 3, self.image_size, self.image_size)), ms.float32)
if "image_context_pos" in kwargs:
image_context_pos = Tensor(kwargs.get("image_context_pos"))
else:
image_context_pos = Tensor(np.ones((13, self.num_queries, 2)), ms.int32)
slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32)
return input_ids, images, image_context_pos, None, None, None, None, None, None, None, None, None, slot_mapping
def _prepare_inputs_for_prefill_flatten(self, input_ids, batch_valid_length, slot_mapping, model_inputs):
"""prepare inputs ids for prefill flatten"""
batch_valid_length_bs = batch_valid_length.shape[0]
input_ids_list = []
for i in range(batch_valid_length_bs):
context_len = batch_valid_length[i]
input_ids_list.append(input_ids[i][:context_len])
input_ids = np.concatenate(input_ids_list, 0)
input_ids = input_ids.reshape((1, -1))
slot_mapping = np.delete(slot_mapping, np.where(slot_mapping == -1))
model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32))
model_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping)
return model_inputs
def prepare_inputs_for_generation(self, input_ids, **kwargs):
"""prepare inputs for generation in inference"""
if self.config.is_dynamic and "origin_inputs" in kwargs:
input_ids = kwargs.get("origin_inputs")
images = kwargs.pop("images")
slot_mapping = kwargs.pop("slot_mapping")
image_context_pos = kwargs.pop("image_context_pos", None)
if image_context_pos is not None and not isinstance(image_context_pos, ms.Tensor):
image_context_pos = ms.Tensor(image_context_pos, mstype.int32)
if slot_mapping is not None:
model_inputs = {
"input_ids": ms.Tensor(input_ids, mstype.int32),
"images": images,
"image_context_pos": image_context_pos,
"slot_mapping": Tensor.from_numpy(slot_mapping)
}
if self.is_first_iteration:
batch_valid_length = kwargs.get("valid_length_each_example")
model_inputs = self._prepare_inputs_for_prefill_flatten(
input_ids, batch_valid_length, slot_mapping, model_inputs
)
else:
model_inputs = {
"input_ids": ms.Tensor(input_ids, mstype.int32),
"images": images,
"image_context_pos": image_context_pos
}
return model_inputs
def pixel_shuffle(self, x, scale_factor=0.5):
"""shuffle pixel tensor"""
n, w, h, c = self.shape(x)
new_h = int(h * scale_factor)
new_c = int(c / scale_factor)
x = self.reshape(x, (n, w, new_h, new_c))
x = self.transpose_4dim(x, (0, 2, 1, 3))
new_w = int(w * scale_factor)
new_c = int(c / (scale_factor * scale_factor))
x = self.reshape(x, (n, new_h, new_w, new_c))
x = self.transpose_4dim(x, (0, 2, 1, 3))
return x
def extract_feature(self, pixel_values):
"""get image embeds"""
vit_embeds = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=False,
select_layer=self.select_layer)
old_vit_embeds_shape = self.shape(vit_embeds)
vit_embeds = self.slice_3dim(vit_embeds,
(0, 1, 0),
(old_vit_embeds_shape[0], old_vit_embeds_shape[1], old_vit_embeds_shape[2]),
(1, 1, 1))
vit_embeds_shape = self.shape(vit_embeds)
h = int(vit_embeds_shape[1] ** 0.5)
w = int(vit_embeds_shape[1] ** 0.5)
vit_embeds = self.reshape(vit_embeds, (vit_embeds_shape[0], w, h, vit_embeds_shape[2]))
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds_shape_2 = self.shape(vit_embeds)
vit_embeds = self.reshape(
vit_embeds, (vit_embeds_shape_2[0], vit_embeds_shape_2[1] * vit_embeds_shape_2[2], vit_embeds_shape_2[3]))
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def construct(self, input_ids, images, image_context_pos: Tensor = None, labels=None, input_position=None,
position_ids=None, attention_mask=None, init_reset=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None):
"""forward of InternVL"""
tokens = input_ids
input_embeds = self.language_model.to_embeddings(tokens)
if self.is_first_iteration:
image_embeds = self.extract_feature(images)
input_embeds = self.update_modal_to_text(image_embeds, input_embeds, image_context_pos)
return self.language_model(
input_ids=input_ids,
labels=labels,
input_position=input_position,
position_ids=position_ids,
attention_mask=attention_mask,
input_embeds=input_embeds,
slot_mapping=slot_mapping,
batch_valid_length=batch_valid_length,
block_tables=block_tables,
init_reset=init_reset,
batch_index=batch_index,
zactivate_len=zactivate_len
)