import re
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from collections import ChainMap
import logging
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_optimizer_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from ..distributed.parallel_state import get_parallel_state
from ...optimizer.muon import Muon
logger = logging.getLogger(__name__)
class MultiOptimizer(Optimizer, Stateful):
"""
A container that handles multiple optimizers (for ep and non-ep parameters when ep+fsdp2 is enabled)
Mapping of name -> torch.optim.Optimizer with convenience methods.
Compatible with torch.distributed.checkpoint optimizer APIs that accept a Mapping.
This class is needed for EP+FSDP2 case because EP and non-EP param have different FSDP sharding dimension (dim-0 vs. dim-1)
For comparison, EP+FSDP1 also shards EP parameters along dim-0 for FSDP, so it can use the default optimizer class.
"""
def __init__(
self,
root_model: nn.Module,
optimizers: dict,
key_names: list[str],
):
self.model = root_model
self.optimizers_dict = optimizers
self._is_multi_optimizer: bool = True
self.key_names = key_names
@property
def state(self):
"""
Returns a read-only aggregated view of the states from all sub-optimizers.
Uses collections.ChainMap to combine the state dictionaries without copying,
providing efficient and unified access while preserving immutability at this level.
"""
state_dicts = [opt.state for opt in self.optimizers_dict.values()]
return ChainMap(*state_dicts)
@property
def param_groups(self):
"""
Returns a flat list aggregating all parameter groups from every sub-optimizer.
This allows the composite optimizer to expose a unified interface compatible
with standard PyTorch optimizer expectations (e.g., for learning rate schedulers).
"""
all_groups = []
for opt in self.optimizers_dict.values():
all_groups.extend(opt.param_groups)
return all_groups
def step(self) -> None:
for opt in self.optimizers_dict.values():
opt.step()
def zero_grad(self) -> None:
for opt in self.optimizers_dict.values():
opt.zero_grad()
def state_dict(
self,
) -> Dict[str, Any]:
merged: Dict[str, Any] = {}
for name in self.key_names:
opt = self.optimizers_dict.get(name)
sd = get_optimizer_state_dict(self.model, opt, options=StateDictOptions(flatten_optimizer_state_dict=True))
overlap = set(merged.keys()) & set(sd.keys())
if overlap:
raise KeyError(
f"Key clash detected while merging state dict for optimizer '{name}': {', '.join(sorted(overlap))}"
)
else:
logger.info("No clashes when merging MultiOptimizer state dicts")
merged.update(sd)
return merged
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
for name in self.key_names:
opt = self.optimizers_dict.get(name)
set_optimizer_state_dict(
self.model,
opt,
optim_state_dict=state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
def register_step_pre_hook(self, hook):
return [opt.register_step_pre_hook(hook) for opt in self.optimizers_dict.values()]
def __len__(self) -> int:
return len(self.optimizers_dict)
def __repr__(self) -> str:
return self.optimizers_dict.__repr__()
def _make_param_groups_for_subset(
    model: "nn.Module",
    params: Iterable[torch.nn.Parameter],
    weight_decay: float,
    no_decay_modules: Optional[List[str]] = None,
    no_decay_params: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Split a subset of ``model``'s parameters into a decayed and an undecayed group.

    Args:
        model: Module used to resolve parameter names and the decay-eligible name list.
        params: Parameters to distribute; those with ``requires_grad == False`` are dropped.
        weight_decay: Value stored on the decayed group.
        no_decay_modules: Forwarded to ``get_parameter_names`` as the forbidden layer types.
        no_decay_params: Forwarded to ``get_parameter_names`` as the forbidden name fragments.

    Returns:
        Up to two ``{"params", "weight_decay"}`` dicts, decayed group first; empty groups are
        omitted. Parameters whose name cannot be resolved on ``model`` land in the undecayed group.
    """
    decay_names = set(get_parameter_names(model, no_decay_modules, no_decay_params))
    param_to_name = {param: name for name, param in model.named_parameters()}

    with_decay: List[torch.nn.Parameter] = []
    without_decay: List[torch.nn.Parameter] = []
    for param in params:
        if not param.requires_grad:
            continue
        # An unresolved parameter maps to None, which is never a decay-eligible name.
        if param_to_name.get(param) in decay_names:
            with_decay.append(param)
        else:
            without_decay.append(param)

    return [
        {"params": bucket, "weight_decay": decay}
        for bucket, decay in ((with_decay, weight_decay), (without_decay, 0.0))
        if bucket
    ]
def _is_muon_eligible(name: str, param: torch.nn.Parameter) -> bool:
    """
    Decide whether a parameter should be placed in a ``use_muon=True`` group.

    A parameter qualifies when it is two-dimensional and its name neither ends with ``.bias``
    nor contains ``embedding`` or ``output_layer``.
    """
    if name.endswith(".bias"):
        return False
    if "embedding" in name or "output_layer" in name:
        return False
    return len(param.shape) == 2
def _mark_muon_param_groups(
    model: "nn.Module",
    param_groups: Sequence[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Split every parameter group into a ``use_muon=True`` and a ``use_muon=False`` group.

    Args:
        model: Module used to look up each parameter's name.
        param_groups: Groups to split; each group's non-``params`` entries are copied onto the
            groups derived from it.

    Returns:
        New group dicts (the inputs are not modified). For each input group the Muon-eligible
        group comes first; empty groups are omitted and parameters with
        ``requires_grad == False`` are dropped.
    """
    param_to_name = {param: name for name, param in model.named_parameters()}

    result: List[Dict[str, Any]] = []
    for group in param_groups:
        shared = {key: value for key, value in group.items() if key != "params"}
        trainable = [p for p in group.get("params", []) if p.requires_grad]
        # Parameters unknown to `model` are checked under the empty name.
        eligible = [_is_muon_eligible(param_to_name.get(p, ""), p) for p in trainable]

        muon_bucket = [p for p, ok in zip(trainable, eligible) if ok]
        fallback_bucket = [p for p, ok in zip(trainable, eligible) if not ok]

        if muon_bucket:
            result.append({**shared, "params": muon_bucket, "use_muon": True})
        if fallback_bucket:
            result.append({**shared, "params": fallback_bucket, "use_muon": False})
    return result
def get_parameter_names(model, forbidden_layer_types, forbidden_param_names):
    """
    Collect the dotted names of ``model``'s parameters, skipping forbidden ones.

    Args:
        model: Module to walk recursively through ``named_children``.
        forbidden_layer_types: Class names of child modules whose parameters are skipped
            entirely; ``None`` means no restriction.
        forbidden_param_names: Substrings that exclude a parameter when found in its
            lower-cased (qualified) name; ``None`` means no restriction.

    Returns:
        List of parameter names: children's parameters first (prefixed with the child name),
        followed by the parameters registered directly on ``model``.
    """
    if forbidden_layer_types is None:
        forbidden_layer_types = []
    if forbidden_param_names is None:
        forbidden_param_names = []

    def _is_forbidden(param_name):
        # Only the parameter name is lower-cased; the patterns are compared as given.
        lowered = param_name.lower()
        return any(pattern in lowered for pattern in forbidden_param_names)

    names = []
    for child_name, child in model.named_children():
        nested = get_parameter_names(child, forbidden_layer_types, forbidden_param_names)
        if child.__class__.__name__ in forbidden_layer_types:
            continue
        for inner in nested:
            qualified = f"{child_name}.{inner}"
            if not _is_forbidden(qualified):
                names.append(qualified)

    names.extend(key for key in model._parameters.keys() if not _is_forbidden(key))
    return names
def build_optimizer(
    model: "nn.Module",
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
    weight_decay: float = 1e-2,
    fused: bool = False,
    optimizer_type: str = "adamw",
    param_groups: Optional[Sequence[Dict[str, Any]]] = None,
    no_decay_modules: Optional[List[str]] = None,
    no_decay_params: Optional[List[str]] = None,
    matched_adamw_rms: float = 0.2,
    muon_momentum: float = 0.95,
    ns_steps: int = 5,
    lr_scaling_plan: Optional[List] = None,
) -> "torch.optim.Optimizer":
    """
    Build the optimizer for ``model``.

    When the expert-parallel group size is larger than 1, construction is delegated to
    ``build_ep_fsdp2_optimizer``. Otherwise parameter groups are resolved in this order:

    1. ``lr_scaling_plan`` is given: groups come from ``get_param_groups`` (any ``param_groups``
       argument is replaced).
    2. ``param_groups`` is ``None``: one group with ``weight_decay`` for decay-eligible trainable
       parameters, plus a ``weight_decay=0.0`` group for the remaining trainable parameters
       when there are any.
    3. Otherwise the caller's ``param_groups`` are used as given.

    Args:
        model: Module whose parameters are optimized.
        lr: Base learning rate.
        betas: AdamW betas (passed to Muon as ``adamw_betas``).
        eps: AdamW epsilon (passed to Muon as ``adamw_eps``).
        weight_decay: Weight decay for decay-eligible parameters.
        fused: Passed to AdamW as ``fused``; ``foreach`` is set to its negation.
        optimizer_type: ``"adamw"`` or ``"muon"``.
        param_groups: Optional pre-built parameter groups.
        no_decay_modules: Module class names excluded from weight decay.
        no_decay_params: Parameter-name fragments excluded from weight decay.
        matched_adamw_rms: Forwarded to ``Muon``.
        muon_momentum: Forwarded to ``Muon`` as ``momentum``.
        ns_steps: Forwarded to ``Muon``.
        lr_scaling_plan: Optional per-pattern learning-rate scaling rules; not used on the
            expert-parallel path.

    Raises:
        ValueError: If ``optimizer_type`` is neither ``"adamw"`` nor ``"muon"``.
    """
    ps = get_parallel_state()
    if ps.get_ep_group_size() > 1:
        logger.info("Building EP+FSDP2 optimizer")
        if lr_scaling_plan:
            # The EP builder takes no scaling plan; make the drop visible instead of silent.
            logger.warning("lr_scaling_plan is ignored when expert parallelism is enabled (EP+FSDP2 optimizer path)")
        return build_ep_fsdp2_optimizer(
            model,
            lr,
            betas,
            eps,
            weight_decay,
            fused,
            optimizer_type,
            param_groups,
            no_decay_modules,
            no_decay_params,
            matched_adamw_rms=matched_adamw_rms,
            muon_momentum=muon_momentum,
            ns_steps=ns_steps,
        )

    if lr_scaling_plan:
        decay_param_names = get_parameter_names(model, no_decay_modules, no_decay_params)
        param_groups = get_param_groups(model, weight_decay, lr, lr_scaling_plan, decay_param_names)
    elif param_groups is None:
        # A set keeps the per-parameter membership test O(1).
        decay_param_names = set(get_parameter_names(model, no_decay_modules, no_decay_params))
        decay_parameters = []
        no_decay_parameters, no_decay_parameter_names = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if n in decay_param_names:
                decay_parameters.append(p)
            else:
                no_decay_parameter_names.append(n)
                no_decay_parameters.append(p)

        # The decay group is always emitted (even when empty), the no-decay group only if needed.
        param_groups = [{"params": decay_parameters, "weight_decay": weight_decay}]
        if no_decay_parameters:
            logger.info(f"Parameters without weight decay: {no_decay_parameter_names}")
            param_groups.append({"params": no_decay_parameters, "weight_decay": 0.0})

    if optimizer_type == "muon":
        param_groups = _mark_muon_param_groups(model, param_groups)
        logger.info(f"Muon parameter groups: {param_groups}")
        optim = Muon(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            matched_adamw_rms=matched_adamw_rms,
            momentum=muon_momentum,
            ns_steps=ns_steps,
            adamw_betas=betas,
            adamw_eps=eps,
        )
    elif optimizer_type == "adamw":
        # foreach and fused are requested mutually exclusively.
        optim = AdamW(param_groups, lr, betas, eps, weight_decay, fused=fused, foreach=not fused)
    else:
        raise ValueError("Only adamw and muon are supported as optimizers.")
    return optim
def group_params_by_lr_ratio(param_names: List[str], params: List[torch.Tensor],
                             lr_scaling_plan: Optional[List[Any]], base_lr: float) -> List[Dict]:
    """
    Group parameters by the learning rate that ``lr_scaling_plan`` assigns to them.

    Each plan entry provides a regular expression (``match``) and a multiplier (``scale``),
    either as attributes or, for plain ``dict`` entries, as keys. A parameter takes the scale
    of the first entry whose pattern matches the start of its name (``re.match``); parameters
    matching no entry use a ratio of 1.0.

    Args:
        param_names: Parameter names, aligned with ``params``.
        params: Parameters to group.
        lr_scaling_plan: Ordered scaling rules; ``None`` or empty means no scaling.
        base_lr: Learning rate that the matched scale multiplies.

    Returns:
        One ``{"params": [...], "lr": lr}`` dict per distinct resulting learning rate, in order
        of first appearance.
    """
    # Normalise every rule to a (pattern, scale) pair once, accepting dicts as well as
    # attribute-style config objects.
    rules = []
    for entry in lr_scaling_plan or []:
        if isinstance(entry, dict):
            rules.append((entry["match"], entry["scale"]))
        else:
            rules.append((entry.match, entry.scale))

    groups_by_lr: Dict[float, Dict] = {}
    for name, param in zip(param_names, params):
        # First matching rule wins.
        lr_ratio = next((scale for pattern, scale in rules if re.match(pattern, name)), 1.0)
        lr = base_lr * lr_ratio
        if lr not in groups_by_lr:
            groups_by_lr[lr] = {"params": [], "lr": lr}
        groups_by_lr[lr]["params"].append(param)
    return list(groups_by_lr.values())
def get_param_groups(
    model: torch.nn.Module,
    weight_decay: float,
    base_lr: float,
    lr_scaling_plan: Optional[List[Dict]] = None,
    decay_param_names: Optional[Iterable[str]] = None,
) -> List[Dict]:
    """
    Build parameter groups split first by weight decay, then by learning-rate ratio.

    Trainable parameters are partitioned into decay / no-decay sets according to
    ``decay_param_names``; each set is then grouped with ``group_params_by_lr_ratio``.

    Args:
        model: Module whose trainable parameters are grouped.
        weight_decay: Weight decay written on the decay groups (no-decay groups get 0.0).
        base_lr: Base learning rate handed to ``group_params_by_lr_ratio``.
        lr_scaling_plan: Learning-rate scaling rules handed to ``group_params_by_lr_ratio``.
        decay_param_names: Names of parameters that receive weight decay. ``None`` is treated
            as empty, i.e. every trainable parameter ends up without weight decay.

    Returns:
        Decay groups followed by no-decay groups; every group carries ``params``, ``lr`` and
        ``weight_decay``.
    """
    # Guard the None default (it would fail on the membership test) and use a set so the
    # per-parameter lookup stays O(1).
    decay_names = set(decay_param_names) if decay_param_names is not None else set()

    no_decay_parameters, no_decay_parameter_names = [], []
    decay_parameters, decay_parameter_names = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n in decay_names:
            decay_parameter_names.append(n)
            decay_parameters.append(p)
        else:
            no_decay_parameter_names.append(n)
            no_decay_parameters.append(p)

    if no_decay_parameters:
        logger.info(f"Parameters without weight decay: {no_decay_parameter_names}")

    decay_groups = group_params_by_lr_ratio(decay_parameter_names, decay_parameters, lr_scaling_plan, base_lr)
    no_decay_groups = group_params_by_lr_ratio(no_decay_parameter_names, no_decay_parameters, lr_scaling_plan, base_lr)

    param_groups = []
    for group in decay_groups:
        group["weight_decay"] = weight_decay
        param_groups.append(group)
    for group in no_decay_groups:
        group["weight_decay"] = 0.0
        param_groups.append(group)
    return param_groups
def _is_ep_param(param: torch.nn.Parameter) -> bool:
    """Return True when ``param`` is a DTensor whose device mesh has an ``"efsdp"`` dimension."""
    if not isinstance(param, DTensor):
        return False
    mesh = getattr(param, "device_mesh", None)
    # mesh_dim_names is None for a mesh created without names; treat that as "no names".
    names = getattr(mesh, "mesh_dim_names", None) or ()
    return "efsdp" in names


def _split_trainable_by_ep(
    params: Iterable[torch.nn.Parameter],
) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Split trainable parameters into (EP, non-EP) lists; frozen parameters are dropped."""
    ep_params: List[torch.nn.Parameter] = []
    non_ep_params: List[torch.nn.Parameter] = []
    for p in params:
        if not p.requires_grad:
            continue
        (ep_params if _is_ep_param(p) else non_ep_params).append(p)
    return ep_params, non_ep_params


def build_ep_fsdp2_optimizer(
    model: "nn.Module",
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.95),
    eps: float = 1e-8,
    weight_decay: float = 1e-2,
    fused: bool = False,
    optimizer_type: str = "adamw",
    param_groups: Optional[List[Dict[str, Any]]] = None,
    no_decay_modules: Optional[List[str]] = None,
    no_decay_params: Optional[List[str]] = None,
    matched_adamw_rms: float = 0.2,
    muon_momentum: float = 0.95,
    ns_steps: int = 5,
) -> "MultiOptimizer":
    """
    Build a MultiOptimizer instance when model is parallelized with EP+FSDP2

    If param_groups provided, it can be a list of dicts with arbitrary parameter groups:
    - Example: [{"params": params1, "lr": lr1},
                {"params": params2, "lr": lr2},
                {"params": params3, "lr": lr3}]
    - Each group's params are automatically split into EP and non-EP based on DTensor mesh
      (a parameter is EP when its device mesh has an "efsdp" dimension)
    - The group's learning rate (defaulting to ``lr``) and any extra keys are copied onto the
      resulting sub-groups
    - A group's own "weight_decay" entry is NOT used: decay is re-derived from ``weight_decay``
      and the no-decay rules (``no_decay_modules`` / ``no_decay_params``)

    If param_groups is None, all trainable parameters of ``model`` are split into EP / non-EP
    and each side is divided into decay / no-decay groups.

    Side effect: sets ``model._ep_param_groups`` to ``{"ep": [...], "non_ep": [...]}`` holding
    the parameters assigned to each side.

    Returns:
        A MultiOptimizer holding an ``"ep"`` and/or a ``"non_ep"`` optimizer; a side without
        parameters gets no optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is neither ``"adamw"`` nor ``"muon"``.
    """
    ep_groups: List[Dict[str, Any]] = []
    non_ep_groups: List[Dict[str, Any]] = []

    if param_groups is not None:
        for group_config in param_groups:
            group_lr = group_config.get("lr", lr)
            group_ep_params, group_non_ep_params = _split_trainable_by_ep(group_config["params"])
            # NOTE(review): "weight_decay" from the caller's group is dropped here, so a group
            # that sets it explicitly still gets the global value - confirm this is intended.
            extras = {
                key: value for key, value in group_config.items() if key not in ("params", "lr", "weight_decay")
            }
            for subset, destination in (
                (group_ep_params, ep_groups),
                (group_non_ep_params, non_ep_groups),
            ):
                if not subset:
                    continue
                subgroups = _make_param_groups_for_subset(
                    model, subset, weight_decay, no_decay_modules, no_decay_params
                )
                for subgroup in subgroups:
                    subgroup["lr"] = group_lr
                    subgroup.update(extras)
                destination.extend(subgroups)
    else:
        ep_params, non_ep_params = _split_trainable_by_ep(p for _, p in model.named_parameters())
        ep_groups = _make_param_groups_for_subset(model, ep_params, weight_decay, no_decay_modules, no_decay_params)
        non_ep_groups = _make_param_groups_for_subset(
            model, non_ep_params, weight_decay, no_decay_modules, no_decay_params
        )

    def _build(groups: Sequence[Dict[str, Any]]) -> Optimizer:
        """Instantiate the requested optimizer type for one side's parameter groups."""
        if optimizer_type == "muon":
            groups = _mark_muon_param_groups(model, groups)
            return Muon(
                groups,
                lr=lr,
                weight_decay=weight_decay,
                matched_adamw_rms=matched_adamw_rms,
                momentum=muon_momentum,
                ns_steps=ns_steps,
                adamw_betas=betas,
                adamw_eps=eps,
            )
        elif optimizer_type == "adamw":
            # foreach and fused are requested mutually exclusively.
            return AdamW(groups, lr, betas, eps, weight_decay, fused=fused, foreach=not fused)
        else:
            raise ValueError("Only adamw and muon are supported as optimizers.")

    optimizer_dict: Dict[str, Optimizer] = {}
    if ep_groups:
        optimizer_dict["ep"] = _build(ep_groups)
    if non_ep_groups:
        optimizer_dict["non_ep"] = _build(non_ep_groups)

    model._ep_param_groups = {
        "ep": [p for g in ep_groups for p in g.get("params", [])],
        "non_ep": [p for g in non_ep_groups for p in g.get("params", [])],
    }

    key_names = list(optimizer_dict.keys())
    multi_opt = MultiOptimizer(model, optimizer_dict, key_names=key_names)
    return multi_opt