"""layers for visualglm"""
import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindformers.models.base_config import BaseConfig
class ImageTextEmbeddingConcat(nn.Cell):
"""
Layer to concat image embedding and text embedding
"""
def __init__(self, pad_token_id):
super().__init__()
self.concat_2d = P.Concat(axis=1)
self.concat_3d = P.Concat(axis=1)
self.not_equal = P.NotEqual()
self.ones = P.Ones()
self.cast = P.Cast()
self.pad_token_id = pad_token_id
def construct(self, image_embeddings: ms.Tensor, pre_text_embeddings: ms.Tensor, post_text_embeddings: ms.Tensor):
pre_text_embeddings = self.cast(pre_text_embeddings, mstype.float32)
post_text_embeddings = self.cast(post_text_embeddings, mstype.float32)
concat_embeds = self.concat_3d([pre_text_embeddings, image_embeddings, post_text_embeddings])
return concat_embeds
class ImageTextEmbeddingPreparationMixIn:
"""
image text embemdding mixin
"""
def __init__(self, config: BaseConfig):
"""init method"""
self.image_text_concat = ImageTextEmbeddingConcat(config.get("pad_token_id", 3))
def to_text_embeddings(self, text_input_ids):
raise NotImplementedError
def prepare_image_text_embedding(self, input_ids, **kwargs):
""" prepare image and text embeddings """
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
input_position = kwargs.get("current_index", None)
if input_position is not None:
input_position = ms.Tensor(input_position, mstype.int32)
concat_inputs_embeds = None
if self.is_first_iteration or not self.use_past:
image_embeddings = kwargs.get("image_embeds")
pre_input_ids = ms.Tensor(kwargs.get("pre_input_ids"), mstype.int32)
image_embeddings_length = image_embeddings.shape[1]
pre_text_embeddings_length = pre_input_ids.shape[1]
post_input_ids = ms.Tensor(input_ids[:, image_embeddings_length + pre_text_embeddings_length:],
mstype.int32)
pre_text_embeddings = self.to_text_embeddings(pre_input_ids)
post_text_embeddings = self.to_text_embeddings(post_input_ids)
concat_inputs_embeds = self.image_text_concat(image_embeddings, pre_text_embeddings, post_text_embeddings)
return {
"input_ids": ms.Tensor(input_ids, mstype.int32),
"input_embeddings": concat_inputs_embeds,
"attention_mask": ms.Tensor(attention_mask, mstype.int32),
"position_ids": ms.Tensor(position_ids, mstype.int32),
"input_position": input_position
}