import os
import glob
import re
import json
from pathlib import Path
import stat
import shutil
import torch
from safetensors.torch import load_file, save_file
from checkpoint.common.permissions import set_directory_permissions
from checkpoint.common.constant import FILE_MODE
# Tracker file in a Megatron-style checkpoint root; holds the latest iteration number
# or the literal "release".
# NOTE: "LASTEST" is a misspelling of "LATEST"; the name is kept because other modules
# may import it.
MEGATRON_LASTEST_ITERATION_FILE_NAME = "latest_checkpointed_iteration.txt"
# Key under which a rank checkpoint stores the model state dict when it is not split
# into virtual-pipeline chunks ("model0", "model1", ...).
MEGATRON_MODEL_KEY = "model"
# File name of the per-rank checkpoint inside each mp_rank_* directory.
MEGATRON_CKPT_NAME = "model_optim_rng.pt"
def load_from_mm(load_dir):
    """Load a Megatron-style checkpoint into a nested list of per-rank state dicts.

    Reads the tracker file in ``load_dir`` to pick the sub-directory to load
    (``release`` or ``iter_XXXXXXX``), then loads every ``mp_rank_*`` directory in it.

    Args:
        load_dir: Root checkpoint directory.

    Returns:
        ``state_dicts[pp_rank][tp_rank]``.
        - Without pipeline parallelism (rank dirs named ``mp_rank_TT``), each entry is
          the ``"model"`` sub-dict of the rank checkpoint.
        - With pipeline parallelism (rank dirs named ``mp_rank_TT_PPP``), each entry is
          the list of ``modelN`` sub-dicts ordered by ``N``.
        - If such a rank has no ``modelN`` keys, its ``"model"`` sub-dict is returned
          instead.

    Raises:
        FileNotFoundError: if the tracker file is missing or no ``mp_rank_*``
            directory exists.
        KeyError: if some ``(pp_rank, tp_rank)`` combination is missing.
    """
    iteration_path = os.path.join(load_dir, MEGATRON_LASTEST_ITERATION_FILE_NAME)
    with os.fdopen(os.open(iteration_path, os.O_RDONLY, stat.S_IRUSR)) as f:
        # Strip the trailing newline/whitespace. Otherwise a file containing
        # "release\n" would not compare equal to "release".
        latest_checkpointed_iteration = f.readline().strip()
    if latest_checkpointed_iteration == "release":
        directory = "release"
    else:
        directory = "iter_{:07d}".format(int(latest_checkpointed_iteration))
    iter_dir = os.path.join(load_dir, directory)

    # Only rank directories are relevant. Stray files would otherwise break the
    # name parsing below.
    sub_dirs = sorted(
        name for name in os.listdir(iter_dir)
        if name.startswith("mp_rank_") and os.path.isdir(os.path.join(iter_dir, name))
    )
    if not sub_dirs:
        raise FileNotFoundError(f"No mp_rank_* directories found in: {iter_dir}")
    # "mp_rank_TT" -> tensor parallel only; "mp_rank_TT_PPP" -> tensor + pipeline parallel.
    enable_pp = len(sub_dirs[0].split('_')) == 4

    pp_tp_state_dicts = {}
    for sub_dir in sub_dirs:
        state_dict_path = os.path.join(iter_dir, sub_dir, MEGATRON_CKPT_NAME)
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(state_dict_path, map_location='cpu', weights_only=False)
        name_parts = sub_dir.split('_')
        tp_rank = int(name_parts[2])
        if enable_pp:
            pp_rank = int(name_parts[3])
            state_dict = _collect_vpp_state_dicts(state_dict)
        else:
            pp_rank = 0
            state_dict = state_dict[MEGATRON_MODEL_KEY]
        pp_tp_state_dicts[(pp_rank, tp_rank)] = state_dict

    pp_size = max(pp_rank for pp_rank, _ in pp_tp_state_dicts) + 1
    tp_size = max(tp_rank for _, tp_rank in pp_tp_state_dicts) + 1
    return [
        [pp_tp_state_dicts[(pp_rank, tp_rank)] for tp_rank in range(tp_size)]
        for pp_rank in range(pp_size)
    ]


def _collect_vpp_state_dicts(state_dict):
    """Return the ``modelN`` entries of a rank checkpoint as a list ordered by ``N``.

    Falls back to the plain ``"model"`` entry when no ``modelN`` key exists. This
    covers pipeline parallelism without virtual-pipeline chunks, which ``save_as_mm``
    writes under the ``"model"`` key.
    """
    vpp_state_dicts = []
    for key, vpp_state_dict in state_dict.items():
        # fullmatch with \d+ so that "model10" is chunk 10, not chunk 1.
        match = re.fullmatch(r'model(\d+)', key)
        if match:
            vpp_state_dicts.append((int(match.group(1)), vpp_state_dict))
    if not vpp_state_dicts and MEGATRON_MODEL_KEY in state_dict:
        return state_dict[MEGATRON_MODEL_KEY]
    vpp_state_dicts.sort(key=lambda item: item[0])
    return [vpp_state_dict for _, vpp_state_dict in vpp_state_dicts]
def save_as_mm(save_dir, state_dicts, latest_checkpointed_iteration="release"):
    """Save ``state_dicts[pp_rank][tp_rank]`` as a Megatron-style checkpoint.

    Args:
        save_dir: Root directory to write into (created if missing).
        state_dicts: Nested list indexed ``[pp_rank][tp_rank]``. Each entry is one of:
            - a state dict, saved under the ``"model"`` key;
            - a list of state dicts, saved as ``model0``, ``model1``, ... together
              with ``checkpoint_version`` 3.0.
        latest_checkpointed_iteration: ``"release"`` or an iteration number (int or
            numeric string). It is written to the tracker file and used to name the
            ``iter_XXXXXXX`` sub-directory.

    Raises:
        ValueError: if ``latest_checkpointed_iteration`` is neither ``"release"``
            nor numeric. The tracker file is not written in that case.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Normalise once: the tracker needs text, and the directory name needs an int.
    # Previously an int failed in write() and a numeric string failed in format().
    iteration_text = str(latest_checkpointed_iteration)
    if iteration_text == "release":
        directory = "release"
    else:
        directory = "iter_{:07d}".format(int(iteration_text))

    # O_TRUNC: without it, overwriting an existing tracker with a shorter value
    # (e.g. "100" over "release") leaves stale trailing bytes.
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
    mode = stat.S_IWUSR | stat.S_IRUSR
    iteration_path = os.path.join(save_dir, MEGATRON_LASTEST_ITERATION_FILE_NAME)
    with os.fdopen(os.open(iteration_path, flags, mode), 'w') as fout:
        fout.write(iteration_text)

    enable_pp = len(state_dicts) > 1
    for pp_rank, tp_state_dicts in enumerate(state_dicts):
        for tp_rank, state_dict in enumerate(tp_state_dicts):
            if enable_pp:
                rank_dir_name = f"mp_rank_{tp_rank:02d}_{pp_rank:03d}"
            else:
                rank_dir_name = f"mp_rank_{tp_rank:02d}"
            state_dict_save_dir = os.path.join(save_dir, directory, rank_dir_name)
            os.makedirs(state_dict_save_dir, exist_ok=True)
            if isinstance(state_dict, list):
                # Virtual-pipeline chunks are stored as model0, model1, ...
                save_dict = {
                    f"model{vpp_rank}": vpp_state_dict
                    for vpp_rank, vpp_state_dict in enumerate(state_dict)
                }
                save_dict['checkpoint_version'] = 3.0
            else:
                save_dict = {MEGATRON_MODEL_KEY: state_dict}
            torch.save(save_dict, os.path.join(state_dict_save_dir, MEGATRON_CKPT_NAME))
    set_directory_permissions(Path(save_dir))
def load_from_hf(hf_dir):
    """Load and merge every ``*.safetensors`` file found (recursively) under ``hf_dir``.

    Args:
        hf_dir: Directory holding a Hugging Face style safetensors checkpoint.

    Returns:
        One flat dict mapping tensor names to CPU tensors. Files are read in sorted
        path order, so if two files share a key, the later path wins deterministically.

    Raises:
        FileNotFoundError: if ``hf_dir`` does not exist or contains no ``.safetensors`` file.
    """
    if not os.path.exists(hf_dir):
        raise FileNotFoundError(f"Directory not found: {hf_dir}")
    # glob's order is filesystem-dependent; sort for a reproducible merge order.
    files = sorted(glob.glob(os.path.join(hf_dir, '**', '*.safetensors'), recursive=True))
    if not files:
        raise FileNotFoundError(f"No .safetensors files found in directory: {hf_dir}")
    state_dict = {}
    for safe_path in files:
        state_dict.update(load_file(str(safe_path), device='cpu'))
    return state_dict
def save_as_hf(state_dict, hf_dir, save_dir):
    """Save ``state_dict`` as a sharded safetensors checkpoint mirroring ``hf_dir``'s layout.

    ``diffusion_pytorch_model.safetensors.index.json`` and ``config.json`` are copied
    from ``hf_dir`` into ``save_dir``. Every tensor listed in the index's
    ``weight_map`` is then written to the shard file that the index names for it,
    so the copied index always points at files that exist.

    Args:
        state_dict: Flat mapping of tensor name to tensor. It must contain every key
            in the index's ``weight_map``; a missing key raises ``KeyError``. Keys
            absent from the ``weight_map`` are not saved.
        hf_dir: Reference checkpoint directory providing the index and config files.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    index_file_name = "diffusion_pytorch_model.safetensors.index.json"
    shutil.copy2(os.path.join(hf_dir, index_file_name), save_dir)
    shutil.copy2(os.path.join(hf_dir, "config.json"), save_dir)
    with open(os.path.join(save_dir, index_file_name), "r", encoding="utf-8") as file:
        weight_map = json.load(file)["weight_map"]

    # Group tensors by the shard file the index assigns them to. The previous code
    # wrote freshly named "model-XXXXX-of-YYYYY" files, which the copied index did
    # not reference. It also produced empty shards when index numbers had gaps.
    shards = {}
    for key, shard_file in weight_map.items():
        shards.setdefault(shard_file, {})[key] = state_dict[key]

    metadata = {"format": "pt"}
    for shard_file, shard_state_dict in shards.items():
        save_file(shard_state_dict, os.path.join(save_dir, shard_file), metadata=metadata)
    set_directory_permissions(Path(save_dir))
def load_pt(source_path, module_name=None):
    """Load a ``torch.save`` file onto the CPU, optionally returning a single entry.

    Args:
        source_path: Path handed to ``torch.load``.
        module_name: When truthy, return ``loaded[module_name]``; otherwise return the
            whole loaded object.
    """
    loaded = torch.load(source_path, map_location='cpu')
    return loaded[module_name] if module_name else loaded
def save_as_pt(state_dict, target_path):
    """Serialize ``state_dict`` to ``target_path`` with ``torch.save``, then apply ``FILE_MODE``."""
    torch.save(obj=state_dict, f=target_path)
    os.chmod(path=target_path, mode=FILE_MODE)
def load_from_layerzero(source_path, iteration=None, prefix=None, ema_model=False, for_release=True):
    """Load a MindSpeed LayerZero checkpoint through ``_create_rank_checkpoint``.

    Args:
        source_path: Root directory containing ``iter_XXXXXXX`` sub-directories.
        iteration: Integer iteration to load (formatted as ``iter_{:07d}``). When
            ``None``, the largest ``iter_<N>`` directory under ``source_path`` is used.
        prefix: Passed to ``remove_model_prefix`` when ``ema_model`` is true.
            Presumably the key prefix of the EMA weights — TODO confirm.
        ema_model: If true, ``remove_model_prefix(prefix)`` is called before loading.
        for_release: Forwarded unchanged to ``_create_rank_checkpoint``.

    Returns:
        The result of ``_create_rank_checkpoint`` for tp/pp index 0 with tp/pp degree 1.

    Raises:
        ValueError: if ``iteration`` is None and no ``iter_<N>`` directory exists.
    """
    # mindspeed is imported inside the function, so it is only required when this
    # loader is actually called. megatron_adaptor is imported first and its name is
    # never used below, i.e. it is imported for its side effects (presumably patching
    # Megatron for MindSpeed — confirm before reordering these imports).
    import mindspeed.megatron_adaptor
    from mindspeed.core.distributed.layerzero.state.scripts.layerzero_checkpointer import LayerzeroCheckpoint
    from mindspeed.core.distributed.layerzero.state.scripts.convert_to_megatron import _create_rank_checkpoint
    if ema_model:
        from mindspeed.core.distributed.layerzero.state.scripts.layerzero_checkpointer import remove_model_prefix
        # The return value is ignored. Presumably this configures module-level state in
        # layerzero_checkpointer that affects the load below — confirm.
        remove_model_prefix(prefix)
    if iteration is None:
        # Default to the newest iter_<N> directory.
        iteration = _get_latest_iter_number(source_path)
    source_path = os.path.join(source_path, 'iter_{:07d}'.format(iteration))
    layerzero_checkpoint = LayerzeroCheckpoint(source_path)
    # Build the checkpoint for rank tp=0/pp=0 with tp/pp degree 1. Assumed to mean a
    # single, non-parallel rank — confirm against mindspeed's convert_to_megatron.
    layerzero_checkpoint = _create_rank_checkpoint(
        layerzero_checkpoint,
        tp_index=0,
        pp_index=0,
        tp_degree=1,
        pp_degree=1,
        for_release=for_release
    )
    return layerzero_checkpoint
def _get_latest_iter_number(directory):
    """Return the largest ``N`` among ``iter_<N>`` sub-directories of ``directory``.

    Plain files and non-matching names are ignored.

    Raises:
        ValueError: if no sub-directory name matches ``iter_<digits>``.
    """
    iter_name_re = re.compile(r'^iter_(\d+)$')
    iteration_numbers = []
    for entry in Path(directory).iterdir():
        if not entry.is_dir():
            continue
        matched = iter_name_re.match(entry.name)
        if matched is None:
            continue
        iteration_numbers.append(int(matched.group(1)))
    if not iteration_numbers:
        raise ValueError("No iter_xxxx directories found.")
    return max(iteration_numbers)