import logging
from typing import Any, Callable, List, Optional
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron.core.dist_checkpointing.dict_utils import nested_values
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedStateDict
from megatron.core.optimizer.clip_grads import count_zeros_fp32, get_grad_norm_fp32
from megatron.core.optimizer.optimizer import (
ChainedOptimizer,
Float16OptimizerWithFloat16Params,
FP32Optimizer,
)
from mindspeed.core.optimizer.muon.utils import LegacyProcessGroupCollection, get_pg_rank, get_pg_size
logger = logging.getLogger(__name__)
class LayerWiseDistributedOptimizer(ChainedOptimizer):
"""Layer-wise distributed optimizer for Megatron-core models.
Experimental distributed optimizer wrapper that distributes weights across DP ranks by layer.
Implemented as a ChainedOptimizer to support multiple optimizers (e.g. Muon + AdamW).
When using it, keep all Megatron distributed-optimizer related options OFF.
How LayerWiseDistributedOptimizer works:
1. Weights are split into lists, and each rank keeps only its own shard in its optimizer.
2. Megatron DDP handles the gradient all-reduce; note that each rank holds the full model and grads.
3. The optimizers are already modified so that only params belonging to this DP rank are updated.
4. grad_norm and zero counting reduce their metrics globally in the step function.
5. A regular update is done with the chained optimizers; each modified optimizer updates only its shard.
6. The updated params are all-gathered to every rank.
"""
def __init__(
    self,
    optimizers: List[torch.optim.Optimizer],
    config,
    pg_collection: Optional[Any] = None,
    init_state_fn_list: Optional[List[Callable]] = None,
    model_chunks: Optional[List] = None,
) -> None:
    """
    Initialize LayerWiseDistributedOptimizer.

    All argument validation happens before ``shard_params`` is called, because
    ``shard_params`` rewrites ``group["params"]`` of the given optimizers in
    place; a rejected call therefore leaves the caller's optimizers untouched.

    Args:
        optimizers: List of base torch optimizers (not yet wrapped in
            Float16OptimizerWithFloat16Params / FP32Optimizer).
        config: OptimizerConfig. ``overlap_param_gather`` and ``bf16`` are read
            with ``getattr`` and default to False when absent.
        pg_collection: ProcessGroupCollection. A LegacyProcessGroupCollection is
            created when None.
        init_state_fn_list: List of init state functions, one per optimizer.
            Defaults to a list of None.
        model_chunks: DDP-wrapped model chunks (needed for overlap_param_gather).

    Raises:
        ValueError: If overlap_param_gather is enabled without model_chunks, or
            if init_state_fn_list and optimizers differ in length.
        TypeError: If any optimizer is already wrapped in a Megatron optimizer.
    """
    overlap_param_gather = getattr(config, "overlap_param_gather", False)

    # --- validation (no side effects on the caller's objects yet) ---
    if overlap_param_gather and model_chunks is None:
        raise ValueError("model_chunks must be provided if overlap_param_gather is True")
    if init_state_fn_list is None:
        init_state_fn_list = [None] * len(optimizers)
    if len(init_state_fn_list) != len(optimizers):
        raise ValueError("init_state_fn_list must match optimizers length")
    for optimizer in optimizers:
        if isinstance(optimizer, (Float16OptimizerWithFloat16Params, FP32Optimizer)):
            raise TypeError(
                "LayerWiseDistributedOptimizer expects base torch optimizers, "
                f"got {type(optimizer).__name__}. Do not pre-wrap with Megatron optimizers."
            )

    # --- process groups and sharding bookkeeping ---
    if pg_collection is None:
        self.pg_collection = LegacyProcessGroupCollection()
    else:
        self.pg_collection = pg_collection
    self.dp_group = self.pg_collection.dp_cp
    self.expt_dp_group = self.pg_collection.expt_dp
    self.dp_cp_params_list = None
    self.dp_params_list = None
    self.expt_dp_params_list = None
    self.overlap_param_gather = overlap_param_gather
    self.model_chunks = model_chunks

    # Sharding must precede wrapping: it edits the base optimizers' param groups.
    self.shard_params(optimizers)
    if self.overlap_param_gather:
        self.set_bucket_layerwise_params_list(model_chunks)

    # --- wrap each base optimizer in the matching Megatron optimizer ---
    use_bf16 = getattr(config, "bf16", False)
    wrapped_optimizers = []
    for optimizer, init_state_fn in zip(optimizers, init_state_fn_list):
        if use_bf16:
            wrapped_optimizer = Float16OptimizerWithFloat16Params(optimizer, config, None, init_state_fn)
        else:
            wrapped_optimizer = FP32Optimizer(optimizer, config, init_state_fn)
        setattr(wrapped_optimizer, "tp_group", self.pg_collection.tp)
        wrapped_optimizers.append(wrapped_optimizer)

    super().__init__(wrapped_optimizers)
    # Assigned after the base constructor so this value is the one that sticks.
    self.config = config
def shard_params(self, optimizers: List[torch.optim.Optimizer]) -> None:
    """Assign every parameter to one owning rank and keep only the local share.

    Parameters from all optimizers' param groups are sorted by ``numel()`` and
    handed out in a forward-then-backward rank order (0..n-1, n-1..0, repeat).
    Groups flagged ``is_expert_parallel`` are distributed over the ``expt_dp``
    group, all others over ``dp_cp``. Each param group's ``"params"`` entry is
    replaced in place by the subset owned by this rank.

    Side effects:
        Sets ``dp_cp_params_list``, ``dp_params_list`` and
        ``expt_dp_params_list`` (per-rank lists of params, or None when the
        corresponding group has size 1; the expert list is also None when it
        is empty or its rank-0 entry is empty). Returns without touching
        anything when both groups have size 1.

    Args:
        optimizers: Base optimizers whose ``param_groups`` are sharded in place.
    """
    dense_world = get_pg_size(self.pg_collection.dp_cp)
    expert_world = get_pg_size(self.pg_collection.expt_dp)
    if dense_world == 1 and expert_world == 1:
        return

    # Visiting order per group: ascending ranks followed by descending ranks.
    dense_order = [*range(dense_world), *reversed(range(dense_world))]
    expert_order = [*range(expert_world), *reversed(range(expert_world))]

    dense_lists = [[] for _ in range(dense_world)]
    expert_lists = [[] for _ in range(expert_world)]
    self.dp_cp_params_list = dense_lists
    self.expt_dp_params_list = expert_lists

    # Flatten all param groups, remembering which group each param came from.
    all_groups = [group for optimizer in optimizers for group in optimizer.param_groups]
    tagged_params = sorted(
        (
            (param, group_pos)
            for group_pos, group in enumerate(all_groups)
            for param in group["params"]
        ),
        key=lambda pair: pair[0].numel(),
    )

    kept_per_group = [[] for _ in all_groups]
    my_dense_rank = get_pg_rank(self.pg_collection.dp_cp)
    my_expert_rank = get_pg_rank(self.pg_collection.expt_dp)

    dense_pos = 0
    expert_pos = 0
    for param, group_pos in tagged_params:
        if all_groups[group_pos].get("is_expert_parallel", False):
            owner = expert_order[expert_pos]
            expert_pos = (expert_pos + 1) % len(expert_order)
            my_rank, owner_lists = my_expert_rank, expert_lists
        else:
            owner = dense_order[dense_pos]
            dense_pos = (dense_pos + 1) % len(dense_order)
            my_rank, owner_lists = my_dense_rank, dense_lists
        if owner == my_rank:
            kept_per_group[group_pos].append(param)
        owner_lists[owner].append(param)

    # Each optimizer now only sees the params this rank owns.
    for group, kept in zip(all_groups, kept_per_group):
        group["params"] = kept

    if dense_world == 1:
        self.dp_cp_params_list = None
        self.dp_params_list = None
    else:
        self.dp_params_list = self.dp_cp_params_list

    if expert_world == 1 or not expert_lists or not expert_lists[0]:
        self.expt_dp_params_list = None
def set_bucket_layerwise_params_list(self, model_chunks):
    """Tell every DDP bucket which of its params each rank owns.

    For each bucket in ``model_chunk.bucket_groups`` (matched against
    ``dp_cp_params_list``) and in ``model_chunk.expert_parallel_bucket_groups``
    if that attribute exists (matched against ``expt_dp_params_list``), a
    per-rank list of params is built and passed to
    ``bucket.set_layerwise_params_list``. When the relevant sharded list is
    None, the bucket receives a single list holding all of its params.

    Args:
        model_chunks: DDP-wrapped model chunks with bucket_groups.
    """

    def _register(bucket_groups, sharded_lists, pg_attr):
        # Shared routine for the dense and expert bucket groups.
        for bucket_group in bucket_groups:
            for bucket in bucket_group.buckets:
                if sharded_lists is None:
                    per_rank = [list(bucket.params_list)]
                else:
                    world = get_pg_size(getattr(self.pg_collection, pg_attr))
                    per_rank = [[] for _ in range(world)]
                    # Keep sharding order, restricted to params in this bucket.
                    for slot, owned in zip(per_rank, sharded_lists):
                        slot.extend(p for p in owned if p in bucket.params)
                bucket.set_layerwise_params_list(per_rank)

    for chunk in model_chunks:
        _register(chunk.bucket_groups, self.dp_cp_params_list, "dp_cp")
        _register(
            getattr(chunk, "expert_parallel_bucket_groups", []),
            self.expt_dp_params_list,
            "expt_dp",
        )
@torch.no_grad()
def allgather_params(self) -> None:
    """All-gather updated params from all ranks.

    For each sharded list that is not None (``dp_cp_params_list`` over the
    ``dp_cp`` group, ``expt_dp_params_list`` over the ``expt_dp`` group), this
    rank flattens the params it owns, all-gathers the flat buffers, and copies
    the other ranks' values into the local model params.

    A group whose per-rank lists hold no elements at all is skipped before any
    tensor is inspected, so an all-empty sharded list no longer raises
    ``IndexError``.
    """

    def _allgather_helper(params_list, group):
        flat_sizes = [sum(p.numel() for p in params) for params in params_list]
        # Check emptiness first: there is no tensor to read device/dtype from
        # when every per-rank list is empty.
        if not flat_sizes or max(flat_sizes) == 0:
            return

        # Any owned param can serve as the device/dtype template; use the first
        # one available rather than assuming rank 0 owns something.
        reference = next(params[0] for params in params_list if len(params) > 0)
        device = reference.device
        dtype = reference.dtype

        rank = get_pg_rank(group)
        dp_size = get_pg_size(group)
        local_params = params_list[rank]
        # A rank that owns nothing still contributes a zero-length buffer.
        src = (
            _flatten_dense_tensors(local_params)
            if len(local_params) > 0
            else torch.empty(0, device=device, dtype=dtype)
        )

        # One receive buffer per rank, sized to that rank's flattened shard;
        # this rank's slot reuses its own send buffer.
        gather_list = [
            src if i == rank else torch.empty(flat_sizes[i], device=device, dtype=dtype)
            for i in range(dp_size)
        ]
        torch.distributed.all_gather(gather_list, src, group=group)

        for idx, params in enumerate(params_list):
            # Local params are already up to date; empty shards carry no data.
            if len(params) == 0 or idx == rank:
                continue
            updated_params = _unflatten_dense_tensors(gather_list[idx], params)
            for updated_p, model_p in zip(updated_params, params):
                model_p.data.copy_(updated_p)

    if self.dp_cp_params_list is not None:
        _allgather_helper(self.dp_cp_params_list, self.pg_collection.dp_cp)
    if self.expt_dp_params_list is not None:
        _allgather_helper(self.expt_dp_params_list, self.pg_collection.expt_dp)
@torch.no_grad()
def get_grad_norm(self):
    """Return the gradient norm over the grads of all chained optimizers.

    Collects ``get_main_grads_for_grad_norm()`` from every chained optimizer
    and passes the combined list to ``get_grad_norm_fp32`` with
    ``grad_stats_parallel_group=None``.
    """
    collected_grads = [
        grad
        for chained in self.chained_optimizers
        for grad in chained.get_main_grads_for_grad_norm()
    ]
    return get_grad_norm_fp32(collected_grads, grad_stats_parallel_group=None)
@torch.no_grad()
def count_zeros(self):
    """Return the zero count computed over the params of all chained optimizers.

    Collects ``get_parameters()`` from every chained optimizer and passes them
    to ``count_zeros_fp32`` with ``grad_stats_parallel_group=None``. The
    ``use_decoupled_grad`` flag comes from
    ``config.use_precision_aware_optimizer_no_fp8_or_ds_fp8``, falling back to
    ``config.use_precision_aware_optimizer`` and finally to False.
    """
    all_params = [
        param for chained in self.chained_optimizers for param in chained.get_parameters()
    ]
    # Fallback chain for configs that lack the newer attribute name.
    fallback_flag = getattr(self.config, "use_precision_aware_optimizer", False)
    decoupled_grad = getattr(
        self.config, "use_precision_aware_optimizer_no_fp8_or_ds_fp8", fallback_flag
    )
    return count_zeros_fp32(
        all_params,
        grad_stats_parallel_group=None,
        use_decoupled_grad=decoupled_grad,
    )
@torch.no_grad()
def step(self):
    """Run the chained optimizer step, then synchronize params across ranks.

    Delegates to the base ``step()``. Unless ``overlap_param_gather`` is set,
    the updated params are all-gathered immediately afterwards via
    ``allgather_params``.

    Returns:
        The ``(update_successful, grad_norm, num_zeros_in_grad)`` triple
        produced by the base class step.
    """
    succeeded, norm, zero_count = super().step()
    if self.overlap_param_gather:
        # The gather is not issued here in overlap mode.
        return succeeded, norm, zero_count
    self.allgather_params()
    return succeeded, norm, zero_count
def load_state_dict(self, state_dict):
    """Load optimizer state, normalizing ``fp32_from_fp16_params`` first.

    ``sharded_state_dict`` stores ``fp32_from_fp16_params`` as an
    index-keyed dict; before delegating to the base class, any such dict is
    converted back to a list ordered by key. The conversion mutates the given
    state in place.

    Args:
        state_dict: With a single chained optimizer, that optimizer's state
            dict. With several, either a mapping whose values are the
            per-optimizer state dicts or a list/tuple of them.
            NOTE(review): the list form is what the base class is expected to
            produce from ``state_dict()`` for several optimizers — confirm
            against ChainedOptimizer.
    """
    if len(self.chained_optimizers) == 1:
        per_optimizer_states = [state_dict]
    elif isinstance(state_dict, (list, tuple)):
        # Sequence form: iterate directly; calling .values() here would fail.
        per_optimizer_states = state_dict
    else:
        per_optimizer_states = state_dict.values()

    for state in per_optimizer_states:
        if "fp32_from_fp16_params" in state and isinstance(state["fp32_from_fp16_params"], dict):
            logger.info("[layerwise] converting fp32_from_fp16_params from dict to list")
            state["fp32_from_fp16_params"] = [value for _, value in sorted(state["fp32_from_fp16_params"].items())]
    super().load_state_dict(state_dict)
def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs):
    """
    Sharded state dict for torch_dist format checkpointing.

    Starts from the base class result and post-processes it in place:
      * every nested value with a ``replica_id`` gets it reset to 0 (int form)
        or to ``(PP, TP, 0)`` (3-tuple form) — for fixed DP usage only;
      * falsy ``fp32_from_fp16_params`` groups are wrapped in
        LocalNonpersistentObject and the list is turned into an index-keyed dict;
      * a falsy optimizer ``state`` is wrapped in LocalNonpersistentObject;
      * each param group's metadata is exchanged with ``all_gather_object`` and
        replaced by the first gathered group that reports having params
        (falling back to the local group), keeping the local ``params`` entry.

    Raises:
        ValueError: If a ``replica_id`` is neither an int nor of length 3.
    """
    result = super().sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)

    for entry in nested_values(result):
        if not hasattr(entry, "replica_id"):
            continue
        current_id = entry.replica_id
        if isinstance(current_id, int):
            entry.replica_id = 0
        elif len(current_id) == 3:
            entry.replica_id = (*current_id[:2], 0)
        else:
            raise ValueError(f"Expected replica_id as int or (PP, TP, DP), got: {entry}")

    # With one chained optimizer the base result is that optimizer's dict
    # itself; wrap it so both cases can be iterated the same way.
    per_optimizer = {1: result} if len(self.chained_optimizers) == 1 else result

    for state in per_optimizer.values():
        if "fp32_from_fp16_params" in state:
            fp32_groups = state["fp32_from_fp16_params"]
            fp32_groups[:] = [
                entry if entry else LocalNonpersistentObject(entry) for entry in fp32_groups
            ]
            state["fp32_from_fp16_params"] = dict(enumerate(fp32_groups))

        optimizer_part = state["optimizer"]
        if not optimizer_part["state"]:
            optimizer_part["state"] = LocalNonpersistentObject(optimizer_part["state"])

        param_groups = optimizer_part["param_groups"]
        for position, local_group in enumerate(param_groups):
            # Hold the local params aside so only metadata is exchanged.
            local_params = local_group.pop("params")
            # Temporarily mark whether this rank's group holds any params.
            local_group["params"] = bool(local_params.unwrap())
            gathered_groups = [None] * torch.distributed.get_world_size()
            torch.distributed.all_gather_object(gathered_groups, local_group)
            chosen_group = next(
                (candidate for candidate in gathered_groups if candidate["params"]),
                local_group,
            )
            chosen_group["params"] = local_params
            param_groups[position] = chosen_group

    return result
def save_state_dict_to_file(self, filename: str) -> None:
    """Write the base-class ``state_dict()`` to disk with ``torch.save``.

    For torch format only.

    Args:
        filename: The filename to save the parameter state.
    """
    optimizer_state = super().state_dict()
    torch.save(optimizer_state, filename)
def load_state_dict_from_file(self, filename: str) -> None:
    """Read a state dict with ``torch.load`` and hand it to the base-class loader.

    For torch format only. The file is read with ``weights_only=True`` and
    passed to the base class ``load_state_dict`` (not this class's override).

    Args:
        filename: The filename to load the parameter state from.
    """
    loaded_state = torch.load(filename, weights_only=True)
    super().load_state_dict(loaded_state)