import warnings
import torch
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
def init_weight_quantization_reuse(pm, args):
    """Enable FP8 quantized-weight reuse and wrap optimizer step methods accordingly.

    The reuse flag on ``FP8GlobalStateManager`` is set from
    ``args.fp8_reuse_quantized_weight`` (treated as ``False`` when the attribute is
    missing). Only when the manager then reports reuse as enabled are the Megatron
    optimizer ``step`` / ``step_with_ready_grads`` methods registered for patching
    with ``optimizer_step_reuse_cleanup_wrapper``.

    Args:
        pm: Patch manager exposing ``register_patch(target_path, replacement)``.
        args: Parsed arguments namespace.
    """
    from mindspeed.te.pytorch.fp8.reuse import optimizer_step_reuse_cleanup_wrapper
    from mindspeed.te.pytorch.fp8.state_manager import FP8GlobalStateManager

    reuse_requested = bool(getattr(args, "fp8_reuse_quantized_weight", False))
    FP8GlobalStateManager.set_weight_quantization_reuse_enabled(reuse_requested)

    # Decide from the state manager's recorded flag, not from the raw argument.
    if not FP8GlobalStateManager.FP8_REUSE_QUANTIZED_WEIGHT:
        return

    # Every optimizer entry point that can perform a parameter update; order matches registration order.
    step_targets = (
        "megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step",
        "megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step_with_ready_grads",
        "megatron.core.optimizer.optimizer.ChainedOptimizer.step",
        "megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.step",
        "megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.step_with_ready_grads",
    )
    for target_path in step_targets:
        pm.register_patch(target_path, optimizer_step_reuse_cleanup_wrapper)
class TransformerEngineBasicFeature(MindSpeedFeature):
    """Route Megatron's Transformer Engine (TE) integration to MindSpeed implementations.

    ``register_args`` adds TE options and extra choices for existing FP8-related flags.
    ``validate_args`` rejects unsupported argument combinations. The patch hooks replace
    ``megatron.core.extensions.transformer_engine`` classes with MindSpeed versions. On the
    FP8 path they also replace several ``transformer_engine`` recipe/autocast entry points.
    """

    def __init__(self):
        super().__init__('transformer-engine-basic', optimization_level=0)

    def register_args(self, parser):
        """Add this feature's command-line options.

        Extra choices are appended to ``--fp8-format``, ``--fp8-recipe`` and
        ``--moe-router-dtype``; those options themselves are not declared in this method.
        """
        group = parser.add_argument_group(title=self.feature_name)

        # Extend the choice lists of options declared outside this method.
        self.add_parser_argument_choices_value(parser, "--fp8-format", 'hif8')
        self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'blockwise')
        self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'mxfp8-32x32')
        self.add_parser_argument_choices_value(parser, "--moe-router-dtype", 'fp8')

        # store_false: args.use_gmm_fp8 is True unless --no-use-gmm-fp8 is passed.
        group.add_argument(
            '--no-use-gmm-fp8', action='store_false', help='not use GMM with scaling recipe.', dest='use_gmm_fp8'
        )
        group.add_argument(
            '--te-comparison-with-cpu',
            action='store_true',
            default=False,
            help='Compare the cast and quantmatmul of te on cpu and npu online.',
        )
        group.add_argument(
            '--te-comparison-with-bf16',
            action='store_true',
            default=False,
            help='Compare the cast and quantmatmul of te with bf16 online.',
        )
        group.add_argument(
            '--te-gmm-mode',
            choices=['performance', 'compatible'],
            default='compatible',
            help='Select the TE-GMM execution mode. '
            '"performance": Enables high-performance optimizations. '
            '"compatible": Default. Ensures compatibility with native TE behavior.',
            dest='te_gmm_mode',
        )
        group.add_argument(
            "--fp8-reuse-quantized-weight",
            action="store_true",
            default=False,
            help="Reuse quantized FP8 weight tensors within one optimizer step.",
        )

    def validate_args(self, args):
        """Reject unsupported TE / FP8 argument combinations.

        Raises:
            AssertionError: FP8 with the local transformer implementation; Ascend CoC with
                Transformer Engine; MC2 with an FP8 recipe other than ``mxfp8``; or
                Transformer Engine with legacy models.
            ValueError: ``hif8`` format without the ``tensorwise`` recipe, or
                ``--fp8-reuse-quantized-weight`` without FP8 enabled.
        """
        if args.fp8 and args.transformer_impl == 'local':
            raise AssertionError('FP8 just support TE implement.')
        # getattr keeps these checks consistent with register_patches and tolerant of unregistered flags.
        if getattr(args, 'use_ascend_coc', False) and args.transformer_impl == 'transformer_engine':
            raise AssertionError('transformer engine does not support ascend coc')
        if getattr(args, 'use_ascend_mc2', False) and args.fp8 and args.fp8_recipe != 'mxfp8':
            raise AssertionError('MC2 is supported only by the mxfp8 recipe in fp8.')
        if getattr(args, "transformer_impl", "transformer_engine") == "transformer_engine" and getattr(
            args, "use_legacy_models", False
        ):
            raise AssertionError('transformer engine only support for mcore models')
        if args.fp8 == 'hif8' and args.fp8_recipe != 'tensorwise':
            raise ValueError("hif8 only support tensorwise scaling type")
        # use_gmm_fp8 defaults to True, so only warn when FP8 training is actually enabled;
        # otherwise non-FP8 runs would get a misleading "using bf16 gmm instead" warning.
        if args.fp8 and args.use_gmm_fp8:
            if args.fp8_recipe not in ('mxfp8', 'mxfp8-32x32', 'tensorwise', 'delayed'):
                warnings.warn(
                    f"gmm fp8 only supports tensorwise, mxfp8, mxfp8-32x32, and delayed recipe, but {args.fp8_recipe} provided, "
                    f"using bf16 gmm instead."
                )
        if getattr(args, "fp8_reuse_quantized_weight", False) and not args.fp8:
            raise ValueError("fp8_reuse_quantized_weight is only valid when FP8 training is enabled")

    def pre_register_patches(self, patch_manager, args):
        """Register ``transformer_engine.pytorch.tensor.QuantizedTensor`` as ``torch.nn.Module``.

        NOTE(review): ``create_dummy=True`` presumably lets the patch manager create the target
        when Transformer Engine itself is absent; confirm against MindSpeedPatchesManager.
        """
        patch_manager.register_patch(
            'transformer_engine.pytorch.tensor.QuantizedTensor', torch.nn.Module, create_dummy=True
        )

    def register_patches(self, patch_manager: MindSpeedPatchesManager, args):
        """Register the TE replacement patches.

        Patches are registered in this order:

        1. ``TransformerBlock.forward``.
        2. ``TransformerBlock._checkpointed_forward``, unless swap-attention or block recompute
           is enabled.
        3. The grouped-linear classes.
        4. Either the FP8 patch set (when ``args.fp8_format`` is set) or the non-FP8
           linear/attention patches.
        """
        from mindspeed.te.pytorch.module.checkpoint import (
            transformer_block_forward,
            transformer_block_checkpointed_forward,
        )

        patch_manager.register_patch(
            'megatron.core.transformer.transformer_block.TransformerBlock.forward', transformer_block_forward
        )
        # NOTE(review): presumably swap_attention / block recompute install their own
        # _checkpointed_forward, which is why this patch is skipped for them -- confirm.
        if not (getattr(args, 'swap_attention', False) or getattr(args, 'recompute_method', False) == 'block'):
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
                transformer_block_checkpointed_forward,
            )

        self._register_grouped_linear_patches(patch_manager, args)

        # NOTE(review): validate_args reads the FP8 format from ``args.fp8``, while this check reads
        # ``args.fp8_format``; confirm the args passed to register_patches expose ``fp8_format``.
        if getattr(args, "fp8_format", False):
            self._register_fp8_patches(patch_manager, args)
        else:
            self._register_non_fp8_patches(patch_manager, args)

    @staticmethod
    def _register_grouped_linear_patches(patch_manager, args):
        """Patch the TE grouped-linear classes with the compatible or the performance variants."""
        if getattr(args, 'te_gmm_mode', 'compatible') != 'performance':
            from mindspeed.te.pytorch.module.grouped_linear import (
                MindSpeedTEGroupedLinear as grouped_linear_cls,
                MindSpeedTEColumnParallelGroupedLinear as column_grouped_linear_cls,
                MindSpeedTERowParallelGroupedLinear as row_grouped_linear_cls,
            )
        else:
            from mindspeed.te.pytorch.module.performance_grouped_linear import (
                MindSpeedTEPerformanceGroupedLinear as grouped_linear_cls,
                MindSpeedTEPerformanceColumnParallelGroupedLinear as column_grouped_linear_cls,
                MindSpeedTEPerformanceRowParallelGroupedLinear as row_grouped_linear_cls,
            )

        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TEGroupedLinear', grouped_linear_cls
        )
        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TEColumnParallelGroupedLinear',
            column_grouped_linear_cls,
        )
        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TERowParallelGroupedLinear',
            row_grouped_linear_cls,
        )

    @staticmethod
    def _register_fp8_patches(patch_manager, args):
        """Patch TE linear/attention classes, FP8 recipes/autocast, padding, checkpoint and MoE specs."""
        from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import (
            MindSpeedTEDotProductAttention,
        )
        from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import (
            MindSpeedTELayerNormColumnParallelLinear,
        )
        from mindspeed.te.pytorch.module.linear import TERowParallelLinear, TEColumnParallelLinear
        from mindspeed.te.pytorch.fp8.constants import Format, Fp8Recipe
        from mindspeed.core.fp8_utils import get_fp8_context
        from mindspeed.te.pytorch.fp8.fp8 import fp8_autocast, fp8_model_init
        from mindspeed.te.pytorch.fp8.recipes import Float8CurrentScaling, MXFP8BlockScaling, TEDelayedScaling
        from mindspeed.te.pytorch.fp8.padding import Fp8Padding, Fp8Unpadding

        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TEColumnParallelLinear', TEColumnParallelLinear
        )
        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TERowParallelLinear', TERowParallelLinear
        )
        # The MindSpeed attention replacement is only registered without context parallelism.
        if int(getattr(args, 'context_parallel_size', 1)) == 1:
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEDotProductAttention', MindSpeedTEDotProductAttention
            )
        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
            MindSpeedTELayerNormColumnParallelLinear,
        )

        # FP8 format/recipe enums, context and autocast entry points.
        patch_manager.register_patch('transformer_engine.common.recipe.Format', Format)
        patch_manager.register_patch('megatron.core.enums.Fp8Recipe', Fp8Recipe)
        patch_manager.register_patch('megatron.core.fp8_utils.get_fp8_context', get_fp8_context)
        patch_manager.register_patch('transformer_engine.pytorch.fp8_model_init', fp8_model_init)
        patch_manager.register_patch('transformer_engine.pytorch.fp8_autocast', fp8_autocast)
        patch_manager.register_patch("transformer_engine.common.recipe.Float8CurrentScaling", Float8CurrentScaling)
        patch_manager.register_patch('transformer_engine.common.recipe.MXFP8BlockScaling', MXFP8BlockScaling)
        patch_manager.register_patch(
            "megatron.core.extensions.transformer_engine.TEDelayedScaling", TEDelayedScaling
        )
        patch_manager.register_patch("megatron.core.extensions.transformer_engine.Fp8Padding", Fp8Padding)
        patch_manager.register_patch("megatron.core.extensions.transformer_engine.Fp8Unpadding", Fp8Unpadding)

        from mindspeed.te.pytorch.module.checkpoint import te_checkpoint
        patch_manager.register_patch('megatron.core.extensions.transformer_engine.te_checkpoint', te_checkpoint)

        if not getattr(args, "moe_fb_overlap", False):
            from mindspeed.core.transformer.moe.moe_feature.fb_overlap.adaptor import (
                dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap,
                get_moe_module_spec_wrapper,
            )
            patch_manager.register_patch(
                'megatron.core.models.gpt.moe_module_specs.get_moe_module_spec', get_moe_module_spec_wrapper
            )
            if getattr(args, 'mtp_num_layers', None):
                patch_manager.register_patch(
                    'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.forward',
                    dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap,
                )

        if getattr(args, "fp8_reuse_quantized_weight", False):
            init_weight_quantization_reuse(patch_manager, args)

    @staticmethod
    def _register_non_fp8_patches(patch_manager, args):
        """Patch TE linear classes with Megatron (or MC2) parallel linears, plus attention/LN-linear."""
        from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import (
            MindSpeedTEDotProductAttention,
        )
        from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
        from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import (
            MindSpeedTELayerNormColumnParallelLinear,
        )

        if not getattr(args, 'use_ascend_mc2', False):
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEColumnParallelLinear', ColumnParallelLinear
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TERowParallelLinear', RowParallelLinear
            )
        else:
            from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2ColumnParallelLinear
            from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2RowParallelLinear
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
                MindSpeedMC2ColumnParallelLinear,
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TERowParallelLinear', MindSpeedMC2RowParallelLinear
            )

        # The MindSpeed attention replacement is only registered without context parallelism.
        if int(getattr(args, 'context_parallel_size', 1)) == 1:
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEDotProductAttention', MindSpeedTEDotProductAttention
            )
        patch_manager.register_patch(
            'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
            MindSpeedTELayerNormColumnParallelLinear,
        )