import json
import logging
import os
import re
from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Set, Tuple
import torch
import torch.distributed as dist
from safetensors import safe_open
from tqdm import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.checkpoint.load_utils import ParamInfo
from mindspeed_mm.fsdp.checkpoint.utils import remove_base_layer_keys
from mindspeed_mm.fsdp.utils.device import empty_cache, get_device_type
from mindspeed_mm.fsdp.utils.utils import tensor_to_dtensor
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class HFWeightFileStream:
    """Lazily stream ``(key, full_tensor)`` pairs out of one safetensors file.

    Iterating opens *filepath* with ``safe_open`` on CPU and fetches tensors
    one key at a time, so the stream itself never holds more than the tensor
    it is currently handing out. The consumer is expected to copy each tensor
    into the model and drop it before asking for the next one.
    """

    # Path of the ``.safetensors`` file to read.
    filepath: str

    def __iter__(self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        with safe_open(self.filepath, framework="pt", device="cpu") as reader:
            # No local binds the tensor, so nothing here keeps the previously
            # yielded tensor alive while the next one is being read.
            yield from ((tensor_name, reader.get_tensor(tensor_name)) for tensor_name in reader.keys())
def locate_hf_weight_files(weights_path: str) -> List[HFWeightFileStream]:
    """Turn the safetensors weights under *weights_path* into stream readers.

    Two standard Hugging Face layouts are recognised, checked in this order:

    - a single ``SAFE_WEIGHTS_NAME`` file (``model.safetensors``), returned as
      a one-element list;
    - a ``SAFE_WEIGHTS_INDEX_NAME`` index (``model.safetensors.index.json``)
      whose ``weight_map`` values name the shard files; every distinct shard
      is returned once, in sorted filename order.

    Raises:
        FileNotFoundError: a shard named by the index does not exist on disk.
        ValueError: neither the single file nor the index is present.
    """
    single_file = os.path.join(weights_path, SAFE_WEIGHTS_NAME)
    if os.path.isfile(single_file):
        return [HFWeightFileStream(single_file)]

    index_file = os.path.join(weights_path, SAFE_WEIGHTS_INDEX_NAME)
    if not os.path.isfile(index_file):
        raise ValueError(f"No HF safetensors weights found under {weights_path}.")

    with open(index_file, "r", encoding="utf-8") as handle:
        # Many tensors share one shard file; dedupe, then sort for a stable order.
        shard_names = sorted(set(json.load(handle)["weight_map"].values()))

    streams: List[HFWeightFileStream] = []
    for shard_name in shard_names:
        shard_path = os.path.join(weights_path, shard_name)
        if not os.path.isfile(shard_path):
            raise FileNotFoundError(
                f"Shard '{shard_name}' referenced by {index_file} is missing: {shard_path}"
            )
        streams.append(HFWeightFileStream(shard_path))
    return streams
def looks_like_hf_weight_dir(path: Optional[str]) -> bool:
    """Return True when *path* is a directory containing HF safetensors weights.

    ``load_format='auto'`` relies on this to tell a Hugging Face weight
    directory apart from a DCP checkpoint directory. The latter carries
    ``latest_checkpointed_iteration.txt`` plus ``.distcp`` files rather than
    the standard HF weight filenames checked here.

    A falsy *path* (``None`` or ``""``) or a non-directory yields False.
    """
    if not path:
        return False
    if not os.path.isdir(path):
        return False
    markers = (SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
    return any(os.path.isfile(os.path.join(path, marker)) for marker in markers)
def _resolve_leaf(model: torch.nn.Module, name: str) -> Tuple[torch.nn.Module, str]:
"""Walk the dotted FQN *name* down to the leaf module that owns it.
For ``model.language_model.layers.0.self_attn.q_proj.weight`` it returns
the ``q_proj`` Linear module and ``"weight"``.
"""
module = model
pieces = name.split(".")
for piece in pieces[:-1]:
if not hasattr(module, piece):
raise ValueError(f"Cannot resolve '{name}': submodule '{piece}' not found.")
module = getattr(module, piece)
return module, pieces[-1]
def convert_weight_key(key: str, model: torch.nn.Module) -> str:
    """Rename a checkpoint *key* using the model's ``_checkpoint_conversion_mapping``.

    If the model defines a non-empty mapping, it is read as an ordered
    ``{regex_pattern: replacement}`` dict. Patterns are tried in order, and
    the first one that matches *key* at least once decides the result; later
    patterns are not applied.

    Before substitution, each replacement has leading ``^`` characters
    stripped and everything from its first ``(`` through its last ``)``
    removed. A key that matches no pattern, or a model without a mapping,
    comes back unchanged.
    """
    conversion = getattr(model, "_checkpoint_conversion_mapping", None)
    if conversion:
        for pattern, raw_replacement in conversion.items():
            cleaned = re.sub(r"\(.*\)", "", raw_replacement.lstrip("^"))
            renamed, hits = re.subn(pattern, cleaned, key)
            if hits:
                return renamed
    return key
def write_full_tensor(model: torch.nn.Module, name: str, full_tensor: torch.Tensor) -> None:
    """Copy the unsharded *full_tensor* into the parameter or buffer *name* of *model*.

    Two cases:

    - Sharded parameter (a DTensor created by ``fully_shard``, detected via a
      ``device_mesh`` attribute): the target already carries its
      ``device_mesh`` / ``placements``. The full tensor is cast to the
      target's dtype and moved to the device of the target's local shard.
      ``tensor_to_dtensor`` then carves out this rank's shard (a local
      Replicate -> Shard redistribute, no comm), and the result is copied in.
    - Plain tensor (replicated parameter or buffer): cast to the target's
      dtype/device and copied in.

    The write is always in place (``copy_``), since FSDP2 holds the parameter
    object.

    Raises:
        ValueError: *name* resolves to neither a parameter nor a buffer, or
            to a slot registered as ``None``; or the checkpoint tensor's shape
            differs from the target's. For a DTensor target, ``.shape`` is the
            global (unsharded) shape.
    """
    leaf_module, local_name = _resolve_leaf(model, name)
    params = leaf_module._parameters
    buffers = leaf_module._buffers
    # ``.get(...) is not None`` also rejects slots registered as None
    # (e.g. ``bias=False``), which previously failed with AttributeError.
    if params.get(local_name) is not None:
        target = params[local_name].data
    elif buffers.get(local_name) is not None:
        target = buffers[local_name]
    else:
        raise ValueError(f"'{name}' is neither a parameter nor a buffer of the model.")

    # ``copy_`` broadcasts, so e.g. a [N] checkpoint tensor would silently be
    # replicated into an [M, N] slot. Require an exact shape match instead.
    expected_shape = tuple(target.shape)
    incoming_shape = tuple(full_tensor.shape)
    if incoming_shape != expected_shape:
        raise ValueError(
            f"Shape mismatch for '{name}': checkpoint tensor has shape {incoming_shape} "
            f"but the model expects {expected_shape}."
        )

    if hasattr(target, "device_mesh"):
        full_tensor = full_tensor.to(device=target.to_local().device, dtype=target.dtype)
        shard = tensor_to_dtensor(full_tensor, target.device_mesh, target.placements)
        target.copy_(shard)
    else:
        target.copy_(full_tensor.to(device=target.device, dtype=target.dtype))
def _lora_base_key_map(param_names: Set[str]) -> Dict[str, str]:
    """Build ``{stripped_key: model_fqn}`` for redirecting HF keys onto LoRA-wrapped params.

    *param_names* is passed to ``remove_base_layer_keys`` as a ``{name: None}``
    dict. Its result is consumed as ``(model_fqn, stripped_key)`` pairs, which
    are inverted here. This is the counterpart of the DCP loader's LoRA base
    key map.

    NOTE(review): assumes ``remove_base_layer_keys`` returns a mapping from
    each input key to its base-layer-free form (presumably dropping a
    ``base_layer`` path segment). Confirm against its definition.
    """
    placeholders = dict.fromkeys(param_names)
    fqn_to_stripped = remove_base_layer_keys(placeholders)
    return {stripped: fqn for fqn, stripped in fqn_to_stripped.items()}
def _retie_embeddings(model: torch.nn.Module) -> None:
"""Re-tie input/output embeddings when the config requests it.
- ``to_empty_if_needed`` broke any shared storage; this restores it.
- AND across ``model.config`` and ``text_config`` covers nested multimodal cases.
- Object-reference assignment so both modules share one nn.Parameter.
"""
config = getattr(model, "config", None)
if config is None:
return
text_config = (
config.get_text_config(decoder=True) if hasattr(config, "get_text_config") else config
)
should_tie = (
(hasattr(config, "tie_word_embeddings") or hasattr(text_config, "tie_word_embeddings"))
and getattr(config, "tie_word_embeddings", True)
and getattr(text_config, "tie_word_embeddings", True)
)
if not should_tie:
return
try:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
if input_embeddings is None or output_embeddings is None:
return
output_embeddings._parameters["weight"] = input_embeddings._parameters["weight"]
except Exception as e:
raise RuntimeError("Failed to tie input/output embeddings after HF load") from e
def _log_unexpected_keys(unexpected_keys: Set[str]) -> None:
if not unexpected_keys:
return
samples = sorted(unexpected_keys)[:5]
suffix = "" if len(unexpected_keys) <= 5 else f" (showing 5 of {len(unexpected_keys)})"
print_rank(
logger.info,
f"HF checkpoint had {len(unexpected_keys)} key(s) not present in the model. Examples{suffix}: {samples}",
)
def post_process_after_load(
    model: torch.nn.Module,
    missing_param_keys: set,
    load_strict: bool = False,
) -> None:
    """Deal with parameters the checkpoint never wrote, then re-tie embeddings.

    Keys containing the substring ``"lora_"`` are ignored. These are
    presumably LoRA adapter weights, which a base HF checkpoint does not
    carry. For any remaining keys:

    - *load_strict* True: raise ``RuntimeError`` listing them (sorted);
    - otherwise: emit a ``logger.warning`` listing them, since they stay
      uninitialized.

    Finally ``_retie_embeddings(model)`` is called. It is not reached when
    the strict check raises.
    """
    absent = {key for key in missing_param_keys if "lora_" not in key}
    if absent and load_strict:
        raise RuntimeError(
            f"{len(absent)} parameter key(s) absent from the HF checkpoint "
            f"(load_strict=True): {sorted(absent)}"
        )
    if absent:
        logger.warning(
            "%d parameter key(s) absent from the HF checkpoint and left "
            "uninitialized: %s. Training will likely produce NaNs. "
            "Pre-load these via the HF file or extend post_process_after_load "
            "to call _init_weights on them.",
            len(absent),
            sorted(absent),
        )
    _retie_embeddings(model)
@torch.no_grad()
def load_hf_weights(
    model: torch.nn.Module, hf_dir: str, enable_lora: bool = False, load_strict: bool = False
) -> None:
    """Load HF safetensors weights into *model* directly, no offline conversion.

    Every rank opens the same files and streams them tensor by tensor. Each
    key is processed as follows:

    - rename it with ``convert_weight_key``;
    - with *enable_lora*, redirect it through the LoRA base-layer map;
    - if it then names a model parameter or buffer, write it with
      ``write_full_tensor``, which slices the full tensor into this rank's
      DTensor shard locally (no cross-rank communication).

    Suitable for small/medium models or shared filesystems. For very large
    models prefer ``rank0_load_and_broadcast_hf_weights``, which reads the
    files once.

    Keys matching nothing are reported via ``_log_unexpected_keys``. A
    parameter key seen a second time is no longer pending and is reported
    the same way. Parameters never written go to ``post_process_after_load``,
    which raises under *load_strict*.

    Assumes ``model`` has already been laid out by ``fully_shard`` and brought
    out of meta via ``to_empty_if_needed`` -- i.e. parameter slots are empty
    DTensors on the real device and buffer values have been preserved.
    """
    param_names = {name for name, _ in model.named_parameters()}
    buffer_names = {name for name, _ in model.named_buffers()}
    lora_base_map = _lora_base_key_map(param_names) if enable_lora else {}
    unexpected_keys: Set[str] = set()

    for shard_stream in locate_hf_weight_files(hf_dir):
        for raw_key, full_tensor in shard_stream:
            key = convert_weight_key(raw_key, model)
            key = lora_base_map.get(key, key)
            if key in param_names:
                param_names.discard(key)
                write_full_tensor(model, key, full_tensor)
            elif key in buffer_names:
                write_full_tensor(model, key, full_tensor)
            else:
                unexpected_keys.add(key)
            # Drop the reference before the stream reads the next tensor.
            # Otherwise the loop variable keeps this tensor alive while the
            # next one is materialised, doubling peak host memory.
            del full_tensor
        # Release device blocks freed by this shard's temporaries, as the
        # rank0-broadcast loader does after each shard.
        empty_cache()

    _log_unexpected_keys(unexpected_keys)
    post_process_after_load(model, missing_param_keys=param_names, load_strict=load_strict)
def load_hf_checkpoint(
    model: torch.nn.Module,
    hf_dir: str,
    *,
    load_rank0_and_broadcast: bool = False,
    enable_lora: bool = False,
    load_strict: bool = False,
) -> bool:
    """Load HF safetensors weights from *hf_dir* into an FSDP-sharded *model*.

    First logs a one-line summary. Its architecture, layer count and
    tie-embeddings fields are read from ``hf_dir/config.json`` when present,
    falling back to ``'?'``. It then dispatches to
    ``rank0_load_and_broadcast_hf_weights`` when *load_rank0_and_broadcast*
    is set, otherwise to ``load_hf_weights``.

    Returns:
        True once loading has completed without raising.
    """
    cfg = {}
    cfg_path = os.path.join(hf_dir, "config.json")
    if os.path.isfile(cfg_path):
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    # Multimodal configs nest LM settings under ``text_config``. It may be
    # absent or explicitly ``null``; in either case use the top level.
    # ``cfg.get("text_config", cfg)`` returned None for ``null`` and crashed below.
    tcfg = cfg.get("text_config")
    if not isinstance(tcfg, dict):
        tcfg = cfg
    method = "rank0-broadcast" if load_rank0_and_broadcast else "every-rank-read"
    print_rank(
        logger.info,
        f"Loading HF safetensors -> FSDP DTensors via {method} (online): "
        f"dir={hf_dir} "
        f"arch={(cfg.get('architectures') or ['?'])[0]} "
        f"layers={tcfg.get('num_hidden_layers', '?')} "
        f"tie_emb={cfg.get('tie_word_embeddings', tcfg.get('tie_word_embeddings', '?'))}",
    )
    if load_rank0_and_broadcast:
        rank0_load_and_broadcast_hf_weights(model, hf_dir, enable_lora=enable_lora, load_strict=load_strict)
    else:
        load_hf_weights(model, hf_dir, enable_lora=enable_lora, load_strict=load_strict)
    return True
@torch.no_grad()
def rank0_load_and_broadcast_hf_weights(
    model: torch.nn.Module, hf_dir: str, enable_lora: bool = False, load_strict: bool = False
) -> None:
    """Load HF safetensors weights via a rank0 read plus ``dist.broadcast``.

    Every rank must run the same collectives in the same order:

    1. rank0 resolves the shard files and broadcasts their count as an int64
       tensor on the accelerator. A count of -1 means rank0 failed to resolve
       them. rank0 then re-raises its error, and every other rank raises
       ``RuntimeError`` instead of waiting forever.
    2. For each shard, rank0 reads every tensor into host memory and
       broadcasts the list of ``ParamInfo(name, shape, dtype)`` with
       ``broadcast_object_list``.
    3. For each entry, rank0 moves its tensor to the device, and the other
       ranks allocate an empty tensor of the advertised shape/dtype. One
       ``dist.broadcast`` fills it; every tensor is broadcast, even keys the
       model does not use. Each rank then dispatches exactly like
       ``load_hf_weights``.

    Total disk I/O is one read of the HF weights (vs ``world_size`` reads for
    ``load_hf_weights``), at the cost of cross-rank communication. rank0
    holds one full shard in host memory at a time.

    Assumes ``model`` has been laid out by ``fully_shard`` and brought out of
    meta via ``to_empty_if_needed`` in ``get_model``.

    NOTE(review): only a failure while locating the shards is propagated to
    peers. An exception raised by rank0 later (e.g. opening a corrupt shard)
    still leaves the other ranks blocked in the next collective.
    """
    rank0 = dist.get_rank() == 0
    torch_device = torch.device(get_device_type())
    param_names = {name for name, _ in model.named_parameters()}
    buffer_names = {name for name, _ in model.named_buffers()}
    lora_base_map = _lora_base_key_map(param_names) if enable_lora else {}
    unexpected_keys: Set[str] = set()

    # Only rank0 touches the filesystem. If it cannot resolve the shard list
    # (wrong dir, missing shard, bad index), it must still join the broadcast
    # below; otherwise every other rank blocks in ``dist.broadcast`` forever.
    shard_paths: List[str] = []
    locate_error: Optional[BaseException] = None
    if rank0:
        try:
            shard_paths = [s.filepath for s in locate_hf_weight_files(hf_dir)]
        except Exception as exc:  # re-raised below, after peers are notified
            locate_error = exc
    shard_count_tensor = torch.tensor(
        [-1 if locate_error is not None else len(shard_paths)],
        dtype=torch.int64,
        device=torch_device,
    )
    dist.broadcast(shard_count_tensor, src=0)
    shard_count = int(shard_count_tensor.item())
    if shard_count < 0:
        if locate_error is not None:
            raise locate_error
        raise RuntimeError(
            f"Rank 0 failed to locate HF safetensors weights under {hf_dir}; "
            "see the rank 0 log for the underlying error."
        )

    # Progress bar is disabled on processes whose LOCAL_RANK is > 0.
    shard_iterable = tqdm(
        range(shard_count),
        desc="Loading HF checkpoint shards",
        disable=int(os.getenv("LOCAL_RANK", "-1")) > 0,
    )
    for shard_id in shard_iterable:
        shard_state: Dict[str, torch.Tensor] = {}
        if rank0:
            with safe_open(shard_paths[shard_id], framework="pt", device="cpu") as f:
                for key in f.keys():
                    shard_state[key] = f.get_tensor(key)
            param_info_list = [
                ParamInfo(name=k, shape=v.shape, dtype=v.dtype)
                for k, v in shard_state.items()
            ]
        else:
            param_info_list = []
        # Non-rank0 ranks learn names/shapes/dtypes for this shard from rank0.
        broadcast_list = [param_info_list]
        dist.broadcast_object_list(broadcast_list, src=0)
        param_info_list = broadcast_list[0]
        for info in param_info_list:
            # The key mapping depends only on the model, so it is computed
            # identically on every rank.
            key = convert_weight_key(info.name, model)
            key = lora_base_map.get(key, key)
            if rank0:
                tensor = shard_state[info.name].to(torch_device, non_blocking=True)
            else:
                tensor = torch.empty(info.shape, dtype=info.dtype, device=torch_device)
            dist.broadcast(tensor, src=0)
            if key in param_names:
                param_names.discard(key)
                write_full_tensor(model, key, tensor)
            elif key in buffer_names:
                write_full_tensor(model, key, tensor)
            else:
                unexpected_keys.add(key)
            del tensor
        del shard_state
        empty_cache()

    _log_unexpected_keys(unexpected_keys)
    post_process_after_load(model, missing_param_keys=param_names, load_strict=load_strict)