import gc
import os
from typing import Any, Dict, Optional
import logging
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import (
FileSystemReader,
FileSystemWriter,
)
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.stateful import Stateful
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.utils.device import empty_cache, synchronize
from mindspeed_mm.fsdp.checkpoint.checkpointer import CheckpointerBase
from mindspeed_mm.fsdp.checkpoint.utils import (
get_checkpoint_name,
read_metadata,
get_checkpoint_tracker_filename,
remove_base_layer_keys,
restore_base_layer_keys,
)
from mindspeed_mm.fsdp.checkpoint.load_utils import rank0_load_and_broadcast_weights
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# File-name template for a rank's extra-state file; the placeholder is filled
# with the rank of the process that writes/reads it.
_EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
# Name of the sub-directory (inside a checkpoint directory) that holds the
# per-rank extra-state files.
_EXTRA_STATE_DIR = "extra_state"
class ModelState(Stateful):
"""
A wrapper around a model to make it stateful.
Args:
model (Model): model to wrap.
"""
def __init__(self, model):
self.model = model
self.parallel_state = get_parallel_state()
@torch.no_grad()
def state_dict(self):
model_state_dict = get_model_state_dict(model=self.model)
return model_state_dict
@torch.no_grad()
def load_state_dict(self, state_dict):
"""
perform the reverse operation for state_dict()
need to drop EP-dim when loading from DCP checkpoints
so that EP-FSDP would not be confused
"""
set_model_state_dict(model=self.model, model_state_dict=state_dict)
class LoraModelState(ModelState):
    """
    Stateful adapter for a LoRA model.

    Args:
        model (Model): lora model to wrap.
    """

    def __init__(self, model):
        super().__init__(model)
        # Set by state_dict() and consumed by load_state_dict(); stays None
        # until state_dict() has been called at least once.
        self.key_mapping = None

    @torch.no_grad()
    def state_dict(self):
        """
        Return the model state dict after ``remove_base_layer_keys`` has
        processed it, remembering that helper's return value as
        ``self.key_mapping``.

        NOTE(review): the helper presumably rewrites the dict in place, since
        the same dict object is what gets returned here -- confirm.
        """
        lora_state = get_model_state_dict(model=self.model)
        self.key_mapping = remove_base_layer_keys(lora_state)
        return lora_state

    @torch.no_grad()
    def load_state_dict(self, state_dict):
        """
        Apply ``restore_base_layer_keys`` with the mapping recorded by the last
        ``state_dict()`` call, then load the result into the wrapped model.

        NOTE(review): if ``state_dict()`` was never called, the mapping passed
        to the helper is None; whether the helper accepts that is not visible
        here -- verify.
        """
        recorded_mapping = self.key_mapping
        restore_base_layer_keys(state_dict, recorded_mapping)
        set_model_state_dict(model=self.model, model_state_dict=state_dict)
class OptimizerState(Stateful):
"""
A wrapper around an optimizer to make it stateful.
Args:
optimizer (Optimizer): optimizer to wrap.
"""
def __init__(self, model, optimizer):
self.model = model
self.optimizer = optimizer
self.parallel_state = get_parallel_state()
self.should_ep_aware = getattr(self.optimizer, "_is_multi_optimizer", False)
def state_dict(self):
if self.should_ep_aware:
if not self.optimizer._is_multi_optimizer:
raise ValueError("EP is enabled but optimizer is not a MultiOptimizer instance")
return self.optimizer.state_dict()
sd = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer)
return sd
def load_state_dict(self, state_dict):
optim_state_from_dcp_load = state_dict
if self.should_ep_aware:
self.optimizer.load_state_dict(optim_state_from_dcp_load)
return
set_optimizer_state_dict(
model=self.model,
optimizers=self.optimizer,
optim_state_dict=optim_state_from_dcp_load,
)
class DistributedCheckpointer(CheckpointerBase):
    """
    Distributed checkpointer built on ``torch.distributed.checkpoint`` (DCP).

    All entry points are classmethods; the two class attributes below carry
    the state of asynchronous saves across calls.
    """

    # Future returned by the most recent ``dcp.async_save`` call, or None when
    # no async save has been started (or the last one was already awaited).
    dcp_save_future: Optional[Any] = None
    # Gloo process group created on the first async save and reused afterwards.
    _async_process_group: Optional[Any] = None

    @classmethod
    def save(
        cls,
        path: str,
        state: Dict[str, Any],
        save_async: bool = False,
        iteration: Optional[int] = None,
        storage_writer: Optional[FileSystemWriter] = None,
    ) -> None:
        """
        Save training state to a distributed checkpoint.

        Args:
            path: root path to save checkpoints under.
            state: state to save; "model" is required, "optimizer" and
                "extra_state" are optional.
            save_async: whether to save asynchronously.
            iteration: global step; used to derive the checkpoint directory
                name and written to the tracker file.
            storage_writer: storage writer backend for dcp.save and
                dcp.async_save. If None, a FileSystemWriter is created.

        Raises:
            ValueError: if "model" is missing from ``state``.
        """
        if "model" not in state:
            raise ValueError("Model must be provided to save a distributed checkpoint.")

        checkpoint_dir = get_checkpoint_name(path, iteration, release=False)
        cls._create_checkpoint_dir(checkpoint_dir)

        # Extra state is written before the model/optimizer shards.
        cls._save_extra_state(checkpoint_dir=checkpoint_dir, state=state)

        save_state = {"model": ModelState(state["model"])}
        if "optimizer" in state:
            save_state["optimizer"] = OptimizerState(model=state["model"], optimizer=state["optimizer"])

        if storage_writer is None:
            storage_writer = cls._create_storage_writer(checkpoint_dir)

        cls.execute_save(save_state=save_state, storage_writer=storage_writer, save_async=save_async)

        # Only one process (rank 0, or the sole process when torch.distributed
        # is not initialized) updates the tracker file.
        # NOTE(review): with save_async=True this runs right after the async
        # save was scheduled, so the tracker may name a checkpoint that is
        # still being written -- confirm this is acceptable to readers.
        if cls._get_rank_or_zero() == 0:
            tracker_filename = get_checkpoint_tracker_filename(path)
            with open(tracker_filename, 'w', encoding='utf-8') as f:
                f.write(str(iteration))

        print_rank(logger.info, f"Saved checkpoint to {checkpoint_dir}")

    @classmethod
    def load(
        cls,
        path: str,
        state: Dict[str, Any],
        process_group=None,
        storage_reader: Optional[FileSystemReader] = None,
        load_rank0_and_broadcast: bool = False,
        load_strict: bool = False,
        enable_lora: bool = False,
    ) -> bool:
        """
        Load training state from a distributed checkpoint.

        Args:
            path: path to load the checkpoint from. If a tracker file exists
                under it, the actual checkpoint directory is resolved from the
                tracker's metadata; otherwise ``path`` itself is used.
            state: state to load into; "model" is required, "optimizer" and
                "extra_state" are optional.
            process_group: process group passed to dcp.load.
            storage_reader: storage reader backend for dcp.load. If None, a
                FileSystemReader is created.
            load_rank0_and_broadcast: if True, delegate loading to
                ``rank0_load_and_broadcast_weights`` instead of dcp.load.
                ``process_group`` and ``load_strict`` are not forwarded on
                that path.
            load_strict: if False, dcp.load is allowed to load partially
                (``allow_partial_load=True``).
            enable_lora: whether to enable LoRA checkpoint loading logic.

        Returns:
            release (bool): whether the loaded checkpoint is a "release"
            checkpoint. For release checkpoints the optimizer and extra state
            are not loaded.

        Raises:
            ValueError: if ``state`` is None or has no "model" entry.
        """
        # Validate the arguments before touching the filesystem.
        if state is None:
            raise ValueError("State dict must be provided to load a distributed checkpoint.")
        if "model" not in state:
            raise ValueError("Model must be provided to load a distributed checkpoint.")

        checkpoint_dir = path
        release = False
        tracker_filename = get_checkpoint_tracker_filename(checkpoint_dir)
        if os.path.isfile(tracker_filename):
            iteration, release = read_metadata(tracker_filename)
            checkpoint_dir = get_checkpoint_name(checkpoint_dir, iteration, release)

        model_wrapper = LoraModelState if enable_lora else ModelState
        load_state = {"model": model_wrapper(state["model"])}
        if not release and "optimizer" in state:
            load_state["optimizer"] = OptimizerState(model=state["model"], optimizer=state["optimizer"])

        if storage_reader is None:
            storage_reader = cls._create_storage_reader(checkpoint_dir)

        if load_rank0_and_broadcast:
            rank0_load_and_broadcast_weights(
                load_state=load_state,
                storage_reader=storage_reader,
            )
        else:
            dcp.load(
                state_dict=load_state,
                storage_reader=storage_reader,
                process_group=process_group,
                planner=DefaultLoadPlanner(allow_partial_load=not load_strict),
            )

        if not release:
            cls._load_extra_state(checkpoint_dir=checkpoint_dir, state=state)

        print_rank(logger.info, f"Loaded checkpoint from {checkpoint_dir}")
        return release

    @classmethod
    def execute_save(
        cls,
        save_state: Dict[str, Any],
        storage_writer: FileSystemWriter,
        save_async: bool,
    ) -> None:
        """
        Execute the DCP save, optionally asynchronously.

        Args:
            save_state: mapping of Stateful wrappers to save.
            storage_writer: storage writer backend handed to DCP.
            save_async: if True use ``dcp.async_save`` and keep the returned
                future in ``cls.dcp_save_future``; otherwise block in
                ``dcp.save``.
        """
        if save_async:
            # Create the dedicated gloo group once and reuse it afterwards.
            if cls._async_process_group is None:
                cls._async_process_group = dist.new_group(backend="gloo")

            if cls.dcp_save_future is not None:
                logger.info("[RANK %s] waiting for previous DCP saving session to end...", dist.get_rank())
                cls.dcp_save_future.result()
                cls.dcp_save_future = None
                # Wait until every rank has finished its previous async save.
                dist.barrier()

            cls.dcp_save_future = dcp.async_save(
                state_dict=save_state,
                storage_writer=storage_writer,
                process_group=cls._async_process_group,
            )
        else:
            dcp.save(
                state_dict=save_state,
                storage_writer=storage_writer,
            )
            if dist.is_initialized():
                dist.barrier()
            # Release host and device memory held by the save before returning.
            gc.collect()
            empty_cache()
            synchronize()

    @classmethod
    def _get_rank_or_zero(cls) -> int:
        """Return this process's rank, or 0 when torch.distributed is not initialized."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    @classmethod
    def _create_checkpoint_dir(cls, checkpoint_dir: str) -> None:
        """Create the checkpoint directory (no error if it already exists)."""
        os.makedirs(checkpoint_dir, exist_ok=True)

    @classmethod
    def _create_storage_reader(cls, checkpoint_dir: str) -> FileSystemReader:
        """Create the default storage reader for DCP."""
        return FileSystemReader(checkpoint_dir)

    @classmethod
    def _create_storage_writer(cls, checkpoint_dir: str) -> FileSystemWriter:
        """Create the default storage writer for DCP."""
        return FileSystemWriter(
            checkpoint_dir,
            thread_count=16,
            single_file_per_rank=True,
            sync_files=False,
        )

    @classmethod
    def _extra_state_path(cls, checkpoint_dir: str) -> str:
        """Return the path of this rank's extra-state file inside ``checkpoint_dir``."""
        return os.path.join(
            checkpoint_dir,
            _EXTRA_STATE_DIR,
            _EXTRA_STATE_FORMAT.format(cls._get_rank_or_zero()),
        )

    @classmethod
    def _save_extra_state(cls, checkpoint_dir: str, state: Dict[str, Any]) -> None:
        """Save ``state["extra_state"]`` to this rank's file; skip with a warning if absent."""
        if "extra_state" not in state:
            logger.warning("extra_state not found in state, skipping extra_state save")
            return

        extra_state_path = cls._extra_state_path(checkpoint_dir)
        os.makedirs(os.path.dirname(extra_state_path), exist_ok=True)
        torch.save(
            state["extra_state"],
            extra_state_path,
        )

    @classmethod
    def _load_extra_state(cls, checkpoint_dir: str, state: Dict[str, Any]) -> None:
        """
        Load this rank's extra-state file into ``state["extra_state"]``.

        Skips with a warning when ``state`` has no "extra_state" key.

        Raises:
            FileNotFoundError: if this rank's extra-state file does not exist.
        """
        if "extra_state" not in state:
            logger.warning("extra_state not found in state, skipping extra_state load")
            return

        # Loading must not create directories inside the checkpoint: that
        # mutates a checkpoint that may be read-only and hides a missing one.
        extra_state_path = cls._extra_state_path(checkpoint_dir)
        if not os.path.isfile(extra_state_path):
            raise FileNotFoundError(f"extra_state file not found: {extra_state_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        state["extra_state"] = torch.load(extra_state_path, weights_only=False)