"""
Optimizer factory for unified construction of single optimizer and EP+FSDP2 multi-optimizer, support Adamw and Muon now.
"""
import torch
from torch import nn
from typing import Any, Dict, List, Iterable, Optional, Sequence, Tuple
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_optimizer_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from transformers.utils import is_torch_npu_available
from mindspeed_llm.fsdp2.optim.muon import Muon
from mindspeed_llm.fsdp2.utils.logging import get_logger
logger = get_logger(__name__)
class MultiOptimizer(Optimizer, Stateful):
def __init__(self, root_model: nn.Module, optimizers: dict, key_names: list[str]):
self.model = root_model
self.optimizers_dict = optimizers
self._is_multi_optimizer: bool = True
self.key_names = key_names
def step(self) -> None:
for opt in self.optimizers_dict.values():
opt.step()
def zero_grad(self) -> None:
for opt in self.optimizers_dict.values():
opt.zero_grad()
def state_dict(self) -> Dict[str, Any]:
merged: Dict[str, Any] = {}
for name in self.key_names:
opt = self.optimizers_dict.get(name)
sd = get_optimizer_state_dict(self.model, opt, options=StateDictOptions(flatten_optimizer_state_dict=True))
overlap = set(merged.keys()) & set(sd.keys())
if overlap:
raise KeyError(f"Key clash detected for optimizer '{name}': {', '.join(sorted(overlap))}")
merged.update(sd)
return merged
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
for name in self.key_names:
opt = self.optimizers_dict.get(name)
set_optimizer_state_dict(
self.model,
opt,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
def __len__(self) -> int:
return len(self.optimizers_dict)
class OptimizerFactory:
"""Optimizer factory class for unified construction of single optimizer/EP+FSDP2 multi-optimizer"""
@staticmethod
def _split_muon_and_adamw_params(
model: nn.Module, all_params: Optional[List[nn.Parameter]] = None
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
"""
Split parameters into Muon parameters and AdamW parameters
- Muon params: 2D parameters that aren't embeddings or lm_head (requires_grad=True)
- AdamW params: All other trainable parameters
Args:
model: Model to be optimized (for parameter name matching)
all_params: Optional, only split parameters in this list (for EP+FSDP2 subset splitting)
"""
muon_params, adamw_params = [], []
name_to_param = dict(model.named_parameters())
if all_params is not None:
for param in all_params:
param_name = next((n for n, p in name_to_param.items() if p is param), None)
if param_name is None or not param.requires_grad:
adamw_params.append(param)
continue
if param.ndim == 2 and "embed" not in param_name and "lm_head" not in param_name:
muon_params.append(param)
else:
adamw_params.append(param)
else:
for name, param in model.named_parameters():
if param.requires_grad:
if param.ndim == 2 and "embed" not in name and "lm_head" not in name:
muon_params.append(param)
else:
adamw_params.append(param)
return muon_params, adamw_params
@staticmethod
def create(
model: torch.nn.Module,
ep_size: int,
lr: float,
optimizer_type: str,
weight_decay: float,
betas: Tuple[float, float],
adam_epsilon: float,
fused: bool = False,
param_groups: Optional[Sequence[Dict[str, Any]]] = None,
no_decay_modules: Optional[List[str]] = None,
no_decay_params: Optional[List[str]] = None,
) -> torch.optim.Optimizer:
"""
Build optimizer instance
Args:
model: Model to be optimized
ep_size: Expert parallel size
lr: Base learning rate
optimizer_type: Type of optimizer
weight_decay: Weight decay coefficient
betas: Tuple of AdamW beta1 and beta2 parameters
adam_epsilon: AdamW epsilon parameter
fused: Whether to enable fused optimizer kernel
param_groups: Custom parameter groups
no_decay_modules: List of module names that do not require weight decay
no_decay_params: List of parameter names that do not require weight decay
"""
if ep_size > 1:
logger.info_rank0("Building EP+FSDP2 optimizer (MultiOptimizer)")
return OptimizerFactory._build_ep_fsdp2_optimizer(
model=model,
lr=lr,
betas=betas,
eps=adam_epsilon,
weight_decay=weight_decay,
fused=fused,
optimizer_type=optimizer_type.lower(),
param_groups=param_groups,
no_decay_modules=no_decay_modules,
no_decay_params=no_decay_params,
)
single_optimizer = OptimizerFactory._create_single_optimizer(
model=model,
lr=lr,
optimizer_type=optimizer_type,
weight_decay=weight_decay,
betas=betas,
adam_epsilon=adam_epsilon,
fused=fused,
param_groups=param_groups,
no_decay_modules=no_decay_modules,
no_decay_params=no_decay_params,
)
return single_optimizer
@staticmethod
def _create_single_optimizer(
model,
lr,
optimizer_type,
weight_decay,
betas,
adam_epsilon,
fused,
param_groups,
no_decay_modules,
no_decay_params,
):
"""
Create single optimizer instance.
"""
optimizer_type = optimizer_type.lower()
foreach = not fused
muon_params = None
adamw_params = None
if optimizer_type == "muon":
muon_params, adamw_params = OptimizerFactory._split_muon_and_adamw_params(model)
logger.info_rank0(
f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params."
)
param_groups = []
else:
if param_groups is None:
decay_param_names = OptimizerFactory._get_parameter_names(model, no_decay_modules, no_decay_params)
decay_dtensors = []
decay_tensors = []
no_decay_dtensors = []
no_decay_tensors = []
no_decay_param_names = []
for n, p in model.named_parameters():
if not p.requires_grad:
continue
is_dtensor = isinstance(p, DTensor)
if n in decay_param_names:
if is_dtensor:
decay_dtensors.append(p)
else:
decay_tensors.append(p)
else:
no_decay_param_names.append(n)
if is_dtensor:
no_decay_dtensors.append(p)
else:
no_decay_tensors.append(p)
param_groups = []
if decay_dtensors:
param_groups.append({"params": decay_dtensors, "weight_decay": weight_decay})
if decay_tensors:
param_groups.append({"params": decay_tensors, "weight_decay": weight_decay})
if no_decay_dtensors:
logger.debug_rank0(f"DTensor parameters without weight decay: {len(no_decay_dtensors)} tensors")
param_groups.append({"params": no_decay_dtensors, "weight_decay": 0.0})
if no_decay_tensors:
logger.debug_rank0(
f"Native Tensor parameters without weight decay: {len(no_decay_tensors)} tensors"
)
param_groups.append({"params": no_decay_tensors, "weight_decay": 0.0})
optimizer = OptimizerFactory._create_optimizer_instance(
optimizer_type=optimizer_type,
param_groups=param_groups,
lr=lr,
betas=betas,
eps=adam_epsilon,
weight_decay=weight_decay,
fused=fused,
foreach=foreach,
muon_params=muon_params,
adamw_params=adamw_params,
)
logger.info_rank0(
f"Created single optimizer {optimizer_type} | lr={lr}, weight_decay={weight_decay} | Total Param Groups: {len(param_groups) if param_groups else 'N/A'}"
)
return optimizer
@staticmethod
def _create_optimizer_instance(
optimizer_type: str,
param_groups: Sequence[Dict[str, Any]],
lr: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
fused: bool = False,
foreach: bool = False,
muon_params: Optional[List[nn.Parameter]] = None,
adamw_params: Optional[List[nn.Parameter]] = None,
) -> Optimizer:
"""
Unified optimizer instantiation entry point, add branches here for newly extended optimizers
"""
if optimizer_type == "adamw":
return AdamW(
param_groups, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, fused=fused, foreach=foreach
)
elif optimizer_type == "muon":
if muon_params is None or adamw_params is None:
raise ValueError("muon_params and adamw_params must be provided for Muon optimizer.")
return Muon(
lr=lr,
wd=weight_decay,
muon_params=muon_params,
adamw_params=adamw_params,
adamw_betas=betas,
adamw_eps=eps,
)
else:
raise ValueError(f"Unsupported optimizer type: {optimizer_type}, supported types: [adamw, muon]")
@staticmethod
def _get_parameter_names(model, forbidden_layer_types, forbidden_param_names):
forbidden_layer_types = forbidden_layer_types or []
forbidden_param_names = forbidden_param_names or []
result = []
for name, child in model.named_children():
child_params = OptimizerFactory._get_parameter_names(child, forbidden_layer_types, forbidden_param_names)
result += [
f"{name}.{n}"
for n in child_params
if child.__class__.__name__ not in forbidden_layer_types
and not any(f in f"{name}.{n}".lower() for f in forbidden_param_names)
]
result += [k for k in model._parameters.keys() if not any(f in k.lower() for f in forbidden_param_names)]
return result
@staticmethod
def _make_param_groups_for_subset(
model: nn.Module,
params: Iterable[torch.nn.Parameter],
weight_decay: float,
no_decay_modules: Optional[List[str]],
no_decay_params: Optional[List[str]],
) -> List[Dict[str, Any]]:
decay_param_names = set(OptimizerFactory._get_parameter_names(model, no_decay_modules, no_decay_params))
name_by_param = {p: n for n, p in model.named_parameters()}
params = [p for p in params if p.requires_grad]
decayed = [p for p in params if name_by_param.get(p) in decay_param_names]
undecayed = [p for p in params if name_by_param.get(p) not in decay_param_names]
groups = []
if decayed:
groups.append({"params": decayed, "weight_decay": weight_decay})
if undecayed:
groups.append({"params": undecayed, "weight_decay": 0.0})
return groups
@staticmethod
def _build_ep_fsdp2_optimizer(
model: nn.Module,
lr: float,
betas: Tuple[float, float],
eps: float,
weight_decay: float,
fused: bool,
optimizer_type: str,
param_groups: Optional[List[Dict[str, Any]]],
no_decay_modules: Optional[List[str]],
no_decay_params: Optional[List[str]],
) -> MultiOptimizer:
ep_groups: List[Dict[str, Any]] = []
non_ep_groups: List[Dict[str, Any]] = []
if param_groups is not None:
if not isinstance(param_groups, list):
raise ValueError("param_groups must be a list")
for group_config in param_groups:
if "params" not in group_config:
raise ValueError(f"Group missing 'params' key: {group_config}")
group_lr = group_config.get("lr", lr)
group_params = group_config["params"]
group_ep_params, group_non_ep_params = [], []
for p in group_params:
if not p.requires_grad:
continue
if DTensor is not None and isinstance(p, DTensor):
mesh = getattr(p, "device_mesh", None)
names = getattr(mesh, "mesh_dim_names", []) if mesh is not None else []
if "efsdp" in names:
group_ep_params.append(p)
continue
group_non_ep_params.append(p)
if group_ep_params:
ep_subgroups = OptimizerFactory._make_param_groups_for_subset(
model, group_ep_params, weight_decay, no_decay_modules, no_decay_params
)
for subgroup in ep_subgroups:
subgroup["lr"] = group_lr
for k, v in group_config.items():
if k not in ["params", "lr", "weight_decay"]:
subgroup[k] = v
ep_groups.extend(ep_subgroups)
if group_non_ep_params:
non_ep_subgroups = OptimizerFactory._make_param_groups_for_subset(
model, group_non_ep_params, weight_decay, no_decay_modules, no_decay_params
)
for subgroup in non_ep_subgroups:
subgroup["lr"] = group_lr
for k, v in group_config.items():
if k not in ["params", "lr", "weight_decay"]:
subgroup[k] = v
non_ep_groups.extend(non_ep_subgroups)
else:
ep_params, non_ep_params = [], []
for _, p in model.named_parameters():
if not p.requires_grad:
continue
if DTensor is not None and isinstance(p, DTensor):
mesh = getattr(p, "device_mesh", None)
names = getattr(mesh, "mesh_dim_names", []) if mesh is not None else []
if "efsdp" in names:
ep_params.append(p)
continue
non_ep_params.append(p)
ep_groups = OptimizerFactory._make_param_groups_for_subset(
model, ep_params, weight_decay, no_decay_modules, no_decay_params
)
non_ep_groups = OptimizerFactory._make_param_groups_for_subset(
model, non_ep_params, weight_decay, no_decay_modules, no_decay_params
)
def _build_optimizer(groups: Sequence[Dict[str, Any]]) -> Optimizer:
foreach = False if is_torch_npu_available() else (not fused)
fused_ = False if is_torch_npu_available() else fused
muon_params = None
adamw_params = None
param_groups = groups
if optimizer_type == "muon":
all_params = [p for g in groups for p in g.get("params", [])]
muon_params, adamw_params = OptimizerFactory._split_muon_and_adamw_params(model, all_params)
logger.info_rank0(
f"EP+FSDP2 Muon sub-optimizer: {len(muon_params)} Muon params, {len(adamw_params)} AdamW params."
)
param_groups = []
return OptimizerFactory._create_optimizer_instance(
optimizer_type=optimizer_type,
param_groups=param_groups,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
fused=fused_,
foreach=foreach,
muon_params=muon_params,
adamw_params=adamw_params,
)
optimizer_dict = {}
if ep_groups:
optimizer_dict["ep"] = _build_optimizer(ep_groups)
if non_ep_groups:
optimizer_dict["non_ep"] = _build_optimizer(non_ep_groups)
model._ep_param_groups = {
"ep": [p for g in ep_groups for p in g.get("params", [])] if ep_groups else [],
"non_ep": [p for g in non_ep_groups for p in g.get("params", [])] if non_ep_groups else [],
}
multi_opt = MultiOptimizer(root_model=model, optimizers=optimizer_dict, key_names=list(optimizer_dict.keys()))
logger.info_rank0(f"Created EP+FSDP2 MultiOptimizer {optimizer_type} | lr={lr}, weight_decay={weight_decay}")
return multi_opt