import torch
from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
from megatron.core import parallel_state
def finish_grad_sync(self):
    """
    Complete gradient synchronization for every bucket group owned by this object.

    The regular bucket groups and the expert-parallel bucket groups are
    concatenated and each one is passed, in order, to ``finish_grad_sync_ldt``,
    which decides per group whether to wait on an outstanding asynchronous
    communication call or to issue a synchronous one (depending on
    ``ddp_config.overlap_grad_reduce``).
    """
    # Build the combined list up front so both attributes are resolved before
    # any group is processed.
    all_bucket_groups = self.bucket_groups + self.expert_parallel_bucket_groups
    for group in all_bucket_groups:
        finish_grad_sync_ldt(group)
def finish_grad_sync_ldt(bucket_group: _ParamAndGradBucketGroup):
    """
    Complete gradient synchronization for all buckets of a single bucket group.

    Behaviour, in order:

    * On the first pipeline stage (virtual pipeline stages ignored) nothing is
      done and the bucket group is left untouched.
    * Otherwise ``param_gather_dispatched`` is reset to ``False``.
    * If ``ddp_config.overlap_grad_reduce`` is False, ``start_grad_sync()`` is
      called on the bucket group and the function returns.
    * If ``ddp_config.num_distributed_optimizer_instances`` is greater than 1,
      the default CUDA stream is made to wait on the bucket group's
      ``communication_stream`` and the function returns.
    * Otherwise every handle in ``grad_reduce_handle`` is waited on and the
      attribute is cleared.

    :param bucket_group: The bucket group to finish grad sync.
    :raises AssertionError: if the overlapped path is taken but
        ``grad_reduce_handle`` is None, i.e. no communication call was issued.
    """
    # First pipeline stage: skip entirely (note this happens before the
    # param_gather_dispatched flag is reset).
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        return

    bucket_group.param_gather_dispatched = False
    config = bucket_group.ddp_config

    # No overlap configured: issue the communication call here.
    if not config.overlap_grad_reduce:
        bucket_group.start_grad_sync()
        return

    # Multiple distributed-optimizer instances: only order the default stream
    # after the communication stream; no handle is waited on in this branch.
    if config.num_distributed_optimizer_instances > 1:
        torch.cuda.default_stream().wait_stream(bucket_group.communication_stream)
        return

    pending_handles = bucket_group.grad_reduce_handle
    if pending_handles is None:
        ready_count = len(bucket_group.params_with_grad)
        total_count = len(bucket_group.params)
        raise AssertionError(
            'Communication call has not been issued for this bucket '
            f'({ready_count}/{total_count} params have grad available)'
        )

    # grad_reduce_handle is iterated here, so it is expected to be a
    # collection of handles rather than a single one.
    for pending in pending_handles:
        pending.wait()
    bucket_group.grad_reduce_handle = None
def register_grad_ready(self, param: torch.nn.Parameter):
    """
    Mark the gradient of ``param`` as ready for gradient synchronization.

    Nothing happens on the first pipeline stage (virtual pipeline stages
    ignored). On other stages the param is only recorded while processing the
    last microbatch; once every param of the bucket group has been recorded,
    ``start_grad_sync()`` is called.

    :param param: The parameter whose gradient has become available.
    :raises AssertionError: if ``ddp_config.overlap_grad_reduce`` is False, if
        ``param`` does not belong to this bucket group, or if ``param`` has
        already been registered.
    """
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        return

    if not self.ddp_config.overlap_grad_reduce:
        raise AssertionError('register_grad_ready() should only be called when overlap_grad_reduce is True')

    # Grads are only registered on the last microbatch; earlier microbatches
    # are ignored.
    if not self.is_last_microbatch:
        return

    if param not in self.param_to_bucket:
        raise AssertionError('Param is not in the bucket group')
    if param in self.params_with_grad:
        raise AssertionError('Cannot set grad twice')

    self.params_with_grad.add(param)

    # All params accounted for: kick off the communication call.
    every_grad_ready = len(self.params_with_grad) == len(self.params)
    if every_grad_ready:
        self.start_grad_sync()