import os
import torch
import torch.distributed as dist
from typing import Any, Type
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
from mindspeed_llm.fsdp2.models.model_registry import ModelRegistry
from mindspeed_llm.fsdp2.distributed.mindspeed_parallel_engine import MindSpeedParallelEngine
from mindspeed_llm.fsdp2.distributed.parallel_engine_config import (
ParallelEngineConfig,
FSDPPlanConfig,
TPPlanConfig,
EPPlanConfig,
CPPlanConfig,
QuantizeConfig
)
from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.models.model_loader import ModelLoader
logger = get_logger(__name__)
class ModelFactory:
"""
Responsible for building HuggingFace native models and wrapping them
as MindSpeed FSDP instances based on parallelization arguments.
Supports two initialization modes controlled by model_args.init_model_with_meta_device:
- False: Load model fully on CPU (original behavior)
- True: Create empty model on meta device, load weights after FSDP wrapping
"""
@staticmethod
def create(model_args, parallel_args):
"""
Creates a MindSpeed FSDP wrapped model.
Args:
model_args: Contains model_name_or_path, trust_remote_code, train_from_scratch,
init_model_with_meta_device, etc.
parallel_args: Contains tp_size, fsdp_size, recompute, ep_size, etc.
"""
if torch.accelerator.is_available():
local_rank = int(os.environ.get("LOCAL_RANK", 0))
target_device = torch.device(torch.accelerator.current_accelerator().type, local_rank)
torch.accelerator.set_device(target_device)
else:
target_device = torch.device("cpu")
use_meta_device = getattr(model_args, 'init_model_with_meta_device', False)
init_device = "meta" if use_meta_device else "cpu"
logger.info_rank0(f"> Model initialization device: {init_device} (init_model_with_meta_device={use_meta_device})")
logger.info_rank0(f"> Loading AutoConfig from {model_args.model_name_or_path}...")
trust_remote_code = model_args.trust_remote_code
hf_config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=trust_remote_code
)
model_cls = None
if getattr(model_args, 'model_id', None):
logger.info_rank0(f"> Using factory mode with model_id: {model_args.model_id}")
model_cls = ModelRegistry.get_model_class(model_args.model_id)
if hasattr(model_cls, 'register_patches'):
model_cls.register_patches(model_args)
loader = ModelLoader(model_args, init_device=init_device)
model, weights_path = loader.create_model(model_cls=model_cls)
world_size = dist.get_world_size() if dist.is_initialized() else 1
tp_size = parallel_args.tp_size
fsdp_size = parallel_args.fsdp_size
cp_size = parallel_args.cp_size
dp_size = world_size // (tp_size * fsdp_size * cp_size)
parallel_config = ModelFactory._build_parallel_config(model_args, parallel_args, dp_size)
logger.info_rank0(f"> Wrapping model with MindSpeed FSDP (TP={tp_size},CP={cp_size}, FSDP={fsdp_size})...")
model = MindSpeedParallelEngine(
config=parallel_config,
model=model,
init_device=init_device,
weights_path=weights_path
)
if init_device == "cpu":
model = model.to(target_device)
return model
@staticmethod
def _build_parallel_config(model_args, parallel_args, dp_size) -> 'ParallelEngineConfig':
"""
Builds the Config based on parallel_args and hardcoded layer name rules.
Note: The wildcards here (e.g., 'model.layers.{*}') are suitable for standard structures like Llama/Qwen.
If using other non-standard models, these strings might need adjustment.
"""
apply_modules = {
parallel_args.fsdp_modules[0]: {'reshard_after_forward': parallel_args.reshard_after_forward,
'shard_placement_fn': parallel_args.shard_placement_fn},
}
for modules in parallel_args.fsdp_modules[1:]:
apply_modules[modules] = {'reshard_after_forward': parallel_args.reshard_after_forward,}
fsdp_plan = FSDPPlanConfig(
ignored_modules=parallel_args.ignored_modules if parallel_args.ignored_modules else [],
apply_modules= apply_modules,
param_dtype=parallel_args.param_dtype,
reduce_dtype=parallel_args.reduce_dtype,
num_to_forward_prefetch=parallel_args.num_to_forward_prefetch,
num_to_backward_prefetch=parallel_args.num_to_backward_prefetch
)
tp_plan = TPPlanConfig(
colwise_parallel=parallel_args.tp_colwise,
rowwise_parallel=parallel_args.tp_rowwise
)
ep_size = parallel_args.ep_size
ep_fsdp_size = parallel_args.ep_fsdp_size
ep_plan = EPPlanConfig(
apply_modules=parallel_args.ep_modules,
apply_efsdp_modules=parallel_args.ep_fsdp_modules,
dispatcher=parallel_args.ep_dispatcher,
)
cp_plan = CPPlanConfig(
context_parallel_type=parallel_args.cp_type,
is_pack=getattr(model_args, "pack", False)
)
recompute_plan = parallel_args.recompute_modules if parallel_args.recompute else []
quantization_plan = QuantizeConfig(
recipe_name=model_args.quant_recipe_name,
apply_modules=model_args.quant_apply_modules,
ignored_modules=model_args.quant_ignored_modules,
quant_converters=model_args.quant_converters,
enable_fsdp_low_precision_all_gather=model_args.enable_fsdp_low_precision_all_gather,
fsdp_low_precision_all_gather_mode=model_args.fsdp_low_precision_all_gather_mode,
)
tp_size = parallel_args.tp_size
fsdp_size = parallel_args.fsdp_size
config = ParallelEngineConfig(
data_parallel_size=dp_size,
fully_shard_parallel_size=fsdp_size,
fsdp_plan=fsdp_plan,
tensor_parallel_size=tp_size,
tp_plan=tp_plan,
expert_parallel_size=ep_size,
expert_fully_shard_parallel_size=ep_fsdp_size,
expert_data_parallel_size=dp_size,
ep_plan=ep_plan,
context_parallel_size=parallel_args.cp_size,
context_parallel_type=parallel_args.cp_type,
cp_plan=cp_plan,
recompute=parallel_args.recompute,
recompute_plan=recompute_plan,
quantization_plan = quantization_plan
)
return config
class AutoModelFactory:
"""
Unified Factory for creating models.
Dispatches to ModelFactory or McoreModelFactory based on the environment.
"""
@staticmethod
def create(*args, **kwargs):
"""
Factory method that forwards arguments to the specific implementation.
"""
return ModelFactory.create(*args, **kwargs)