from transformers import Mistral3ForConditionalGeneration
from megatron.training import print_rank_0
from mindspeed_mm.models.transformers.base_model import FSDP2Mixin, WeightInitMixin
from mindspeed_mm.models.transformers.custom_model_registry import register_model
from mindspeed_mm.models.transformers.mistral3.modules import MMMistralAttention
from mindspeed_mm.models.transformers.mistral3.modeling_mistral import MMMistralModel
@register_model("mistral3")
class MultiModelMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration, FSDP2Mixin, WeightInitMixin):
    """Mistral3 conditional-generation model wired to the MindSpeed-MM text stack.

    Passed to ``register_model`` under the key ``"mistral3"``. On construction
    the stock language model is replaced by ``MMMistralModel`` and every
    decoder layer's ``self_attn`` by ``MMMistralAttention``.
    """

    def __init__(self, config):
        """Build the base model, then swap in the MindSpeed-MM text modules.

        Args:
            config: Model configuration; ``config.text_config`` is forwarded to
                the replacement language model and attention modules.
        """
        super().__init__(config)
        text_config = config.text_config
        # Replace the language model created by the parent constructor.
        self.model.language_model = MMMistralModel(config=text_config)
        # Re-create attention per layer so each instance gets its own layer_idx.
        for idx, layer in enumerate(self.model.language_model.layers):
            layer.self_attn = MMMistralAttention(config=text_config, layer_idx=idx)

    def freeze(self, config):
        """Disable gradients for the sub-modules selected by ``config``.

        Args:
            config: Object exposing ``vision_encoder.freeze``,
                ``vision_projector.freeze`` and ``text_decoder.freeze`` flags.

        A parameter is frozen when any selected key occurs as a substring of
        its qualified name.
        """
        forbidden_modules = set()
        if config.vision_encoder.freeze:
            vision_model_keys = ["vision_tower"]
            print_rank_0(f"Set vision model not trainable: {vision_model_keys}")
            forbidden_modules.update(vision_model_keys)
        if config.vision_projector.freeze:
            projector_keys = ["multi_modal_projector"]
            print_rank_0(f"Set vision projector not trainable: {projector_keys}")
            forbidden_modules.update(projector_keys)
        if config.text_decoder.freeze:
            language_model_keys = ["language_model", "lm_head"]
            print_rank_0(f"Set language model not trainable: {language_model_keys}")
            forbidden_modules.update(language_model_keys)

        if not forbidden_modules:
            # Nothing selected: skip walking the parameter list.
            return

        # Walk the whole module rather than ``self.model`` so that a top-level
        # ``lm_head`` is reachable; iterating only ``self.model`` cannot match
        # the "lm_head" key when the head lives on the outer module.
        # NOTE(review): assumes ``lm_head`` is an attribute of this class and
        # not of ``self.model`` -- confirm against the installed transformers.
        for name, param in self.named_parameters():
            if any(forbidden_module in name for forbidden_module in forbidden_modules):
                param.requires_grad_(False)