"""Pretrain utilities."""
import os
import sys
import gc
import os
from functools import wraps
import torch
import torch_npu
from datetime import datetime
from megatron.training import get_args
from megatron.training import get_timers
from megatron.training import is_last_rank
from megatron.core import parallel_state
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.training import print_rank_0
from megatron.training.arguments import parse_args
from megatron.training.global_vars import (set_args, get_tensorboard_writer, get_wandb_writer,
get_one_logger)
from megatron.training.training import num_floating_point_operations
from megatron.training.utils import print_rank_last, report_memory
from megatron.training.theoretical_memory_usage import report_theoretical_memory
from mindspeed.auto_settings.auto_settings import AutoSettings
from mindspeed.auto_settings.module.black.auto_patch import AutoPatcher
from mindspeed.core.memory.auto_pipeline.autopipeline import autopipeline_profiling
from mindspeed.core.performance.auto_pipeline_perf.autopipeline_perf import (autopipelineperf_profiling, check_out_of_memory,
calculate_num_of_activations, check_skip_profiling,
broadcast_skip_in_ranks)
from mindspeed.core.performance.auto_pipeline_perf.optimpipeline_solver import solve_optimpipeline, broadcast_oom_in_ranks, broadcast_mbs_in_ranks, save_profiling_data
from mindspeed.core.performance.auto_pipeline_perf.schedulepipeline_solver import (solve_pipelineschedule, broadcast_enable_schedule_in_ranks,
broadcast_scheduler_in_ranks, broadcast_layer_in_ranks,
all_gather_time, average_time_by_rank)
from mindspeed.core.memory.auto_pipeline.autopipeline_apply import apply_autopipeline
from mindspeed.core.memory.auto_pipeline.autopipeline_solver import solve_autopipeline, broadcast_policy_in_ranks, destroy_global_vars
from mindspeed.arguments import parse_args_wrapper
POLICY = None
OPTIMIZED_MBS_LIST = None
PP_SCHEDULE_LIST = None
OPTIMAL_LAYERS = None
ORIGIN_MBS = None
DATA_PARALLEL_SIZE = 1
ENABLE_SCHEDULER = False
FLOPS_COUNTER = None
RECORDED_COUNT = 0
TRAVERSED_COUNT = 0
def generated_flops_counter():
from torch_npu.utils.flops_count import FlopsCounter
global FLOPS_COUNTER
FLOPS_COUNTER = FlopsCounter()
def get_flops_counter():
global FLOPS_COUNTER
if FLOPS_COUNTER is None:
generated_flops_counter()
return FLOPS_COUNTER
def set_count(count):
global RECORDED_COUNT
global TRAVERSED_COUNT
RECORDED_COUNT = count[0]
TRAVERSED_COUNT = count[1]
def get_count():
global RECORDED_COUNT
global TRAVERSED_COUNT
if RECORDED_COUNT == 0 and TRAVERSED_COUNT == 0:
flops_counter = get_flops_counter()
count = flops_counter.get_flops()
set_count(count)
return RECORDED_COUNT, TRAVERSED_COUNT
def train_decorator(train):
@wraps(train)
def wrapper(*args, **kwargs):
args_ = get_args()
if args_.profile:
args_.profile_npu = True
args_.profile = False
else:
args_.profile_npu = False
is_profile = hasattr(args_, 'profile_npu') and args_.profile_npu \
and ((torch.distributed.get_rank() in args_.profile_ranks) or (-1 in args_.profile_ranks))
if is_profile:
active = args_.profile_step_end - args_.profile_step_start
skip_first = args_.profile_step_start
if args_.profile_with_cpu:
activities = [torch_npu.profiler.ProfilerActivity.NPU, torch_npu.profiler.ProfilerActivity.CPU]
else:
activities = [torch_npu.profiler.ProfilerActivity.NPU]
if args_.profile_level == 'level0':
profiler_level = torch_npu.profiler.ProfilerLevel.Level0
elif args_.profile_level == 'level1':
profiler_level = torch_npu.profiler.ProfilerLevel.Level1
elif args_.profile_level == 'level2':
profiler_level = torch_npu.profiler.ProfilerLevel.Level2
else:
raise ValueError(f"profiler_level only support level0, level1, level2, but gets {args_.profile_level}")
experimental_config = torch_npu.profiler._ExperimentalConfig(
aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
profiler_level=profiler_level,
l2_cache=False
)
with torch_npu.profiler.profile(
activities=activities,
record_shapes=args_.profile_record_shapes,
profile_memory=args_.profile_with_memory,
with_stack=args_.profile_with_stack,
experimental_config=experimental_config,
schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=active, repeat=1, skip_first=skip_first),
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(args_.profile_save_path)
) as prof:
args_.prof = prof
return train(*args, **kwargs)
else:
return train(*args, **kwargs)
return wrapper
def train_step_decorator(train_step):
@wraps(train_step)
def wrapper(*args, **kwargs):
nonlocal train_step
args_ = get_args()
flop_count = None
if args_.op_cal_tflops:
flop_count = get_flops_counter()
flop_count.start()
if os.getenv('OOTB_OPTIMIZER_PROFILING_BLACK', 'FALSE') == 'TRUE':
custom_train_step = AutoPatcher(args_.prof_file).hook_train_step(train_step)
ret = custom_train_step(*args, **kwargs)
else:
ret = train_step(*args, **kwargs)
is_profile = (hasattr(args_, 'profile_npu') and args_.profile_npu
and (torch.distributed.get_rank() in args_.profile_ranks or -1 in args_.profile_ranks))
if is_profile:
args_.prof.step()
if args_.op_cal_tflops:
counts = flop_count.get_flops()
set_count(counts)
flop_count.stop()
return ret
return wrapper
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
wandb_writer = get_wandb_writer()
one_logger = get_one_logger()
advanced_iters_key = 'advanced iterations'
skipped_iters_key = 'skipped iterations'
nan_iters_key = 'nan iterations'
if not skipped_iter:
total_loss_dict[advanced_iters_key] = total_loss_dict.get(
advanced_iters_key, 0) + 1
else:
if advanced_iters_key not in total_loss_dict:
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = total_loss_dict.get(
skipped_iters_key, 0) + skipped_iter
got_nan = False
for key in loss_dict:
if not skipped_iter:
total_loss_dict[key] = total_loss_dict.get(
key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key]
else:
value = loss_dict[key].float().sum().item()
is_nan = value == float('inf') or \
value == -float('inf') or \
value != value
got_nan = got_nan or is_nan
total_loss_dict[nan_iters_key] = total_loss_dict.get(
nan_iters_key, 0) + int(got_nan)
timers_to_log = [
'forward-backward',
'forward-compute',
'backward-compute',
'batch-generator',
'forward-recv',
'forward-send',
'backward-recv',
'backward-send',
'forward-send-forward-recv',
'forward-send-backward-recv',
'backward-send-forward-recv',
'backward-send-backward-recv',
'forward-backward-send-forward-backward-recv',
'layernorm-grads-all-reduce',
'embedding-grads-all-reduce',
'all-grads-sync',
'params-all-gather',
'optimizer-copy-to-main-grad',
'optimizer-unscale-and-check-inf',
'optimizer-clip-main-grad',
'optimizer-count-zeros',
'optimizer-inner-step',
'optimizer-copy-main-to-model-params',
'optimizer']
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
if one_logger:
job_name = os.environ.get('SLURM_JOB_NAME', None)
current_app_tag = f'{job_name}_{batch_size}_{args.world_size}'
one_logger.log_app_tag(current_app_tag)
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
if args.log_timers_to_tensorboard and \
(iteration % args.tensorboard_log_interval == 0):
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
if args.log_learning_rate_to_tensorboard:
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.log_batch_size_to_tensorboard:
writer.add_scalar('batch-size', batch_size, iteration)
writer.add_scalar('batch-size vs samples', batch_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'batch-size': batch_size}, iteration)
for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + ' vs samples', loss_dict[key],
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({key: loss_dict[key]}, iteration)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'loss-scale': loss_scale}, iteration)
if args.log_world_size_to_tensorboard:
writer.add_scalar('world-size', args.world_size, iteration)
writer.add_scalar('world-size vs samples', args.world_size,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'world-size': args.world_size}, iteration)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'grad-norm': grad_norm}, iteration)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'params-norm': params_norm}, iteration)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
iteration,
)
if args.num_experts is not None:
moe_loss_scale = 1 / get_num_microbatches()
track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
throughput = num_floating_point_operations(args, batch_size) / (
elapsed_time_per_iteration * 10**12 * args.world_size)
counts_0, counts_1 = get_count()
counts_0_tensor = torch.tensor([counts_0], device="npu")
counts_1_tensor = torch.tensor([counts_1], device="npu")
torch.distributed.all_reduce(
counts_0_tensor, op=torch.distributed.ReduceOp.SUM
)
torch.distributed.all_reduce(
counts_1_tensor, op=torch.distributed.ReduceOp.SUM
)
mfu = counts_0_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)
hfu = counts_1_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)
if args.log_timers_to_tensorboard:
if writer:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
if wandb_writer:
wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
iteration)
log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
log_string += ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples: {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time_per_iteration * 1000.0)
if args.log_throughput:
log_string += f' theoretical throughput per NPU (TFLOP/s/NPU): {throughput:.1f} |'
log_string += f' actual throughput per NPU (TFLOP/s/NPU): {mfu:.1f} |'
log_string += f' actual throughput per NPU with recompute (TFLOP/s/NPU): {hfu:.1f} |'
if args.log_timers_to_tensorboard:
if writer:
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
log_string += ' learning rate: {:.6E} |'.format(learning_rate)
if args.decoupled_lr is not None and (parallel_state.is_pipeline_first_stage(ignore_virtual=True) or
parallel_state.is_pipeline_last_stage(ignore_virtual=True)):
assert decoupled_learning_rate is not None
log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate)
else:
assert decoupled_learning_rate is None
log_string += ' global batch size: {:5d} |'.format(batch_size)
for key in total_loss_dict:
if key not in [advanced_iters_key, skipped_iters_key,
nan_iters_key]:
avg = total_loss_dict[key].item() / \
float(max(1, total_loss_dict[advanced_iters_key]))
if avg > 0.0:
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format(
total_loss_dict[nan_iters_key])
total_loss_dict[advanced_iters_key] = 0
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
report_memory('(after {} iterations)'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
return report_memory_flag
def pretrain_decorator(pretrain):
@wraps(pretrain)
def wrapper(*args, **kwargs):
global POLICY
global OPTIMIZED_MBS_LIST
global PP_SCHEDULE_LIST
global OPTIMAL_LAYERS
global ORIGIN_MBS
global DATA_PARALLEL_SIZE
global ENABLE_SCHEDULER
new_parse_args = parse_args_wrapper(parse_args)
argument = new_parse_args(kwargs.get('extra_args_provider'), False)
if argument.auto_settings:
if argument.rank % torch.cuda.device_count() != 0:
return
set_args(argument)
settings = AutoSettings()
settings.auto_setting_fun(argument)
return
if argument.automated_pipeline and not argument.num_layer_list:
context, POLICY = autopipeline_profiling(args[1], args[2], args[3],
args[0], None, argument)
if context:
POLICY = solve_autopipeline(context)
parallel_state.destroy_global_memory_buffer()
parallel_state.destroy_model_parallel()
destroy_global_vars()
gc.collect()
torch.cuda.empty_cache()
if argument.automated_pipeline_perf:
ORIGIN_MBS = argument.micro_batch_size
is_skip, exist_policy = check_skip_profiling(argument, config_file="autopipeline_perf_config.json")
if not is_skip:
global_context = []
mbs_time, pp_schedule_time = 0, 0
mbs_tries = 1
num_forwards_first_stage = 0
is_oom = False
forward_time_dict = {}
backward_time_dict = {}
while mbs_tries < ORIGIN_MBS + 2:
context = autopipelineperf_profiling(mbs_tries, args[1], args[2], args[3],
args[0], None)
if mbs_tries == ORIGIN_MBS:
schedule_context = context
forward_time_list = all_gather_time(argument, schedule_context['fwd_time'])
forward_time_dict = average_time_by_rank(forward_time_list)
backward_time_list = all_gather_time(argument, schedule_context['bwd_time'])
backward_time_dict = average_time_by_rank(backward_time_list)
num_forwards_first_stage = calculate_num_of_activations(schedule_context)
parallel_state.destroy_global_memory_buffer()
parallel_state.destroy_model_parallel()
destroy_global_vars()
gc.collect()
torch.cuda.empty_cache()
global_context.append((context['fwd_time'], context['bwd_time'], context['comm_time']))
DATA_PARALLEL_SIZE = context['data_parallel_size']
if not is_oom:
is_oom = check_out_of_memory(argument, context, mbs_tries)
is_oom = broadcast_oom_in_ranks(0, is_oom)
mbs_tries += 1
if mbs_tries <= ORIGIN_MBS and is_oom:
raise AssertionError(
'A risk of Out of Memory could occur, please '
'reset to a smaller micro batch size.')
if mbs_tries > ORIGIN_MBS and is_oom:
break
if len(global_context) > 0:
OPTIMIZED_MBS_LIST, mbs_time = solve_optimpipeline(argument, DATA_PARALLEL_SIZE, global_context)
PP_SCHEDULE_LIST, pp_schedule_time, OPTIMAL_LAYERS = solve_pipelineschedule(argument, DATA_PARALLEL_SIZE, num_forwards_first_stage, forward_time_dict, backward_time_dict)
if torch.distributed.get_rank() == 0 and mbs_time > pp_schedule_time and num_forwards_first_stage > 2:
ENABLE_SCHEDULER = True
ENABLE_SCHEDULER = broadcast_enable_schedule_in_ranks(0, ENABLE_SCHEDULER)
optimized_policy = (ENABLE_SCHEDULER, OPTIMIZED_MBS_LIST, PP_SCHEDULE_LIST, OPTIMAL_LAYERS)
save_profiling_data(optimized_policy, config_file="autopipeline_perf_config.json")
else:
ENABLE_SCHEDULER = exist_policy[0]
OPTIMIZED_MBS_LIST = exist_policy[1]
PP_SCHEDULE_LIST = exist_policy[2]
OPTIMAL_LAYERS = exist_policy[3]
pretrain(*args, **kwargs)
if os.getenv('OOTB_OPTIMIZER_PROFILING', 'FALSE') == 'TRUE':
from mindspeed.auto_settings.module.parse.profiling_parse.profiling_node_parse import GatherNodeProfiling
res_dir = argument.profile_save_path
cur_rank = torch.distributed.get_rank()
if res_dir and cur_rank % torch.cuda.device_count() == 0:
GatherNodeProfiling(res_dir).parse_node_pkl(argument)
return wrapper
def setup_model_and_optimizer_decorator(setup_model_and_optimizer):
@wraps(setup_model_and_optimizer)
def wrapper(*args, **kwargs):
global POLICY
global OPTIMIZED_MBS_LIST
global PP_SCHEDULE_LIST
global OPTIMAL_LAYERS
global ENABLE_SCHEDULER
argument = get_args()
if argument.automated_pipeline and POLICY:
if torch.distributed.get_rank() == 0:
broadcast_policy_in_ranks(0, POLICY)
else:
broadcast_policy_in_ranks(0)
if argument.automated_pipeline_perf and ENABLE_SCHEDULER:
broadcast_scheduler_in_ranks(0, PP_SCHEDULE_LIST)
broadcast_layer_in_ranks(0, OPTIMAL_LAYERS)
elif argument.automated_pipeline_perf and OPTIMIZED_MBS_LIST:
broadcast_mbs_in_ranks(0, OPTIMIZED_MBS_LIST)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(*args, **kwargs)
if argument.recompute_module_list:
apply_autopipeline(model)
if os.getenv('OOTB_OPTIMIZER_PROFILING_BLACK', 'FALSE') == 'TRUE':
from mindspeed.auto_settings.module.black.auto_patch import AutoPatcher
auto_patcher = AutoPatcher(argument.prof_file)
auto_patcher.register_recursive_hook("", model, auto_patcher.context)
return model, optimizer, opt_param_scheduler
return wrapper