from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class FusedEmaAdamwFeature(MindSpeedFeature):
def __init__(self):
super().__init__('optimizer-selection')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--ema-decay', type=float, default=0.9999,
help='Set ema_decay of fused_ema_adamw optimizer.')
def pre_validate_args(self, args):
if args.optimizer_selection == 'fused_ema_adamw' and args.ema_decay < 0 or args.ema_decay > 1:
raise AssertionError("ema_decay must be in the range [0, 1].")
if not hasattr(args, 'disable_gloo_group'):
setattr(args, 'disable_gloo_group', False)
def pre_register_patches(self, patch_manager, args):
if args.optimization_level >= 0 and args.optimizer_selection == 'fused_ema_adamw':
from mindspeed.core.optimizer.fused_ema_adamw.fused_ema_adamw import FusedEmaAdamW as AdamW
patch_manager.register_patch(
'apex.optimizers.FusedAdam', AdamW, create_dummy=True)
def register_patches(self, patch_manager, args):
if args.optimization_level >= 2 and args.optimizer_selection == 'fused_ema_adamw':
from mindspeed.core.optimizer.fused_ema_adamw.adaptor import generate_state_dict_ema_wrapper, save_checkpoint_ema_wrapper
from mindspeed.core.optimizer.fused_ema_adamw.adaptor import ema_distrib_optimizer_init_wrapper
patch_manager.register_patch(
'megatron.training.checkpointing.save_checkpoint', save_checkpoint_ema_wrapper)
patch_manager.register_patch(
'megatron.training.checkpointing.generate_state_dict', generate_state_dict_ema_wrapper)
patch_manager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
ema_distrib_optimizer_init_wrapper)
if hasattr(args, "ema_decay"):
from mindspeed.core.optimizer.fused_ema_adamw.adaptor import get_megatron_optimizer_func_wrapper
patch_manager.register_patch('megatron.core.optimizer.get_megatron_optimizer',
get_megatron_optimizer_func_wrapper)