from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
_QUANT_STATE_CHOICES = ('fp8', 'hif8', 'mxfp8')
class LowPrecisionOptimizerFeature(MindSpeedFeature):
def __init__(self):
super().__init__('low-precision-optimizer', optimization_level=0)
@staticmethod
def _normalize_flags(args):
quant_states = getattr(args, 'quant_states', None)
if isinstance(quant_states, str):
quant_states = quant_states.lower()
if quant_states is not None and quant_states not in _QUANT_STATE_CHOICES:
raise AssertionError(
f"Low precision optimizer only supports quant_states {_QUANT_STATE_CHOICES}, got '{quant_states}'."
)
quant_states_enabled = bool(quant_states)
quant_grads_requested = getattr(args, 'quant_grads', False)
use_quant = quant_states_enabled or quant_grads_requested
if use_quant:
args.use_quant_optimizer = True
args.quant_states = quant_states
args.quant_states_enabled = quant_states_enabled
if quant_grads_requested:
args.quant_grads = True
quant_grad_dtype = getattr(args, 'quant_grads_dtype', None)
if quant_grad_dtype in {'fp16'}:
args.quant_grads_dtype = quant_grad_dtype
else:
args.quant_grads_dtype = 'fp16'
else:
args.quant_grads = False
args.quant_grads_dtype = None
return True
args.quant_states = quant_states
args.quant_states_enabled = False
if getattr(args, 'quant_grads', False):
args.quant_grads = False
args.quant_grads_dtype = None
args.use_quant_optimizer = False
return False
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
option_strings = {opt for action in parser._actions for opt in action.option_strings}
if '--quant-states' not in option_strings:
group.add_argument(
'--quant-states',
choices=_QUANT_STATE_CHOICES,
default=None,
help='Select quantization format for optimizer states (default: disabled).',
)
if '--quant-grads' not in option_strings:
group.add_argument(
'--quant-grads',
action='store_true',
help='Enable gradient quantization; dtype inferred from main_grads or defaults to fp16.',
)
def validate_args(self, args):
self._normalize_flags(args)
if getattr(args, "quant_grads", False) and getattr(args, "gradient_accumulation_fusion", False):
raise AssertionError(
"quant_grads is incompatible with "
"gradient_accumulation_fusion. Please use "
"--no-gradient-accumulation-fusion or turn off --quant-grads."
)
if getattr(args, 'swap_optimizer', False) and (
getattr(args, 'quant_grads', False) or getattr(args, 'quant_states', False)
):
raise AssertionError(
"--quant-states/--quant-grads are incompatible with --swap-optimizer."
"To use --quant-states/--quant-grads, please turn off --swap-optimizer."
)
def register_patches(self, patch_manager, args):
quant_enabled = self._normalize_flags(args)
if not quant_enabled:
return
quant_grads_enabled = bool(getattr(args, 'quant_grads', False))
patch_specs = []
import mindspeed.core.optimizer.low_precision.quant_optimizer_hooks as optimizer_hooks
import mindspeed.core.optimizer.low_precision.quant_distributed_hooks as distributed_hooks
from mindspeed.core.optimizer.low_precision import quant_grad_clip as grad_clip
from mindspeed.core.optimizer.low_precision import param_and_grad_buffer
from mindspeed.core.optimizer.low_precision import finalize_model_grads as quant_finalize
from mindspeed.core.models.gpt.gpt_model import gptmodel_init_wrapper
from mindspeed.core.tensor_parallel.layers import copy_tensor_model_parallel_attributes_wrapper
from mindspeed.core.optimizer.low_precision.language_model import transformer_language_model_init_wrapper
patch_specs.extend(
[
(
'megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.prepare_grads',
optimizer_hooks.prepare_grads_impl,
True,
),
(
'megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step_with_ready_grads',
optimizer_hooks.step_with_ready_grads_impl,
True,
),
(
'megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step',
optimizer_hooks.mixed_precision_optimizer_step_impl,
True,
),
(
'megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params.__init__',
optimizer_hooks.reuse_fp32_param_init_wrapper,
True,
),
(
'megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__',
optimizer_hooks.optimizer_config_init_wrapper,
True,
),
(
'megatron.core.optimizer.optimizer_config.OptimizerConfig.__post_init__',
optimizer_hooks.optimizer_config_post_init_wrapper,
True,
),
(
'megatron.core.optimizer._get_megatron_optimizer_based_on_param_groups',
optimizer_hooks.get_optimizer_builder_wrapper,
True,
),
(
'megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.start_grad_sync',
param_and_grad_buffer.quant_grad_start_grad_sync_wrapper,
False,
),
(
'megatron.core.distributed.param_and_grad_buffer._ParamAndGradBuffer.__init__',
param_and_grad_buffer.quant_grad_param_and_grad_buffer_init_wrapper,
False,
),
(
'megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads',
quant_finalize._allreduce_word_embedding_grads,
False,
),
(
'megatron.core.distributed.finalize_model_grads._allreduce_position_embedding_grads',
quant_finalize._allreduce_position_embedding_grads,
False,
),
(
'megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads',
quant_finalize._allreduce_layernorm_grads,
False,
),
(
'megatron.core.distributed.finalize_model_grads._allreduce_conditional_embedding_grads',
quant_finalize._allreduce_conditional_embedding_grads,
False,
),
(
'megatron.core.distributed.finalize_model_grads._update_router_expert_bias',
quant_finalize._update_router_expert_bias,
False,
),
(
'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_backward_post_hook',
distributed_hooks.ddp_make_backward_post_hook_wrapper,
False,
),
(
'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
optimizer_hooks.distributed_optimizer_init_wrapper,
True,
),
(
'megatron.core.models.gpt.gpt_model.GPTModel.__init__',
gptmodel_init_wrapper,
True,
),
(
'megatron.core.tensor_parallel.layers.copy_tensor_model_parallel_attributes',
copy_tensor_model_parallel_attributes_wrapper,
True,
),
(
'megatron.legacy.model.language_model.TransformerLanguageModel.__init__',
transformer_language_model_init_wrapper,
True,
),
]
)
if quant_grads_enabled:
patch_specs.extend(
[
(
'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._collect_main_grad_data_for_unscaling',
distributed_hooks.collect_main_grad_data_for_unscaling_quant,
True,
),
(
'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer._copy_model_grads_to_main_grads',
distributed_hooks.copy_model_grads_to_main_grads_quant,
True,
),
(
'megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params._collect_main_grad_data_for_unscaling',
optimizer_hooks.collect_main_grad_data_for_unscaling_wrapper,
True,
),
(
'megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params._copy_model_grads_to_main_grads',
optimizer_hooks.copy_model_grads_to_main_grads,
True,
),
(
'megatron.core.optimizer.optimizer.MixedPrecisionOptimizer._unscale_main_grads_and_check_for_nan',
optimizer_hooks.unscale_main_grads_and_check_for_nan,
True,
),
(
'megatron.core.optimizer.optimizer.MegatronOptimizer.get_main_grads_for_grad_norm',
optimizer_hooks.get_main_grads_for_grad_norm,
True,
),
(
'megatron.core.optimizer.optimizer._zero_grad_group_helper',
optimizer_hooks.zero_grad_group_helper_wrapper,
True,
),
(
'megatron.core.optimizer.clip_grads.get_grad_norm_fp32',
grad_clip.get_grad_norm_fp32,
True,
),
(
'megatron.core.optimizer.clip_grads.clip_grad_by_total_norm_fp32',
grad_clip.clip_grad_by_total_norm_fp32_wrapper,
True,
),
]
)
for target, func, force in patch_specs:
patch_manager.register_patch(target, func, force_patch=force)