import time
from functools import wraps
import os
from logging import getLogger
import torch
from megatron.core import mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.utils import get_model_config
from megatron.training import one_logger_utils
from megatron.training.checkpointing import save_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.training import append_to_progress_log
from megatron.training.training import setup_model_and_optimizer
from megatron.training.training import build_train_valid_test_data_iterators
from megatron.training.training import train
from megatron.training.training import evaluate_and_print_results
from megatron.training.training import print_datetime
from megatron.training.training import preprocess_common_state_dict
from megatron.core.num_microbatches_calculator import (
get_current_global_batch_size,
get_num_microbatches,
update_num_microbatches)
from megatron.training.utils import (
calc_params_l2_norm,
check_adlr_autoresume_termination,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model)
from megatron.training.global_vars import (
get_args,
get_signal_handler,
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
from megatron.training.async_utils import maybe_finalize_async_save
from mindspeed.core.transformer.moe.expert_placement.executor import build_param_params_module_mlp_map
from mindspeed.core.transformer.moe.expert_placement.executor import expert_weight_and_optimizer_state_placement
from mindspeed.core.transformer.moe.expert_placement.planner import print_expert_load
# Fixed reference timestamp in seconds. pretrain() subtracts it from the start
# time before packing that time into a float tensor and adds it back after the
# cross-rank all-reduce (presumably so the value stays small enough to survive
# single-precision storage -- TODO confirm).
_BASE_TIME = 1742613446
# Wall-clock time captured when this module is imported; pretrain() replaces it
# with the minimum start time found across ranks.
_TRAIN_START_TIME = time.time()
# Module-level logger named after this module.
LOG = getLogger(__name__)
@torch.no_grad()
def update_ema(
    ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999
) -> None:
    """Step the EMA model towards the current model.

    For each parameter of ``model`` the same-named parameter of ``ema_model``
    is updated in place as ``ema = decay * ema + (1 - decay) * param``.
    A parameter named exactly ``"pos_embed"`` and any parameter whose
    ``requires_grad`` is false are left untouched.

    Args:
        ema_model: Module holding the exponential-moving-average weights;
            modified in place.
        model: Module whose current parameter values are blended in.
        optimizer: Not used by this function; kept so existing callers that
            pass it keep working.
        decay: Weight given to the existing EMA value.

    Raises:
        KeyError: If an updated parameter of ``model`` has no parameter of the
            same name in ``ema_model``.
    """
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        if name == "pos_embed" or not param.requires_grad:
            continue
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler, config):
    """Run one training iteration and return its averaged losses.

    The step zeroes gradients, runs the forward/backward function over the
    current number of microbatches, applies the optimizer, optionally updates
    the EMA weights and DINO momentum, advances the learning-rate scheduler
    when the optimizer update succeeded, and finally averages the per-microbatch
    losses on the last pipeline stage.

    Args:
        forward_step_func: Forwarded to the forward/backward function.
        data_iterator: Forwarded to the forward/backward function.
        model: Iterable of model chunks.
        optimizer: Optimizer whose ``step()`` returns
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        opt_param_scheduler: Scheduler stepped by the number of samples
            processed when the update succeeded.
        config: Not read inside this function.

    Returns:
        Tuple ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_dict`` is empty on every pipeline stage except the last;
        ``skipped_iter`` is 1 when the optimizer update did not succeed,
        otherwise 0.
    """
    args = get_args()
    timers = get_timers()

    def _dino_pretraining():
        # Evaluated on demand, both before and after the optimizer step.
        return getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino"

    # Start from clean gradient buffers.
    for chunk in model:
        chunk.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward and backward over all microbatches.
    run_forward_backward = get_forward_backward_func()
    losses_reduced = run_forward_backward(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=get_num_microbatches(),
        seq_length=args.seq_length,
        micro_batch_size=args.micro_batch_size,
        decoder_seq_length=args.decoder_seq_length,
        forward_only=False,
    )

    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    if _dino_pretraining():
        unwrap_model(model[0]).cancel_gradients_last_layer(args.curr_iteration)

    # Parameter update.
    optimizer_timer = timers('optimizer', log_level=1)
    optimizer_timer.start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    if args.use_ema:
        for ema_source in unwrap_model(model):
            update_ema(ema_source.ema, ema_source, optimizer=optimizer)

    if _dino_pretraining():
        unwrap_model(model[0]).update_momentum(args.curr_iteration)

    # The scheduler only advances when the optimizer actually stepped.
    if update_successful:
        step_increment = (
            get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
        )
        opt_param_scheduler.step(increment=step_increment)
    skipped_iter = 0 if update_successful else 1

    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    # Only the last pipeline stage holds losses to report.
    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    averaged_losses = {}
    for key in losses_reduced[0].keys():
        total = 0
        count = 0
        for microbatch_losses in losses_reduced:
            entry = microbatch_losses[key]
            if isinstance(entry, (tuple, list)):
                # Entry is a (sum, count) pair.
                total += entry[0]
                count += entry[1]
            else:
                total += entry
                count += 1
        averaged_losses[key] = total / count
    return averaged_losses, skipped_iter, grad_norm, num_zeros_in_grad
def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    model_type,
    forward_step_func,
    process_non_loss_data_func=None,
    extra_args_provider=None,
    args_defaults=None,
    get_embedding_ranks=None,
    get_position_embedding_ranks=None,
    non_loss_data_func=None,
):
    """Main training entry point: initialize, build, train, evaluate.

    Steps, in order:
        1. Initialize Megatron with ``initialize_megatron``.
        2. If the environment variable ``OOTB_OPTIMIZER_PARSE_ARGS`` equals
           ``"TRUE"``, dump the settings via ``get_settings`` and return.
        3. Call ``args_defaults['init_func']`` when that key is present.
        4. Synchronise the job start time across ranks (MIN all-reduce).
        5. Build the model, optimizer and scheduler; when
           ``args.enable_expert_placement`` is set, attach the parameter map
           built by ``build_param_params_module_mlp_map`` to the optimizer(s).
        6. If ``OOTB_OPTIMIZER_PARSE_MODEL`` equals ``"TRUE"``, dump the model
           parameters via ``get_model_params`` and return.
        7. Build the train/valid/test data iterators (one set per model chunk
           when a virtual pipeline size is configured).
        8. Unless ``args.skip_train`` is set, run ``train`` and save a final
           checkpoint if the last iteration did not fall on a save interval.
        9. Evaluate on the validation and test sets when enabled.
        10. Finish the wandb writer, pending async saves and the one-logger.

    Args:
        train_valid_test_dataset_provider: Passed to
            ``build_train_valid_test_data_iterators``.
        model_provider: Passed to ``setup_model_and_optimizer``.
        model_type: Passed to ``setup_model_and_optimizer``.
        forward_step_func: Passed to ``train`` and to
            ``evaluate_and_print_results``.
        process_non_loss_data_func: Passed to ``train`` and to
            ``evaluate_and_print_results``.
        extra_args_provider: Passed to ``initialize_megatron``.
        args_defaults: Passed to ``initialize_megatron``; an empty dict is
            used when ``None``. May carry an ``'init_func'`` callable.
        get_embedding_ranks: Passed to ``initialize_megatron``.
        get_position_embedding_ranks: Passed to ``initialize_megatron``.
        non_loss_data_func: Passed to ``train`` and to
            ``evaluate_and_print_results``.

    Returns:
        None.

    Raises:
        RuntimeError: If ``args.non_persistent_ckpt_type`` is ``'local'``.
    """
    global _TRAIN_START_TIME

    if args_defaults is None:
        args_defaults = {}

    initialize_megatron(
        extra_args_provider=extra_args_provider,
        args_defaults=args_defaults,
        get_embedding_ranks=get_embedding_ranks,
        get_position_embedding_ranks=get_position_embedding_ranks
    )

    # Profiling mode: dump the parsed settings and stop before any training setup.
    if os.getenv("OOTB_OPTIMIZER_PARSE_ARGS", "FALSE") == "TRUE":
        args = get_args()
        if not args.vocab_size:
            from megatron.training.tokenizer.tokenizer import build_tokenizer
            tokenizer = build_tokenizer(args)
            args.vocab_size = tokenizer.vocab_size
        from mindspeed.auto_settings.module.parse.profiling_parse import get_settings
        get_settings(args, args.profile_save_path)
        print_rank_0("================OOTB_OPTIMIZER_PARSE_ARGS END EXIT!====================")
        return

    if 'init_func' in args_defaults:
        init_func = args_defaults['init_func']
        init_func()

    args = get_args()
    timers = get_timers()

    if args.log_progress:
        append_to_progress_log("Starting job")

    set_jit_fusion_options()

    # Agree on the earliest start time across ranks. The value is shifted by
    # _BASE_TIME before it goes into the float tensor and shifted back after.
    start_time_tensor = torch.npu.FloatTensor([_TRAIN_START_TIME - _BASE_TIME])
    LOG.info(
        "original _TRAIN_START_TIME is (seconds) %s, start_time_tensor is %s",
        _TRAIN_START_TIME,
        start_time_tensor.item(),
    )
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item() + _BASE_TIME
    LOG.info("adjusted _TRAIN_START_TIME is (seconds) %s", _TRAIN_START_TIME)

    app_metrics = {}
    app_metrics['app_start_time'] = round(_TRAIN_START_TIME * 1000.0)
    app_metrics['app_model_init_start_time'] = round(_TRAIN_START_TIME * 1000.0)

    print_rank_0(
        "time to initialize megatron (seconds): {:.3f}".format(
            time.time() - _TRAIN_START_TIME
        )
    )
    print_datetime('after megatron is initialized')
    app_metrics['app_model_init_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    one_logger_utils.on_pretrain_start()

    # Local non-persistent checkpointing is rejected outright, so the context
    # handed to the checkpointing code is always empty.
    if args.non_persistent_ckpt_type == 'local':
        raise RuntimeError('LocalCheckpointManagers are not yet integrated')
    checkpointing_context = {}

    # Model, optimizer and learning-rate scheduler.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    app_metrics['app_build_optimizer_start_time'] = one_logger_utils.get_timestamp_in_ms()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type, checkpointing_context=checkpointing_context)

    if getattr(args, "enable_expert_placement", False):
        params_module_mlp_map = build_param_params_module_mlp_map(model)
        if hasattr(optimizer, "chained_optimizers"):
            for optimizer_sub in optimizer.chained_optimizers:
                optimizer_sub.params_module_mlp_map = params_module_mlp_map
        else:
            optimizer.params_module_mlp_map = params_module_mlp_map

    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
    app_metrics['app_build_optimizer_finish_time'] = one_logger_utils.get_timestamp_in_ms()
    config = get_model_config(model[0])

    # Profiling mode: dump the model parameters and stop before building data.
    if os.getenv("OOTB_OPTIMIZER_PARSE_MODEL", "FALSE") == "TRUE":
        from mindspeed.auto_settings.module.parse.profiling_parse import get_model_params
        get_model_params(model, mpu.get_pipeline_model_parallel_rank(), args.profile_save_path)
        print_rank_0("================OOTB_OPTIMIZER_PARSE_MODEL END EXIT!====================")
        return

    # Data iterators.
    app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
    timers('train/valid/test-data-iterators-setup', log_level=0).start(
        barrier=True)
    if args.virtual_pipeline_model_parallel_size is not None:
        # One iterator triple per model chunk, built with the matching
        # virtual pipeline rank set.
        train_data_iterator = []
        valid_data_iterator = []
        test_data_iterator = []
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            iterators = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
            train_data_iterator.append(iterators[0])
            valid_data_iterator.append(iterators[1])
            test_data_iterator.append(iterators[2])
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')
    app_metrics['app_build_dataiters_finish_time'] = one_logger_utils.get_timestamp_in_ms()

    one_logger_utils.track_config_flags(args.train_iters, args.skip_train, args.do_train,
                                        args.do_valid, args.do_test, args.dataloader_type,
                                        args.retro_project_dir, args.retro_cyclic_train_iters)

    if args.enable_ft_package:
        # ft_integration is not imported at module level in this file, so it
        # is imported here, right where it is needed.
        from megatron.training import ft_integration
        if ft_integration.get_rank_monitor_client() is not None:
            ft_integration.get_rank_monitor_client().init_workload_monitoring()
            ft_timeouts = ft_integration.get_rank_monitor_client().timeouts
            print_rank_0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")

    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup',
                'train/valid/test-data-iterators-setup'], barrier=True)

    one_logger = get_one_logger()
    if one_logger:
        one_logger.log_metrics(app_metrics)

    if not args.skip_train:
        print_rank_0('training ...')

        if args.dataloader_type == 'cyclic' and args.retro_project_dir:
            assert args.retro_cyclic_train_iters is not None
            args.train_iters = args.retro_cyclic_train_iters
            print_rank_0("retro cyclic train iters : %d" % args.train_iters)

        iteration = 0
        if args.do_train and args.train_iters > 0:
            if getattr(args, "enable_expert_placement", False):
                expert_weight_and_optimizer_state_placement(args, model, optimizer)
                # NOTE(review): the load report is assumed to belong to the
                # expert-placement branch -- confirm the intended nesting.
                if getattr(args, "print_expert_load", False):
                    print_expert_load(args, model, iteration)
            iteration, num_floating_point_operations_so_far = train(
                forward_step_func,
                model, optimizer, opt_param_scheduler,
                train_data_iterator, valid_data_iterator,
                process_non_loss_data_func, config, checkpointing_context, non_loss_data_func)

        print_datetime('after training is done')

        # Save a final checkpoint only when training ran and the last
        # iteration was not already covered by a periodic save.
        if args.save and iteration != 0 and iteration % args.save_interval != 0:
            # Imported here because the module is not imported at file level.
            from megatron.training import ft_integration
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
                            num_floating_point_operations_so_far, checkpointing_context,
                            train_data_iterator=train_data_iterator,
                            ft_client=ft_integration.get_rank_monitor_client(
                                ft_integration.StateMachineActions.SAVE_CHECKPOINT),
                            preprocess_common_state_dict_fn=preprocess_common_state_dict)

        if one_logger:
            one_logger.log_metrics({
                'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
            })
    else:
        print_rank_0('skipping training (--skip-train is on) ...')
        iteration = args.iteration

    if args.do_valid:
        prefix = f'iteration {iteration} on validation set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train,
                                   non_loss_data_func=non_loss_data_func)

    if args.do_test:
        prefix = f'iteration {iteration} on test set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train,
                                   non_loss_data_func=non_loss_data_func)

    wandb_writer = get_wandb_writer()
    if wandb_writer:
        wandb_writer.finish()
    maybe_finalize_async_save(blocking=True)

    if one_logger:
        one_logger.log_metrics({
            'app_finish_time': one_logger_utils.get_timestamp_in_ms()
        })
    one_logger_utils.finish()
def num_floating_point_wrapper(fn):
    """Decorate ``fn(args, batch_size)`` so it sees a layer count without no-op layers.

    While ``fn`` runs, ``args.num_layers`` is reduced by the number of entries
    in ``args.noop_layers`` when that attribute is a ``set``; otherwise the
    count is left unchanged. The original value is restored afterwards, also
    when ``fn`` raises.

    Args:
        fn: Callable taking ``(args, batch_size)``.

    Returns:
        The wrapped callable, returning whatever ``fn`` returns.
    """
    @wraps(fn)
    def wrapper(args, batch_size):
        noop_layers = args.noop_layers
        # Computed once so the restore below undoes exactly what was subtracted.
        noop_count = len(noop_layers) if isinstance(noop_layers, set) else 0
        args.num_layers -= noop_count
        try:
            return fn(args, batch_size)
        finally:
            # Restore even on failure so args is never left with a reduced count.
            args.num_layers += noop_count
    return wrapper
def get_device_wrapper(func):
    """Decorate a device-lookup function to special-case the ``'hccl'`` backend.

    When ``torch.distributed.get_backend()`` reports ``'hccl'`` the wrapper
    builds the device itself: ``torch.device('cuda')`` when no local rank is
    given, otherwise ``torch.device(f'cuda:{local_rank}')``. For any other
    backend the call is forwarded to ``func`` unchanged.

    Args:
        func: The device-lookup callable being wrapped. Its first positional
            argument is treated as the local rank.

    Returns:
        The wrapped callable.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        backend = torch.distributed.get_backend()
        if backend != 'hccl':
            return func(*args, **kwargs)

        # The rank may be absent or passed by keyword; do not assume args[0]
        # exists. The keyword name assumes the wrapped function calls its
        # parameter `local_rank` -- TODO confirm against the wrapped signature.
        local_rank = args[0] if args else kwargs.get('local_rank')
        if local_rank is None:
            return torch.device('cuda')
        return torch.device(f'cuda:{local_rank}')
    return wrapper
def get_device_arch_version():
    """Return a fixed device architecture version.

    The value is hard-coded rather than queried from hardware. In the GPU
    numbering scheme (8: Ampere, 9: Hopper, 10: Blackwell, ...) it corresponds
    to Ampere.
    """
    ampere_arch_version = 8
    return ampere_arch_version