import inspect
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, get_args
import torch
from torch.optim.optimizer import ParamsT
from mindspeed.core.optimizer.muon.muon_utils import (
NSCoeffT,
get_muon_scale_factor,
newton_schulz_tp,
)
from mindspeed.core.optimizer.muon.orthogonalized_optimizer import OrthogonalizedOptimizer
from mindspeed.core.optimizer.muon.optimizer_config import ParamKey, ParamPredicate
from mindspeed.core.optimizer.muon.utils import get_pg_size
# Availability flag for the emerging-optimizer code path; unconditionally True in this module.
# NOTE(review): presumably read by importers that gate on optional optimizer support -- confirm.
HAVE_EMERGING_OPTIMIZERS = True
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_supported_coefficient_types() -> Tuple[str, ...]:
    """List the coefficient type names accepted by the local Muon implementation.

    The names are not hard-coded here: they are extracted from the type
    arguments of the imported ``NSCoeffT`` type via ``typing.get_args``, so
    whatever that type declares is what gets reported.

    Returns:
        The type arguments of ``NSCoeffT`` as a tuple.
    """
    coefficient_names = get_args(NSCoeffT)
    return coefficient_names
def validate_coefficient_type(coefficient_type: str) -> None:
    """Check that *coefficient_type* is one of the supported coefficient types.

    Args:
        coefficient_type: Name looked up in the tuple returned by
            ``get_supported_coefficient_types()``.

    Raises:
        ValueError: If the name is not a member of that tuple.
    """
    supported = get_supported_coefficient_types()
    if coefficient_type in supported:
        return
    raise ValueError(f"Unsupported muon coefficient type '{coefficient_type}'. Supported types: {supported}")
def _eopt_init_state_fn(opt, config=None):
    """Build emerging-optimizer state up front (for the torch_dist checkpoint format).

    Calls ``opt._init_group(group, skip_non_grad_params=False)`` once for every
    entry of ``opt.param_groups``, in iteration order.

    Args:
        opt: Optimizer exposing ``param_groups`` and ``_init_group``.
        config: Accepted as part of the signature; not read by this function.
    """
    for param_group in opt.param_groups:
        # skip_non_grad_params=False is passed explicitly for every group.
        opt._init_group(param_group, skip_non_grad_params=False)
def _default_param_overrides_factory() -> Dict[ParamKey, Dict[str, Any]]:
    """Build the default per-parameter override mapping.

    The mapping has a single entry: parameters matched by the
    ``_is_nonlinear_or_embedding`` predicate (registered under the name
    ``"nonlinear_or_embedding"``) are assigned ``{"optimizer": "adam"}``.

    Returns:
        A newly created dict on every call.
    """
    nonlinear_or_embedding_key = ParamKey(
        predicate=ParamPredicate(name="nonlinear_or_embedding", fn=_is_nonlinear_or_embedding)
    )
    return {nonlinear_or_embedding_key: {"optimizer": "adam"}}
@dataclass
class EmergingOptimizerEntry:
    """Registry record holding what is needed to create and configure an emerging optimizer.

    Attributes:
        optimizer_cls: The optimizer class. ``_create_emerging_optimizer``
            instantiates it as ``optimizer_cls(param_groups, **kwargs)``.
        init_state_fn: Callable that initialises optimizer state (needed for
            checkpoint formats); returned to the caller together with the
            optimizer instance. Defaults to ``_eopt_init_state_fn``.
        config_to_kwargs: Optional callable invoked as
            ``config_to_kwargs(config, model_chunks, pg_collection)`` that
            returns a dict of constructor kwargs. When ``None``, no extra
            kwargs are passed to the constructor.
        default_param_overrides: Per-parameter config overrides
            (e.g. route non-linear params to Adam). Each instance gets its own
            dict produced by ``_default_param_overrides_factory``.
    """
    optimizer_cls: type
    init_state_fn: Callable = _eopt_init_state_fn
    config_to_kwargs: Optional[Callable] = None
    default_param_overrides: Dict[ParamKey, Dict[str, Any]] = field(default_factory=_default_param_overrides_factory)
def _create_emerging_optimizer(config, param_groups, eopt_name, model_chunks, pg_collection):
    """Build the emerging optimizer registered under *eopt_name*.

    Args:
        config: Passed to the registry entry's ``config_to_kwargs`` hook.
        param_groups: First positional argument of the optimizer constructor.
        eopt_name: Key into ``_EMERGING_OPTIMIZERS``.
        model_chunks: Forwarded to ``config_to_kwargs``.
        pg_collection: Forwarded to ``config_to_kwargs``.

    Returns:
        A ``(optimizer, init_state_fn)`` tuple, where ``init_state_fn`` is
        taken from the registry entry.

    Raises:
        KeyError: If *eopt_name* is not present in ``_EMERGING_OPTIMIZERS``.
    """
    entry = _EMERGING_OPTIMIZERS[eopt_name]
    build_kwargs = entry.config_to_kwargs
    # Without a hook the optimizer is built from param_groups alone.
    eopt_kwargs = {} if build_kwargs is None else build_kwargs(config, model_chunks, pg_collection)
    return entry.optimizer_cls(param_groups, **eopt_kwargs), entry.init_state_fn
def _is_nonlinear_or_embedding(param):
    """Tell whether *param* should NOT be handled by the emerging optimizer.

    A parameter matches when its ``is_embedding_or_output_parameter`` attribute
    is truthy, or when ``param.shape`` does not have exactly two dimensions.

    Returns:
        The truthy attribute value itself when the flag is set (``shape`` is
        not inspected in that case); otherwise a bool for the shape check.
    """
    embedding_flag = getattr(param, "is_embedding_or_output_parameter", False)
    if embedding_flag:
        return embedding_flag
    return len(param.shape) != 2
def _get_qkv_split_shapes(model_cfg) -> List[int]:
    """Derive the QKV split sizes from a model config.

    Args:
        model_cfg: Object exposing ``num_attention_heads``,
            ``num_query_groups`` and ``kv_channels``.

    Returns:
        A three-element list:
        ``(num_attention_heads // num_query_groups) * kv_channels`` followed by
        ``kv_channels`` twice.
    """
    heads_per_group = model_cfg.num_attention_heads // model_cfg.num_query_groups
    kv_width = model_cfg.kv_channels
    return [heads_per_group * kv_width, kv_width, kv_width]
# Registry mapping an optimizer name to its EmergingOptimizerEntry. It starts empty, is
# populated further down in this module, and is read by _create_emerging_optimizer.
_EMERGING_OPTIMIZERS: Dict[str, EmergingOptimizerEntry] = {}
class TensorParallelMuon(OrthogonalizedOptimizer):
    """Tensor Parallel Muon optimizer.

    Combines ``newton_schulz_tp`` and ``get_muon_scale_factor`` into the
    ``scaled_orthogonalize_fn`` handed to ``OrthogonalizedOptimizer``, and
    provides ``orthogonalize``, which picks the process group and partition
    dimension per parameter and can orthogonalize the components of a fused
    QKV tensor separately.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum: float = 0.95,
        nesterov: bool = True,
        weight_decay: float = 0.01,
        use_decoupled_weight_decay: bool = True,
        split_qkv: bool = False,
        is_qkv_fn: Optional[Callable[[torch.Tensor], bool]] = None,
        qkv_split_shapes: Optional[Tuple[int, int, int]] = None,
        fp32_matmul_prec: str = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        scale_mode: str = "spectral",
        extra_scale_factor: float = 1.0,
        pg_collection: Optional[Any] = None,
        tp_mode: Literal["blockwise", "duplicated", "distributed"] = "duplicated",
    ) -> None:
        """Configure the optimizer.

        Args:
            params: Forwarded to ``OrthogonalizedOptimizer.__init__``.
            lr: Forwarded positionally to the base class.
            momentum: Forwarded positionally to the base class.
            nesterov: Forwarded to the base class.
            weight_decay: Forwarded to the base class.
            use_decoupled_weight_decay: Selects the base class
                ``weight_decay_method``: ``"decoupled"`` when true, ``"l2"``
                otherwise.
            split_qkv: When true, tensors for which ``is_qkv_fn`` returns true
                are split according to ``qkv_split_shapes`` and each component
                is orthogonalized on its own.
            is_qkv_fn: Predicate applied to the parameter tensor; required
                when ``split_qkv`` is true.
            qkv_split_shapes: Section sizes used to split the grouped view of
                a QKV tensor; required once a QKV parameter is encountered.
            fp32_matmul_prec: Forwarded to the base class.
            coefficient_type: Validated with ``validate_coefficient_type`` and
                passed to ``newton_schulz_tp``.
            num_ns_steps: Passed to ``newton_schulz_tp`` as ``steps``.
            scale_mode: Passed to ``get_muon_scale_factor`` as ``mode``.
            extra_scale_factor: Extra multiplier applied to the scaled,
                orthogonalized tensor.
            pg_collection: Object providing ``tp`` and ``expt_tp`` process
                groups; when falsy, ``None`` is used as the process group.
            tp_mode: ``"blockwise"`` ignores the parameter's ``partition_dim``
                and calls ``newton_schulz_tp`` in ``"duplicated"`` mode; any
                other value is forwarded unchanged.

        Raises:
            ValueError: If ``coefficient_type`` is unsupported, if
                ``num_ns_steps`` is less than 1, or if ``split_qkv`` is true
                while ``is_qkv_fn`` is ``None``.
        """
        validate_coefficient_type(coefficient_type)
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
        if split_qkv and is_qkv_fn is None:
            # Fail fast: orthogonalize() calls is_qkv_fn for every parameter when
            # split_qkv is set, so a missing predicate could never work.
            raise ValueError("split_qkv=True requires is_qkv_fn to be provided")

        # "blockwise" is mapped to "duplicated" for newton_schulz_tp; computed
        # once here because it does not change between calls.
        ns_tp_mode = "duplicated" if tp_mode == "blockwise" else tp_mode

        def scaled_orthogonalize_fn(
            grad: torch.Tensor,
            tp_group: Optional[torch.distributed.ProcessGroup],
            partition_dim: Optional[int] = None,
        ) -> torch.Tensor:
            logger.debug(
                "Orthogonalizing grad with %s steps, %s coefficient, %s scale mode, extra_scale_factor=%s",
                num_ns_steps,
                coefficient_type,
                scale_mode,
                extra_scale_factor,
            )
            size = [grad.size(-2), grad.size(-1)]
            if partition_dim is not None:
                # The partitioned dimension is multiplied by the group size before
                # the scale factor is computed (presumably to recover the
                # unsharded shape -- confirm against get_muon_scale_factor).
                size[partition_dim] *= get_pg_size(tp_group)
            orth_grad = newton_schulz_tp(
                grad,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                tp_group=tp_group,
                partition_dim=partition_dim,
                tp_mode=ns_tp_mode,
            )
            scale_factor = get_muon_scale_factor(size[0], size[1], mode=scale_mode)
            return orth_grad * scale_factor * extra_scale_factor

        self.pg_collection = pg_collection
        self.tp_mode = tp_mode
        self.split_qkv = split_qkv
        self.is_qkv_fn = is_qkv_fn
        self.qkv_split_shapes = qkv_split_shapes
        weight_decay_method = "decoupled" if use_decoupled_weight_decay else "l2"
        OrthogonalizedOptimizer.__init__(
            self,
            params,
            lr,
            momentum,
            nesterov=nesterov,
            weight_decay=weight_decay,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Orthogonalize the momentum.

        Args:
            p: The parameter tensor. It is necessary to pass the param tensor in
                addition to the momentum because a lot of information is only
                available on the param tensor, attributes for example
                (``expert_tp`` and ``partition_dim`` are read here).
            grad: The momentum tensor.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The orthogonalized gradient tensor.

        Raises:
            ValueError: If a QKV parameter is encountered while
                ``qkv_split_shapes`` is ``None``.
        """
        if self.pg_collection:
            tp_group = self.pg_collection.expt_tp if getattr(p, "expert_tp", False) else self.pg_collection.tp
        else:
            tp_group = None

        partition_dim = None if self.tp_mode == "blockwise" else getattr(p, "partition_dim", None)
        if partition_dim == -1:
            # -1 is normalized to None before calling scaled_orthogonalize_fn.
            partition_dim = None

        if not (self.split_qkv and self.is_qkv_fn(p)):
            return self.scaled_orthogonalize_fn(grad, tp_group, partition_dim)

        if self.qkv_split_shapes is None:
            raise ValueError("split_qkv=True requires qkv_split_shapes to be provided")

        grad_shape = grad.shape
        logger.debug(
            "qkv split grad shape %s, split shapes %s",
            grad_shape,
            self.qkv_split_shapes,
        )
        split_sections = list(self.qkv_split_shapes)
        group_width = sum(split_sections)
        num_query_groups = grad_shape[0] // group_width
        # View dim 0 as (query group, per-group rows) and split the per-group rows
        # into the configured sections.
        components = torch.split(
            grad.view(num_query_groups, group_width, -1),
            split_sections,
            dim=1,
        )
        # Flatten each component to 2-D, orthogonalize and scale it, then restore
        # the grouped layout so the pieces can be concatenated back in order.
        orthogonalized = [
            self.scaled_orthogonalize_fn(
                component.reshape(-1, grad_shape[-1]), tp_group, partition_dim
            ).view(num_query_groups, -1, grad_shape[-1])
            for component in components
        ]
        return torch.cat(orthogonalized, dim=1).view(grad_shape)
def _kwargs_from_config(optimizer_cls: type, prefix: str, config) -> Dict[str, Any]:
    """Collect constructor kwargs for *optimizer_cls* from attributes of *config*.

    Every parameter name of ``optimizer_cls.__init__`` except ``self`` and
    ``params`` is resolved against *config*: the ``{prefix}_{name}`` attribute
    wins when present, otherwise the unprefixed ``{name}`` attribute is used.
    Names with neither attribute are omitted from the result.

    Args:
        optimizer_cls: Class whose ``__init__`` signature is inspected.
        prefix: Prefix tried first when looking up config attributes.
        config: Object whose attributes supply the values.

    Returns:
        Mapping from init parameter name to the value found on *config*.
    """
    init_params = inspect.signature(optimizer_cls.__init__).parameters
    kwargs: Dict[str, Any] = {}
    for name in init_params:
        if name == "self" or name == "params":
            continue
        # Candidates in priority order: prefixed attribute first, then the bare name.
        for attr_name in (f"{prefix}_{name}", name):
            if hasattr(config, attr_name):
                kwargs[name] = getattr(config, attr_name)
                break
    return kwargs
def _muon_config_to_kwargs(config, model_chunks, pg_collection) -> Dict[str, Any]:
    """Translate an optimizer config into ``TensorParallelMuon`` constructor kwargs.

    Starts from ``_kwargs_from_config(TensorParallelMuon, "muon", config)`` and
    then sets three entries unconditionally, replacing anything picked up from
    *config* under the same names.

    Args:
        config: Source of the config-derived kwargs.
        model_chunks: Sequence of model chunks; only ``model_chunks[0].config``
            is read, to compute the QKV split shapes.
        pg_collection: Stored under the ``pg_collection`` key as-is.

    Returns:
        The kwargs dict for the ``TensorParallelMuon`` constructor.
    """
    kwargs = _kwargs_from_config(TensorParallelMuon, "muon", config)
    kwargs.update(
        {
            # A parameter counts as QKV when it carries a truthy ``is_qkv`` attribute.
            "is_qkv_fn": lambda p: getattr(p, "is_qkv", False),
            "qkv_split_shapes": _get_qkv_split_shapes(model_chunks[0].config),
            "pg_collection": pg_collection,
        }
    )
    return kwargs
# Register the Muon entry under the "muon" key of the optimizer registry.
_EMERGING_OPTIMIZERS["muon"] = EmergingOptimizerEntry(
    optimizer_cls=TensorParallelMuon,
    init_state_fn=_eopt_init_state_fn,
    config_to_kwargs=_muon_config_to_kwargs,
)