import os
import glob
import json
from contextlib import contextmanager
from typing import Optional, Dict, Tuple, Set
import torch
import torch.nn as nn
from torch.distributed.tensor import distribute_tensor
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors import safe_open
try:
from transformers.modeling_utils import no_init_weights
except ImportError:
from transformers.initialization import no_init_weights
from mindspeed_llm.fsdp2.utils.logging import get_logger
from mindspeed_llm.fsdp2.utils.global_vars import get_args
logger = get_logger(__name__)
@contextmanager
def init_empty_weights():
"""
A context manager under which models are initialized with all parameters on the meta device.
"""
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = (
param
if param.device == torch.device("meta")
else param_cls(module._parameters[name].to("meta"), **kwargs)
)
try:
nn.Module.register_parameter = register_empty_parameter
yield
finally:
nn.Module.register_parameter = old_register_parameter
def reset_hf_initialized_flag(module: nn.Module) -> None:
"""Reset HuggingFace's _is_hf_initialized flag."""
if hasattr(module, "_is_hf_initialized"):
setattr(module, "_is_hf_initialized", False)
for child in module.children():
reset_hf_initialized_flag(child)
def _find_submodule(module: nn.Module, name: str) -> Tuple[nn.Module, str]:
"""Find the leaf module according to the name."""
pieces = name.split(".")
for piece in pieces[:-1]:
module = getattr(module, piece)
return module, pieces[-1]
class ModelLoader:
"""Load model on CPU or meta device."""
def __init__(self, model_args, init_device: str = "cpu"):
self.model_args = model_args
self.init_device = init_device
self.trust_remote_code = getattr(model_args, 'trust_remote_code', False)
self.model_path = model_args.model_name_or_path
self.train_from_scratch = getattr(model_args, 'train_from_scratch', False)
self.hf_config = None
def load_config(self) -> AutoConfig:
"""Load HuggingFace model config."""
logger.info_rank0(f"> Loading config from {self.model_path}...")
self.hf_config = AutoConfig.from_pretrained(
self.model_path,
trust_remote_code=self.trust_remote_code,
attn_implementation="eager" if get_args().cp_size > 1 else None,
)
return self.hf_config
def create_model(self, model_cls=None) -> Tuple[nn.Module, Optional[str]]:
"""Create model based on init_device."""
if self.hf_config is None:
self.load_config()
if self.init_device == "meta":
return self._create_on_meta(model_cls)
else:
return self._create_on_cpu(model_cls)
def _create_on_cpu(self, model_cls=None) -> Tuple[nn.Module, None]:
"""Create and load model on CPU."""
if model_cls is not None:
logger.info_rank0(f"> Loading {model_cls.__name__} on CPU...")
model = model_cls.from_pretrained(
self.model_path,
config=self.hf_config,
low_cpu_mem_usage=True,
device_map="cpu",
torch_dtype=torch.float32
)
elif self.train_from_scratch:
logger.info_rank0("> Creating model with random weights on CPU...")
model = AutoModelForCausalLM.from_config(
self.hf_config,
trust_remote_code=self.trust_remote_code,
torch_dtype=torch.float32
)
else:
logger.info_rank0(f"> Loading pretrained model on CPU from {self.model_path}...")
model = AutoModelForCausalLM.from_pretrained(
self.model_path,
config=self.hf_config,
trust_remote_code=self.trust_remote_code,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
device_map="cpu"
)
return model, None
def _create_on_meta(self, model_cls=None) -> Tuple[nn.Module, Optional[str]]:
"""Create empty model on meta device."""
weights_path = None if self.train_from_scratch else self.model_path
if model_cls is not None:
logger.info_rank0(f"> Creating empty {model_cls.__name__} on meta device...")
with init_empty_weights(), no_init_weights():
if hasattr(model_cls, '_from_config'):
model = model_cls._from_config(self.hf_config)
else:
model = model_cls.from_config(self.hf_config)
elif self.train_from_scratch:
logger.info_rank0("> Creating empty model on meta device for random init...")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
self.hf_config,
trust_remote_code=self.trust_remote_code,
torch_dtype=torch.float32
)
else:
logger.info_rank0(f"> Creating empty model on meta device (weights: {self.model_path})...")
with init_empty_weights(), no_init_weights():
model = AutoModelForCausalLM.from_config(
self.hf_config,
trust_remote_code=self.trust_remote_code,
torch_dtype=torch.float32
)
logger.info_rank0(f"> Model created on meta device. Weights path: {weights_path}")
return model, weights_path
class WeightLoader:
"""
Load weights into FSDP-wrapped model.
"""
@staticmethod
def load(
model: nn.Module,
weights_path: Optional[str],
device: Optional[str] = None
) -> None:
"""Load or initialize weights after FSDP wrapping."""
if device is None:
device = torch.accelerator.current_accelerator().type
if weights_path is None:
WeightLoader._init_random(model, device)
else:
WeightLoader._load_pretrained(model, weights_path, device)
@staticmethod
def _init_random(model: nn.Module, device: str) -> None:
"""Initialize model with random weights."""
logger.info_rank0(f"> Initializing random weights on {device}...")
model.to_empty(device=device)
model = model.float()
reset_hf_initialized_flag(model)
if hasattr(model, 'init_weights'):
model.init_weights()
logger.info_rank0("> Random initialization done")
@staticmethod
@torch.no_grad()
def _load_pretrained(model: nn.Module, weights_path: str, device: str) -> None:
"""
Load pretrained weights.
"""
logger.info_rank0(f"> Loading pretrained weights from {weights_path}...")
buffer_dict = {name: buffer.clone() for name, buffer in model.named_buffers()}
parameter_names_to_load = {name for name, _ in model.named_parameters()}
logger.info_rank0(f"> Saved {len(buffer_dict)} buffers, {len(parameter_names_to_load)} parameters to load")
model.to_empty(device=device)
model = model.float()
logger.info_rank0(f"> Model materialized to {device}")
state_dict_files = WeightLoader._get_state_dict_files(weights_path)
for state_dict_file in state_dict_files:
for name, tensor in WeightLoader._iterate_state_dict(state_dict_file):
if name in buffer_dict:
buffer_dict[name] = tensor.clone()
elif name in parameter_names_to_load:
parameter_names_to_load.remove(name)
WeightLoader._dispatch_parameter(model, name, tensor)
else:
logger.debug(f"> Unexpected key in state dict: {name}")
WeightLoader._post_process(model, buffer_dict, parameter_names_to_load, distribute_tensor)
if torch.distributed.is_initialized():
torch.distributed.barrier()
logger.info_rank0("> Pretrained weights loaded successfully")
@staticmethod
def _get_state_dict_files(weights_path: str):
"""Get list of state dict files."""
index_file = os.path.join(weights_path, "model.safetensors.index.json")
if os.path.exists(index_file):
with open(index_file, 'r') as f:
index = json.load(f)
files = set(index["weight_map"].values())
return [os.path.join(weights_path, f) for f in sorted(files)]
single_safetensor = os.path.join(weights_path, "model.safetensors")
if os.path.exists(single_safetensor):
return [single_safetensor]
safetensor_files = sorted(glob.glob(os.path.join(weights_path, "*.safetensors")))
if safetensor_files:
return safetensor_files
pytorch_files = sorted(glob.glob(os.path.join(weights_path, "*.bin")))
if pytorch_files:
return pytorch_files
pytorch_files = sorted(glob.glob(os.path.join(weights_path, "*.pt")))
if pytorch_files:
return pytorch_files
raise FileNotFoundError(f"No weight files found in {weights_path}")
@staticmethod
def _iterate_state_dict(filepath: str):
"""Iterate over state dict file, yielding (key, tensor) pairs."""
if filepath.endswith(".safetensors"):
with safe_open(filepath, framework="pt", device="cpu") as f:
for key in f.keys():
yield key, f.get_tensor(key)
else:
state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
for key, tensor in state_dict.items():
yield key, tensor
@staticmethod
def _dispatch_parameter(
model: nn.Module,
name: str,
tensor: torch.Tensor,
) -> None:
"""
Copy a checkpoint tensor into the target parameter of a parallelized model.
"""
module, local_name = _find_submodule(model, name)
orig_param = dict(module.named_parameters(recurse=False))[local_name]
orig_tensor = orig_param.data
if not hasattr(orig_tensor, "_local_tensor"):
src = tensor.to(device=orig_tensor.device, dtype=orig_tensor.dtype, non_blocking=True)
orig_param.data.copy_(src, non_blocking=True)
return
mesh = orig_tensor.device_mesh
placements = orig_tensor.placements
local_target = orig_tensor._local_tensor
local_cpu = tensor
for mesh_dim, placement in enumerate(placements):
if not placement.is_shard():
continue
shard_dim = placement.dim
mesh_size = mesh.size(mesh_dim)
mesh_rank = mesh.get_local_rank(mesh_dim)
full_len = local_cpu.size(shard_dim)
if full_len % mesh_size == 0:
chunk_len = full_len // mesh_size
local_cpu = local_cpu.narrow(shard_dim, mesh_rank * chunk_len, chunk_len)
else:
chunks = torch.chunk(local_cpu, mesh_size, dim=shard_dim)
if mesh_rank < len(chunks):
local_cpu = chunks[mesh_rank]
else:
empty_shape = list(local_cpu.shape)
empty_shape[shard_dim] = 0
local_cpu = local_cpu.new_empty(empty_shape)
local_cpu = local_cpu.contiguous().pin_memory()
local_dev = local_cpu.to(
device=local_target.device,
dtype=local_target.dtype,
non_blocking=True,
)
local_target.copy_(local_dev, non_blocking=True)
@staticmethod
def _dispatch_buffer(
model: nn.Module,
name: str,
buffer: torch.Tensor,
dtensor_factory
) -> None:
"""
Assign buffer to model.
"""
module, local_name = _find_submodule(model, name)
orig_buffer = dict(module.named_buffers(recurse=False))[local_name]
if hasattr(orig_buffer, "device_mesh"):
device_mesh = orig_buffer.device_mesh
placements = orig_buffer.placements
module.register_buffer(local_name, dtensor_factory(
buffer.to(dtype=orig_buffer.dtype),
device_mesh,
placements
))
else:
dict(module.named_buffers(recurse=False))[local_name].copy_(buffer.to(device=orig_buffer.device, dtype=orig_buffer.dtype))
@staticmethod
def _init_parameter(model: nn.Module, name: str) -> None:
"""
Initialize missing parameter.
"""
pieces = name.split(".")
init_func = None
module = model
for piece in pieces[:-1]:
if hasattr(module, "_init_weights"):
init_func = getattr(module, "_init_weights", None)
module = getattr(module, piece)
if init_func is not None:
module.apply(init_func)
else:
logger.warning(f"> Cannot find _init_weights for {name}, skipping initialization")
@staticmethod
def _post_process(
model: nn.Module,
buffer_dict: Dict[str, torch.Tensor],
parameter_names_left: Set[str],
dtensor_factory
) -> None:
"""
Post-process after weight loading.
"""
for name, buffer in buffer_dict.items():
try:
WeightLoader._dispatch_buffer(model, name, buffer, dtensor_factory)
except Exception as e:
logger.warning(f"> Failed to restore buffer {name}: {e}")
logger.info_rank0(f"> Restored {len(buffer_dict)} buffers")
if parameter_names_left:
logger.info_rank0(f"> Missing {parameter_names_left} parameters, initializing them...")
for name in parameter_names_left:
try:
WeightLoader._init_parameter(model, name)
except Exception as e:
logger.warning(f"> Failed to initialize {name}: {e}")
if getattr(model.config, "tie_word_embeddings", True):
try:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
if output_embeddings is not None and input_embeddings is not None:
output_embeddings.register_parameter(
"weight",
input_embeddings.weight,
)
logger.info_rank0("> Tied input/output embeddings")
except Exception as e:
logger.warning(f"> Failed to tie embeddings: {e}")