from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32
def step(self):
"""
This patch for solve count_zeros automatically used in ChainedOptimizer.
ChainedOptimizer will step all optimizers one by one.
"""
found_inf_flag = self.prepare_grads()
if found_inf_flag:
return False, None, None
grad_norm = self.get_grad_norm()
for optimizer in self.chained_optimizers:
if hasattr(optimizer, 'is_stub_optimizer') and optimizer.is_stub_optimizer:
continue
if optimizer.config.clip_grad > 0.0:
clip_grad_by_total_norm_fp32(
optimizer.get_parameters(),
max_norm=optimizer.config.clip_grad,
total_norm=grad_norm,
use_decoupled_grad=optimizer.config.use_precision_aware_optimizer,
)
num_zeros_in_grad = 0
for optimizer in self.chained_optimizers:
num_zeros_in_grad += (
optimizer.count_zeros() if optimizer.config.log_num_zeros_in_grad else 0
)
update_successful = self.step_with_ready_grads()
return update_successful, grad_norm, num_zeros_in_grad