from functools import wraps, partial
import inspect
import torch
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False, dense_or_moe_group: str = None):
    """
    Initiates param sync (all-gather) communication operations for model parameters.

    By default, when overlap_param_gather is set to True, dispatches asynchronous communication
    calls; when overlap_param_gather is set to False, calls synchronous communication
    ops. Can override this default behavior using flags below.

    Returns without doing anything when ``force_sync`` is False,
    ``self.overlap_param_gather_with_optimizer_step`` is truthy and
    ``force_dispatch`` is False.

    Args:
        *unused: extra positional arguments; accepted and ignored.
        force_sync (bool, optional): force synchronous collective regardless of
            other settings. Forwarded to every selected bucket group.
        force_dispatch (bool, optional): force dispatch regardless of other settings.
        dense_or_moe_group (str, optional): selects which bucket groups are synced.
            ``None`` (default) uses ``self.bucket_groups`` followed by
            ``self.expert_parallel_bucket_groups``; ``'dense'`` uses only
            ``self.bucket_groups``; ``'moe'`` uses only
            ``self.expert_parallel_bucket_groups``.

    Raises:
        ValueError: if ``dense_or_moe_group`` is not ``None``, ``'dense'`` or ``'moe'``.
    """
    if not force_sync:
        # NOTE(review): presumably the all-gather was already issued as part of the
        # optimizer step in this mode, so it is not dispatched again -- confirm.
        if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
            return

    if dense_or_moe_group is None:
        bucket_groups = self.bucket_groups + self.expert_parallel_bucket_groups
    elif dense_or_moe_group == 'dense':
        bucket_groups = self.bucket_groups
    elif dense_or_moe_group == 'moe':
        bucket_groups = self.expert_parallel_bucket_groups
    else:
        # Previously this fell through with `bucket_groups` unbound and failed
        # below with an opaque UnboundLocalError.
        raise ValueError(
            f"dense_or_moe_group must be None, 'dense' or 'moe', got {dense_or_moe_group!r}"
        )

    for bucket_group in bucket_groups:
        bucket_group.start_param_sync(force_sync=force_sync)
def step_with_ready_grads_distrib_opti_wrapper(func):
    """Decorate an optimizer step so param sync is restricted to one bucket-group kind.

    While ``func`` runs, every model chunk in ``self.model_chunks`` whose
    ``start_param_sync`` accepts a ``dense_or_moe_group`` parameter has that
    method temporarily replaced by a ``functools.partial`` binding
    ``dense_or_moe_group`` to the optimizer's ``is_moe_param`` attribute
    (``None`` when the attribute is absent). The original callables are put
    back afterwards, including when ``func`` raises.

    Args:
        func: the step function to wrap; its first positional argument must be
            the optimizer instance (an object exposing ``model_chunks``).

    Returns:
        The wrapped function, which returns whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        self = args[0]
        is_moe_param = getattr(self, 'is_moe_param', None)

        # (model_chunk, original callable) for exactly the chunks that were patched,
        # so chunks without the parameter are never touched on restore.
        patched = []
        try:
            for model_chunk in self.model_chunks:
                original = model_chunk.start_param_sync
                if 'dense_or_moe_group' in inspect.signature(original).parameters:
                    model_chunk.start_param_sync = partial(
                        original, dense_or_moe_group=is_moe_param
                    )
                    patched.append((model_chunk, original))
            return func(*args, **kwargs)
        finally:
            # Always undo the monkey-patch, even when the step raises, so the
            # group restriction cannot leak into later calls.
            for model_chunk, original in patched:
                model_chunk.start_param_sync = original

    return wrapper
def get_megatron_optimizer_wrapper(func):
    """Decorate an optimizer factory so its sub-optimizers get an ``is_moe_param`` tag.

    If the object returned by ``func`` has a ``chained_optimizers`` attribute and
    the model chunks can be located (keyword ``model_chunks`` or the second
    positional argument), every entry of ``chained_optimizers`` gets
    ``is_moe_param = 'dense'``. If any trainable parameter of any model chunk has
    a falsy ``allreduce`` attribute, the last entry is re-tagged
    ``is_moe_param = 'moe'``. In every other case the result is returned untouched.

    Args:
        func: the optimizer factory to wrap.

    Returns:
        The wrapped factory, which returns the object produced by ``func``.
    """

    def _has_trainable_expert_param(chunks):
        # True as soon as one parameter with requires_grad set reports allreduce falsy.
        for chunk in chunks:
            cfg = chunk.ddp_config
            if cfg.use_custom_fsdp:
                param_items = chunk.optimizer_named_parameters()
            else:
                param_items = chunk.named_parameters()
            for _, candidate in param_items:
                # Under this sharding strategy the yielded object wraps the real
                # parameter, which is reachable through ``orig_param``.
                if (
                    cfg.use_custom_fsdp
                    and cfg.data_parallel_sharding_strategy == "optim_grads_params"
                ):
                    candidate = candidate.orig_param
                if candidate.requires_grad and not getattr(candidate, 'allreduce', True):
                    return True
        return False

    @wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        if not hasattr(result, 'chained_optimizers'):
            return result

        if 'model_chunks' in kwargs:
            chunks = kwargs['model_chunks']
        elif len(args) > 1:
            chunks = args[1]
        else:
            # No way to find the model chunks: leave the optimizer untagged.
            return result

        for sub_optimizer in result.chained_optimizers:
            sub_optimizer.is_moe_param = 'dense'

        # NOTE(review): presumably the last chained optimizer is the one that
        # owns the expert-parallel parameters -- confirm against the factory.
        if _has_trainable_expert_param(chunks):
            result.chained_optimizers[-1].is_moe_param = 'moe'
        return result

    return wrapper