from typing import List, Optional
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron.core import parallel_state
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_attr_wrapped_model
from megatron.core.distributed.finalize_model_grads import _reshard_if_dtensor, _unshard_if_dtensor
from megatron.training import get_args
from mindspeed.core.tensor_parallel.tp_2d.group_api_2d import TPXCollectiveComm
from mindspeed.core.optimizer.low_precision import finalize_model_grads as quant_finalize
from megatron.core.transformer.moe.moe_utils import get_updated_expert_bias
def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
    """Return the name of the attribute under which ``param``'s gradient is stored.

    Args:
        param: Parameter (or any object) to inspect for a ``main_grad`` attribute.
        use_custom_fsdp: When true, ``"fsdp_managed_main_grad"`` is returned
            without inspecting ``param`` at all.

    Returns:
        ``"fsdp_managed_main_grad"`` if ``use_custom_fsdp`` is truthy, otherwise
        ``"main_grad"`` when ``param`` has that attribute, otherwise ``"grad"``.
    """
    # The custom-FSDP flag takes precedence over whatever the parameter carries.
    if use_custom_fsdp:
        return "fsdp_managed_main_grad"
    return "main_grad" if hasattr(param, "main_grad") else "grad"
def allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Two independent passes are made over the parameters of every model chunk:

    1. When the tensor-model-parallel world size is greater than one and either
       ``config.sequence_parallel`` or ``config.qk_layernorm`` is set, the
       ``main_grad`` of every trainable parameter that is flagged
       ``sequence_parallel`` or whose name contains ``q_layernorm`` /
       ``k_layernorm`` is all-reduced over the tensor-model-parallel group.
    2. The ``main_grad`` of every trainable parameter flagged ``2d_tp`` is
       all-reduced over the group returned by
       ``TPXCollectiveComm.get_comm_group()``.

    In both passes the gradients are flattened into one buffer, reduced with a
    single collective call, and copied back into the original tensors in place.

    Args:
        model: Model chunks whose parameters are inspected.
        config: Supplies the ``sequence_parallel`` and ``qk_layernorm`` flags.
    """
    tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
    if tp_world_size > 1 and (config.sequence_parallel or config.qk_layernorm):
        # Frozen parameters are skipped; everything else qualifies by flag or by name.
        sp_grads = [
            weight.main_grad.data
            for chunk in model
            for param_name, weight in get_attr_wrapped_model(chunk, 'named_parameters')()
            if weight.requires_grad
            and (
                getattr(weight, 'sequence_parallel', False)
                or 'q_layernorm' in param_name
                or 'k_layernorm' in param_name
            )
        ]
        if sp_grads:
            # One collective over a flat buffer instead of one per tensor.
            flat = _flatten_dense_tensors(sp_grads)
            torch.distributed.all_reduce(
                flat, group=parallel_state.get_tensor_model_parallel_group()
            )
            reduced_views = _unflatten_dense_tensors(flat, sp_grads)
            for target, reduced in zip(sp_grads, reduced_views):
                target.copy_(reduced)

    # NOTE(review): this 2D tensor-parallel pass is placed outside the guard
    # above, so it runs regardless of the sequence-parallel settings — confirm
    # this nesting matches the intended behavior.
    tp2d_grads = [
        weight.main_grad.data
        for chunk in model
        for _, weight in get_attr_wrapped_model(chunk, "named_parameters")()
        if weight.requires_grad and getattr(weight, "2d_tp", False)
    ]
    if tp2d_grads:
        flat = _flatten_dense_tensors(tp2d_grads)
        torch.distributed.all_reduce(flat, group=TPXCollectiveComm.get_comm_group())
        reduced_views = _unflatten_dense_tensors(flat, tp2d_grads)
        for target, reduced in zip(tp2d_grads, reduced_views):
            target.copy_(reduced)
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync.

    The function does nothing unless this rank belongs to the embedding group
    and that group has more than one member. The model chunk is ``model[-1]``
    on a rank that is a last (and not first) pipeline stage, and ``model[0]``
    otherwise. The reduction is performed only when the unwrapped model shares
    its embedding and output weights or ``config.mtp_num_layers`` is truthy,
    and only when a gradient is actually present.

    When ``args.use_quant_optimizer`` is set and the low-precision finalize
    module exposes ``_maybe_adjust_quant_scale``, that hook is called with the
    gradient and the embedding group before the all-reduce.

    Args:
        model: Model chunks held by this rank.
        config: Transformer configuration; ``mtp_num_layers`` is read if present.
    """
    # Guard clauses: only ranks inside a multi-member embedding group take part.
    if not parallel_state.is_rank_in_embedding_group(ignore_virtual=True):
        return
    if torch.distributed.get_world_size(parallel_state.get_embedding_group()) <= 1:
        return

    # Only a rank that is last-but-not-first uses the final chunk; every other
    # case (first stage, or neither) uses the first chunk.
    use_last_chunk = (
        not parallel_state.is_pipeline_first_stage(ignore_virtual=True)
        and parallel_state.is_pipeline_last_stage(ignore_virtual=True)
    )
    chunk = model[-1] if use_last_chunk else model[0]

    # Read the DDP config from the wrapper before unwrapping it.
    ddp_config = chunk.ddp_config
    inner_model = get_attr_wrapped_model(chunk, 'pre_process', return_model_obj=True)
    if not (
        inner_model.share_embeddings_and_output_weights
        or getattr(config, 'mtp_num_layers', 0)
    ):
        return

    weight = inner_model.shared_embedding_or_output_weight()
    grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
    stored_grad = getattr(weight, grad_attr)
    full_grad = _unshard_if_dtensor(stored_grad)
    if full_grad is None:
        return

    # Optional quantized-optimizer hook, looked up defensively by name.
    rescale = None
    if get_args().use_quant_optimizer and quant_finalize is not None:
        rescale = getattr(quant_finalize, "_maybe_adjust_quant_scale", None)
    if rescale is not None:
        rescale(full_grad, parallel_state.get_embedding_group())

    torch.distributed.all_reduce(full_grad, group=parallel_state.get_embedding_group())
    # Store the reduced gradient back under the attribute it was read from.
    setattr(weight, grad_attr, _reshard_if_dtensor(full_grad, stored_grad))
def _update_router_expert_bias_for_patch(model: List[torch.nn.Module], config: TransformerConfig):
    """
    Update the expert bias of the router for a global batch.

    Every module (across all model chunks) that has a non-``None``
    ``expert_bias`` contributes its ``local_tokens_per_expert`` and
    ``expert_bias`` tensors. Both sets are stacked along a new leading
    dimension and passed to ``get_updated_expert_bias`` together with
    ``config.moe_router_bias_update_rate``. Each module's token counter is then
    zeroed in place and its bias overwritten in place with the matching slice
    of the result.

    NOTE(review): the all-reduce of local_tokens_per_expert across TPxCPxDP
    ranks is not performed in this function; presumably it happens inside
    ``get_updated_expert_bias`` — confirm.

    Args:
        model: Model chunks to scan for modules carrying an expert bias.
        config: Supplies ``moe_router_bias_update_rate``.
    """
    biased_modules = [
        submodule
        for chunk in model
        for submodule in get_attr_wrapped_model(chunk, 'modules')()
        if getattr(submodule, 'expert_bias', None) is not None
    ]
    # Nothing to update when no module carries an expert bias.
    if not biased_modules:
        return

    token_counts = [submodule.local_tokens_per_expert for submodule in biased_modules]
    biases = [submodule.expert_bias for submodule in biased_modules]

    # Batch all routers into a single call by stacking on a new dim 0.
    new_biases = get_updated_expert_bias(
        torch.stack(token_counts, dim=0),
        torch.stack(biases, dim=0),
        config.moe_router_bias_update_rate,
    )

    for counter, bias, new_bias in zip(token_counts, biases, new_biases):
        # Reset the accumulator, then write the refreshed bias in place.
        counter.zero_()
        bias.copy_(new_bias)