import logging
import math
import warnings
from contextlib import nullcontext
from functools import wraps
from logging import getLogger
from typing import Dict, List
import torch
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.distributed.param_and_grad_buffer import (
BufferType,
dist_all_gather_func,
dist_reduce_scatter_func,
shard_buffer,
)
from megatron.core.utils import (
is_torch_min_version,
log_on_each_pipeline_stage,
)
from megatron.core.fp8_utils import is_float8tensor
from mindspeed.args_utils import get_full_args
# Module-level logger; passed to log_on_each_pipeline_stage when the bucket layout is reported.
logger = getLogger(__name__)
def pipe_register_grad_ready_wrapper(register_grad_ready):
    """Decorate ``register_grad_ready`` so it is gated by the pipe-experts flag.

    The returned function forwards to ``register_grad_ready(self, param)`` only
    when both ``self.is_last_microbatch`` and
    ``mindspeed.moe.pipe_experts.FLAG_GRAD_REDUCE`` are truthy; otherwise the
    call does nothing. It always returns ``None``.

    Args:
        register_grad_ready: the original method taking ``(self, param)``.

    Returns:
        The gated replacement, carrying the wrapped function's metadata.

    Raises:
        AssertionError: at call time, if ``self.ddp_config.overlap_grad_reduce``
            is falsy.
    """

    @wraps(register_grad_ready)
    def wrapper(self, param: torch.nn.Parameter):
        overlap_enabled = self.ddp_config.overlap_grad_reduce
        assert overlap_enabled, 'register_grad_ready() should only be called when overlap_grad_reduce is True'

        # Imported inside the call so the flag's current value is read on
        # every invocation rather than captured once at import time.
        from mindspeed.moe.pipe_experts import FLAG_GRAD_REDUCE

        if not self.is_last_microbatch:
            return
        if not FLAG_GRAD_REDUCE:
            return
        register_grad_ready(self, param)

    return wrapper
def reuse_fp32_param_param_and_grad_buffer_init_wrapper(init_func):
    """Wrap a buffer initializer so its ``math.ceil`` calls round up to even numbers.

    When both ``reuse_fp32_param`` and ``use_distributed_optimizer`` are set on
    the global args, ``math.ceil`` is temporarily replaced (process-wide) by a
    variant that rounds up to the next even integer while ``init_func`` runs.
    The original ``math.ceil`` is always restored afterwards, including when
    ``init_func`` raises, so a failed initialization cannot leave the patched
    function installed for unrelated code.

    Args:
        init_func: the initializer to wrap; called with the original arguments.

    Returns:
        The wrapping function. Like the wrapped initializer is expected to, it
        returns ``None`` (the result of ``init_func`` is not propagated).
    """

    @wraps(init_func)
    def reuse_fp32_param_param_and_grad_buffer_init(*args, **kwargs):
        global_args = get_full_args()
        needs_even_ceil = (
            global_args.reuse_fp32_param and global_args.use_distributed_optimizer
        )
        if not needs_even_ceil:
            init_func(*args, **kwargs)
            return

        original_ceil = math.ceil

        def ceil_even(x):
            # Smallest even integer that is >= x.
            return original_ceil(original_ceil(x) / 2) * 2

        # NOTE: this patches the math module globally, not just for init_func;
        # the try/finally guarantees the patch never outlives this call.
        math.ceil = ceil_even
        try:
            init_func(*args, **kwargs)
        finally:
            math.ceil = original_ceil

    return reuse_fp32_param_param_and_grad_buffer_init
def start_param_sync(self, force_sync: bool = False):
    """
    Initiates all necessary param all-gathers for this bucket group.

    When ddp_config.overlap_param_gather is set to True, dispatches asynchronous
    communication calls (unless force_sync is True). When ddp_config.overlap_param_gather
    is set to False, makes synchronous calls.

    One all-gather is issued per bucket. When the calls are asynchronous, their
    handles are kept as a list in self.param_gather_handle; otherwise that
    attribute is reset to None. self.param_gather_dispatched is set to True
    once the collectives have been issued.

    Args:
        force_sync (bool, optional): force synchronous collective regardless of
            other settings if true. If asynchronous all-gathers are already
            outstanding, they are waited on and no new collective is issued.

    Raises:
        AssertionError: if the distributed optimizer is not in use, or if
            force_sync is False while handles are still outstanding.
    """
    assert self.ddp_config.use_distributed_optimizer

    if force_sync:
        if self.param_gather_handle is not None:
            # param_gather_handle is a list with one handle per bucket (see
            # below), so every entry has to be waited on individually.
            for handle in self.param_gather_handle:
                handle.wait()
            self.param_gather_handle = None
            return
    else:
        assert self.param_gather_handle is None

    async_op = self.ddp_config.overlap_param_gather and not force_sync

    self.param_gather_handle = []
    for bucket in self.buckets:
        # This rank's shard of the bucket is the input; the full bucket buffer
        # is the all-gather output.
        local_data_view = shard_buffer(
            bucket.param_data, self.intra_distributed_optimizer_instance_size
        )[self.intra_distributed_optimizer_instance_rank]
        handle = dist_all_gather_func(
            bucket.param_data,
            local_data_view,
            group=self.intra_distributed_optimizer_instance_group,
            async_op=async_op,
        )
        self.param_gather_handle.append(handle)

    if not async_op:
        # Synchronous calls leave nothing to wait for later.
        self.param_gather_handle = None
    self.param_gather_dispatched = True
def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
    """
    Completes the param all-gather for this bucket group and, unless told
    otherwise, kicks off the all-gather of the next bucket group.

    Only valid when ddp_config.overlap_param_gather is True. If the all-gather
    for this group has not been dispatched yet, it is dispatched first. Any
    outstanding handles are then waited on and cleared.

    Args:
        skip_next_bucket_dispatch (bool, optional): if true, do NOT dispatch
            the next bucket group's communication after this one finishes.

    Raises:
        AssertionError: if the distributed optimizer is not in use or
            ddp_config.overlap_param_gather is False.
    """
    config = self.ddp_config
    assert config.use_distributed_optimizer
    assert config.overlap_param_gather

    # Make sure there is something to wait on.
    if not self.param_gather_dispatched:
        self.start_param_sync()

    pending_handles = self.param_gather_handle
    if pending_handles is None:
        return

    for pending in pending_handles:
        pending.wait()
    self.param_gather_handle = None

    next_group = self.next_param_gather_bucket_group
    if next_group is None or skip_next_bucket_dispatch:
        return

    if not next_group.param_gather_dispatched:
        next_group.start_param_sync()
        return

    # The next group was already started by someone else: overlap is degraded.
    warnings.warn(
        "The next bucket's parameter all-gather operation has already been "
        "dispatched. This may be caused by a mismatch between the order of "
        "parameter registration and forward pass execution, which will "
        "hurt the communication-computation overlap performance."
    )
def start_grad_sync(self):
    """
    Initiates grad sync (all-reduce or reduce-scatter) communication operations
    for all buckets in the bucket group.

    When ddp_config.overlap_grad_reduce is set to True and there is a single
    distributed-optimizer instance, dispatches asynchronous communication calls
    and keeps one handle per bucket in self.grad_reduce_handle. With multiple
    distributed-optimizer instances and overlap enabled, the collectives are
    issued on self.communication_stream instead (no handles are kept).
    Otherwise the calls are synchronous and self.grad_reduce_handle is None.

    Gradient checks (self.check_grads) run when either
    ddp_config.check_for_nan_in_grad or ddp_config.check_for_large_grads is set.

    Raises:
        AssertionError: if a previous grad sync is still outstanding.
    """
    assert (
        self.grad_reduce_handle is None
    ), 'Should not have multiple communication calls outstanding at once'

    config = self.ddp_config

    # Either flag alone must trigger the check; each flag selects its own test.
    if config.check_for_nan_in_grad or config.check_for_large_grads:
        self.check_grads(
            check_for_nan_or_inf=config.check_for_nan_in_grad,
            check_for_large=config.check_for_large_grads,
        )

    # Apply the per-bucket scaling in place before reducing.
    for bucket in self.buckets:
        if bucket.gradient_scaling_factor != 1.0:
            bucket.grad_data *= bucket.gradient_scaling_factor

    reduce_op = torch.distributed.ReduceOp.SUM
    if config.average_in_collective:
        reduce_op = torch.distributed.ReduceOp.AVG

    # Handle-based async ops are only used with a single optimizer instance;
    # with several instances, overlap is achieved through a separate stream.
    async_op = (
        config.overlap_grad_reduce
        and config.num_distributed_optimizer_instances == 1
    )
    if (
        config.num_distributed_optimizer_instances > 1
        and config.overlap_grad_reduce
    ):
        stream_context = torch.cuda.stream(self.communication_stream)
        # The communication stream must not start reducing before the default
        # stream has finished producing the gradients.
        self.communication_stream.wait_stream(torch.cuda.default_stream())
    else:
        stream_context = nullcontext()

    if config.use_distributed_optimizer:
        communication_group = self.intra_distributed_optimizer_instance_group
    else:
        communication_group = self.data_parallel_group

    self.grad_reduce_handle = []
    # Issue the collectives under stream_context so that, in the
    # multi-instance case, they actually run on communication_stream (which
    # finish_grad_sync later makes the default stream wait for).
    with stream_context:
        for bucket in self.buckets:
            if config.use_distributed_optimizer:
                local_data_view = shard_buffer(
                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
                )[self.intra_distributed_optimizer_instance_rank]
                handle = dist_reduce_scatter_func(
                    local_data_view,
                    bucket.grad_data,
                    op=reduce_op,
                    group=communication_group,
                    async_op=async_op,
                )
            else:
                handle = torch.distributed.all_reduce(
                    bucket.grad_data,
                    op=reduce_op,
                    group=communication_group,
                    async_op=async_op,
                )
            self.grad_reduce_handle.append(handle)

    # With multiple optimizer instances, additionally all-reduce this rank's
    # shard across instances.
    if (
        config.use_distributed_optimizer
        and config.num_distributed_optimizer_instances > 1
    ):
        self.grad_reduce_handle = []
        with stream_context:
            for bucket in self.buckets:
                local_data_view = shard_buffer(
                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
                )[self.intra_distributed_optimizer_instance_rank]
                handle = torch.distributed.all_reduce(
                    local_data_view,
                    op=reduce_op,
                    group=self.inter_distributed_optimizer_instance_group,
                    async_op=async_op,
                )
                self.grad_reduce_handle.append(handle)

    if not async_op:
        # Nothing to wait on later for synchronous / stream-ordered calls.
        self.grad_reduce_handle = None
def finish_grad_sync(self):
    """
    Completes grad sync (all-reduce or reduce-scatter) for every bucket in the
    bucket group.

    * Without ddp_config.overlap_grad_reduce, the whole sync is performed here
      by calling self.start_grad_sync().
    * With overlap and more than one distributed-optimizer instance, the
      default CUDA stream is made to wait for self.communication_stream.
    * Otherwise, every handle stored in self.grad_reduce_handle is waited on
      and the attribute is cleared.

    In all cases self.param_gather_dispatched is reset to False first.

    Raises:
        AssertionError: if overlap is enabled (single instance) but no
            communication has been issued yet.
    """
    self.param_gather_dispatched = False
    config = self.ddp_config

    if not config.overlap_grad_reduce:
        # No overlap: do the whole (synchronous) reduction now.
        self.start_grad_sync()
        return

    if config.num_distributed_optimizer_instances > 1:
        # Stream-ordered path: there are no handles, only a stream dependency.
        torch.cuda.default_stream().wait_stream(self.communication_stream)
        return

    assert self.grad_reduce_handle is not None, (
        f'Communication call has not been issued for this bucket '
        f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
    )
    outstanding = self.grad_reduce_handle
    for work in outstanding:
        work.wait()
    self.grad_reduce_handle = None
def param_and_grad_buffer_init_pad(
    self,
    ddp_config: DistributedDataParallelConfig,
    param_dtype: torch.dtype,
    grad_dtype: torch.dtype,
    params: List[torch.nn.Parameter],
    data_parallel_group: torch.distributed.ProcessGroup,
    bucket_size: int,
    param_to_name: Dict[torch.nn.Parameter, str],
    gradient_scaling_factor: float,
    param_indices: List[int],
):
    """
    Initializer-style function that lays `params` out in flat param/grad
    buffers on `self` and partitions them into buckets.

    Steps, in order:
      1. Optionally override `grad_dtype` from the global args
         (`quant_grads` / `quant_grads_dtype`).
      2. Walk `params` in reverse order, assigning each one an index range in
         the buffer and a bucket id (`self.param_index_map`,
         `self.bucket_indices`).
      3. Allocate `self.grad_data` (always) and `self.param_data` (only with
         the distributed optimizer) as zero-filled 1-D tensors on the current
         CUDA device.
      4. Re-point each param at its buffer slice via `self._get`, set
         `param.main_grad`, and create bucket objects via `self._new_bucket`.
      5. Log the resulting bucket layout.

    With the distributed optimizer, param start indices are rounded up to a
    multiple of 64 and bucket end indices are rounded up to a multiple of
    `data_parallel_world_size * (args.param_and_grad_buffer_pad // element_size)`.

    Args:
        ddp_config: DDP configuration; `use_distributed_optimizer` is read here.
        param_dtype: dtype of the param buffer (and basis for `element_size`).
        grad_dtype: dtype of the grad buffer unless overridden by `quant_grads`.
        params: parameters to place in the buffer; must not contain duplicates.
        data_parallel_group: process group whose world size drives padding.
        bucket_size: element-count threshold at which a bucket is closed;
            `None` disables size-based splitting.
        param_to_name: maps each param to its name (used only for logging).
        gradient_scaling_factor: stored on `self` as-is.
        param_indices: stored on `self` as-is.

    Raises:
        AssertionError: on duplicate params or if the computed layout violates
            the size/padding invariants checked below.
    """
    # Optional override of the gradient buffer dtype from the global args:
    # 'bf16' (case-insensitive) selects bfloat16, anything else float16.
    quant_args = get_full_args()
    if getattr(quant_args, 'quant_grads', False):
        requested_dtype = getattr(quant_args, 'quant_grads_dtype', None)
        if isinstance(requested_dtype, str):
            requested_dtype = requested_dtype.lower()
        if requested_dtype == 'bf16':
            grad_dtype = torch.bfloat16
        else:
            grad_dtype = torch.float16
    self.ddp_config = ddp_config
    self.params = params
    self.param_indices = param_indices

    # Check that params are unique.
    unique_params = set()
    for param in params:
        assert param not in unique_params
        unique_params.add(param)
    del unique_params

    # Store attributes that will be needed later.
    self.param_dtype = param_dtype
    self.grad_dtype = grad_dtype
    self.data_parallel_group = data_parallel_group
    self.data_parallel_world_size = torch.distributed.get_world_size(
        group=self.data_parallel_group
    )
    self.gradient_scaling_factor = gradient_scaling_factor

    # Data structures filled in below.
    self.buckets = []
    self.param_to_bucket = {}  # Not populated in this function.
    self.param_index_map = {}  # param -> (start_index, end_index, bucket_id)

    def _pad(number_to_be_padded: int, divisor: int) -> int:
        # Round up to the next multiple of `divisor`.
        # NOTE: math.ceil is looked up on the module at call time, so code that
        # temporarily replaces math.ceil (as
        # reuse_fp32_param_param_and_grad_buffer_init_wrapper does around the
        # function it wraps) changes the rounding performed here.
        return int(math.ceil(number_to_be_padded / divisor) * divisor)

    def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
        """
        Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
        """
        if self.ddp_config.use_distributed_optimizer:
            # 4 when param_dtype is torch.float, otherwise 2 (presumably the
            # byte width of fp16/bf16 -- TODO confirm for other dtypes).
            element_size = 4 if param_dtype == torch.float else 2
            global_args = get_full_args()
            # presumably param_and_grad_buffer_pad is in bytes and align_size
            # in elements -- verify against the argument's definition.
            align_size = global_args.param_and_grad_buffer_pad // element_size
            return _pad(bucket_end_index, self.data_parallel_world_size * align_size)
        return bucket_end_index

    def _pad_start_of_param_if_needed(param_start_index: int) -> int:
        """
        Pads start index of param if using distributed optimizer (to ensure "good" alignment).
        """
        if self.ddp_config.use_distributed_optimizer:
            # Start every param on a multiple of 64 elements.
            return _pad(param_start_index, 64)
        return param_start_index

    # First pass: compute the layout (index ranges and bucket ids) only.
    param_start_index = 0
    bucket_start_index = param_start_index
    bucket_params = set()
    self.bucket_indices = []  # (start_index, padded end_index) per bucket
    per_bucket_numel_unpadded = []
    bucket_id = 0

    def _update_bucket_metadata(param_end_index: int) -> int:
        """
        Record metadata for the bucket starting at bucket_start_index and ending with the
        passed-in param_end_index. Returns the bucket's end_index.
        """
        nonlocal bucket_start_index, bucket_params, bucket_id
        per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
        bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)

        # Record metadata of new bucket.
        self.bucket_indices.append((bucket_start_index, bucket_end_index))
        bucket_start_index = bucket_end_index

        # Prepare for next bucket.
        bucket_params = set()
        bucket_id += 1

        # Return the potentially padded bucket_end_index.
        return bucket_end_index

    def _does_param_require_new_bucket(param):
        """
        Split shared embedding parameters into separate bucket if using distributed
        optimizer that makes use of reduce-scatters instead of all-reduces.
        This ensures that the first and last pipeline stage partition optimizer state
        for the shared embedding parameters the same way across DP replicas, allowing
        the DP reduce-scatter to be before the embedding all-reduce.
        """
        return (
            getattr(param, "shared_embedding", False)
            and self.ddp_config.use_distributed_optimizer
        )

    # Params are visited in reverse list order.
    for param in params[::-1]:
        this_numel = param.data.nelement()
        param_start_index = _pad_start_of_param_if_needed(param_start_index)

        # A param that needs its own bucket first closes the bucket holding the
        # params accumulated so far (which end at the current start index).
        if _does_param_require_new_bucket(param):
            if self.ddp_config.use_distributed_optimizer:
                # Make sure the new bucket starts on a padded boundary.
                if param_start_index % self.data_parallel_world_size != 0:
                    param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
            if len(bucket_params) > 0:
                bucket_end_index = _update_bucket_metadata(param_start_index)

        param_end_index = param_start_index + this_numel
        self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
        bucket_params.add(param)

        # Close the bucket once it reaches bucket_size elements, or right away
        # if this param requires a bucket of its own.
        if (
            bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size
        ) or _does_param_require_new_bucket(param):
            bucket_end_index = _update_bucket_metadata(param_end_index)
            param_start_index = bucket_end_index
        else:
            param_start_index = param_end_index

    # Add remaining params to a final bucket.
    if len(bucket_params) > 0:
        bucket_end_index = _update_bucket_metadata(param_end_index)

    # NOTE(review): bucket_end_index is only bound if at least one bucket was
    # created above, so an empty `params` list would fail here.
    self.numel = bucket_end_index
    self.numel_unpadded = sum(per_bucket_numel_unpadded)
    assert self.numel_unpadded <= self.numel
    if self.ddp_config.use_distributed_optimizer:
        assert self.numel % self.data_parallel_world_size == 0
    else:
        # Without the distributed optimizer no padding is ever applied.
        assert self.numel == self.numel_unpadded

    # Allocate the flat buffers. The param buffer only exists with the
    # distributed optimizer.
    self.param_data = None
    if self.ddp_config.use_distributed_optimizer:
        self.param_data = torch.zeros(
            self.numel,
            dtype=self.param_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
    self.grad_data = torch.zeros(
        self.numel,
        dtype=self.grad_dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )

    # Second pass: map params onto the buffers and create the bucket objects,
    # visiting params in the same (reversed) order as the first pass.
    bucket_params = []
    bucket_start_index = 0
    cur_bucket_id = 0
    for param in params[::-1]:
        param_start_index, param_end_index, bucket_id = self.param_index_map[param]

        # Re-point the param's data at its region of self.param_data
        # (self._get presumably returns a view into the buffer -- defined
        # outside this function).
        if self.param_data is not None:
            old_param_data = param.data
            new_param_data = self._get(
                param.data.shape, param_start_index, buffer_type=BufferType.PARAM
            )
            # Float8 tensors are updated through `_data` instead of `.data`.
            if is_float8tensor(param):
                param._data = new_param_data
            else:
                param.data = new_param_data
            # The previous storage is expected not to be a view of another tensor.
            assert old_param_data._base is None
            # Copy the existing values into the buffer-backed storage.
            param.data.detach().copy_(old_param_data)
            del old_param_data

        param.main_grad = self._get(
            param.data.shape, param_start_index, buffer_type=BufferType.GRAD
        )

        # This param belongs to the next bucket: finalize the current one.
        if bucket_id != cur_bucket_id:
            # The finished bucket's end is recomputed from the start index of
            # the first param of the new bucket.
            bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
            self.buckets.append(
                self._new_bucket(
                    bucket_params=bucket_params,
                    start_index=bucket_start_index,
                    end_index=bucket_end_index,
                    numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                    bucket_id=cur_bucket_id,
                )
            )
            bucket_start_index = bucket_end_index
            bucket_params = []
            # Bucket ids must advance one at a time, in creation order.
            assert cur_bucket_id + 1 == len(self.buckets)
            assert bucket_id == cur_bucket_id + 1
            cur_bucket_id = bucket_id
        bucket_params.append(param)

    # Add remaining params to a final bucket.
    if len(bucket_params) > 0:
        bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
        self.buckets.append(
            self._new_bucket(
                bucket_params=bucket_params,
                start_index=bucket_start_index,
                end_index=bucket_end_index,
                numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                bucket_id=cur_bucket_id,
            )
        )

    # Log the bucket layout: bucket count, then per bucket its element count
    # and the names of its params.
    log_strs = []
    log_strs.append(
        f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
    )
    for index, bucket in enumerate(self.buckets):
        numel = 0
        for param in bucket.params:
            numel += param.data.nelement()
        log_strs.append(f'Params for bucket {index + 1} ({numel} elements):')
        for param in bucket.params:
            log_strs.append(f'\t{param_to_name[param]}')
    log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))