"""
This module provides utility functions for PyTorch Distributed Checkpoint (DCP), including device helpers, memory management, state_dict sharding, serialization, checkpoint cleanup, and DCP-to-torch state_dict conversion.
"""
import gc
import os
import torch
from collections import OrderedDict
from functools import lru_cache
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union, Tuple
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from safetensors.torch import save_file
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, Replicate
from mindspeed_llm.fsdp2.distributed.expert_parallel.expert_parallel import get_ep_modules
from mindspeed_llm.fsdp2.utils.logging import get_logger
logger = get_logger(__name__)
@lru_cache
def get_dtype_size(dtype: "torch.dtype") -> int:
    """
    Return the size (in bytes) of a single element of a torch dtype.

    The explicit table is adapted from safetensors so that size calculation
    stays consistent across serialization backends. Dtypes that are not in
    the table fall back to ``dtype.itemsize`` when the running torch build
    exposes it, so valid dtypes such as ``torch.complex64`` no longer fail.

    Args:
        dtype (torch.dtype): Torch data type

    Returns:
        int: Size in bytes for a single element

    Raises:
        KeyError: If ``dtype`` is not a known torch dtype and its size
            cannot be determined.
    """
    sizes = {
        torch.int64: 8,
        torch.float32: 4,
        torch.int32: 4,
        torch.bfloat16: 2,
        torch.float16: 2,
        torch.int16: 2,
        torch.uint8: 1,
        torch.int8: 1,
        torch.bool: 1,
        torch.float64: 8,
    }
    # float8 dtypes are looked up by name because they may be absent from the
    # running torch build. Only the ones that exist are registered, so a
    # missing dtype never becomes a bogus ``None`` key in the table.
    for fp8_name in ("float8_e4m3fn", "float8_e5m2"):
        fp8_dtype = getattr(torch, fp8_name, None)
        if fp8_dtype is not None:
            sizes[fp8_dtype] = 1

    if dtype in sizes:
        return sizes[dtype]

    # Fallback for real torch dtypes missing from the table above.
    if isinstance(dtype, torch.dtype):
        itemsize = getattr(dtype, "itemsize", None)
        if itemsize is not None:
            return itemsize

    # Same exception type the plain table lookup used to raise.
    raise KeyError(dtype)
def synchronize() -> None:
    """
    Synchronize the current accelerator device.

    Delegates to ``torch.accelerator.synchronize()`` so callers do not need
    to know which accelerator backend is in use.
    """
    accelerator = torch.accelerator
    accelerator.synchronize()
def empty_cache() -> None:
    """
    Run Python garbage collection, then release cached accelerator memory.

    The collector runs first, and ``torch.accelerator.empty_cache()`` is
    called afterwards.
    """
    gc.collect()
    accelerator = torch.accelerator
    accelerator.empty_cache()
def get_shard_info(
    state_dict: Dict[str, "torch.Tensor"],
    save_dtype: Optional[Union[str, "torch.dtype"]],
    shard_size: int,
    safe_serialization: bool,
) -> Tuple[bool, int, Dict[str, str]]:
    """
    Compute sharding information for a state_dict.

    Tensors are packed greedily in iteration order: a new shard is started
    whenever adding the next tensor would push a non-empty shard past
    ``shard_size``. A single tensor larger than ``shard_size`` therefore gets
    a shard of its own.

    Args:
        state_dict (Dict[str, Tensor]): Model state_dict
        save_dtype (str or torch.dtype): Target dtype for saving. A string is
            resolved as an attribute of ``torch``; any other non-dtype value
            (e.g. ``None``) means each tensor's own dtype is used.
        shard_size (int): Maximum shard size in bytes
        safe_serialization (bool): Whether to use safetensors format

    Returns:
        Tuple:
            - is_sharded (bool): True only when more than one shard is needed.
              An empty ``state_dict`` is reported as not sharded.
            - total_size (int): Sum of the serialized sizes in bytes
            - weight_map (Dict[str, str]): Parameter name -> shard file name
    """
    # The dtype override does not depend on the tensor, so resolve it once.
    if isinstance(save_dtype, str):
        forced_dtype = getattr(torch, save_dtype)
    elif isinstance(save_dtype, torch.dtype):
        forced_dtype = save_dtype
    else:
        forced_dtype = None

    current_size, total_size = 0, 0
    current_shard, shard_list = [], []
    for name, tensor in state_dict.items():
        dtype = tensor.dtype if forced_dtype is None else forced_dtype
        tensor_size = tensor.numel() * get_dtype_size(dtype)
        # Close the current shard only if it already holds something;
        # otherwise an oversized tensor would produce an empty shard.
        if current_size != 0 and current_size + tensor_size > shard_size:
            total_size += current_size
            shard_list.append(current_shard)
            current_size = 0
            current_shard = []
        current_size += tensor_size
        current_shard.append(name)

    # Flush the trailing, partially filled shard.
    if current_size != 0:
        total_size += current_size
        shard_list.append(current_shard)

    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    num_shards = len(shard_list)
    weight_map = OrderedDict()
    if num_shards <= 1:
        # Zero shards (empty state_dict) or one shard: a single un-suffixed file.
        is_sharded = False
        for shard in shard_list:
            for name in shard:
                weight_map[name] = weights_name
    else:
        is_sharded = True
        prefix, extension = weights_name.rsplit(".", maxsplit=1)
        for shard_idx, shard in enumerate(shard_list):
            file_name = f"{prefix}-{shard_idx + 1:05d}-of-{num_shards:05d}.{extension}"
            for name in shard:
                weight_map[name] = file_name
    return is_sharded, total_size, weight_map
def save_state_dict(
    state_dict: Dict[str, "torch.Tensor"],
    path_to_save: "os.PathLike",
    safe_serialization: bool,
) -> None:
    """
    Write a state_dict to a single file on disk.

    Args:
        state_dict (Dict[str, Tensor]): State dictionary to save
        path_to_save (PathLike): Output file path
        safe_serialization (bool): When true the file is written with
            safetensors ``save_file`` (with ``{"format": "pt"}`` metadata);
            otherwise ``torch.save`` is used.
    """
    if not safe_serialization:
        torch.save(state_dict, path_to_save)
        return

    save_file(state_dict, path_to_save, metadata={"format": "pt"})
def cleanup_old_checkpoints(training_args, shared_file_system: bool = True):
    """
    Delete the oldest checkpoint directories so at most ``save_total_limit`` remain.

    Checkpoint directories are the sub-directories of
    ``<output_dir>/checkpoints``, ordered by modification time (oldest
    first). On a shared filesystem only global rank 0 deletes; otherwise the
    process with ``LOCAL_RANK`` 0 on each node deletes its local copy. The
    function ends with a ``torch.distributed.barrier()``, so every rank must
    call it together.

    Nothing happens (and no barrier is entered) when ``training_args`` has no
    ``save_total_limit`` attribute, or when it is ``None`` or not positive.

    Args:
        training_args: Training arguments containing output_dir and save_total_limit
        shared_file_system (bool): Whether nodes share the same filesystem.
    """
    limit = getattr(training_args, "save_total_limit", None)
    if limit is None or limit <= 0:
        return

    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if shared_file_system:
        is_deleter = rank == 0
    else:
        is_deleter = local_rank == 0

    if is_deleter:
        root = os.path.join(training_args.output_dir, "checkpoints")
        if os.path.isdir(root):
            candidates = [os.path.join(root, entry) for entry in os.listdir(root)]
            dated = [
                (os.path.getmtime(path), path)
                for path in candidates
                if os.path.isdir(path)
            ]
            dated.sort(key=lambda pair: pair[0])

            if len(dated) > limit:
                import shutil

                # Everything except the newest ``limit`` entries is stale.
                for _, stale_path in dated[:-limit]:
                    try:
                        if os.path.isdir(stale_path):
                            if rank == 0:
                                logger.info_rank0(f"Removing old checkpoint: {stale_path}")
                            shutil.rmtree(stale_path)
                    except FileNotFoundError:
                        # Directory vanished between listing and removal.
                        continue

    torch.distributed.barrier()
def drop_ep_dim(loaded_tensor: torch.Tensor, device_mesh: DeviceMesh) -> torch.Tensor:
    """Strip the Expert Parallelism (EP) dimension from a tensor loaded from a DCP checkpoint.

    Behaviour by input type:
        - plain tensor: returned unchanged;
        - DTensor with one placement: its local shard is returned;
        - DTensor with two placements: its local shard is re-wrapped as a
          DTensor on the ``"ep_fsdp"`` sub-mesh with ``Shard(1)``.

    Args:
        loaded_tensor (torch.Tensor | DTensor): The tensor loaded from a DCP checkpoint,
            potentially containing an extra EP dimension.
        device_mesh (DeviceMesh): A 2D device mesh with dimensions ("ep", "ep_fsdp").

    Returns:
        torch.Tensor | DTensor: The tensor with the EP dimension removed.

    Raises:
        ValueError: If ``device_mesh`` is not 2-dimensional.
        RuntimeError: If a DTensor has neither 1 nor 2 placements.
    """
    if device_mesh.ndim != 2:
        raise ValueError(f"device_mesh.ndim must be 2, got {device_mesh.ndim}")
    ep_fsdp_mesh = device_mesh["ep_fsdp"]

    if not isinstance(loaded_tensor, DTensor):
        return loaded_tensor

    placement_count = len(loaded_tensor.placements)
    if placement_count == 1:
        return loaded_tensor.to_local()
    if placement_count == 2:
        local_shard = loaded_tensor.to_local()
        return DTensor.from_local(
            local_shard, device_mesh=ep_fsdp_mesh, placements=[Shard(1)]
        )

    raise RuntimeError(
        f"Expect EP parameters to be DTensor with 1 or 2 placements, got {loaded_tensor.placements}"
    )
def restore_ep_dim(origin_tensor: torch.Tensor, device_mesh: DeviceMesh) -> DTensor:
    """Re-attach the Expert Parallelism (EP) dimension to a tensor before DCP saving.

    Behaviour by input type:
        - DTensor: its local shard is wrapped on the full 2D mesh with
          placements ``[Shard(0), Shard(1)]``;
        - plain tensor: wrapped on the ``"ep"`` sub-mesh with ``Shard(0)``.

    Args:
        origin_tensor (torch.Tensor | DTensor): The model tensor to be saved, in its current EP-sharded layout.
        device_mesh (DeviceMesh): A 2D device mesh with dimensions ("ep", "ep_fsdp").

    Returns:
        DTensor: The tensor with the EP dimension restored.

    Raises:
        ValueError: If ``device_mesh`` is not 2-dimensional.
        RuntimeError: If ``origin_tensor`` is not a tensor.
    """
    if device_mesh.ndim != 2:
        raise ValueError(f"device_mesh.ndim must be 2, got {device_mesh.ndim}")
    ep_mesh = device_mesh["ep"]

    # DTensor must be tested first: it would also pass ``torch.is_tensor``.
    if isinstance(origin_tensor, DTensor):
        local_shard = origin_tensor.to_local()
        return DTensor.from_local(
            local_shard, device_mesh=device_mesh, placements=[Shard(0), Shard(1)]
        )

    if not torch.is_tensor(origin_tensor):
        raise RuntimeError(f"origin_tensor - {origin_tensor} is not a tensor!")

    return DTensor.from_local(origin_tensor, device_mesh=ep_mesh, placements=[Shard(0)])
@dataclass
class EPSpecInfo:
    """Describe how one Expert Parallelism (EP) parameter is laid out as a
    DTensor when it is saved to or restored from a checkpoint.

    Attributes:
        placement (Union[Shard, Replicate]): DTensor placement used for the
            EP dimension (typically Shard(0) for expert-parallel parameters).
        ep_fsdp_mesh (DeviceMesh): 2D device mesh with dimensions
            ("ep", "ep_fsdp") used to build the DTensor layout during save/load.
    """

    # Placement strategy along the EP dimension.
    placement: Union[Shard, Replicate]
    # 2D ("ep", "ep_fsdp") mesh the parameter lives on.
    ep_fsdp_mesh: DeviceMesh
def build_ep_fqn2spec_info(
    model, parallel_state, ep_plan
) -> Dict[str, EPSpecInfo]:
    """Map every expert-parallel parameter FQN to its checkpoint layout spec.

    A 2D ("ep", "ep_fsdp") device mesh is built from consecutive indices
    ``0 .. ep_size * efsdp_size - 1`` reshaped to ``(ep_size, efsdp_size)``.
    Every parameter that lives under a module returned by ``get_ep_modules``
    receives an ``EPSpecInfo`` with ``Shard(0)`` placement and that mesh.

    Args:
        model: The model instance. Must support `named_modules()` and `named_parameters()`.
            Its inner `model.model` attribute is passed to `get_ep_modules`.
        parallel_state (ParallelState): Queried for the EP group size and, when
            E-FSDP is enabled, the E-FSDP group size (1 is used otherwise).
        ep_plan: The expert parallelism plan forwarded to `get_ep_modules`.

    Returns:
        Dict[str, EPSpecInfo]: Expert parameter FQN -> `EPSpecInfo`.
    """
    ep_size = parallel_state.get_ep_group_size()
    if parallel_state.is_efsdp_enable():
        efsdp_size = parallel_state.get_efsdp_group_size()
    else:
        efsdp_size = 1

    device_type = torch.accelerator.current_accelerator().type
    mesh_layout = torch.arange(ep_size * efsdp_size).view(ep_size, efsdp_size)
    ep_fsdp_mesh = DeviceMesh(
        device_type=device_type,
        mesh=mesh_layout,
        mesh_dim_names=("ep", "ep_fsdp"),
    )

    expert_modules = get_ep_modules(model.model, ep_plan)
    expert_module_names = {
        module_name
        for module_name, module in model.named_modules()
        if module in expert_modules
    }

    def _is_expert_param(param_fqn: str) -> bool:
        # A parameter counts as expert-parallel if its FQN equals an expert
        # module's name or sits underneath it ("<module>.<rest>").
        return any(
            param_fqn == owner or param_fqn.startswith(owner + ".")
            for owner in expert_module_names
        )

    return {
        param_fqn: EPSpecInfo(placement=Shard(0), ep_fsdp_mesh=ep_fsdp_mesh)
        for param_fqn, _ in model.named_parameters()
        if _is_expert_param(param_fqn)
    }