import torch
from megatron.training import get_args
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
from megatron.core.extensions.transformer_engine import TENorm
from mindspeed_mm.models.common.module import MultiModalModule
class MultimodalProjector(MultiModalModule):
"""
MultimodalProjector will take the encoded input with input_size hidden state and project
it into the hidden size of the language model for multimodal training. When projectors is
type affine linear_fc1 form submodules is used.
Args:
transformer_config (TransformerConfig): Transformer config.
submodules (MLPSubmodules): Specifies MLP submodules for mlp type projector.
projector_type (str): Projector type.
input_size (int): Input size from feature encoder.
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
use_postshuffle_norm: bool = False,
):
super().__init__(config=config)
self.projector_type = config.model_id
self.use_postshuffle_norm = use_postshuffle_norm
if submodules is None:
raise AssertionError("MLPSubmodules must be provided")
if self.projector_type == "mlp":
self.encoder = MLP(
config=config,
submodules=submodules,
input_size=config.input_size
)
elif self.projector_type == "affine":
self.encoder = build_module(
submodules.linear_fc1,
config.input_size,
config.hidden_size,
config=config,
init_method=config.init_method,
gather_output=True,
bias=config.add_bias_linear,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name=None,
)
elif self.projector_type == "lnmlp":
self.ffn_hidden_size = config.ffn_hidden_size
self.layernorm = TENorm(
config=config,
hidden_size=self.ffn_hidden_size if use_postshuffle_norm else config.input_size,
eps=config.layernorm_epsilon)
self.encoder = MLP(
config=config,
submodules=submodules,
input_size=config.ffn_hidden_size
)
else:
raise Exception(f"Unsupported multimodal projection type {self.projector_type}")
def forward(self, hidden_states):
if self.projector_type == "lnmlp":
if self.use_postshuffle_norm:
hidden_states = hidden_states.view(-1, self.ffn_hidden_size)
hidden_states = self.layernorm(hidden_states).view(-1, self.ffn_hidden_size)
encoder_output, encoder_output_bias = self.encoder(hidden_states)
if encoder_output_bias is not None:
encoder_output = encoder_output + encoder_output_bias
encoder_output = make_viewless_tensor(
inp=encoder_output,
requires_grad=True,
keep_graph=True,
)
return encoder_output