from typing import Optional, List, cast
import inspect
import warnings
import dataclasses
import os
from functools import reduce
import torch
from torch import Tensor
from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE
from torch.distributed.checkpoint.storage import StorageWriter, StorageReader
from torch.distributed.checkpoint.planner import SavePlanner, LoadPlanner, ReadItem, LoadItemType, LoadPlan
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.logger import _dcp_method_logger
from torch.distributed.checkpoint.filesystem import _StoragePrefix, _StorageInfo
from torch.distributed.checkpoint import FileSystemReader
def partial_save_dcp_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
planner: Optional[SavePlanner] = None,
part_idx: int = 0
):
"""
Save a partial shard of a Distributed Checkpoint (DCP) state_dict.
This function enables a single process (e.g., in a single-machine or testing context)
to save its portion of model weights as torch_dcp format. It coordinates between
a SavePlanner (which decides how tensors are laid out) and a StorageWriter (which
handles actual I/O). The function returns global metadata (typically populated only
by a coordinator rank in multi-process settings) and the results of write operations.
Args:
state_dict (STATE_DICT_TYPE): The local subset of the model state dict to save.
storage_writer (StorageWriter): Handles writing data to persistent storage.
planner (Optional[SavePlanner]): Custom save planner; uses DefaultSavePlanner if None.
part_idx (int): Offset index used to generate unique storage prefixes (e.g., "__2_").
Returns:
Tuple[Optional[Metadata], List[Any]]:
- global_metadata: checkpoint metadata.
- all_writes: Results of the write operations.
"""
if planner is None:
planner = DefaultSavePlanner()
global_metadata = None
ckpt_kwargs = {}
ckpt_id = getattr(storage_writer, "checkpoint_id", None)
if ckpt_id is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def local_step():
storage_meta = storage_writer.storage_meta()
if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
warnings.warn(
"The function definition for SavePlanner.set_up_planner has been updated"
" to include the storage_meta argument. Please update your implementation"
" to include this parameter."
)
planner.set_up_planner(state_dict, True)
else:
planner.set_up_planner(
state_dict=state_dict,
storage_meta=storage_meta,
is_coordinator=True
)
storage_writer.set_up_storage_writer(True)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
nonlocal global_metadata
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
all_local_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i+part_idx}_"))
for i, plan in enumerate(all_local_plans)
]
return all_local_plans
local_plan = local_step()
all_local_plan = global_step([local_plan])[0]
@_dcp_method_logger(**ckpt_kwargs)
def write_data():
final_local_plan = planner.finish_plan(all_local_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)
all_writes.wait()
return all_writes.value()
all_writes = write_data()
return global_metadata, all_writes
def save_metadata(
global_metadata,
all_writes,
storage_writer: StorageWriter
):
"""
Finalize a Distributed Checkpoint (DCP) by saving its global metadata.
Args:
global_metadata: The consolidated metadata describing the entire checkpoint.
Typically generated during the global planning phase.
all_writes: Results from previous tensor write operations (e.g., list of
written file names or async write futures). Used by the writer
to finalize references in the metadata.
storage_writer (StorageWriter): The writer responsible for persisting data
and metadata to the underlying storage backend.
"""
ckpt_kwargs = {}
ckpt_id = getattr(storage_writer, "checkpoint_id", None)
if ckpt_id is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def finish_checkpoint():
storage_writer.finish(metadata=global_metadata, results=all_writes)
return global_metadata
return finish_checkpoint()
def merge_meta_info(
    global_meta_infos: List[Metadata],
):
    """
    Merge several DCP (Distributed Checkpoint) Metadata objects into one.

    Intended for checkpoints saved in several parts (e.g. via partial saves),
    each of which produced its own Metadata. Elements are folded left to right.
    For each pair, these dict fields are combined:
      - ``state_dict_metadata``
      - ``planner_data`` (``None`` is treated as an empty dict)
      - ``storage_data`` (``None`` is treated as an empty dict)
    A combined optional field stays ``None`` only when it is ``None`` on both
    sides. All other dataclass fields are copied from the first element.

    Keys are assumed to be disjoint across parts. If they overlap, the entry
    from the later element silently wins.

    Args:
        global_meta_infos (List[Metadata]): Metadata objects to merge, in order.

    Returns:
        Metadata: The merged metadata. The first element itself is returned when
        the list has exactly one entry, and None when the list is empty. Input
        objects are never mutated.
    """
    if not global_meta_infos:
        return None

    def _merge_optional(left, right):
        # Union of two optional mappings; keep None only if both are None.
        if left is None and right is None:
            return None
        return {**(left or {}), **(right or {})}

    def _merge_pair(acc, nxt):
        # dataclasses.replace builds a new object, leaving both inputs untouched.
        return dataclasses.replace(
            acc,
            state_dict_metadata={**acc.state_dict_metadata, **nxt.state_dict_metadata},
            planner_data=_merge_optional(acc.planner_data, nxt.planner_data),
            storage_data=_merge_optional(acc.storage_data, nxt.storage_data),
        )

    return reduce(_merge_pair, global_meta_infos[1:], global_meta_infos[0])
def load_metadata(
storage_reader: StorageReader
):
"""
Load the global metadata of a Distributed Checkpoint (DCP) from persistent storage.
This function uses a `StorageReader` to read the checkpoint's central metadata file,
which typically contains:
- `state_dict_metadata`
- `planner_data`
The metadata is essential for correctly reconstructing the full state dict during loading.
Args:
storage_reader (StorageReader): An object capable of reading checkpoint data from
the underlying storage backend (e.g., filesystem, cloud).
Returns:
Metadata: The deserialized global metadata object describing the entire checkpoint.
"""
ckpt_kwargs = {}
ckpt_id = getattr(storage_reader, "checkpoint_id", None)
if ckpt_id is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def read_metadata():
metadata = storage_reader.read_metadata()
return metadata
return read_metadata()
def partial_load_dcp_state_dict(
metadata: Metadata,
storage_reader: StorageReader,
planner: Optional[LoadPlanner] = None,
):
"""
Load a partial subset of a Distributed Checkpoint (DCP) state dictionary.
This function is designed for scenarios where only a portion of the full model
needs to be loaded (e.g., loading specific layers or shards). The input `metadata`
describes only the relevant tensors to load—not the entire checkpoint.
It uses a LoadPlanner and StorageReader to:
- Plan which tensors to load based on the provided metadata,
- Read the corresponding data from storage,
- Populate a local `state_dict` in-place.
Note: The resulting `state_dict` will contain only the keys covered by the input `metadata`.
Args:
metadata (Metadata): Partial metadata describing the subset of tensors to load.
Must include entries in `state_dict_metadata` for the desired keys.
storage_reader (StorageReader): Handles reading tensor data from persistent storage.
planner (Optional[LoadPlanner]): Custom load planner. If not provided, uses
`_EmptyStateDictLoadPlanner`, which initializes an
empty state dict and populates it during loading.
Returns:
STATE_DICT_TYPE: A state dictionary containing only the tensors specified in `metadata`.
"""
state_dict: STATE_DICT_TYPE = {}
if planner is None:
planner = _EmptyStateDictLoadPlanner()
ckpt_kwargs = {}
ckpt_id = getattr(storage_reader, "checkpoint_id", None)
if ckpt_id is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@_dcp_method_logger(**ckpt_kwargs)
def local_step():
planner.set_up_planner(state_dict, metadata, True)
storage_reader.set_up_storage_reader(metadata, True)
local_plan = planner.create_local_plan()
local_plan = storage_reader.prepare_local_plan(local_plan)
return local_plan
@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
all_local_plans = planner.create_global_plan(all_local_plans)
all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
return all_local_plans
local_plan = local_step()
central_plan = global_step([local_plan])[0]
@_dcp_method_logger(**ckpt_kwargs)
def read_data():
final_local_plan = planner.finish_plan(central_plan)
all_reads = storage_reader.read_data(final_local_plan, planner)
all_reads.wait()
read_data()
return state_dict
def extract_metadata(
    selected_keys: List[str],
    metadata: Metadata
):
    """
    Build a new Metadata restricted to the given state-dict keys.

    Filters three components of ``metadata``:
      - ``state_dict_metadata``: entries whose key is in ``selected_keys``;
      - ``storage_data``: entries whose index ``.fqn`` is in ``selected_keys``;
      - ``planner_data``: entries whose key is in ``selected_keys``.
    A ``None`` ``storage_data`` or ``planner_data`` is treated as empty, so the
    result always carries dicts for all three fields. Other fields of
    ``metadata`` are not copied.

    The result can be used to load only part of a checkpoint.

    Args:
        selected_keys (List[str]): Names to keep, matched against the keys of
            ``state_dict_metadata`` / ``planner_data`` and the ``fqn`` of each
            ``storage_data`` index. Names absent from ``metadata`` are ignored.
        metadata (Metadata): Complete metadata of a DCP checkpoint. It is not modified.

    Returns:
        Metadata: A new Metadata holding only the selected entries.
    """
    # Build the lookup set once: O(1) membership instead of scanning the list
    # for every entry, which matters for checkpoints with many shards.
    wanted = set(selected_keys)

    partial_state_dict_metadata = {
        key: value
        for key, value in metadata.state_dict_metadata.items()
        if key in wanted
    }
    partial_storage_data = {
        index: info
        for index, info in (metadata.storage_data or {}).items()
        if index.fqn in wanted
    }
    partial_planner_data = {
        key: path
        for key, path in (metadata.planner_data or {}).items()
        if key in wanted
    }
    return Metadata(
        state_dict_metadata=partial_state_dict_metadata,
        storage_data=partial_storage_data,
        planner_data=partial_planner_data,
    )