import os
from datetime import datetime
import torch
from megatron.core import parallel_state
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.transformer.moe.moe_utils import track_moe_metrics
from megatron.training import get_args, get_timers
from megatron.training.training import num_floating_point_operations
from megatron.training.utils import print_rank_last, report_memory
from megatron.training.theoretical_memory_usage import report_theoretical_memory
from megatron.training.global_vars import (
get_tensorboard_writer,
get_wandb_writer,
get_one_logger
)
from .tflops_utils import get_count
def training_log(
    loss_dict,
    total_loss_dict,
    learning_rate,
    decoupled_learning_rate,
    iteration,
    loss_scale,
    report_memory_flag,
    skipped_iter,
    grad_norm,
    params_norm,
    num_zeros_in_grad
):
    """Log training information such as losses, timing and throughput.

    Accumulates this iteration's losses and iteration counters into
    ``total_loss_dict``, writes scalars to the TensorBoard / W&B writers every
    ``args.tensorboard_log_interval`` iterations, and every
    ``args.log_interval`` iterations prints a summary line and resets the
    accumulators.

    Args:
        loss_dict: Mapping from loss name to this iteration's loss tensor.
        total_loss_dict: Running accumulator, mutated in place. Besides one
            entry per loss name it holds the 'advanced iterations',
            'skipped iterations' and 'nan iterations' counters.
        learning_rate: Current learning rate; must not be None on iterations
            where the summary line is printed.
        decoupled_learning_rate: Decoupled learning rate. On summary
            iterations it is asserted to be non-None exactly when
            ``args.decoupled_lr`` is set and this rank is the first or last
            pipeline stage, and None otherwise.
        iteration: Current iteration number.
        loss_scale: Current loss scale.
        report_memory_flag: When truthy, a memory report is emitted on the
            first summary iteration where ``learning_rate > 0``.
        skipped_iter: Added to the skipped-iteration counter; a falsy value
            means the iteration was applied and its losses are accumulated.
        grad_norm: Gradient norm, logged when not None.
        params_norm: Parameter norm, logged when not None.
        num_zeros_in_grad: Count of zeros in the gradient, logged when not
            None.

    Returns:
        ``report_memory_flag``, set to False once the memory report has been
        emitted.
    """
    # Local import keeps this function self-contained; only needed for the
    # finite-ness check on skipped iterations below.
    import math

    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()
    wandb_writer = get_wandb_writer()
    one_logger = get_one_logger()

    # Advanced, skipped, and NaN iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'

    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0

    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter

    # Update losses and set the NaN flag.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key]
        else:
            # On a skipped iteration the loss is not accumulated; only check
            # whether it was inf/NaN. `value` is a Python float here, so the
            # stdlib check is required (torch.isfinite only accepts Tensors).
            value = loss_dict[key].float().sum().item()
            is_nan = not math.isfinite(value)
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Timers reported to TensorBoard and to the console log.
    timers_to_log = [
        'forward-backward',
        'forward-compute',
        'backward-compute',
        'batch-generator',
        'forward-recv',
        'forward-send',
        'backward-recv',
        'backward-send',
        'forward-send-forward-recv',
        'forward-send-backward-recv',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'forward-backward-send-forward-backward-recv',
        'layernorm-grads-all-reduce',
        'embedding-grads-all-reduce',
        'all-grads-sync',
        'params-all-gather',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-count-zeros',
        'optimizer-inner-step',
        'optimizer-copy-main-to-model-params',
        'optimizer']

    # Global batch size for this iteration.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Tag the run in one_logger with job name, batch size and world size.
    if one_logger:
        job_name = os.environ.get('SLURM_JOB_NAME', None)
        current_app_tag = f'{job_name}_{batch_size}_{args.world_size}'
        one_logger.log_app_tag(current_app_tag)

    # Iterations (advanced + skipped) since the counters were last reset.
    total_iterations = total_loss_dict[advanced_iters_key] + \
        total_loss_dict[skipped_iters_key]

    # TensorBoard / W&B values.
    # The timer write is not gated on `writer`: it is invoked whenever the
    # option is enabled, and `writer` may be None here.
    if args.log_timers_to_tensorboard and \
            (iteration % args.tensorboard_log_interval == 0):
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)
    if writer and (iteration % args.tensorboard_log_interval == 0):
        if wandb_writer:
            wandb_writer.log({'samples vs steps': args.consumed_train_samples},
                             iteration)
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            if args.decoupled_lr is not None:
                writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'learning-rate': learning_rate}, iteration)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'batch-size': batch_size}, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({key: loss_dict[key]}, iteration)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'loss-scale': loss_scale}, iteration)
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'world-size': args.world_size}, iteration)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'grad-norm': grad_norm}, iteration)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
            if wandb_writer:
                wandb_writer.log({'params-norm': params_norm}, iteration)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )

    # Mixture-of-experts metrics, scaled by 1 / number of microbatches.
    if args.num_experts is not None:
        moe_loss_scale = 1 / get_num_microbatches()
        track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed(barrier=True)
        elapsed_time_per_iteration = elapsed_time / total_iterations

        # Theoretical throughput in TFLOP/s per device.
        throughput = num_floating_point_operations(args, batch_size) / (
            elapsed_time_per_iteration * 10**12 * args.world_size)

        # Sum the two counters from get_count() across all ranks and convert
        # them to TFLOP/s per device.
        # NOTE(review): presumably counts_0 / counts_1 are measured FLOPs
        # without / with recomputation (per the log labels below) -- confirm
        # against tflops_utils.
        counts_0, counts_1 = get_count()
        counts_0_tensor = torch.tensor([counts_0], device="npu")
        counts_1_tensor = torch.tensor([counts_1], device="npu")
        torch.distributed.all_reduce(
            counts_0_tensor, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            counts_1_tensor, op=torch.distributed.ReduceOp.SUM
        )
        mfu = counts_0_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)
        hfu = counts_1_tensor.cpu().item() / (10 ** 12 * elapsed_time_per_iteration * args.world_size)

        if args.log_timers_to_tensorboard:
            if writer:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
            if wandb_writer:
                wandb_writer.log({'iteration-time': elapsed_time_per_iteration},
                                 iteration)

        # Build the console summary line.
        log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
        log_string += ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        if args.log_throughput:
            log_string += f' theoretical throughput per NPU (TFLOP/s/NPU): {throughput:.1f} |'
            log_string += f' actual throughput per NPU (TFLOP/s/NPU): {mfu:.1f} |'
            log_string += f' actual throughput per NPU with recompute (TFLOP/s/NPU): {hfu:.1f} |'
            if args.log_timers_to_tensorboard:
                if writer:
                    writer.add_scalar('throughput', throughput, iteration)
                if wandb_writer:
                    wandb_writer.log({'throughput': throughput}, iteration)
        assert learning_rate is not None
        log_string += ' learning rate: {:.6E} |'.format(learning_rate)
        # decoupled_learning_rate must be set only on the first and last
        # pipeline stages, and only when args.decoupled_lr is configured.
        if args.decoupled_lr is not None and (parallel_state.is_pipeline_first_stage(ignore_virtual=True) or
                                              parallel_state.is_pipeline_last_stage(ignore_virtual=True)):
            assert decoupled_learning_rate is not None
            log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate)
        else:
            assert decoupled_learning_rate is None
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average over advanced iterations only; max(1, ...) avoids
                # dividing by zero when every iteration was skipped.
                avg = total_loss_dict[key].item() / \
                    float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                # Reset the accumulator for the next logging window.
                total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda')
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])

        # Reset the iteration counters for the next logging window.
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory once, on the first summary iteration with a
            # positive learning rate.
            if torch.distributed.get_rank() == 0:
                num_microbatches = get_num_microbatches()
                report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True)
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag