import warnings
import torch
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
class TransformerEngineBasicFeature(MindSpeedFeature):
def __init__(self):
super().__init__('transformer-engine-basic', optimization_level=0)
def register_args(self, parser):
group = parser.add_argument_group(title=self.feature_name)
self.add_parser_argument_choices_value(parser, "--fp8-format", 'hif8')
self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'groupwise')
self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'blockwise')
self.add_parser_argument_choices_value(parser, "--moe-router-dtype", 'fp8')
group.add_argument('--no-use-gmm-fp8', action='store_false',
help='not use GMM with scaling recipe.', dest='use_gmm_fp8')
group.add_argument('--te-comparison-with-cpu', action='store_true',
default=False, help='Compare the cast and quantmatmul of te on cpu and npu online.')
group.add_argument('--te-comparison-with-bf16', action='store_true',
default=False, help='Compare the cast and quantmatmul of te with bf16 online.')
group.add_argument('--te-gmm-mode',
choices=['performance', 'compatible'],
default='compatible',
help='Select the TE-GMM execution mode. '
'"performance": Enables high-performance optimizations. '
'"compatible": Default. Ensures compatibility with native TE behavior.',
dest='te_gmm_mode')
group.add_argument("--fp8-reuse-quantized-weight", action="store_true",
default=False, help="Reuse quantized FP8 weight tensors within one optimizer step.",
)
def validate_args(self, args):
if args.fp8 and args.transformer_impl == 'local':
raise AssertionError('FP8 just support TE implement.')
if args.use_ascend_coc and args.transformer_impl == 'transformer_engine':
raise AssertionError('transformer engine does not support ascend coc')
if args.use_ascend_mc2 and args.fp8 and args.fp8_recipe != 'mxfp8':
raise AssertionError('MC2 is supported only by the mxfp8 recipe in fp8.')
if (getattr(args, "transformer_impl", "transformer_engine") == "transformer_engine"
and getattr(args, "use_legacy_models", False)):
raise AssertionError('transformer engine only support for mcore models')
if args.fp8 == 'hif8':
if args.fp8_recipe != 'tensorwise':
raise ValueError("hif8 only support tensorwise scaling type")
if args.use_gmm_fp8:
if args.fp8_recipe not in ('mxfp8', 'tensorwise', 'delayed'):
warnings.warn(
f"gmm fp8 only supports tensorwise, mxfp8, and delayed recipe, but {args.fp8_recipe} provided, "
f"using bf16 gmm instead.")
if getattr(args, "fp8_reuse_quantized_weight", False) and not args.fp8:
raise ValueError("fp8_reuse_quantized_weight is only valid when FP8 training is enabled")
def pre_register_patches(self, pm, args):
pm.register_patch('transformer_engine.pytorch.tensor.QuantizedTensor', torch.nn.Module, create_dummy=True)
def register_patches(self, pm: MindSpeedPatchesManager, args):
if not getattr(args, 'te_gmm_mode', 'compatible') == 'performance':
from mindspeed.te.pytorch.module.grouped_linear import MindSpeedTEGroupedLinear, \
MindSpeedTEColumnParallelGroupedLinear, MindSpeedTERowParallelGroupedLinear
pm.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear', MindSpeedTEGroupedLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TEColumnParallelGroupedLinear',
MindSpeedTEColumnParallelGroupedLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TERowParallelGroupedLinear',
MindSpeedTERowParallelGroupedLinear)
else:
from mindspeed.te.pytorch.module.performance_grouped_linear import MindSpeedTEPerformanceGroupedLinear, \
MindSpeedTEPerformanceColumnParallelGroupedLinear, MindSpeedTEPerformanceRowParallelGroupedLinear
pm.register_patch('megatron.core.extensions.transformer_engine.TEGroupedLinear',
MindSpeedTEPerformanceGroupedLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TEColumnParallelGroupedLinear',
MindSpeedTEPerformanceColumnParallelGroupedLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TERowParallelGroupedLinear',
MindSpeedTEPerformanceRowParallelGroupedLinear)
if getattr(args, "fp8_format", False):
from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import \
MindSpeedTEDotProductAttention
from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import \
MindSpeedTELayerNormColumnParallelLinear
from mindspeed.te.pytorch.module.linear import TERowParallelLinear, TEColumnParallelLinear
from mindspeed.te.pytorch.fp8.constants import Format, Fp8Recipe
from mindspeed.core.fp8_utils import get_fp8_context
from mindspeed.te.pytorch.fp8.fp8 import fp8_autocast, fp8_model_init
from mindspeed.te.pytorch.fp8.recipes import Float8CurrentScaling, MXFP8BlockScaling, TEDelayedScaling
from mindspeed.te.pytorch.fp8.padding import Fp8Padding, Fp8Unpadding
pm.register_patch('megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
TEColumnParallelLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TERowParallelLinear', TERowParallelLinear)
if int(getattr(args, 'context_parallel_size', 1)) == 1:
pm.register_patch('megatron.core.extensions.transformer_engine.TEDotProductAttention',
MindSpeedTEDotProductAttention)
pm.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
MindSpeedTELayerNormColumnParallelLinear)
pm.register_patch('transformer_engine.common.recipe.Format', Format)
pm.register_patch('megatron.core.enums.Fp8Recipe', Fp8Recipe)
pm.register_patch('megatron.core.fp8_utils.get_fp8_context', get_fp8_context)
pm.register_patch('transformer_engine.pytorch.fp8_model_init', fp8_model_init)
pm.register_patch('transformer_engine.pytorch.fp8_autocast', fp8_autocast)
pm.register_patch("transformer_engine.common.recipe.Float8CurrentScaling", Float8CurrentScaling)
pm.register_patch('transformer_engine.common.recipe.MXFP8BlockScaling', MXFP8BlockScaling)
pm.register_patch("megatron.core.extensions.transformer_engine.TEDelayedScaling", TEDelayedScaling)
pm.register_patch("megatron.core.extensions.transformer_engine.Fp8Padding", Fp8Padding)
pm.register_patch("megatron.core.extensions.transformer_engine.Fp8Unpadding", Fp8Unpadding)
from mindspeed.te.pytorch.module.checkpoint import transformer_block_checkpointed_forward, te_checkpoint
if not (
getattr(args, 'swap_attention', False)
or getattr(args, 'recompute_method', False) == 'block'
):
pm.register_patch('megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
transformer_block_checkpointed_forward)
pm.register_patch('megatron.core.extensions.transformer_engine.te_checkpoint', te_checkpoint)
if not getattr(args, "moe_fb_overlap", False):
from mindspeed.core.transformer.moe.moe_feature.fb_overlap.adaptor import (
dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap, get_moe_module_spec_wrapper)
pm.register_patch('megatron.core.models.gpt.moe_module_specs.get_moe_module_spec',
get_moe_module_spec_wrapper)
if getattr(args, 'mtp_num_layers', None):
pm.register_patch(
'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.forward',
dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap)
if getattr(args, "fp8_reuse_quantized_weight", False):
from mindspeed.features_manager.megatron_basic.transformer_engine_basic import init_weight_quantization_reuse
init_weight_quantization_reuse(pm, args)
else:
from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import \
MindSpeedTEDotProductAttention
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import \
MindSpeedTELayerNormColumnParallelLinear
if not getattr(args, 'use_ascend_mc2', False):
pm.register_patch('megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
ColumnParallelLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TERowParallelLinear', RowParallelLinear)
else:
from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2ColumnParallelLinear
from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2RowParallelLinear
pm.register_patch('megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
MindSpeedMC2ColumnParallelLinear)
pm.register_patch('megatron.core.extensions.transformer_engine.TERowParallelLinear',
MindSpeedMC2RowParallelLinear)
if int(getattr(args, 'context_parallel_size', 1)) == 1:
pm.register_patch('megatron.core.extensions.transformer_engine.TEDotProductAttention',
MindSpeedTEDotProductAttention)
pm.register_patch('megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
MindSpeedTELayerNormColumnParallelLinear)