import os
import importlib
import logging
import torch
import torch.distributed as dist
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from accelerate import init_empty_weights
from mindspeed.fsdp.utils.str_match import module_name_match
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.params.model_args import ModelArguments
from mindspeed_mm.fsdp.params.feature_args import FeatureArguments
from mindspeed_mm.fsdp.params.training_args import TrainingArguments
from mindspeed_mm.fsdp.utils.register import model_register
from mindspeed_mm.fsdp.models.base_model import BaseModel
logger = logging.getLogger(__name__)
class ModelHub:
    """
    Responsible for building HuggingFace native models.

    Two construction paths exist:

    * a ``transformers`` model, used when ``AutoConfig`` can be loaded from
      ``model_args.model_name_or_path``;
    * a custom model registered in ``model_register`` under ``model_args.model_id``,
      used as a fallback when ``AutoConfig`` loading fails.

    After construction, submodules whose names match ``model_args.freeze`` have
    their parameters' ``requires_grad`` set to ``False``.
    """

    # Maximum number of characters of an AutoConfig error message kept in the fallback warning.
    _MAX_ERROR_MSG_LEN = 200

    @staticmethod
    def _reset_hf_initialized_flags(model: torch.nn.Module) -> None:
        """Set ``_is_hf_initialized`` to ``False`` on every submodule where it is truthy.

        Used after building a model under ``init_empty_weights()``.
        NOTE(review): presumably this ensures a later weight-initialization pass (after the
        meta-device model is materialized) does not skip these modules - confirm with the caller
        that materializes the weights.
        """
        for module in model.modules():
            if getattr(module, "_is_hf_initialized", False):
                module._is_hf_initialized = False

    @staticmethod
    def _freeze_modules(model: torch.nn.Module, patterns) -> None:
        """Disable gradients for every submodule whose qualified name matches any pattern.

        Args:
            model: Model whose submodules are scanned via ``named_modules()``.
            patterns: Iterable of name patterns understood by ``module_name_match``.
                ``None`` or an empty iterable freezes nothing.
        """
        # Materialize once so a generator is not exhausted after the first module.
        patterns = list(patterns or ())
        if not patterns:
            return
        # Collect first, then freeze. Each module is recorded at most once even if several
        # patterns match it, so it is not frozen/logged repeatedly.
        frozen_named_modules = [
            (name, module)
            for name, module in model.named_modules()
            if any(module_name_match(pattern, name) for pattern in patterns)
        ]
        for name, module in frozen_named_modules:
            print_rank(logger.info, f"freezing module {name}...")
            for param in module.parameters():
                param.requires_grad_(False)

    @staticmethod
    def _build_custom_model(model_args: ModelArguments, training_args: TrainingArguments) -> BaseModel:
        """Build a custom model registered in ``model_register`` under ``model_args.model_id``.

        Args:
            model_args: Must provide a non-empty ``model_id``. It is also passed unchanged to the
                model class's ``_from_config`` / ``from_pretrained``.
            training_args: If ``init_model_with_meta_device`` is set, the model is built empty under
                ``init_empty_weights()``; otherwise ``from_pretrained(model_args)`` is used.

        Returns:
            The model instance after ``.float()``.

        Raises:
            ValueError: If ``model_id`` is missing/empty or is not registered.
        """
        model_id = getattr(model_args, "model_id", None)
        if not model_id:
            raise ValueError("`model_id` must be provided in model_args when using custom models.")
        model_cls = model_register.get(model_id)
        if model_cls is None:
            raise ValueError(f"model_id '{model_id}' is not registered in MODEL_MAPPINGS. ")

        if training_args.init_model_with_meta_device:
            with init_empty_weights():
                model = model_cls._from_config(model_args).float()
            ModelHub._reset_hf_initialized_flags(model)
        else:
            model = model_cls.from_pretrained(model_args).float()
        return model

    @staticmethod
    def _build_transformers_model(transformer_config: PretrainedConfig, model_args: ModelArguments,
                                  feature_args: FeatureArguments, training_args: TrainingArguments) -> PreTrainedModel:
        """Build a HuggingFace model from an already-loaded config.

        The model class is resolved from ``model_args.model_id`` via ``model_register`` when a
        ``model_id`` is given. Otherwise it is resolved from the first entry of
        ``transformer_config.architectures``, looked up in the ``transformers`` package.
        If the class defines a callable ``overwrite_transformer_config``, it is applied to the
        config first.

        Args:
            transformer_config: Config returned by ``AutoConfig.from_pretrained``.
            model_args: Provides ``model_id``, ``model_name_or_path`` and ``trust_remote_code``.
            feature_args: Forwarded to ``overwrite_transformer_config`` when present.
            training_args: ``init_model_with_meta_device`` selects empty (meta) construction
                versus loading weights with ``from_pretrained`` onto CPU in float32.

        Returns:
            The constructed model.

        Raises:
            ValueError: If no model class can be resolved.
        """
        # HF configs may carry ``architectures=None``; normalize to a list.
        architectures = getattr(transformer_config, "architectures", None) or []
        model_id = getattr(model_args, "model_id", None)

        model_cls = None
        if model_id:
            # An explicit model_id takes precedence over the config's architectures.
            model_cls = model_register.get(model_id)
        elif architectures:
            transformers_module = importlib.import_module("transformers")
            model_cls = getattr(transformers_module, architectures[0], None)
        if model_cls is None:
            raise ValueError(
                f"load model from config failed: could not resolve a model class "
                f"(model_id={model_id!r}, architectures={architectures!r})"
            )

        # Optional per-model hook allowing the class to adjust the config before construction.
        if callable(getattr(model_cls, 'overwrite_transformer_config', None)):
            transformer_config = model_cls.overwrite_transformer_config(transformer_config, model_args, feature_args)

        if training_args.init_model_with_meta_device:
            with init_empty_weights():
                model = model_cls._from_config(transformer_config).float()
            ModelHub._reset_hf_initialized_flags(model)
        else:
            model = model_cls.from_pretrained(
                model_args.model_name_or_path,
                config=transformer_config,
                dtype=torch.float32,
                low_cpu_mem_usage=True,
                device_map="cpu",
                trust_remote_code=model_args.trust_remote_code
            )
        return model

    @staticmethod
    def _try_load_transformer_config(model_args: ModelArguments):
        """Load ``AutoConfig`` for ``model_args.model_name_or_path``, or return ``None`` on failure.

        Any exception is treated as "not a HuggingFace model". A truncated first line of the
        error is logged as a warning so that genuine HF loading problems remain visible.
        """
        print_rank(logger.info, f"> Loading AutoConfig from {model_args.model_name_or_path}...")
        try:
            return AutoConfig.from_pretrained(
                model_args.model_name_or_path,
                trust_remote_code=model_args.trust_remote_code,
                _attn_implementation=model_args.attn_implementation
            )
        except Exception as e:  # deliberate broad catch: failure selects the custom-model path
            first_line = next(iter(str(e).splitlines()), "")
            max_len = ModelHub._MAX_ERROR_MSG_LEN
            msg = first_line[:max_len] + "..." if len(first_line) > max_len else first_line
            logger.warning(
                f"AutoConfig.from_pretrained failed for '{model_args.model_name_or_path}' "
                f"({type(e).__name__}: {msg}); falling back to custom model builder. "
                f"If you intended to load a HuggingFace model, check the error above."
            )
            return None

    @staticmethod
    def build(model_args: ModelArguments, feature_args: FeatureArguments, training_args: TrainingArguments):
        """
        Build a model instance from HuggingFace based on model arguments and training configuration.

        If ``AutoConfig`` loads successfully, a ``transformers`` model is built. Otherwise the
        registered custom model named by ``model_args.model_id`` is built. Afterwards, submodules
        matching ``model_args.freeze`` are frozen.

        Args:
            model_args: Contains model_name_or_path, trust_remote_code, attn_implementation,
                model_id, freeze, etc.
            feature_args: Feature configuration forwarded to the model class's optional
                ``overwrite_transformer_config`` hook.
            training_args: Contains training configuration like init_model_with_meta_device, etc.

        Returns:
            Configured model instance ready for training.
        """
        transformer_config = ModelHub._try_load_transformer_config(model_args)
        if transformer_config is not None:
            print_rank(logger.info, "Building transformers model from configuration...")
            model = ModelHub._build_transformers_model(transformer_config, model_args, feature_args,
                                                       training_args)
        else:
            print_rank(logger.info, "Building custom model...")
            model = ModelHub._build_custom_model(model_args, training_args)

        ModelHub._freeze_modules(model, getattr(model_args, "freeze", None))
        return model