import copy
import warnings
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from mindspeed.core.optimizer.muon.emerging_optimizers import (
HAVE_EMERGING_OPTIMIZERS,
_EMERGING_OPTIMIZERS,
_create_emerging_optimizer,
)
from mindspeed.core.optimizer.muon.layer_wise_optimizer import LayerWiseDistributedOptimizer
from mindspeed.core.optimizer.muon.optimizer_config import (
ParamKey,
ParamPredicate,
ParamWithNamePredicate,
)
from mindspeed.core.optimizer.muon.optimizer_param_scheduler import (
ParamGroupOverride,
combine_param_group_overrides,
param_group_override_to_tuple,
)
from mindspeed.core.optimizer.muon.utils import LegacyProcessGroupCollection
# Names of the extra per-parameter flags used by the Muon integration in this
# module. They are registered as tensor-parallel attribute defaults by
# `add_muon_tensor_model_parallel_attributes`, copied from source to destination
# tensors by `copy_muon_tensor_model_parallel_attributes_wrapper`, and set on
# matching parameters in `_get_megatron_emerging_optimizer`.
_MUON_TENSOR_MODEL_PARALLEL_ATTRIBUTES = ("expert_tp", "is_qkv")
def add_muon_tensor_model_parallel_attributes():
    """Register the Muon flags in Megatron 0.12.1's tensor-parallel attribute defaults.

    Only the in-memory ``_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS`` mapping of
    ``megatron.core.tensor_parallel.layers`` is touched; each flag gets the
    default ``False`` unless an entry with that name is already present.
    """
    from megatron.core.tensor_parallel import layers

    attribute_defaults = layers._MODEL_PARALLEL_ATTRIBUTE_DEFAULTS
    for flag_name in _MUON_TENSOR_MODEL_PARALLEL_ATTRIBUTES:
        # setdefault leaves an already-registered value untouched.
        attribute_defaults.setdefault(flag_name, False)
def copy_muon_tensor_model_parallel_attributes_wrapper(func):
    """Decorate an attribute-copy function so the Muon flags are copied as well.

    The wrapped ``func(destination_tensor, source_tensor)`` runs first; then
    every flag listed in ``_MUON_TENSOR_MODEL_PARALLEL_ATTRIBUTES`` that exists
    on the source is mirrored onto the destination. The result of ``func`` is
    returned unchanged.
    """

    @wraps(func)
    def wrapper(destination_tensor, source_tensor):
        outcome = func(destination_tensor, source_tensor)
        absent = object()  # sentinel: distinguishes "not set" from any real value
        for flag_name in _MUON_TENSOR_MODEL_PARALLEL_ATTRIBUTES:
            flag_value = getattr(source_tensor, flag_name, absent)
            if flag_value is not absent:
                setattr(destination_tensor, flag_name, flag_value)
        return outcome

    return wrapper
def param_is_not_tensor_parallel_duplicate(param, tp_group=None):
    """Tell whether ``param`` should be counted on this rank (TP duplicate filter).

    Dev-style filter with a Megatron 0.12.1 fallback.

    Args:
        param: Parameter to inspect. A truthy ``tensor_model_parallel``
            attribute marks it as tensor-parallel.
        tp_group: Optional tensor-parallel process group. When given, the rank
            inside this group is used instead of Megatron's global ``mpu`` state.

    Returns:
        ``True`` when the parameter is tensor-parallel, or when this process is
        rank 0 of the tensor-parallel group; ``False`` otherwise.
    """
    if getattr(param, "tensor_model_parallel", False):
        return True
    if tp_group is not None:
        return torch.distributed.get_rank(group=tp_group) == 0
    # Imported only on the fallback path: the two branches above do not need
    # Megatron's global parallel state, so they should not pay for (or depend
    # on) this import.
    from megatron.core import mpu

    return mpu.get_tensor_model_parallel_rank() == 0
def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
    """Collect the gradients that take part in the grad-norm computation.

    Replacement for ``MegatronOptimizer.get_main_grads_for_grad_norm`` that
    forwards the optimizer's own ``tp_group`` (if it has one) to the
    tensor-parallel duplicate filter.

    Returns:
        Gradients of parameters that have a gradient, are not shared, and pass
        ``param_is_not_tensor_parallel_duplicate``.
    """
    from megatron.core.transformer.module import param_is_not_shared

    def _grad_of(param):
        # With the precision-aware optimizer the gradient is read from
        # `decoupled_grad` (None when absent) instead of `grad`.
        if getattr(self.config, "use_precision_aware_optimizer", False):
            return getattr(param, "decoupled_grad", None)
        return param.grad

    selected = []
    for param in self.get_parameters():
        grad = _grad_of(param)
        if grad is None:
            continue
        if not param_is_not_shared(param):
            continue
        if not param_is_not_tensor_parallel_duplicate(param, getattr(self, "tp_group", None)):
            continue
        selected.append(grad)
    return selected
def count_zeros_fp32(
    parameters,
    grad_stats_parallel_group: torch.distributed.ProcessGroup,
    use_decoupled_grad: bool = False,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float:
    """Count zero-valued gradient entries, filtering TP duplicates via ``tp_group``.

    Args:
        parameters: A single tensor or an iterable of parameters.
        grad_stats_parallel_group: Group over which the local count is summed.
        use_decoupled_grad: Read ``decoupled_grad`` instead of ``grad``.
        tp_group: Optional tensor-parallel group handed to
            ``param_is_not_tensor_parallel_duplicate``.

    Returns:
        The summed zero count as a Python float.
    """
    from megatron.core.transformer.module import param_is_not_shared
    from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    grad_attr = "decoupled_grad" if use_decoupled_grad else "grad"
    total_num_zeros = torch.tensor([0.0], dtype=torch.float, device="cuda")
    data_parallel_group = None

    for param in parameters:
        grad_obj = getattr(param, grad_attr, None)
        if grad_obj is None:
            continue
        if not (param_is_not_shared(param) and param_is_not_tensor_parallel_duplicate(param, tp_group)):
            continue
        data_parallel_group = get_data_parallel_group_if_dtensor(grad_obj, data_parallel_group)
        local_grad = to_local_if_dtensor(grad_obj).detach()
        total_num_zeros += local_grad.numel() - torch.count_nonzero(local_grad)

    # A data-parallel group is only set by the DTensor helper above; when it
    # is, reduce across it before the stats-group reduction.
    if data_parallel_group:
        torch.distributed.all_reduce(
            total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=data_parallel_group
        )
    torch.distributed.all_reduce(
        total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=grad_stats_parallel_group
    )
    return total_num_zeros.item()
def megatron_optimizer_count_zeros(self) -> float:
    """Replacement for ``MegatronOptimizer.count_zeros`` that forwards ``tp_group``.

    Delegates to ``count_zeros_fp32`` using this optimizer's parameters, its
    grad-stats parallel group, the ``use_precision_aware_optimizer`` config flag
    (default ``False``) and the optimizer's ``tp_group`` attribute (default ``None``).
    """
    params = self.get_parameters()
    stats_group = self.get_grad_stats_parallel_group()
    decoupled = getattr(self.config, "use_precision_aware_optimizer", False)
    own_tp_group = getattr(self, "tp_group", None)
    return count_zeros_fp32(
        params,
        grad_stats_parallel_group=stats_group,
        use_decoupled_grad=decoupled,
        tp_group=own_tp_group,
    )
def chained_optimizer_count_zeros(self):
    """Sum ``count_zeros()`` over the chained optimizers that log zero counts.

    Each child optimizer is asked for its own count (so its own ``tp_group`` is
    used); children whose ``config.log_num_zeros_in_grad`` is falsy contribute 0.
    """
    return sum(
        child.count_zeros() if child.config.log_num_zeros_in_grad else 0
        for child in self.chained_optimizers
    )
def _get_muon_config_overrides(
    config,
    no_weight_decay_cond: Optional[Callable],
    scale_lr_cond: Optional[Callable],
    lr_mult: float,
) -> Dict[ParamKey, ParamGroupOverride]:
    """Translate the legacy optimizer-factory arguments into ParamKey overrides.

    Args:
        config: Optimizer config; ``apply_wd_to_qk_layernorm``, ``decoupled_lr``
            and ``decoupled_min_lr`` are read when present.
        no_weight_decay_cond: Optional ``(name, param) -> bool`` callable that
            selects parameters whose weight decay multiplier becomes 0.
        scale_lr_cond: Optional ``(name, param) -> bool`` callable that selects
            parameters whose learning-rate multiplier becomes ``lr_mult``.
        lr_mult: Multiplier used together with ``scale_lr_cond``.

    Returns:
        Mapping from ``ParamKey`` to the ``ParamGroupOverride`` to apply.
    """
    # 1) Which parameters are exempt from weight decay (wd_mult=0.0).
    if no_weight_decay_cond is not None:

        def _caller_excludes_from_wd(param, name):
            return no_weight_decay_cond(name, param)

        wd_exempt_key = ParamKey(
            with_name_predicate=ParamWithNamePredicate(
                name="no_weight_decay_cond",
                fn=_caller_excludes_from_wd,
            )
        )
    elif getattr(config, "apply_wd_to_qk_layernorm", False):

        def _len1_or_bias_outside_qk_layernorm(param, name):
            # q/k layernorm parameters keep their weight decay in this mode.
            if "q_layernorm." in name or "k_layernorm." in name:
                return False
            return len(param.shape) == 1 or name.endswith(".bias")

        wd_exempt_key = ParamKey(
            with_name_predicate=ParamWithNamePredicate(
                name="s1_not_qkln",
                fn=_len1_or_bias_outside_qk_layernorm,
            )
        )
    else:
        wd_exempt_key = ParamKey(
            name="*.bias",
            predicate=ParamPredicate(name="param_len_1", fn=lambda param: len(param.shape) == 1),
        )

    overrides = {wd_exempt_key: ParamGroupOverride(wd_mult=0.0)}

    # 2) Optional learning-rate scaling selected by the caller.
    if scale_lr_cond is not None:

        def _caller_scales_lr(param, name):
            return scale_lr_cond(name, param)

        scaled_lr_key = ParamKey(
            with_name_predicate=ParamWithNamePredicate(name="scale_lr_cond", fn=_caller_scales_lr)
        )
        overrides[scaled_lr_key] = ParamGroupOverride(lr_mult=lr_mult)

    # 3) Optional decoupled learning rate for parameters carrying the
    #    `is_embedding_or_output_parameter` attribute.
    if getattr(config, "decoupled_lr", None) is not None:
        decoupled_override = ParamGroupOverride(max_lr=config.decoupled_lr)
        if getattr(config, "decoupled_min_lr", None) is not None:
            decoupled_override["min_lr"] = config.decoupled_min_lr
        overrides[ParamKey(attr="is_embedding_or_output_parameter")] = decoupled_override

    return overrides
def _get_param_groups(
    model_chunks: List,
    config,
    config_overrides: Optional[Dict[ParamKey, ParamGroupOverride]],
) -> List[Dict]:
    """Build optimizer parameter groups from the config and per-parameter overrides.

    A parameter may match several ``ParamKey`` entries. All matching overrides
    are merged into one ``ParamGroupOverride``; parameters that end up with the
    same merged override (and the same expert-parallel flag) share a group.

    Args:
        model_chunks (List[MegatronModule]): model chunks whose trainable
            parameters are grouped.
        config (OptimizerConfig): optimizer configuration; ``lr`` and ``min_lr``
            provide the default ``max_lr`` / ``min_lr`` of every group.
        config_overrides (Optional[Dict[ParamKey, ParamGroupOverride]]):
            per-parameter overrides. With ``None`` no override matching is done.

    Returns:
        List of parameter-group dicts.

    Raises:
        ValueError: if an override contains the protected key ``'params'``.
    """
    # Bucket every trainable parameter by (merged override tuple, expert flag).
    params_map = {}
    for model_chunk in model_chunks:
        for name, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue

            matched_overrides: List[ParamGroupOverride] = []
            if config_overrides is not None:
                matched_overrides = [
                    override
                    for param_key, override in config_overrides.items()
                    if param_key.matches(param, name)
                ]

            merged_override: Optional[ParamGroupOverride] = (
                combine_param_group_overrides(matched_overrides) if matched_overrides else None
            )
            override_tuple: Optional[Tuple[Tuple[str, Any], ...]] = param_group_override_to_tuple(
                merged_override
            )
            # Parameters with `allreduce == False` are treated as expert-parallel.
            bucket_key = (override_tuple, not getattr(param, "allreduce", True))
            params_map.setdefault(bucket_key, []).append(param)

    # Take the union of bucket keys over all ranks; keys that only exist on
    # other ranks yield a group with an empty parameter list here.
    params_key = list(params_map)
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        gathered_params_key = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered_params_key, params_key)
        for rank_keys in gathered_params_key:
            for remote_key in rank_keys:
                if remote_key not in params_key:
                    params_key.append(remote_key)

    param_groups = []
    # Keys without overrides sort first; the rest sort by their override tuple.
    for bucket_key in sorted(params_key, key=lambda item: (item[0] is not None, item[0])):
        override_tuple, is_expert_parallel = bucket_key
        params = params_map.get(bucket_key, [])

        if override_tuple is None:
            param_override = ParamGroupOverride()
        else:
            param_override = ParamGroupOverride(dict(override_tuple))

        # The default LR schedule applies unless some override key mentions "lr".
        uses_default_lr_schedule: bool = not override_tuple or not any(
            "lr" in override_name for override_name in param_override
        )

        default_config = ParamGroupOverride(
            wd_mult=1.0,
            lr_mult=1.0,
            is_decoupled_lr=False,
            max_lr=config.lr,
            min_lr=config.min_lr,
        )
        if "params" in param_override:
            raise ValueError("'params' should not be in param_override, this is a protected key")

        param_groups.append(
            {
                "params": params,
                "is_expert_parallel": is_expert_parallel,
                "default_config": uses_default_lr_schedule,
                **default_config,
                **param_override,
            }
        )

    return param_groups
def get_megatron_optimizer_based_on_param_groups_wrapper(func):
    """Add ``skip_megatron_wrapping`` / ``pg_collection`` support to an older factory.

    If ``func`` already accepts ``skip_megatron_wrapping`` (or ``**kwargs``) the
    call is forwarded untouched. Otherwise both keywords are stripped, and:

    * without ``skip_megatron_wrapping`` the call is forwarded to ``func``;
    * with it, a bare optimizer is built here from ``config`` and
      ``param_groups`` and ``(optimizer, init_state_fn)`` is returned
      (``(None, None)`` when ``param_groups`` is empty).

    Raises (from the wrapper, on the bare-optimizer path):
        ValueError: for ``use_precision_aware_optimizer`` or ``optimizer_cpu_offload``.
        ImportError: when ``lion`` is requested but cannot be imported.
        Exception: when ``config.optimizer`` is not ``adam``, ``lion`` or ``sgd``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        import inspect

        accepted = inspect.signature(func).parameters
        takes_var_keyword = any(
            spec.kind == inspect.Parameter.VAR_KEYWORD for spec in accepted.values()
        )
        if "skip_megatron_wrapping" in accepted or takes_var_keyword:
            return func(*args, **kwargs)

        # The wrapped factory knows neither keyword, so remove both.
        skip_megatron_wrapping = kwargs.pop("skip_megatron_wrapping", False)
        kwargs.pop("pg_collection", None)
        if not skip_megatron_wrapping:
            return func(*args, **kwargs)

        config = kwargs["config"] if "config" in kwargs else args[0]
        param_groups = kwargs["param_groups"] if "param_groups" in kwargs else args[2]

        if getattr(config, "use_precision_aware_optimizer", False):
            raise ValueError("skip_megatron_wrapping=True is incompatible with use_precision_aware_optimizer.")
        if getattr(config, "optimizer_cpu_offload", False):
            raise ValueError("skip_megatron_wrapping=True is incompatible with optimizer_cpu_offload.")

        import megatron.core.optimizer as optimizer_mod

        if not param_groups:
            return None, None

        optimizer_name = config.optimizer

        if optimizer_name == 'adam':
            adam_kwargs = {
                "params": param_groups,
                "lr": config.lr,
                "weight_decay": config.weight_decay,
                "betas": (config.adam_beta1, config.adam_beta2),
                "eps": config.adam_eps,
            }
            if hasattr(config, "optimizer_cuda_graph"):
                adam_kwargs["capturable"] = config.optimizer_cuda_graph

            adam_cls = optimizer_mod.Adam
            is_torch_adam = adam_cls.__module__.startswith("torch.optim")
            decoupled_weight_decay = getattr(config, "decoupled_weight_decay", True)
            try:
                supports_adam_w_mode = "adam_w_mode" in inspect.signature(adam_cls.__init__).parameters
            except (TypeError, ValueError):
                # Signature not introspectable: assume non-torch Adam takes adam_w_mode.
                supports_adam_w_mode = not is_torch_adam

            if supports_adam_w_mode:
                adam_kwargs["adam_w_mode"] = decoupled_weight_decay
            elif is_torch_adam:
                adam_cls = torch.optim.AdamW if decoupled_weight_decay else torch.optim.Adam
            elif not decoupled_weight_decay:
                adam_cls = torch.optim.Adam

            optimizer = adam_cls(**adam_kwargs)

            def init_state_fn(opt, config=None):
                for group in opt.param_groups:
                    for p in group['params']:
                        state = opt.state[p]
                        if len(state) != 0:
                            continue
                        if config is not None and config.use_precision_aware_optimizer:
                            opt.initialize_state(p)
                        else:
                            state['exp_avg'] = torch.zeros_like(p.data)
                            state['exp_avg_sq'] = torch.zeros_like(p.data)

            return optimizer, init_state_fn

        if optimizer_name == 'lion':
            try:
                from emerging_optimizers.scalar_optimizers import Lion
            except ImportError as exc:
                raise ImportError(
                    "Lion optimizer requires emerging_optimizers >= 0.2. "
                    "Please install or upgrade it to use --optimizer lion."
                ) from exc

            lion_betas = (
                getattr(config, "lion_beta1", 0.95),
                getattr(config, "lion_beta2", 0.98),
            )
            optimizer = Lion(
                param_groups,
                lr=config.lr,
                betas=lion_betas,
                weight_decay=config.weight_decay,
            )

            def init_state_fn(opt, config=None):
                for group in opt.param_groups:
                    for p in group['params']:
                        state = opt.state[p]
                        if len(state) == 0:
                            state['exp_avg'] = torch.zeros_like(p.data)

            return optimizer, init_state_fn

        if optimizer_name == 'sgd':
            optimizer = optimizer_mod.SGD(
                param_groups,
                lr=config.lr,
                weight_decay=config.weight_decay,
                momentum=config.sgd_momentum,
            )
            return optimizer, None

        raise Exception('{} optimizer is not supported.'.format(config.optimizer))

    return wrapper
def _get_megatron_emerging_optimizer(
    config,
    model_chunks: List,
    config_overrides: Optional[Dict[ParamKey, Any]] = None,
    pg_collection: Optional[LegacyProcessGroupCollection] = None,
):
    """Build an emerging optimizer using Megatron dev's high-level flow.

    Args:
        config: Optimizer config. ``config.optimizer`` names the emerging
            optimizer; a ``dist_`` prefix is deprecated and is mapped to the
            bare name with layer-wise distribution switched on.
        model_chunks: Model chunks whose trainable parameters are optimized.
            Parameters are tagged in place with ``expert_tp`` / ``is_qkv``
            based on their names.
        config_overrides: Optional per-parameter overrides. The mapping is
            copied before the optimizer's default overrides are merged in, so
            the caller's object is left unmodified.
        pg_collection: Process groups to use; a fresh
            ``LegacyProcessGroupCollection`` is created when omitted.

    Returns:
        A ``LayerWiseDistributedOptimizer`` when layer-wise distribution is
        enabled, otherwise a ``ChainedOptimizer`` over the built optimizers.

    Raises:
        ImportError: if the local emerging-optimizer implementation is missing.
        ValueError: for an unknown emerging optimizer or when ``fp16`` is set.
        RuntimeError: if the scalar fallback factory does not return a tuple
            while layer-wise distribution is enabled.
    """
    from megatron.core.optimizer import _get_megatron_optimizer_based_on_param_groups
    from megatron.core.optimizer.optimizer import (
        ChainedOptimizer,
        Float16OptimizerWithFloat16Params,
        FP32Optimizer,
    )

    eopt_name = config.optimizer
    use_layer_wise = bool(getattr(config, "use_layer_wise_distributed_optimizer", False))

    if eopt_name.startswith("dist_"):
        bare_name = eopt_name[len("dist_") :]
        warnings.warn(
            f"optimizer='{eopt_name}' is deprecated. Use optimizer='{bare_name}' "
            "with use_layer_wise_distributed_optimizer=True.",
            DeprecationWarning,
            stacklevel=3,
        )
        eopt_name = bare_name
        use_layer_wise = True

    if not HAVE_EMERGING_OPTIMIZERS:
        raise ImportError(f"MindSpeed local emerging optimizer implementation is required for optimizer='{eopt_name}'.")
    if eopt_name not in _EMERGING_OPTIMIZERS:
        raise ValueError(f"Unsupported emerging optimizer: {eopt_name}")
    if getattr(config, "fp16", False):
        raise ValueError("emerging optimizer with fp16 is not supported.")

    if pg_collection is None:
        pg_collection = LegacyProcessGroupCollection()

    # Tag trainable parameters by name with the Muon-specific flags.
    for model_chunk in model_chunks:
        for name, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue
            if "experts" in name and "shared" not in name:
                param.expert_tp = True
            if "linear_qkv.weight" in name and len(param.shape) == 2:
                param.is_qkv = True

    # Merge into a private copy: updating the argument in place would leak the
    # optimizer's default overrides into the caller's dictionary.
    config_overrides = dict(config_overrides) if config_overrides is not None else {}
    config_overrides.update(_EMERGING_OPTIMIZERS[eopt_name].default_param_overrides)

    all_param_groups = _get_param_groups(model_chunks, config, config_overrides)

    # Bucket the groups by (optimizer name, expert flag). With layer-wise
    # distribution the expert flag is always False.
    grouped_param_groups = defaultdict(list)
    for group in all_param_groups:
        opt_name = group.get("optimizer", eopt_name)
        is_expert = group["is_expert_parallel"] and not use_layer_wise
        grouped_param_groups[(opt_name, is_expert)].append(group)

    results = []
    for (opt_name, is_expert), groups in grouped_param_groups.items():
        if not groups:
            continue

        model_parallel_group = pg_collection.tp_ep_pp if is_expert else pg_collection.mp

        if opt_name in _EMERGING_OPTIMIZERS:
            optimizer, init_state_fn = _create_emerging_optimizer(
                config, groups, eopt_name, model_chunks, pg_collection
            )
            if use_layer_wise:
                # The layer-wise optimizer receives the bare optimizer and its init fn.
                result = (optimizer, init_state_fn)
            else:
                if getattr(config, "bf16", False):
                    optimizer = Float16OptimizerWithFloat16Params(optimizer, config, None, init_state_fn)
                else:
                    optimizer = FP32Optimizer(optimizer, config, init_state_fn)
                optimizer.grad_stats_parallel_group = model_parallel_group
                optimizer.tp_group = pg_collection.tp
                result = optimizer
        else:
            # Groups routed to a non-emerging optimizer go through Megatron's
            # own factory using a shallow copy of the config.
            fallback_config = copy.copy(config)
            fallback_config.optimizer = opt_name
            fallback_config.use_distributed_optimizer = False
            result = _get_megatron_optimizer_based_on_param_groups(
                config=fallback_config,
                model_chunks=model_chunks,
                param_groups=groups,
                model_parallel_group=model_parallel_group,
                pg_collection=pg_collection,
                skip_megatron_wrapping=use_layer_wise,
            )
            if use_layer_wise and not isinstance(result, tuple):
                raise RuntimeError(
                    "Megatron _get_megatron_optimizer_based_on_param_groups must "
                    "support skip_megatron_wrapping for Muon layer-wise scalar fallback."
                )
            if not use_layer_wise and hasattr(result, "config"):
                # Point the wrapped optimizer back at the original config.
                result.config = config

        results.append(result)

    if use_layer_wise:
        base_optimizers, init_fns = (), ()
        if results:
            base_optimizers, init_fns = zip(*results)
        return LayerWiseDistributedOptimizer(
            list(base_optimizers),
            config,
            pg_collection=pg_collection,
            init_state_fn_list=list(init_fns),
            model_chunks=model_chunks if getattr(config, "overlap_param_gather", False) else None,
        )

    return ChainedOptimizer(results)
def get_megatron_optimizer_muon_wrapper(func):
    """Route ``muon`` / ``dist_muon`` optimizer requests to the emerging-optimizer builder.

    Any other optimizer name is forwarded to ``func`` unchanged. For Muon the
    arguments ``no_weight_decay_cond``, ``scale_lr_cond`` and ``lr_mult`` are
    read positionally (indices 2, 3, 4) when supplied that way, otherwise from
    keywords, and converted to config overrides.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        def pick(index, keyword, default=None):
            # A positional value takes precedence over the keyword of the same name.
            if len(args) > index:
                return args[index]
            return kwargs.get(keyword, default)

        config = pick(0, "config")
        if getattr(config, "optimizer", None) not in ("muon", "dist_muon"):
            return func(*args, **kwargs)

        overrides = _get_muon_config_overrides(
            config,
            pick(2, "no_weight_decay_cond"),
            pick(3, "scale_lr_cond"),
            pick(4, "lr_mult", 1.0),
        )
        return _get_megatron_emerging_optimizer(
            config=config,
            model_chunks=pick(1, "model_chunks"),
            config_overrides=overrides,
        )

    return wrapper