from typing import Dict, Union
import torch
from torch import nn
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from mindspeed_mm.models.ae.base import AEModel
from mindspeed_mm.models.transformer.transformer_model import TransformerModel
from mindspeed_mm.models.vision.vision_model import VisionModel
class OmniModel(nn.Module):
"""
OmniModel - A unified multi-modal model that can incorporate vision encoder,
multi-modal large language model (MLLM), and autoencoder components.
This model supports flexible configurations for different multi-modal tasks,
allowing selective freezing of components during training.
"""
def __init__(self, config):
"""Initialize the OmniModel with configuration.
Args:
config: Configuration object containing model architecture settings
"""
super().__init__()
self.config = core_transformer_config_from_args(get_args())
self.add_image_encoder = config.image_encoder is not None
self.add_mllm = config.mllm is not None
self.add_ae = config.ae is not None
if self.add_image_encoder:
self.image_encoder = VisionModel(config.image_encoder)
if self.add_mllm:
self.mllm = TransformerModel(config).get_model()
if self.add_ae:
self.ae = AEModel(config.ae).get_model()
def set_input_tensor(self, input_tensor):
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
def freeze(
self,
freeze_mllm: bool = False,
freeze_image_encoder: bool = False,
freeze_image_projection: bool = False,
freeze_ae: bool = True,
):
if self.add_image_encoder:
self.image_encoder.freeze(freeze_image_encoder, freeze_image_projection)
if self.add_mllm and freeze_mllm:
for param in self.mllm.parameters():
param.requires_grad = False
if self.add_ae and freeze_ae:
for param in self.ae.parameters():
param.requires_grad = False
def forward(self, inputs) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
"""Forward pass through the OmniModel.
Args:
inputs: Dictionary containing input tensors for different modalities
Returns:
Model outputs, which could be a dictionary of tensors or a single tensor
"""
if self.add_image_encoder and "vit_input_embeds" not in inputs:
vit_input_embeds = self.image_encoder(**inputs)
inputs["vit_inputs_embeds"] = vit_input_embeds
if self.add_ae and "padded_latent" not in inputs:
padded_latent = self.ae.encode(**inputs)
inputs["padded_latent"] = padded_latent
outputs = self.mllm(**inputs)
return outputs