Muon 优化器特性
在大规模语言模型的训练中,优化器的选择对收敛速度和最终性能有着重要影响。传统的 Adam/AdamW 优化器虽然通用性强,但将参数视为独立的一维向量进行更新,忽略了矩阵参数的结构信息。Muon 优化器通过引入动量正交化(Momentum Orthogonalization)技术,针对 2D/3D 矩阵参数实现了更高效的梯度下降策略,在大模型训练中展现出更快的收敛速度。
实现原理
torchtitan-npu 采用了 Muon + AdamW 混合优化器策略,核心代码定义在 torchtitan_npu/patches/optimizer/muon_optimizer.py 。
参数分配策略
模型参数根据维度和语义自动路由到不同的优化器:
- 2D 参数(如 Linear 层的权重矩阵)→ 使用 Muon 优化器
- 例外:名称中包含
embed、lm_head、output的 2D 参数 → 使用 AdamW(此类层不适合正交化更新)
- 例外:名称中包含
- 3D 参数(如 MoE 专家权重)→ 使用 Muon 优化器
- 1D 及其他参数(如偏置、LayerNorm 参数)→ 使用 AdamW 优化器
两种优化器被统一封装在 MuonHybridOptimizersContainer 中,对外提供与标准优化器一致的 step()、zero_grad()、state_dict() / load_state_dict() 接口。
Newton-Schulz 正交化
Muon 的核心操作是对动量梯度进行正交化(LMO,Low-orthogonal Matrix Operation),通过 Newton-Schulz 迭代算法实现矩阵的零次幂近似(即投影到正交矩阵流形上):
Xk+1=a⋅Xk+b⋅XkXkTXk+c⋅(XkXkT)2XkX_{k+1} = a \cdot X_k + b \cdot X_k X_k^T X_k + c \cdot (X_k X_k^T)^2 X_k
其中主系数为 (a,b,c)=(3.4445,−4.7750,2.0315)(a, b, c) = (3.4445, -4.7750, 2.0315),在零点处具有最大斜率,确保快速收敛。
混合 Newton-Schulz(hybrid_ns)
当启用 muon_hybrid_ns = true 时,采用 DeepSeek-V4 提出的混合迭代策略:
- 前 8 步:使用主系数 (3.4445,−4.7750,2.0315)(3.4445, -4.7750, 2.0315)
- 后 2 步:切换到次系数 (2.0,−1.5,0.5)(2.0, -1.5, 0.5)
此策略在保持收敛速度的同时,提升了正交化结果的数值稳定性。
学习率调整模式
Muon 优化器支持两种学习率调整模式(通过 muon_adjust_lr_fn 配置),其核心区别在于如何根据矩阵形状调整学习率,以及是否需要独立的超参数调优。
| 模式 | 调整公式 | 说明 |
|---|---|---|
original |
γ←γ⋅max(1,A/B)\gamma \leftarrow \gamma \cdot \sqrt{\max(1, A/B)} | Keller Jordan 原始实现,根据矩阵宽高比调整 |
match_rms_adamw |
γ←0.18⋅γ⋅max(A,B)\gamma \leftarrow 0.18 \cdot \gamma \cdot \sqrt{\max(A, B)} | Moonshot 实现,直接复用 AdamW 的 lr 和 weight_decay |
original 模式
该模式源自 Muon 创始人 Keller Jordan 的原始实现。调整公式为:
γadjusted=γ×max(1,AB)\gamma_{\text{adjusted}} = \gamma \times \sqrt{\max\left(1, \frac{A}{B}\right)}
其中 AA 和 BB 是矩阵的两个维度。这个调整的目的是:让正交化后的梯度更新在不同形状的矩形矩阵上具有一致的 RMS(Root Mean Square)。
- 当 A≤BA \le B(宽矩阵,如 FFN 中的中间层)时,系数为 1,不做额外调整
- 当 A>BA > B(高矩阵,如输出层)时,按 A/B\sqrt{A/B} 缩放
由于调整幅度较大,通常需要单独为 Muon 调优学习率(即配置 muon_lr),一般来说可以将 AdamW 的学习率放大 10 倍来作为 Muon 的学习率。
match_rms_adamw 模式
该模式来自 Moonshot 团队的论文 Muon is Scalable for LLM Training。调整公式为:
γadjusted=0.18×γ×max(A,B)\gamma_{\text{adjusted}} = 0.18 \times \gamma \times \sqrt{\max(A, B)}
注:当前实现使用系数 0.18(与 DeepSeek-V4 一致),Moonshot 原始论文使用 0.2。
这个模式的设计目标是:让 Muon 可以直接复用已经为 AdamW 调优好的学习率和权重衰减超参数,无需额外的超参数搜索。
模式选择建议
- 使用
match_rms_adamw(默认):如果你已经为 AdamW 调优好了超参数,希望直接尝试 Muon 而不想重新调参 - 使用
original:如果你愿意投入时间单独调优 Muon 的学习率,追求可能的更好收敛效果
分布式通信机制
DistributedMuon 优化器针对不同的并行策略实现了三种参数更新路径:
| 并行策略 | 通信方式 | 适用场景 |
|---|---|---|
| FSDP | all_to_all 双向通信 | FSDP 分片的 2D 参数 |
| DDP | all_gather 通信 | DDP 复制的 2D 参数 |
| Expert/MoE | 本地 LMO(无跨卡通信) | 3D 专家权重(已通过 EP 切分) |
FSDP 路径(step_fsdp):采用 owner-per-bucket 方案,每个 rank 负责一组参数的完整 LMO 计算。通过两轮 all_to_all 通信:第一轮将各 rank 的梯度分片汇聚到参数 owner,owner 完成正交化后,第二轮将更新分片发回各 rank。
DDP 路径(step_ddp):每个 rank 先对本地负责的参数计算 LMO 更新,然后通过 all_gather 将更新广播给所有 rank。支持与 TP 的联合使用,在 LMO 计算前先通过 all_gather 收集 TP 分片。
Expert 路径(step_experts):专家权重已通过 Expert Parallel 切分到各 rank,每个 rank 独立对本地专家参数执行 LMO,无需跨卡通信。
Swap Optimizer(推荐)
对于显存受限的场景,Muon 优化器支持通过 swap_optimizer = true 启用优化器状态 CPU 卸载,将 Muon 的 momentum_buffer 和 AdamW 的 exp_avg、exp_avg_sq 卸载到 CPU,仅在优化器 step 期间按需换入 NPU。
swap_merge_buckets 配置
swap_merge_buckets 控制 Muon momentum_buffer H2D/D2H 操作的合并粒度。Muon 参数被分成多个 bucket(由 FSDP all_to_all 通信决定),每个 bucket 独立执行一次 H2D(step 前)和 D2H(step 后)。swap_merge_buckets 决定将多少个连续 bucket 的 H2D/D2H 合并为一次 stream 操作。
配置选项
在训练任务的 TOML 配置文件(例如 torchtitan_npu/models/deepseek_v4/train_configs/deepseek_v4_285b_4layers_debug_muon.toml,或实际启动训练时 --job.config_file 所指向的路径)中,找到对应的 [optimizer] 节,并添加以下配置以启用 Muon 优化器:
| 配置项 | 类型 | 默认值 | 说明 |
|---|---|---|---|
name |
str | "AdamW" | 优化器类型,设置为 "Muon" 启用本特性 |
lr |
float | — | 基础学习率,AdamW 部分直接使用;Muon 部分取决于 muon_adjust_lr_fn |
muon_lr |
float | None | Muon 专用学习率。仅当 muon_adjust_lr_fn = "original" 时生效;若不设置则回退到 lr;当 muon_adjust_lr_fn = "match_rms_adamw" 时此值被忽略 |
muon_momentum |
float | 0.95 | Muon 的动量因子 |
muon_enable_nesterov |
bool | True | 是否启用 Nesterov 动量 |
muon_ns_steps |
int | 5 | Newton-Schulz 正交化迭代步数,影响正交化精度与计算开销。值越大正交化越精确,但计算量也越大 |
muon_adjust_lr_fn |
str | "match_rms_adamw" | 学习率调整模式:"original" 或 "match_rms_adamw" |
muon_hybrid_ns |
bool | False | 是否启用混合 Newton-Schulz 迭代(前 8 步用主系数,后续步用次系数) |
swap_optimizer |
bool | False | 是否启用 Swap Optimizer,将优化器状态卸载到 CPU 并异步换入换出以节省显存 |
swap_merge_buckets |
int | 1 | Swap H2D/D2H 合并桶数。值越大 stream 同步开销越低但峰值显存略增。推荐 4~16 |
配置示例
示例 1:match_rms_adamw 模式
最简配置,直接复用 AdamW 超参,无需额外调参即可使用 Muon:
[job]
custom_config_module = "torchtitan_npu.config.custom_config" # 使能本代码仓的自定义配置
[optimizer]
name = "Muon" # 使用 Muon 混合优化器
lr = 2.2e-4 # 基础学习率(Muon 和 AdamW 共用)
weight_decay = 0.1
muon_momentum = 0.95 # Muon 动量因子
muon_enable_nesterov = true # 启用 Nesterov 动量
muon_ns_steps = 10 # 正交化步数
muon_adjust_lr_fn = "match_rms_adamw" # 复用 AdamW 超参(默认值,可省略)
muon_hybrid_ns = true # 启用混合 NS
示例 2:original 模式
需要为 Muon 单独设置学习率,适合愿意投入调参资源的场景:
[job]
custom_config_module = "torchtitan_npu.config.custom_config" # 使能本代码仓的自定义配置
[optimizer]
name = "Muon" # 使用 Muon 混合优化器
lr = 3e-4 # AdamW 部分的学习率
muon_lr = 3e-3 # Muon 专用学习率(通常为 AdamW lr 的 10 倍)
weight_decay = 0.01
muon_momentum = 0.95 # Muon 动量因子
muon_enable_nesterov = true # 启用 Nesterov 动量
muon_ns_steps = 5 # 正交化步数
muon_adjust_lr_fn = "original" # 使用独立的 lr 调度器
示例 3:启用 Swap Optimizer
在显存受限的场景下,将优化器状态卸载到 CPU,通过异步 H2D/D2H 流水线减少性能损失:
[job]
custom_config_module = "torchtitan_npu.config.custom_config" # 使能本代码仓的自定义配置
[optimizer]
name = "Muon" # 使用 Muon 混合优化器
lr = 2.2e-4
weight_decay = 0.1
muon_momentum = 0.95
muon_enable_nesterov = true
muon_ns_steps = 10
muon_adjust_lr_fn = "match_rms_adamw"
muon_hybrid_ns = true
swap_optimizer = true # 启用 Swap Optimizer,优化器状态卸载到 CPU
swap_merge_buckets = 4 # 每 4 个 bucket 合并一次 H2D/D2H,平衡性能与显存
注意事项
swap_optimizer与virtual_allocator互斥:两者不能同时启用,请根据场景选择其一- Swap Optimizer 推荐场景:显存受限但希望尽量保持训练性能的场景。通过异步 stream overlap 实现 H2D/D2H 与计算的并行,性能损失较小
- 断点续训兼容:Swap Optimizer 的
state_dict()/load_state_dict()已完整支持 checkpoint 保存与恢复,续训 loss 与连续训练完全一致