from typing import Optional, Union
import torch
from megatron.core.extensions.transformer_engine import TENorm
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.training import get_args
from mindspeed_mm.models.common.module import MultiModalModule
from mindspeed_mm.utils.utils import get_device
from .vision_transformer_block import VisionTransformerBlock
class CLIPViT(MultiModalModule):
    """
    CLIP ViT vision model.

    Splits an image into non-overlapping patches with a strided convolution,
    optionally prepends learnable class tokens, adds learned position
    embeddings, applies a pre-transformer norm and runs the sequence through a
    ``VisionTransformerBlock``. Any class tokens that were prepended are
    removed from the returned sequence.

    Args:
        config (TransformerConfig): Transformer config. In addition to what
            the base class and the transformer block consume, this class reads
            ``device``, ``hidden_size``, ``patch_size``, ``image_size``,
            ``add_class_token`` and ``class_token_len`` from it.
        transformer_layer_spec (ModuleSpec): Specifies module to use for
            transformer layers.
        pre_process (bool): Accepted for interface compatibility but not used;
            the transformer block is always built with ``pre_process=True``.
        post_process (bool): Accepted for interface compatibility but not
            used; the transformer block is always built with
            ``post_process=False``.
        *args, **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``config.image_size`` is not a multiple of
            ``config.patch_size``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        pre_process: bool = True,
        post_process: bool = True,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(config=config)
        self.device = get_device(config.device)
        self.visual_hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        # Only square inputs are configurable: both sides come from image_size.
        self.img_h = config.image_size
        self.img_w = config.image_size
        # Read from the global Megatron args; defaults to False when absent.
        self.use_flash_attn = getattr(get_args(), "use_flash_attn", False)

        if self.img_h % self.patch_size != 0:
            raise AssertionError("patch_size should be an exact divisor of img_height")
        if self.img_w % self.patch_size != 0:
            raise AssertionError("patch_size should be an exact divisor of img_width")

        self.num_patches_per_dim_h = self.img_h // self.patch_size
        self.num_patches_per_dim_w = self.img_w // self.patch_size
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w

        self.add_class_token = config.add_class_token
        self.class_token_len = config.class_token_len
        # Total sequence length seen by the transformer: patches plus the
        # class tokens, when enabled.
        self.seq_length = self.num_patches + (
            self.class_token_len if self.add_class_token else 0
        )

        # Patch embedding: kernel == stride == patch_size, so each output
        # location corresponds to exactly one non-overlapping patch.
        self.conv1 = torch.nn.Conv2d(
            in_channels=3,
            out_channels=self.visual_hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        # Plain tensor attribute (not a registered buffer), shape [1, seq_length].
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).to(self.device)
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.visual_hidden_size
        )

        if self.add_class_token:
            self.class_token = torch.nn.Parameter(
                torch.randn(1, self.class_token_len, self.visual_hidden_size)
            )

        self.ln_pre = TENorm(
            config=self.config,
            hidden_size=self.visual_hidden_size,
        )

        self.model_type = ModelType.encoder_or_decoder

        self.decoder = VisionTransformerBlock(
            config=config,
            spec=transformer_layer_spec,
            pre_process=True,
            post_process=False,
        )

    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
        """
        Sets the input tensor for the model.

        The tensor is forwarded unchanged to the transformer block.

        Args:
            input_tensor (torch.Tensor): Input tensor to hand to the decoder.
        """
        self.decoder.set_input_tensor(input_tensor)

    def forward(
        self,
        pixel_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Forward function of the CLIP ViT Model. This function passes the input
        tensors through the embedding layer and then the transformer.

        Args:
            pixel_values (torch.Tensor): Input images of shape
                [batch, 3, img_h, img_w]; the spatial size must produce
                exactly ``num_patches`` patches.
            attention_mask (torch.Tensor with dtype=bool, optional): Attention
                mask handed to the transformer block. If None and flash
                attention is disabled, an all-False mask of shape
                [1, 1, seq_length, seq_length] is created. If None and flash
                attention is enabled, None is passed through.
                NOTE(review): presumably True marks masked-out positions
                (Megatron convention) - confirm against the attention
                implementation.
            *args, **kwargs: Accepted and ignored.

        Returns:
            x (torch.Tensor): Output of the transformer block in batch-first
            layout with the class tokens removed, i.e. [b, num_patches, h]
            assuming the block preserves the sequence shape.

        Raises:
            AssertionError: If the embedded sequence length does not equal
                ``seq_length``.
        """
        # [b, 3, H, W] -> [b, h, H/p, W/p] -> [b, h, n] -> [b, n, h]
        x = self.conv1(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        if self.add_class_token:
            class_token = self.class_token.expand(x.shape[0], -1, -1)
            x = torch.cat([class_token, x], dim=1)

        if x.shape[1] != self.seq_length:
            raise AssertionError(f"{x.shape[1]} != {self.seq_length}")

        # position_ids is not a registered buffer, so it does not follow the
        # module across devices; align it with the activations (no-op when it
        # is already there).
        position_ids = self.position_ids.to(x.device)
        x = x + self.position_embeddings(position_ids)
        x = self.ln_pre(x)

        # The transformer block is fed sequence-first: [b, s, h] -> [s, b, h].
        x = x.permute(1, 0, 2).contiguous()

        if attention_mask is None and not self.use_flash_attn:
            # All-False boolean mask, built directly on the activations' device.
            attention_mask = torch.zeros(
                1, 1, self.seq_length, self.seq_length,
                dtype=torch.bool,
                device=x.device,
            )

        x = self.decoder(x, attention_mask)

        # Back to batch-first: [s, b, h] -> [b, s, h].
        x = x.permute(1, 0, 2).contiguous()

        # Drop exactly the class tokens that were prepended above. Without
        # class tokens nothing is dropped, so no patch token is lost; with
        # class_token_len > 1 all of them are removed, not just the first.
        if self.add_class_token:
            x = x[:, self.class_token_len:]
        return x