"""Feature gate for backporting Muon optimizer support to Megatron 0.12.x."""
import warnings
from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
def _add_choice(parser: ArgumentParser, option: str, value: str) -> None:
for action in parser._actions:
if option in action.option_strings and action.choices is not None:
if value not in action.choices:
action.choices = list(action.choices) + [value]
return
def _add_apply_wd_to_qk_layernorm_arg(group, parser: ArgumentParser) -> None:
for action in parser._actions:
if "--apply-wd-to-qk-layernorm" in action.option_strings:
return
group.add_argument(
"--apply-wd-to-qk-layernorm",
action="store_true",
help="Apply weight decay to qk layernorm as a special case.",
)
class MuonOptimizerFeature(MindSpeedFeature):
def __init__(self):
super().__init__("muon-optimizer", optimization_level=0)
def register_args(self, parser: ArgumentParser):
_add_choice(parser, "--optimizer", "muon")
_add_choice(parser, "--optimizer", "dist_muon")
group = parser.add_argument_group(title=self.feature_name)
group.add_argument(
"--muon-momentum",
type=float,
default=0.95,
help="Momentum factor for Muon optimizer.",
)
group.add_argument(
"--muon-no-split-qkv",
action="store_false",
default=True,
dest="muon_split_qkv",
help="Whether to split QKV parameters for Muon optimizer.",
)
group.add_argument(
"--muon-nesterov",
action="store_true",
help="Whether to use Nesterov-style momentum in Muon.",
)
group.add_argument(
"--muon-scale-mode",
type=str,
default="spectral",
choices=["spectral", "unit_rms_norm", "shape_scaling"],
help="Scale mode for Muon optimizer.",
)
group.add_argument(
"--muon-fp32-matmul-prec",
type=str,
default="medium",
choices=["low", "medium", "high"],
help="FP32 matmul precision for Newton-Schulz iteration.",
)
group.add_argument(
"--muon-coefficient-type",
type=str,
default="quintic",
help="Newton-Schulz coefficient type for Muon optimizer.",
)
group.add_argument(
"--muon-num-ns-steps",
type=int,
default=5,
help="Number of Newton-Schulz steps for Muon optimizer.",
)
group.add_argument(
"--muon-tp-mode",
type=str,
default="blockwise",
choices=["blockwise", "duplicated", "distributed"],
help="How to perform NS calculation for tensor-parallel weights.",
)
group.add_argument(
"--muon-extra-scale-factor",
type=float,
default=1.0,
help="Additional scale factor for the Muon update.",
)
group.add_argument(
"--muon-scalar-optimizer",
type=str,
default="adam",
choices=["adam", "lion"],
help="Optimizer for scalar/non-matrix params when using Muon.",
)
_add_apply_wd_to_qk_layernorm_arg(group, parser)
def post_validate_args(self, args):
if getattr(args, "optimizer", None) == "dist_muon":
warnings.warn(
"optimizer='dist_muon' is deprecated. Use --optimizer muon --use-distributed-optimizer instead.",
DeprecationWarning,
stacklevel=2,
)
args.optimizer = "muon"
args.use_layer_wise_distributed_optimizer = True
if getattr(args, "optimizer", None) == "muon" and getattr(args, "use_distributed_optimizer", False):
args.use_layer_wise_distributed_optimizer = True
args.use_distributed_optimizer = False
if getattr(args, "optimizer", None) == "muon" and getattr(
args, "overlap_param_gather_with_optimizer_step", False
):
warnings.warn(
"MindSpeed Muon on Megatron 0.12.1 does not support "
"--overlap-param-gather-with-optimizer-step; disabling it for this path.",
UserWarning,
stacklevel=2,
)
args.overlap_param_gather_with_optimizer_step = False
def validate_args(self, args):
if getattr(args, "optimizer", None) != "muon":
return
if getattr(args, "fp16", False):
raise AssertionError("Muon optimizer does not support fp16; use bf16 or fp32.")
if getattr(args, "use_custom_fsdp", False) or getattr(args, "use_torch_fsdp2", False):
raise AssertionError("Muon optimizer does not support FSDP in this MindSpeed patch.")
if getattr(args, "overlap_param_gather", False):
if not getattr(args, "use_layer_wise_distributed_optimizer", False):
raise AssertionError(
"Muon optimizer requires layer-wise distributed optimizer when "
"--overlap-param-gather is enabled. Use --use-distributed-optimizer "
"or optimizer='dist_muon'."
)
if not getattr(args, "overlap_grad_reduce", False):
raise AssertionError("Must use --overlap-param-gather with --overlap-grad-reduce.")
def register_patches(self, patch_manager, args):
if getattr(args, "optimizer", None) not in ("muon", "dist_muon"):
return
from mindspeed.core.optimizer.muon.adaptor import (
add_muon_tensor_model_parallel_attributes,
chained_optimizer_count_zeros,
copy_muon_tensor_model_parallel_attributes_wrapper,
count_zeros_fp32,
get_megatron_optimizer_muon_wrapper,
get_megatron_optimizer_based_on_param_groups_wrapper,
get_main_grads_for_grad_norm,
megatron_optimizer_count_zeros,
param_is_not_tensor_parallel_duplicate,
)
from mindspeed.core.optimizer.muon.muon import get_megatron_muon_optimizer
from mindspeed.core.optimizer.muon.optimizer_config import optimizer_config_init_wrapper
from mindspeed.core.optimizer.muon.checkpointing import (
load_base_checkpoint_layer_wise_optimizer_wrapper,
load_checkpoint_layer_wise_optimizer_wrapper,
save_checkpoint_layer_wise_optimizer_wrapper,
)
add_muon_tensor_model_parallel_attributes()
patch_manager.register_patch(
"megatron.core.optimizer.get_megatron_optimizer",
get_megatron_optimizer_muon_wrapper,
)
patch_manager.register_patch(
"megatron.core.optimizer._get_megatron_optimizer_based_on_param_groups",
get_megatron_optimizer_based_on_param_groups_wrapper,
)
patch_manager.register_patch(
"megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__",
optimizer_config_init_wrapper,
)
patch_manager.register_patch(
"megatron.core.tensor_parallel.layers.copy_tensor_model_parallel_attributes",
copy_muon_tensor_model_parallel_attributes_wrapper,
)
patch_manager.register_patch(
"megatron.core.optimizer.muon.get_megatron_muon_optimizer",
get_megatron_muon_optimizer,
create_dummy=True,
)
patch_manager.register_patch(
"megatron.core.tensor_parallel.layers.param_is_not_tensor_parallel_duplicate",
param_is_not_tensor_parallel_duplicate,
)
patch_manager.register_patch(
"megatron.core.optimizer.optimizer.MegatronOptimizer.get_main_grads_for_grad_norm",
get_main_grads_for_grad_norm,
)
patch_manager.register_patch(
"megatron.core.optimizer.clip_grads.count_zeros_fp32",
count_zeros_fp32,
)
patch_manager.register_patch(
"megatron.core.optimizer.optimizer.MegatronOptimizer.count_zeros",
megatron_optimizer_count_zeros,
)
patch_manager.register_patch(
"megatron.core.optimizer.optimizer.ChainedOptimizer.count_zeros",
chained_optimizer_count_zeros,
)
patch_manager.register_patch(
"megatron.training.checkpointing.save_checkpoint",
save_checkpoint_layer_wise_optimizer_wrapper,
)
patch_manager.register_patch(
"megatron.training.checkpointing._load_base_checkpoint",
load_base_checkpoint_layer_wise_optimizer_wrapper,
)
patch_manager.register_patch(
"megatron.training.checkpointing.load_checkpoint",
load_checkpoint_layer_wise_optimizer_wrapper,
)
from mindspeed.core.optimizer.muon.param_and_grad_buffer import (
distributed_data_parallel_start_param_sync_wrapper,
finish_grad_sync_wrapper,
finish_param_sync,
param_and_grad_bucket_group_init_wrapper,
set_layerwise_params_list,
start_param_sync,
)
patch_manager.register_patch(
"megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucket.set_layerwise_params_list",
set_layerwise_params_list,
create_dummy=True,
)
patch_manager.register_patch(
"megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.__init__",
param_and_grad_bucket_group_init_wrapper,
)
patch_manager.register_patch(
"megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.start_param_sync",
start_param_sync,
)
patch_manager.register_patch(
"megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.finish_param_sync",
finish_param_sync,
)
patch_manager.register_patch(
"megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.finish_grad_sync",
finish_grad_sync_wrapper,
)
patch_manager.register_patch(
"megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.start_param_sync",
distributed_data_parallel_start_param_sync_wrapper,
)