import fnmatch
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Tuple, Union
import torch
from mindspeed.args_utils import get_full_args
# Last-resort default values for optimizer-config attributes.
#
# `optimizer_config_init_wrapper` walks this table after the wrapped init
# function has run: for every key the config object does not already have, it
# sets the attribute from the namespace returned by `get_full_args()`, and
# falls back to the value listed here when that namespace lacks the key too.
_MUON_OPTIMIZER_CONFIG_DEFAULTS = {
    "muon_momentum": 0.95,
    "muon_split_qkv": True,
    "muon_nesterov": False,
    "muon_scale_mode": "spectral",
    "muon_fp32_matmul_prec": "medium",
    "muon_coefficient_type": "quintic",
    "muon_num_ns_steps": 5,
    "muon_tp_mode": "blockwise",
    "muon_extra_scale_factor": 1.0,
    "muon_scalar_optimizer": "adam",
    "apply_wd_to_qk_layernorm": False,
    "use_layer_wise_distributed_optimizer": False,
}
def optimizer_config_init_wrapper(init_func):
    """Wrap an init function so missing optimizer options get backfilled.

    The returned wrapper first calls ``init_func`` with the arguments it was
    given, then treats the first positional argument as the object being
    initialized. For every key in ``_MUON_OPTIMIZER_CONFIG_DEFAULTS`` that the
    object does not already expose as an attribute, the attribute is set from
    the namespace returned by ``get_full_args()``, falling back to the table's
    default when the namespace does not define it.

    Args:
        init_func: The init function to wrap. Its return value is discarded.

    Returns:
        The wrapping function (which itself returns ``None``).
    """

    @wraps(init_func)
    def optimizer_config_init(*args, **kwargs):
        init_func(*args, **kwargs)

        # The instance is always taken from the first positional argument.
        config = args[0]
        full_args = get_full_args()

        for option, fallback in _MUON_OPTIMIZER_CONFIG_DEFAULTS.items():
            # Never overwrite something the wrapped init already provided.
            if hasattr(config, option):
                continue
            resolved = getattr(full_args, option, fallback)
            setattr(config, option, resolved)

    return optimizer_config_init
@dataclass(frozen=True)
class ParamPredicate:
    """Named, hashable wrapper around a ``param -> bool`` matching function.

    Wrapping the function makes it usable inside the hashable ``ParamKey``.
    Equality and hashing are generated by ``@dataclass(frozen=True)`` and look
    at ``name`` only, because ``fn`` is declared with ``compare=False`` and
    ``hash=False``. Two predicates that share a name therefore compare equal
    even when they wrap different callables.

    Example:
        >>> shape_1_param = ParamPredicate(name="s1", fn=lambda param: len(param.shape) == 1)
        >>> shape_1_param(torch.empty(10))
        True
        >>> shape_1_param_copy = ParamPredicate(name="s1", fn=lambda param: len(param.shape) == 1)
        >>> shape_1_param == shape_1_param_copy  # name is used to match
        True
        >>> {shape_1_param, shape_1_param_copy} == {shape_1_param}  # set hashing works properly
        True
    """

    # Identifier that alone determines equality and hash.
    name: str
    # Wrapped matching function; deliberately excluded from __eq__/__hash__.
    fn: Callable[[torch.nn.Parameter], bool] = field(compare=False, hash=False)

    def __call__(self, param: torch.nn.Parameter) -> bool:
        """Return whatever the wrapped function yields for ``param``."""
        matcher = self.fn
        return matcher(param)
@dataclass(frozen=True)
class ParamWithNamePredicate:
    """Named, hashable wrapper around a ``(param, name) -> bool`` function.

    Wrapping the function makes it usable inside the hashable ``ParamKey``.
    Unlike a predicate that only sees the parameter, the wrapped function also
    receives the parameter's name, so it can exclude parameters by name.

    Example:
        >>> shape_1_not_qkln_param = ParamWithNamePredicate(
        ...     name="s1_not_qkln",
        ...     fn=lambda param, name: (
        ...         (len(param.shape) == 1 or name.endswith(".bias"))
        ...         and not ("q_layernorm." in name or "k_layernorm." in name)
        ...     ),
        ... )
        >>> shape_1_not_qkln_param(torch.empty(10), "interesting.bias")
        True
        >>> shape_1_not_qkln_param(torch.empty(10), "interesting.q_layernorm.bias")
        False

    The parentheses around the ``or`` are required: ``and`` binds tighter than
    ``or``, so without them every 1-D parameter would match regardless of the
    layernorm exclusion.

    NOTE:
        __hash__ and __eq__ are automatically generated by
        @dataclass(frozen=True) based solely on 'name' because we set
        compare=False/hash=False on 'fn'.
    """

    # Identifier that alone determines equality and hash.
    name: str
    # Wrapped matching function; deliberately excluded from __eq__/__hash__.
    fn: Callable[[torch.nn.Parameter, str], bool] = field(compare=False, hash=False)

    def __call__(self, param: torch.nn.Parameter, name: str) -> bool:
        """Return whatever the wrapped function yields for ``(param, name)``."""
        return self.fn(param, name)
@dataclass(frozen=True)
class ParamKey:
    """Key to group parameters by. All such grouped parameters can share an
    optimizer config specification.

    A parameter matches the key as soon as ANY single criterion matches: one
    of the name patterns, one of the attributes, one of the predicates, or one
    of the with-name predicates. Every field accepts either a single value or
    a tuple of values and defaults to an empty tuple (which matches nothing).
    """

    name: Union[str, Tuple[str, ...]] = field(default_factory=tuple)
    """Parameter name(s), will use unix filesystem path syntax for matching."""

    attr: Union[str, Tuple[str, ...]] = field(default_factory=tuple)
    """Parameter attribute(s); a truthy attribute on the parameter is a match."""

    predicate: Union[ParamPredicate, Tuple[ParamPredicate, ...]] = field(default_factory=tuple)
    """Predicate(s) to match parameters by. If multiple predicates are provided, any must match."""

    with_name_predicate: Union[
        ParamWithNamePredicate, Tuple[ParamWithNamePredicate, ...]
    ] = field(default_factory=tuple)
    """
    Predicate(s) to match parameters with their name. If multiple predicates are provided,
    any must match. This is useful if you need to filter out some parameters from an otherwise
    positive match by their name.
    """

    @staticmethod
    def _as_tuple(value, single_type):
        """Return ``value`` as a tuple, wrapping a lone ``single_type`` instance."""
        if isinstance(value, single_type):
            return (value,)
        return tuple(value)

    def matches(self, param: torch.nn.Parameter, param_name: str) -> bool:
        """Returns true if passed-in parameter (with name) matches this key.

        Criteria are tried in order (name patterns, attributes, predicates,
        with-name predicates) and evaluation stops at the first hit.

        Args:
            param (torch.nn.Parameter): Handle to parameter object.
            param_name (str): Name of parameter in underlying PyTorch module.

        Returns:
            bool: True if the parameter satisfies at least one criterion.
        """
        # fnmatchcase skips os.path.normcase, so matching is case-sensitive and
        # identical on every platform (fnmatch.fnmatch would fold case and
        # rewrite "/" on Windows). On POSIX the two are equivalent.
        name_patterns = self._as_tuple(self.name, str)
        if any(fnmatch.fnmatchcase(param_name, pattern) for pattern in name_patterns):
            return True

        # A missing attribute counts the same as a falsy one.
        attr_names = self._as_tuple(self.attr, str)
        if any(getattr(param, attr_name, False) for attr_name in attr_names):
            return True

        predicates = self._as_tuple(self.predicate, ParamPredicate)
        if any(check(param) for check in predicates):
            return True

        named_predicates = self._as_tuple(self.with_name_predicate, ParamWithNamePredicate)
        return any(check(param, param_name) for check in named_predicates)