import logging
from logging import getLogger
from functools import wraps
from typing import Callable, Dict, List, Optional
import torch
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron.training import get_args
from megatron.core import mpu
from megatron.core.utils import is_te_min_version, log_single_rank
from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBuffer
from megatron.core.transformer.module import MegatronModule
from megatron.core.optimizer import (
_get_param_groups_and_buffers,
MegatronOptimizer,
ConstantGradScaler, DynamicGradScaler,
OptimizerConfig
)
logger = getLogger(__name__)
def get_megatron_optimizer_based_on_param_groups(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    param_groups: List,
    per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None,
    model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
    ori_dp_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_idx: Optional[int] = None,
    distributed_optimizer_instance_id: Optional[int] = 0,
) -> MegatronOptimizer:
    """Get the high-availability Megatron optimizer for a set of parameter groups.

    A base optimizer (fused Adam / fused SGD, or a hybrid CPU-offload optimizer) is built
    from ``param_groups`` and then wrapped in ``TTPReplicaOptimizer`` when the distributed
    optimizer is enabled, or in ``TTPFP16ReplicaOptimizer`` otherwise.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (list): list of model chunks.
        param_groups (list): list of parameter groups. When empty, no base optimizer is
            created and ``None`` is handed to the wrapping optimizer instead.
        per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
        model_parallel_group (torch.distributed.ProcessGroup, optional): stored on the
            non-distributed optimizer as ``grad_stats_parallel_group``. Defaults to None.
        data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
            distributed optimizer. Defaults to None.
        data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
            group for distributed optimizer. Defaults to None.
        ori_dp_group (torch.distributed.ProcessGroup, optional): forwarded unchanged to the
            TTP replica optimizers as ``ori_dp_group``. Defaults to None.
        data_parallel_group_idx (int, optional): data-parallel group index for distributed
            optimizer. Defaults to None.
        distributed_optimizer_instance_id (int, optional): Distributed optimizer instance.
            Defaults to 0.

    Returns:
        Instance of MegatronOptimizer.

    Raises:
        Exception: if ``config.optimizer`` is neither ``'adam'`` nor ``'sgd'`` (non-offload
            path), or if none of fp16 / bf16 / distributed optimizer is enabled (a plain
            FP32 optimizer is not supported by high availability).
    """
    if param_groups:
        if config.optimizer_cpu_offload:
            # These names are only needed on the CPU-offload path and are not imported at
            # module level, so bring them into scope here.
            # NOTE(review): the CPUAdam/CPUSGD aliases are assumed to mirror upstream
            # Megatron (torch.optim.AdamW / torch.optim.SGD) -- confirm.
            import warnings

            from torch.optim import AdamW as CPUAdam
            from torch.optim import SGD as CPUSGD

            from megatron.core.optimizer import HybridDeviceOptimizer

            if torch.__version__ < '2.3.0':
                warnings.warn(
                    "CPU offload is recommended for PyTorch >= 2.3.0, "
                    "untested versions below this may have convergence issues."
                )
            # NOTE(review): earlier code read config.use_torch_optimizer_for_cpu_offload to
            # swap the GPU optimizer for the CPU one, but that choice was unconditionally
            # overwritten by the assignments below, so the flag had no effect. The effective
            # behaviour is preserved here; confirm whether the flag should be honoured.
            if config.optimizer == 'adam':
                gpu_optimizer_cls = Adam
                cpu_optimizer_cls = CPUAdam
                optimizer_defaults = {
                    "lr": config.lr,
                    "weight_decay": config.weight_decay,
                    "betas": (config.adam_beta1, config.adam_beta2),
                    "eps": config.adam_eps,
                    "bias_correction": True,
                    "fused": True,
                }
            else:
                # Any optimizer name other than 'adam' is treated as SGD on this path.
                gpu_optimizer_cls = SGD
                cpu_optimizer_cls = CPUSGD
                optimizer_defaults = {
                    "lr": config.lr,
                    "weight_decay": config.weight_decay,
                    "momentum": config.sgd_momentum,
                }
            optimizer = HybridDeviceOptimizer(
                param_groups,
                offload_fraction=config.optimizer_offload_fraction,
                cpu_optimizer_cls=cpu_optimizer_cls,
                gpu_optimizer_cls=gpu_optimizer_cls,
                overlap_cpu_optimizer_d2h_h2d=config.overlap_cpu_optimizer_d2h_h2d,
                pin_cpu_grads=config.pin_cpu_grads,
                pin_cpu_params=config.pin_cpu_params,
                param_update_in_fp32=True,
                **optimizer_defaults,
            )
            init_state_fn = None
        elif config.optimizer == 'adam':
            kwargs = {
                "params": param_groups,
                "lr": config.lr,
                "weight_decay": config.weight_decay,
                "betas": (config.adam_beta1, config.adam_beta2),
                "eps": config.adam_eps,
            }
            if config.use_precision_aware_optimizer:
                kwargs.update(
                    {
                        "master_weights": True,
                        "use_decoupled_grad": True,
                        "master_weight_dtype": config.main_params_dtype,
                        "exp_avg_dtype": config.exp_avg_dtype,
                        "exp_avg_sq_dtype": config.exp_avg_sq_dtype,
                    }
                )
                if is_te_min_version("2.1.0.dev0"):
                    kwargs.update({"store_param_remainders": True})
            optimizer = Adam(**kwargs)

            def init_state_fn(opt, config=None):
                # Populate empty per-parameter Adam state: plain zero tensors by default,
                # or the optimizer's own initializer for the precision-aware variant.
                for group in opt.param_groups:
                    for p in group['params']:
                        if len(opt.state[p]) == 0:
                            if config is None or not config.use_precision_aware_optimizer:
                                opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                                opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
                            else:
                                opt.initialize_state(p)

        elif config.optimizer == 'sgd':
            optimizer = SGD(
                param_groups,
                lr=config.lr,
                weight_decay=config.weight_decay,
                momentum=config.sgd_momentum,
            )
            init_state_fn = None
        else:
            raise Exception(f'{config.optimizer} optimizer is not supported.')
    else:
        # No parameter groups: the wrapping optimizer is still created below, with no
        # base optimizer behind it.
        optimizer = None
        init_state_fn = None

    # High availability only provides mixed-precision / distributed optimizer wrappers.
    if not (config.fp16 or config.bf16 or config.use_distributed_optimizer):
        raise Exception("High availability do not support FP32 Optimizer")

    # A configured constant loss scale takes precedence; otherwise dynamic scaling is
    # used for fp16 only, and no scaler at all in the remaining cases.
    grad_scaler = None
    if config.loss_scale:
        grad_scaler = ConstantGradScaler(config.loss_scale)
    elif config.fp16:
        grad_scaler = DynamicGradScaler(
            initial_scale=config.initial_loss_scale,
            min_scale=config.min_loss_scale,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=config.loss_scale_window,
            hysteresis=config.hysteresis,
        )

    optimizer_args = [
        optimizer,
        config,
        grad_scaler,
        init_state_fn,
    ]

    # Imported inside the function, as in the original code, rather than at module level.
    from mindspeed_llm.core.high_availability import TTPReplicaOptimizer, TTPFP16ReplicaOptimizer

    if config.use_distributed_optimizer:
        optimizer = TTPReplicaOptimizer(
            *optimizer_args,
            model_chunks=model_chunks,
            per_model_buffers=per_model_buffers,
            data_parallel_group=data_parallel_group,
            data_parallel_group_gloo=data_parallel_group_gloo,
            data_parallel_group_idx=data_parallel_group_idx,
            distributed_optimizer_instance_id=distributed_optimizer_instance_id,
            ori_dp_group=ori_dp_group,
        )
    else:
        optimizer = TTPFP16ReplicaOptimizer(*optimizer_args, ori_dp_group=ori_dp_group)
        optimizer.grad_stats_parallel_group = model_parallel_group
    return optimizer
def get_megatron_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    no_weight_decay_cond: Optional[Callable] = None,
    scale_lr_cond: Optional[Callable] = None,
    lr_mult: float = 1.0,
    use_gloo_process_groups: bool = True,
) -> MegatronOptimizer:
    """Retrieve the high-availability Megatron optimizer for model chunks.

    Separate optimizers are built for non-expert (dense) parameters and for
    expert-parallel parameters; when more than one optimizer results they are combined
    in a ``TTPReplicaChainedOptimizer``. With custom FSDP enabled on the first chunk's
    DDP config, a single unfiltered pass over the dense chunk groups is used instead.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (List[MegatronModule]): model chunks to get optimizer for.
        no_weight_decay_cond (func, optional): function to determine whether a parameter
            should not perform weight decay. Defaults to None.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate. Defaults to None.
        lr_mult (float, optional): learning rate multiplier for parameters that
            satisfy scale_lr_cond. Defaults to 1.0.
        use_gloo_process_groups (bool): accepted as part of the signature, but not
            referenced by this implementation; the replica Gloo groups are always
            passed to the underlying optimizers.

    Returns:
        Instance of MegatronOptimizer.
    """
    log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

    # Each entry pairs a list of dense chunks with the overlap flag to stamp on them.
    # With overlap enabled, the first chunk gets its own optimizer (flag on) and the
    # remaining chunks share a second one (flag off).
    if config.overlap_param_gather_with_optimizer_step:
        dense_partitions = [([model_chunks[0]], True), (model_chunks[1:], False)]
    else:
        dense_partitions = [(model_chunks, False)]

    model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())

    # When the full data-parallel group is larger than the partial one, the instance id
    # is this rank's position in the inter-partial data-parallel group.
    full_dp_world_size = torch.distributed.get_world_size(
        mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=False)
    )
    partial_dp_world_size = torch.distributed.get_world_size(
        mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=True)
    )
    if full_dp_world_size > partial_dp_world_size:
        distributed_optimizer_instance_id = torch.distributed.get_rank(
            mpu.get_inter_partial_data_parallel_group()
        )
    else:
        distributed_optimizer_instance_id = 0

    from mindspeed_llm.core.high_availability import (
        TTPReplicaChainedOptimizer,
        ttp_get_dp_cp_replica_group,
        ttp_get_dp_cp_replica_group_gloo,
        ttp_get_dp_ep_replica_group,
        ttp_get_dp_ep_replica_group_gloo,
    )

    def _single_or_chained(built):
        # A lone optimizer is returned as-is; several are chained together.
        if len(built) == 1:
            return built[0]
        return TTPReplicaChainedOptimizer(built)

    # Keyword arguments common to every parameter-grouping call below.
    grouping_kwargs = {
        "config": config,
        "no_weight_decay_cond": no_weight_decay_cond,
        "scale_lr_cond": scale_lr_cond,
        "lr_mult": lr_mult,
    }

    optimizers = []
    model_chunk_offset = 0

    if model_chunks[0].ddp_config.use_custom_fsdp:
        # Custom FSDP: no expert/non-expert split and the overlap flag is not applied.
        for fsdp_chunks, _unused_flag in dense_partitions:
            param_groups, buffers = _get_param_groups_and_buffers(
                fsdp_chunks,
                model_chunk_offset=model_chunk_offset,
                filter_fn=lambda g: True,
                buffer_name='buffers',
                **grouping_kwargs,
            )
            optimizers.append(
                get_megatron_optimizer_based_on_param_groups(
                    config,
                    model_chunks=fsdp_chunks,
                    param_groups=param_groups,
                    per_model_buffers=buffers,
                    model_parallel_group=mpu.get_model_parallel_group(),
                    data_parallel_group=ttp_get_dp_cp_replica_group(),
                    data_parallel_group_gloo=ttp_get_dp_cp_replica_group_gloo(),
                    ori_dp_group=mpu.get_data_parallel_group(with_context_parallel=True),
                    data_parallel_group_idx=model_parallel_rank,
                )
            )
            model_chunk_offset += 1
        return _single_or_chained(optimizers)

    # Dense (non-expert) parameters: one optimizer per partition.
    for dense_model_chunks, overlap_flag in dense_partitions:
        param_groups, buffers = _get_param_groups_and_buffers(
            dense_model_chunks,
            model_chunk_offset=model_chunk_offset,
            filter_fn=lambda g: not g['is_expert_parallel'],
            buffer_name='buffers',
            **grouping_kwargs,
        )
        for chunk in dense_model_chunks:
            chunk.overlap_param_gather_with_optimizer_step = overlap_flag
        optimizers.append(
            get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=dense_model_chunks,
                param_groups=param_groups,
                per_model_buffers=buffers,
                model_parallel_group=mpu.get_model_parallel_group(),
                data_parallel_group=ttp_get_dp_cp_replica_group(),
                data_parallel_group_gloo=ttp_get_dp_cp_replica_group_gloo(),
                ori_dp_group=mpu.get_data_parallel_group(with_context_parallel=True),
                data_parallel_group_idx=model_parallel_rank,
                distributed_optimizer_instance_id=distributed_optimizer_instance_id,
            )
        )
        model_chunk_offset += 1

    # Expert-parallel parameters across all chunks get their own optimizer, if any exist.
    moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
        model_chunks,
        model_chunk_offset=0,
        filter_fn=lambda g: g['is_expert_parallel'],
        buffer_name='expert_parallel_buffers',
        **grouping_kwargs,
    )
    if len(moe_param_groups) > 0:
        # The rank index for expert parameters is taken within the expert
        # tensor/model/pipeline parallel group instead of the model-parallel group.
        model_parallel_rank = torch.distributed.get_rank(
            mpu.get_expert_tensor_model_pipeline_parallel_group()
        )
        optimizers.append(
            get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=model_chunks,
                param_groups=moe_param_groups,
                per_model_buffers=moe_buffers,
                model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
                data_parallel_group=ttp_get_dp_ep_replica_group(),
                data_parallel_group_gloo=ttp_get_dp_ep_replica_group_gloo(),
                ori_dp_group=mpu.get_data_modulo_expert_parallel_group(),
                data_parallel_group_idx=model_parallel_rank,
            )
        )

    return _single_or_chained(optimizers)
def get_megatron_optimizer_wrapper(fn):
    """Return a replacement for ``fn`` that always uses the high-availability factory.

    The wrapped function ``fn`` is never called; it only supplies the metadata
    (name, docstring, ...) copied onto the returned callable. All positional and
    keyword arguments are forwarded unchanged to ``get_megatron_optimizer``.
    """

    def _high_availability_optimizer(*args, **kwargs):
        return get_megatron_optimizer(*args, **kwargs)

    return wraps(fn)(_high_availability_optimizer)