import os
import random
import warnings
import sys
import numpy as np
import torch
import torch.distributed as dist
from megatron.training.checkpointing import get_rng_state
from megatron.training.global_vars import get_args
from megatron.training.utils import print_rank_0
from megatron.core import mpu, tensor_parallel
from .state_dict import shard_state_dict, clean_ignored_modules, use_zero1_params
from .optim_state import _shard_optim_state_dict
PARALLE_STATE_KAY = "parallel_state"
MODEL_KEY = "model"
RNG_STATE_KEY = "rng_state"
SHRAD_KEY = "shard_state_dict"
EMA_MODEL_KEY = "ema_model"
OPTIM_STATE_KEY = "optimizer"
OPTIM_INFO_KEY = "optimizer_param_key_to_fqn"
OPTIM_SCHEDULER_KEY = "opt_param_scheduler"
LR_SCHEDULER_KEY = "lr_scheduler"
def save_checkpoint(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far=None,
checkpointing_context=None,
pipeline_rank=None,
expert_rank=None,
tensor_rank=None,
pipeline_parallel=None,
expert_parallel=None,
non_persistent_ckpt=False,
train_data_iterator=None,
preprocess_common_state_dict_fn=None):
"""Save a model checkpoint.
Checkpointing context is used to persist some checkpointing state
throughout a single job. Must be initialized externally (not used if None).
"""
args = get_args()
if not hasattr(args, "save"):
setattr(args, "save", "ckpt")
print_rank_0('saving checkpoint at iteration {:7d} to {} '.format(
iteration, args.save))
rng_state = get_rng_state(False)
checkpoint_name = get_checkpoint_name(args.save, iteration, release=False)
state_dict = generate_state_dict(args, model, optimizer, opt_param_scheduler, rng_state,
False, iteration)
state_dict[PARALLE_STATE_KAY] = generate_3D_parallel_state()
state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
ensure_directory_exists(checkpoint_name)
print_rank_0(f"Start Saving to {checkpoint_name}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
torch.save(state_dict, checkpoint_name)
if dist.is_initialized():
dist.barrier()
def generate_3D_parallel_state():
if not dist.is_initialized():
raise RuntimeError("Distributed environment is not initialized.")
if not mpu.is_initialized():
raise RuntimeError(
"Megatron's parallel utilities are not initialized.")
global_rank = dist.get_rank()
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
tp_degree = mpu.get_tensor_model_parallel_world_size()
pp_degree = mpu.get_pipeline_model_parallel_world_size()
dp_degree = mpu.get_data_parallel_world_size()
parallel_state = {
'tp_rank': tp_rank,
'pp_rank': pp_rank,
'dp_rank': dp_rank,
'tp_degree': tp_degree,
'pp_degree': pp_degree,
'dp_degree': dp_degree,
'global_rank': global_rank
}
return parallel_state
def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None):
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
if iteration is not None:
state_dict['iteration'] = iteration
if not len(model) == 1:
raise ValueError(f"Only single model is supported, VPP not supported")
use_zero1_params(model[0])
state_dict[MODEL_KEY] = clean_ignored_modules(
model[0], model[0].state_dict())
state_dict[SHRAD_KEY] = shard_state_dict(model[0], state_dict[MODEL_KEY])
if not args.no_save_optim:
if optimizer is not None:
state_dict[OPTIM_STATE_KEY] = optimizer.state_dict()
state_dict[OPTIM_INFO_KEY] = _shard_optim_state_dict(
model[0], optimizer.optimizer, state_dict[OPTIM_STATE_KEY])
if getattr(args, "optimizer_selection", None) == 'fused_ema_adamw':
try:
ema_optimizer_applier(optimizer)
state_dict[EMA_MODEL_KEY] = clean_ignored_modules(
model[0], model[0].state_dict())
state_dict = ema_state_dict_to_cpu(
state_dict, EMA_MODEL_KEY)
ema_optimizer_restore(optimizer)
print_rank_0("Ema model successful saved in state_dict")
except KeyError:
warnings.warn(
f"ema_optimizer_applier failed with KeyError, ema_model not saved")
if opt_param_scheduler is not None:
state_dict[OPTIM_SCHEDULER_KEY] = \
opt_param_scheduler.state_dict()
if not args.no_save_rng:
state_dict[RNG_STATE_KEY] = rng_state
return state_dict
def get_checkpoint_name(checkpoints_path, iteration, release=False):
"""Determine the directory name for this rank's checkpoint."""
if checkpoints_path is None:
raise ValueError("checkpoints_path cannot be None")
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
common_path = os.path.join(checkpoints_path, directory)
global_rank = dist.get_rank()
return os.path.join(common_path, f"model_{global_rank}.pt")
def ensure_directory_exists(filename, check_parent=True):
"""Build filename's path if it does not already exists."""
if filename is None:
raise AssertionError(f"Got {filename} filename")
dirname = os.path.dirname(filename) if check_parent else filename
os.makedirs(dirname, exist_ok=True)
def load_layerzero_checkpoint(models, ckpt_dir, optimizer=None, opt_param_scheduler=None):
if ckpt_dir is None:
raise AssertionError(f"Got {ckpt_dir} filename")
if len(models) != 1:
raise ValueError(f"VPP is not supported by layerzero currently")
rank = dist.get_rank()
sd_file = os.path.join(ckpt_dir, f"model_{rank}.pt")
if not os.path.exists(sd_file):
raise FileNotFoundError(
f"No checkpoint found in load directory or pretrained directory: no such file {sd_file}")
args = get_args()
state_dict = torch.load(sd_file)
for i in range(len(models)):
models[i].load_state_dict(state_dict[MODEL_KEY], strict=False)
if not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict[OPTIM_STATE_KEY])
if opt_param_scheduler is not None:
if LR_SCHEDULER_KEY in state_dict:
opt_param_scheduler.load_state_dict(
state_dict[LR_SCHEDULER_KEY])
else:
opt_param_scheduler.load_state_dict(
state_dict[OPTIM_SCHEDULER_KEY])
except KeyError as e:
raise RuntimeError('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(ckpt_dir)) from e
args.num_floating_point_operations_so_far = state_dict.get(
'num_floating_point_operations_so_far', 0)
if args.finetune:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
iteration = 0
args.iteration = iteration
update_consumed_samples(args, state_dict)
resume_rng_states(args, state_dict)
if torch.distributed.is_initialized():
torch.distributed.barrier()
print_rank_0(f' successfully loaded checkpoint from {ckpt_dir} '
f'[ t {mpu.get_tensor_model_parallel_rank()}, '
f'p {mpu.get_pipeline_model_parallel_rank()} ] '
f'at iteration {iteration}')
return args.iteration, args.num_floating_point_operations_so_far
def update_consumed_samples(args, state_dict):
if 'args' in state_dict and not args.finetune:
checkpoint_args = state_dict['args']
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
try:
from megatron.core.num_microbatches_calculator import update_num_microbatches
update_num_microbatches(
consumed_samples=args.consumed_train_samples)
except ImportError:
pass
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else:
print_rank_0('could not find arguments in the checkpoint ...')
def resume_rng_states(args, state_dict):
if not args.finetune and not args.no_load_rng:
try:
if RNG_STATE_KEY in state_dict:
if args.data_parallel_random_init:
rng_state = state_dict[RNG_STATE_KEY][mpu.get_data_parallel_rank(
)]
else:
rng_state = state_dict[RNG_STATE_KEY][0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
if not rng_state['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
if not state_dict['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError as e:
raise RuntimeError('Unable to load rng state from checkpoint '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the rng state, '
'exiting ...') from e
def ema_optimizer_applier(optimizer):
if hasattr(optimizer, "optimizer"):
optimizer.optimizer.store(optimizer.optimizer.param_groups)
optimizer.optimizer.copy_to()
return
def ema_optimizer_restore(optimizer):
if hasattr(optimizer, "optimizer"):
optimizer.optimizer.restore(optimizer.optimizer.param_groups)
return
def ema_state_dict_to_cpu(state_dict, ema_key):
for k, v in state_dict[ema_key].items():
if not torch.is_tensor(v):
continue
new_v = v.detach().cpu().clone()
state_dict[ema_key][k] = new_v
return state_dict