import argparse
from pathlib import Path
import torch
import torch.distributed.checkpoint as dcp
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.checkpoint import HuggingFaceStorageReader
from torchtitan.components.checkpoint import ModelWrapper
@torch.inference_mode()
def convert_from_hf(
    input_dir: Path, output_dir: Path, model_name: str, model_flavor: str
) -> None:
    """Convert a Hugging Face checkpoint into a DCP checkpoint.

    The train spec registered under ``model_name`` supplies the model class,
    the per-flavor model arguments, and the state-dict adapter used to map
    between the native state-dict layout and the Hugging Face layout.

    Steps:
      1. Instantiate the model on CPU and take its state dict.
      2. Map that state dict to the HF layout with ``sd_adapter.to_hf`` to
         obtain a template describing what to read.
      3. Load the HF checkpoint from ``input_dir`` into that template.
      4. Map the loaded tensors back with ``sd_adapter.from_hf`` and save
         them with ``dcp.save`` to ``output_dir``.

    Args:
        input_dir: Directory containing the Hugging Face checkpoint.
        output_dir: Destination passed to ``dcp.save`` as ``checkpoint_id``.
        model_name: Name looked up via ``get_train_spec``.
        model_flavor: Key into the train spec's ``model_args`` mapping.

    Raises:
        ValueError: If the train spec has no state-dict adapter.
        AssertionError: If the adapter factory returns ``None``.
    """
    train_spec = train_spec_module.get_train_spec(model_name)
    model_args = train_spec.model_args[model_flavor]

    missing_adapter_msg = "trying to convert checkpoint from HF to DCP safetensors format, but sd_adapter is not provided."

    # Validate the adapter *before* calling it (and before building the
    # model): if the attribute is None, calling it would raise an opaque
    # "'NoneType' object is not callable" TypeError and the message above
    # would never be shown.
    adapter_factory = train_spec.state_dict_adapter
    if adapter_factory is None:
        raise ValueError(missing_adapter_msg)

    # Build the model on CPU; only its state-dict structure is needed here.
    with torch.device("cpu"):
        model = train_spec.model_cls(model_args)
    model = ModelWrapper(model)

    sd_adapter = adapter_factory(model_args, None)
    # Kept from the original: guards against a factory that returns None.
    assert sd_adapter is not None, missing_adapter_msg

    # The HF-layout dict acts as the load template; dcp.load fills the
    # passed state dict in place, which is why it is reused afterwards.
    state_dict = model._get_state_dict()
    hf_state_dict = sd_adapter.to_hf(state_dict)
    dcp.load(
        hf_state_dict,
        storage_reader=HuggingFaceStorageReader(path=input_dir),
    )

    # Map the loaded HF tensors back to the native layout and write DCP.
    state_dict = sd_adapter.from_hf(hf_state_dict)
    dcp.save(
        state_dict,
        checkpoint_id=output_dir,
    )
if __name__ == "__main__":
    # Command-line entry point: parse the checkpoint locations and the model
    # selection, then run the HF -> DCP conversion.
    parser = argparse.ArgumentParser(description="Convert HF checkpoint to DCP format.")
    parser.add_argument(
        "input_dir", type=Path, help="Input directory with HF checkpoint"
    )
    parser.add_argument("output_dir", type=Path, help="Output directory for DCP.")
    # No nargs="?" on these options: with it, a bare "--model_name" (flag
    # given without a value) would silently produce None instead of a usage
    # error, and the failure would surface later inside the conversion.
    parser.add_argument(
        "--model_name",
        type=str,
        default="llama3",
        help="Model name used to look up the train spec (default: %(default)s).",
    )
    parser.add_argument(
        "--model_flavor",
        type=str,
        default="8B",
        help="Model flavor selecting the model args (default: %(default)s).",
    )
    args = parser.parse_args()

    convert_from_hf(
        args.input_dir,
        args.output_dir,
        args.model_name,
        args.model_flavor,
    )