import functools
import json
import logging
import os
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
import torch
from .weight_utils import convert_expert_format, detect_expert_format
logger = logging.getLogger(__name__)
@dataclass
class SaveConfig:
enabled: bool = False
save_format: str = "dcp"
save_expert_format: str | None = None
hf_save_dir: str | None = None
num_experts: int = 0
_adapter: Any | None = field(default=None, repr=False)
_patched: bool = field(default=False, repr=False)
def reset(self):
self.enabled = False
self.save_format = "dcp"
self.save_expert_format = None
self.hf_save_dir = None
self._adapter = None
self._patched = False
self.num_experts = 0
def get_adapter(self):
return self._adapter
def set_adapter(self, adapter):
self._adapter = adapter
def is_patched(self):
return self._patched
def set_patched(self, patched):
self._patched = patched
# Single module-wide configuration instance read and mutated by the functions below.
_config = SaveConfig()
# Unpatched CheckpointManager.save, captured by apply_patch() the first time it runs.
_original_save: Callable | None = None
# Unpatched CheckpointManager._flattened_model_states_sd, captured by apply_patch().
_original_model_states_sd: Callable | None = None
def configure_from_model_args(model_args: Any, adapter: Any | None = None):
    """Populate the module configuration from ``model_args``.

    Each option is read with ``getattr``; a missing attribute or a ``None``
    value falls back to the listed default.

    Args:
        model_args: Object carrying the optional ``save_patch_enabled``,
            ``save_format``, ``save_expert_format`` and ``hf_save_dir``
            attributes.
        adapter: State-dict adapter to store in the configuration.
    """
    # (config field, model_args attribute, fallback when missing or None)
    option_table = (
        ("enabled", "save_patch_enabled", False),
        ("save_format", "save_format", "dcp"),
        ("save_expert_format", "save_expert_format", None),
        ("hf_save_dir", "hf_save_dir", None),
    )
    for config_field, arg_name, fallback in option_table:
        supplied = getattr(model_args, arg_name, None)
        setattr(_config, config_field, fallback if supplied is None else supplied)
    _config.set_adapter(adapter)
def is_enabled() -> bool:
    """Return the ``enabled`` flag of the module configuration."""
    enabled_flag = _config.enabled
    return enabled_flag
def _convert_state_dict_for_save(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return ``state_dict`` in the expert format required for saving.

    Nothing is converted when no ``save_expert_format`` is configured, when
    the detected format already equals the target, or when the detected
    format is ``"none"``. For ``save_format == "hf"`` the target is always
    ``"standard"``, regardless of the configured expert format.
    """
    requested = _config.save_expert_format
    if not requested:
        return state_dict

    current = detect_expert_format(state_dict)
    target = "standard" if _config.save_format == "hf" else requested

    if current in (target, "none"):
        return state_dict

    logger.info(f"Converting expert format for save: {current} -> {target}")
    return convert_expert_format(state_dict, target)
def _get_total_experts() -> int:
    """Return the total number of experts declared by the adapter's model args.

    The value is read from ``adapter.model_args.moe_args.num_experts``. When
    any link of that chain is missing or ``None`` (no adapter configured, a
    model without ``moe_args``, ...), the function falls back to
    ``_config.num_experts`` (0 by default) instead of raising
    ``AttributeError``, so callers can treat 0 as "no experts".
    """
    model_args = getattr(_config.get_adapter(), "model_args", None)
    moe_args = getattr(model_args, "moe_args", None)
    num_experts = getattr(moe_args, "num_experts", None)
    if num_experts is None:
        # getattr on None simply yields the default, so every missing link
        # in the chain ends up here.
        return _config.num_experts
    return num_experts
def _expert_ids_from_keys(keys, pattern) -> set[int]:
    """Return the expert indices captured by ``pattern`` (group 1) in ``keys``."""
    found = set()
    for key in keys:
        match = pattern.search(key)
        if match:
            found.add(int(match.group(1)))
    return found


def _tensor_to_cpu(value):
    """Return ``value`` moved to CPU when it is a tensor, otherwise unchanged."""
    return value.cpu() if isinstance(value, torch.Tensor) else value


def _convert_to_hf_and_save(state_dict: dict[str, Any], output_dir: str):
    """Convert ``state_dict`` to HF format and write it as safetensors.

    Steps:
      1. Drop entries whose key starts with ``train_state``, ``optimizer``,
         ``lr_scheduler`` or ``dataloader`` and convert the rest with the
         configured adapter's ``to_hf``.
      2. Non-expert entries: ``full_tensor()`` is called on every rank for
         values that provide it; only rank 0 keeps a CPU copy.
      3. Expert entries (key contains ``.experts``): every rank materialises
         its own values on CPU. Unless the ids already look global, the
         numeric id in ``.experts.<id>.`` is rewritten to
         ``rank * experts_per_rank + id``.
      4. When running distributed with more than one rank and a non-zero
         ``experts_per_rank``, the per-rank expert dicts are exchanged with
         ``all_gather_object``.
      5. Rank 0 merges everything and writes ``model.safetensors`` plus
         ``model.safetensors.index.json`` into ``output_dir``.

    Any exception is logged rather than raised. When a process group is
    initialised, every rank finishes on a barrier.

    Args:
        state_dict: Flattened state dict as produced by the checkpoint manager.
        output_dir: Directory to create (if needed) and write into.
    """
    is_distributed = torch.distributed.is_initialized()
    rank = torch.distributed.get_rank() if is_distributed else 0
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    is_main = rank == 0
    expert_id_re = re.compile(r"\.experts\.(\d+)\.")

    try:
        from safetensors.torch import save_file

        excluded = ("train_state", "optimizer", "lr_scheduler", "dataloader")
        model_state_dict = {
            k: v for k, v in state_dict.items() if not k.startswith(excluded)
        }

        total_experts = _get_total_experts()
        # NOTE(review): assumes experts are split evenly over *all* ranks,
        # i.e. expert-parallel degree == world size — confirm for mixed
        # EP/DP layouts.
        experts_per_rank = (
            total_experts // world_size if total_experts > 0 and world_size > 1 else 0
        )

        hf_state_dict = _config.get_adapter().to_hf(model_state_dict)

        # True partition: every key goes down exactly one path. Using two
        # different substring tests here would process some keys twice.
        expert_keys = sorted(k for k in hf_state_dict if ".experts" in k)
        non_expert_keys = sorted(k for k in hf_state_dict if ".experts" not in k)

        expert_ids_in_hf = _expert_ids_from_keys(expert_keys, expert_id_re)
        min_expert_id = min(expert_ids_in_hf, default=0)
        max_expert_id = max(expert_ids_in_hf, default=0)
        # Heuristic: ids beyond the per-rank range, or a non-zero rank whose
        # smallest id is not 0, indicate the keys already carry global ids.
        already_global = (max_expert_id >= experts_per_rank) or (
            rank > 0 and min_expert_id > 0
        )
        if is_main:
            logger.info(
                f"experts_per_rank = {experts_per_rank}, already_global = {already_global}"
            )

        cpu_non_expert = {}
        for key in non_expert_keys:
            value = hf_state_dict[key]
            if hasattr(value, "full_tensor"):
                # Must run on every rank, even though only rank 0 keeps it.
                value = value.full_tensor()
            if is_main:
                cpu_non_expert[key] = _tensor_to_cpu(value)

        def remap_expert_id(m):
            # Translate a rank-local expert id into its global id.
            global_id = rank * experts_per_rank + int(m.group(1))
            return f".experts.{global_id}."

        local_expert_state = {}
        for key in expert_keys:
            value = hf_state_dict[key]
            new_key = key if already_global else expert_id_re.sub(remap_expert_id, key)
            if hasattr(value, "full_tensor"):
                value = value.full_tensor()
            local_expert_state[new_key] = _tensor_to_cpu(value)

        if is_distributed and world_size > 1 and experts_per_rank > 0:
            all_expert_states = [None] * world_size
            torch.distributed.all_gather_object(all_expert_states, local_expert_state)
        else:
            all_expert_states = [local_expert_state]

        if is_main:
            merged_experts = {}
            for expert_dict in all_expert_states:
                if expert_dict:
                    merged_experts.update(expert_dict)
            # Expert entries take precedence on a key collision.
            cpu_state = {**cpu_non_expert, **merged_experts}

            if not cpu_state:
                logger.warning("No tensors to save")
            else:
                os.makedirs(output_dir, exist_ok=True)
                save_path = os.path.join(output_dir, "model.safetensors")
                save_file(cpu_state, save_path)

                total_size = sum(
                    t.numel() * t.element_size()
                    for t in cpu_state.values()
                    if isinstance(t, torch.Tensor)
                )
                index = {
                    "metadata": {"total_size": total_size},
                    "weight_map": {k: "model.safetensors" for k in sorted(cpu_state)},
                }
                index_path = os.path.join(output_dir, "model.safetensors.index.json")
                with open(index_path, "w", encoding="utf-8") as f:
                    json.dump(index, f, indent=2)

                sorted_ids = sorted(_expert_ids_from_keys(cpu_state, expert_id_re))
                # A checkpoint without numbered experts (e.g. a dense model)
                # is valid; indexing an empty list here would turn a
                # successful save into a logged failure.
                expert_range = (
                    f"{sorted_ids[0]}-{sorted_ids[-1]}" if sorted_ids else "none"
                )
                if total_experts > 0 and sorted_ids and len(sorted_ids) != total_experts:
                    logger.warning(
                        f"Saved {len(sorted_ids)} distinct expert ids but the model "
                        f"declares {total_experts}; the HF checkpoint may be incomplete"
                    )
                logger.info(
                    f"Saved HF checkpoint to {save_path}"
                    f"({len(cpu_state)} tensors, {len(sorted_ids)} experts: {expert_range},"
                    f"{total_size / 1e9:.2f}GB)"
                )
    except Exception as e:
        # Best-effort save: report the failure instead of aborting training.
        if is_main:
            logger.error(f"Failed to save HF checkpoint: {e}", exc_info=True)
        else:
            # Failures on other ranks used to vanish without a trace.
            logger.error(
                f"Failed to save HF checkpoint on rank {rank}: {e}", exc_info=True
            )
    finally:
        if is_distributed:
            torch.distributed.barrier()
def _create_patched_model_states_sd(original_method: Callable) -> Callable:
    """Build a replacement for ``_flattened_model_states_sd``.

    The replacement calls ``original_method`` and, when the patch is enabled
    and a ``save_expert_format`` is configured, passes the result through
    ``_convert_state_dict_for_save``. Otherwise the result is returned as is.
    """

    @functools.wraps(original_method)
    def wrapper(self, *args, **kwargs):
        flattened = original_method(self, *args, **kwargs)
        if not (_config.enabled and _config.save_expert_format):
            return flattened
        return _convert_state_dict_for_save(flattened)

    return wrapper
def _create_patched_save(original_save: Callable) -> Callable:
    """Wrap ``save`` so it can emit either a DCP or an HF checkpoint.

    Behaviour of the returned function:
      * Patch disabled: delegate straight to ``original_save``.
      * ``self.enable_checkpoint`` present and falsy: do nothing.
      * ``self.interval`` present: save only on ``last_step`` or when
        ``curr_step`` is a multiple of the interval.
      * ``save_format == "dcp"``: call ``original_save`` and return its result.
      * ``save_format == "hf"`` with ``hf_save_dir`` set: write an HF
        checkpoint to ``<hf_save_dir>/step_<curr_step>`` and return ``None``.
      * Anything else saves nothing; a warning is logged so the missing
        checkpoint does not go unnoticed.
    """

    @functools.wraps(original_save)
    def patched_save(self, curr_step: int, last_step: bool = False):
        if not _config.enabled:
            return original_save(self, curr_step, last_step)

        if hasattr(self, "enable_checkpoint") and not self.enable_checkpoint:
            return None

        if hasattr(self, "interval"):
            if not (last_step or curr_step % self.interval == 0):
                return None

        save_format = _config.save_format
        if save_format == "dcp":
            return original_save(self, curr_step, last_step)

        if save_format == "hf":
            if _config.hf_save_dir:
                state_dict = self._flattened_model_states_sd()
                hf_dir = os.path.join(_config.hf_save_dir, f"step_{curr_step}")
                _convert_to_hf_and_save(state_dict, hf_dir)
            else:
                logger.warning(
                    f"save_format is 'hf' but hf_save_dir is not set; "
                    f"no checkpoint written at step {curr_step}"
                )
            return None

        logger.warning(
            f"Unsupported save_format {save_format!r}; "
            f"no checkpoint written at step {curr_step}"
        )
        return None

    return patched_save
def apply_patch() -> bool:
    """Install the save patches on torchtitan's ``CheckpointManager``.

    Returns:
        ``True`` when the patch is already in place or was installed now;
        ``False`` when the patch is disabled in the configuration or the
        installation raised (the error is logged, not propagated).
    """
    global _original_save, _original_model_states_sd

    if _config.is_patched():
        return True
    if not _config.enabled:
        return False

    try:
        from torchtitan.components.checkpoint import CheckpointManager as manager_cls

        if hasattr(manager_cls, "_flattened_model_states_sd"):
            # Capture the pristine method only once so repeated attempts
            # never wrap an already-wrapped function.
            if _original_model_states_sd is None:
                _original_model_states_sd = manager_cls._flattened_model_states_sd
            manager_cls._flattened_model_states_sd = _create_patched_model_states_sd(
                _original_model_states_sd
            )

        if _original_save is None:
            _original_save = manager_cls.save
        manager_cls.save = _create_patched_save(_original_save)

        _config.set_patched(True)
        return True
    except Exception as e:
        logger.error(f"Failed to apply checkpoint patch: {e}", exc_info=True)
        return False
def patch_llama4_checkpoint_support():
    """Route Llama4 checkpointing through the NPU state-dict adapter.

    Assigns the NPU ``dcp_load`` onto ``CheckpointManager`` and replaces
    ``torchtitan.models.llama4.get_train_spec`` with a wrapper that sets
    ``state_dict_adapter`` to ``Llama4StateDictAdapterNpu`` on the spec it
    returns.
    """
    import torchtitan.models.llama4 as llama4_module
    from torchtitan.components.checkpoint import CheckpointManager
    from torchtitan.protocols.train_spec import TrainSpec
    from torchtitan_npu.models.llama4.model.state_dict_adapter import (
        dcp_load,
        Llama4StateDictAdapterNpu,
    )

    CheckpointManager.dcp_load = dcp_load

    # Keep a handle on the upstream factory before it is replaced below.
    upstream_get_train_spec = llama4_module.get_train_spec

    def patch_llama4_spec() -> TrainSpec:
        train_spec = upstream_get_train_spec()
        train_spec.state_dict_adapter = Llama4StateDictAdapterNpu
        return train_spec

    llama4_module.get_train_spec = patch_llama4_spec
def patch_dsv3_checkpoint_support():
    """Route DeepSeek-V3 checkpointing through the NPU state-dict adapter.

    Replaces ``torchtitan.models.deepseek_v3.get_train_spec`` with a wrapper
    that sets ``state_dict_adapter`` to ``DeepSeekV3StateDictAdapterNpu`` on
    the spec it returns.
    """
    import torchtitan.models.deepseek_v3 as deepseek_v3_module
    from torchtitan.protocols.train_spec import TrainSpec
    from torchtitan_npu.models.deepseek_v3.model.state_dict_adapter import (
        DeepSeekV3StateDictAdapterNpu,
    )

    # Keep a handle on the upstream factory before it is replaced below.
    upstream_get_train_spec = deepseek_v3_module.get_train_spec

    def patch_deepseek_v3_spec() -> TrainSpec:
        train_spec = upstream_get_train_spec()
        train_spec.state_dict_adapter = DeepSeekV3StateDictAdapterNpu
        return train_spec

    deepseek_v3_module.get_train_spec = patch_deepseek_v3_spec