import argparse
from pathlib import Path
from typing import Callable, Optional

from safetensors.torch import load_file
from torch.distributed.checkpoint import FileSystemWriter
from tqdm import tqdm

from checkpoint.common.constant import LATEST_TXT
from checkpoint.common.dcp_utils import partial_save_dcp_state_dict, merge_meta_info, save_metadata
from checkpoint.common.permissions import set_directory_permissions
from checkpoint.vlm_model.hf_to_mm import load_from_hf, save_by_dcp
def hf_to_dcp(
    hf_dir: str,
    dcp_dir: str,
    prefix: Optional[str] = None,
):
    """Convert a Hugging Face checkpoint directory into a DCP checkpoint in one pass.

    The state dict returned by ``load_from_hf`` is re-keyed with ``prefix``
    prepended to every key and then handed to ``save_by_dcp``.

    Args:
        hf_dir: Path to the Hugging Face format checkpoint directory.
        dcp_dir: Path of the directory the DCP checkpoint is written to.
        prefix: String prepended to every state-dict key (e.g. ``"model."``).
            ``None`` and ``""`` both leave the keys unchanged.
    """
    # Interpolating None into the f-string would rename every key to
    # "None<key>", so treat a missing prefix as the empty string.
    key_prefix = "" if prefix is None else prefix
    state_dict = load_from_hf(Path(hf_dir))
    state_dict = {f"{key_prefix}{name}": tensor for name, tensor in state_dict.items()}
    save_by_dcp(state_dict, Path(dcp_dir))
def hf_to_dcp_sharded(
    hf_dir: str,
    dcp_dir: str,
    state_dict_convert_func: Optional[Callable[[dict], dict]] = None,
):
    """Convert a Hugging Face checkpoint to DCP one ``*.safetensors`` file at a time.

    By default, DCP shards are split following the same sharding logic as the
    original Hugging Face (HF) checkpoint weights: every ``*.safetensors`` file
    found directly in ``hf_dir`` is loaded on CPU, optionally converted, and
    saved as its own part under ``<dcp_dir>/release``. The per-part metadata is
    merged and written once all parts have been saved.

    Args:
        hf_dir: Directory containing the ``*.safetensors`` files.
        dcp_dir: Root directory of the DCP checkpoint to create.
        state_dict_convert_func: Optional callable that receives the state dict
            of one file and returns the state dict to save. ``None`` saves the
            loaded state dict unchanged.

    Raises:
        FileNotFoundError: If ``hf_dir`` contains no ``*.safetensors`` files.
    """
    iter_name = "release"

    # Fail before touching the output directory: with nothing to convert the
    # run would otherwise leave an empty checkpoint directory behind.
    shard_files = sorted(Path(hf_dir).glob("*.safetensors"))
    if not shard_files:
        raise FileNotFoundError(f"No *.safetensors files found in '{hf_dir}'")

    save_root_dir = Path(dcp_dir)
    save_path = save_root_dir.joinpath(iter_name)
    save_path.mkdir(exist_ok=True, parents=True)
    storage_writer = FileSystemWriter(save_path)

    meta_infos = []
    all_writes = []
    for i, safe_path in enumerate(tqdm(shard_files, desc="Processing files")):
        state_dict = load_file(str(safe_path), device="cpu")
        if state_dict_convert_func is not None:
            state_dict = state_dict_convert_func(state_dict)
        save_dict = {"model": state_dict}
        if i == 0:
            # The checkpoint version entry is stored once, with the first part.
            save_dict["checkpoint_version"] = 3.0
        global_meta, all_write = partial_save_dcp_state_dict(save_dict, storage_writer, part_idx=i)
        meta_infos.append(global_meta)
        all_writes.append(all_write)

    merged_meta = merge_meta_info(meta_infos)
    save_metadata(merged_meta, all_writes, storage_writer)

    # Write the latest-iteration marker only after every part and the merged
    # metadata are on disk, so a failed run never advertises a partial checkpoint.
    save_root_dir.joinpath(LATEST_TXT).write_text(iter_name)
    set_directory_permissions(Path(dcp_dir))
if __name__ == "__main__":
    # Command-line entry point: pick the one-pass or the sharded converter
    # depending on --sharded; both prepend --prefix to every state-dict key.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--hf-dir",
        type=str,
        required=True,
        help="Path to HF format checkpoint directory",
    )
    cli.add_argument(
        "--dcp-dir",
        type=str,
        required=True,
        help="Path to save torch_dcp format model",
    )
    cli.add_argument(
        "--prefix",
        type=str,
        default="",
        help="Key prefix for state dict (e.g., 'model.')",
    )
    cli.add_argument(
        "--sharded",
        action="store_true",
        help="Enable sharded conversion to reduce memory usage (process one shard at a time)",
    )
    opts = cli.parse_args()

    if not opts.sharded:
        hf_to_dcp(opts.hf_dir, opts.dcp_dir, prefix=opts.prefix)
    else:
        def _with_prefix(shard_state):
            """Return a copy of one shard's state dict with the prefix on every key."""
            return {f"{opts.prefix}{key}": value for key, value in shard_state.items()}

        hf_to_dcp_sharded(
            opts.hf_dir,
            opts.dcp_dir,
            state_dict_convert_func=_with_prefix,
        )