import argparse
import json
import shutil
from pathlib import Path
from typing import Dict, Optional
from pydantic import validate_arguments, DirectoryPath, FilePath
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from transformers import AutoConfig, AutoProcessor
from safetensors.torch import save_file
from checkpoint.common.permissions import set_directory_permissions
@validate_arguments
def load_dcp_state_dict(dcp_checkpoint_dir: DirectoryPath) -> STATE_DICT_TYPE:
    """Read a torch DCP checkpoint directory into an in-memory state dict.

    The checkpoint is loaded without a process group (``no_dist=True``) using
    ``_EmptyStateDictLoadPlanner``, which builds the state dict entries from the
    checkpoint's own metadata rather than requiring a pre-shaped target dict.

    Args:
        dcp_checkpoint_dir: Existing directory containing the DCP checkpoint.

    Returns:
        The value stored under the top-level ``"model"`` key when that key is
        present; otherwise the whole loaded state dict.
    """
    loaded: STATE_DICT_TYPE = {}
    reader = FileSystemReader(str(dcp_checkpoint_dir))
    _load_state_dict(
        loaded,
        storage_reader=reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    # Unwrap a top-level "model" entry if the checkpoint has one.
    return loaded.get("model", loaded)
def find_safetensors_index(directory: Path) -> Optional[Path]:
    """Find the ``*.safetensors.index.json`` file in the given directory.

    The search is non-recursive and only regular files are considered.

    Args:
        directory: Directory to search. A ``str`` path is also accepted.

    Returns:
        Path to the index file, or ``None`` when ``directory`` is not an
        existing directory or contains no index file. If several index files
        are present, the one whose name sorts first is returned, so the result
        does not depend on the filesystem's ``iterdir`` ordering.
    """
    directory = Path(directory)
    if not directory.is_dir():
        return None
    candidates = (
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.name.endswith(".safetensors.index.json")
    )
    # iterdir() order is arbitrary; pick deterministically by name.
    return min(candidates, key=lambda entry: entry.name, default=None)
@validate_arguments
def save_hf_weights(
    save_path: Path,
    model_assets_dir: DirectoryPath,
    state_dict: Dict,
    prefix: str = "",
):
    """Write ``state_dict`` as sharded safetensors matching an existing HF index.

    The ``*.safetensors.index.json`` found in ``model_assets_dir`` is copied
    verbatim into ``save_path``. Its ``weight_map`` decides which shard file
    each tensor goes into. Every shard is written under exactly the filename
    the index references, so the copied index and the written files agree
    regardless of the shard naming scheme (including single-file indexes).

    Args:
        save_path: Output directory; created if it does not exist.
        model_assets_dir: Directory holding the reference safetensors index.
        state_dict: Mapping from (optionally prefixed) parameter names to tensors.
        prefix: String prepended to each ``weight_map`` key to look it up in
            ``state_dict`` (e.g. ``"model."``). The unprefixed key is what gets
            stored in the shard.

    Raises:
        FileNotFoundError: If no safetensors index file exists in
            ``model_assets_dir``.
    """
    save_path.mkdir(parents=True, exist_ok=True)
    index_file = find_safetensors_index(Path(model_assets_dir))
    if index_file is None:
        raise FileNotFoundError(f"Could not find safetensors index file in directory {model_assets_dir}")
    shutil.copy2(index_file, save_path)
    with open(index_file, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    # Group tensors by the shard filename the index actually references, instead
    # of parsing a shard number out of the name and renaming the files, which
    # would desynchronize the copied index from the files on disk.
    shards: Dict[str, Dict] = {}
    for key, shard_name in weight_map.items():
        shard = shards.setdefault(shard_name, {})
        full_key = f"{prefix}{key}"
        if full_key in state_dict:
            shard[key] = state_dict[full_key]
        else:
            # Best-effort: the copied index will still list this key.
            print(f"Missing key: '{full_key}' in state_dict")

    metadata = {"format": "pt"}
    for shard_name, shard in shards.items():
        save_file(shard, save_path / shard_name, metadata=metadata)
    set_directory_permissions(save_path)
@validate_arguments
def merge_dcp_to_hf(
    load_dir: DirectoryPath,
    save_dir: str | Path,
    model_assets_dir: DirectoryPath,
    prefix: str = "",
):
    """
    Load model in torch DCP format and save in Hugging Face format.

    Config and processor are read from ``model_assets_dir`` and re-saved into
    ``save_dir``. The weights from the DCP checkpoint in ``load_dir`` are then
    written as safetensors shards following the index in ``model_assets_dir``.

    Args:
        load_dir: Directory containing the torch DCP checkpoint.
        save_dir: Output directory for the Hugging Face model.
        model_assets_dir: Directory with config, processor/tokenizer files and
            the reference ``*.safetensors.index.json``.
        prefix: Key prefix used in the DCP state dict (e.g. ``"model."``).
    """
    assets_dir = str(model_assets_dir)
    # Load the assets before the checkpoint so a broken assets dir fails fast,
    # before the full set of weights is read into memory.
    # trust_remote_code matches the processor call: both read the same directory,
    # and custom-architecture configs also need it.
    config = AutoConfig.from_pretrained(assets_dir, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(assets_dir, trust_remote_code=True)
    state_dict = load_dcp_state_dict(load_dir)

    save_path = Path(save_dir)
    config.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    save_hf_weights(
        save_path=save_path,
        model_assets_dir=assets_dir,
        state_dict=state_dict,
        prefix=prefix,
    )
if __name__ == "__main__":
    # Command-line entry point: convert a DCP checkpoint into a Hugging Face model dir.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--load-dir",
        type=str,
        required=True,
        help="Path to DCP checkpoint directory",
    )
    cli.add_argument(
        "--save-dir",
        type=str,
        required=True,
        help="Path to save HF format model",
    )
    cli.add_argument(
        "--model-assets-dir",
        type=str,
        required=True,
        help="Path to model assets (config, tokenizer, etc.)",
    )
    cli.add_argument(
        "--prefix",
        type=str,
        default="",
        help="Key prefix for state dict (e.g., 'model.')",
    )
    parsed = cli.parse_args()
    print(f"Merge Args: {parsed}")

    merge_dcp_to_hf(
        load_dir=parsed.load_dir,
        save_dir=parsed.save_dir,
        model_assets_dir=parsed.model_assets_dir,
        prefix=parsed.prefix,
    )
    print(f"Merge to HF format success! Saved to: {parsed.save_dir}")