# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
import warnings

import torch

from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager


def init_weight_quantization_reuse(pm, args):
    from mindspeed.te.pytorch.fp8.reuse import optimizer_step_reuse_cleanup_wrapper
    from mindspeed.te.pytorch.fp8.state_manager import FP8GlobalStateManager

    FP8GlobalStateManager.set_weight_quantization_reuse_enabled(
        bool(getattr(args, "fp8_reuse_quantized_weight", False))
    )

    if FP8GlobalStateManager.FP8_REUSE_QUANTIZED_WEIGHT:
        pm.register_patch(
            "megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step",
            optimizer_step_reuse_cleanup_wrapper,
        )
        pm.register_patch(
            "megatron.core.optimizer.optimizer.MixedPrecisionOptimizer.step_with_ready_grads",
            optimizer_step_reuse_cleanup_wrapper,
        )
        pm.register_patch(
            "megatron.core.optimizer.optimizer.ChainedOptimizer.step",
            optimizer_step_reuse_cleanup_wrapper,
        )
        pm.register_patch(
            "megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.step",
            optimizer_step_reuse_cleanup_wrapper,
        )
        pm.register_patch(
            "megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.step_with_ready_grads",
            optimizer_step_reuse_cleanup_wrapper,
        )


class TransformerEngineBasicFeature(MindSpeedFeature):
    def __init__(self):
        super().__init__('transformer-engine-basic', optimization_level=0)

    def register_args(self, parser):
        group = parser.add_argument_group(title=self.feature_name)
        self.add_parser_argument_choices_value(parser, "--fp8-format", 'hif8')
        self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'blockwise')
        self.add_parser_argument_choices_value(parser, "--fp8-recipe", 'mxfp8-32x32')
        self.add_parser_argument_choices_value(
            parser, "--moe-router-dtype", 'fp8'
        )  # Validation argument for router dtype.

        group.add_argument(
            '--no-use-gmm-fp8', action='store_false', help='not use GMM with scaling recipe.', dest='use_gmm_fp8'
        )
        group.add_argument(
            '--te-comparison-with-cpu',
            action='store_true',
            default=False,
            help='Compare the cast and quantmatmul of te on cpu and npu online.',
        )
        group.add_argument(
            '--te-comparison-with-bf16',
            action='store_true',
            default=False,
            help='Compare the cast and quantmatmul of te with bf16 online.',
        )
        group.add_argument(
            '--te-gmm-mode',
            choices=['performance', 'compatible'],
            default='compatible',
            help='Select the TE-GMM execution mode. '
            '"performance": Enables high-performance optimizations. '
            '"compatible": Default. Ensures compatibility with native TE behavior.',
            dest='te_gmm_mode',
        )
        group.add_argument(
            "--fp8-reuse-quantized-weight",
            action="store_true",
            default=False,
            help="Reuse quantized FP8 weight tensors within one optimizer step.",
        )

    def validate_args(self, args):
        if args.fp8 and args.transformer_impl == 'local':
            raise AssertionError('FP8 just support TE implement.')
        if args.use_ascend_coc and args.transformer_impl == 'transformer_engine':
            raise AssertionError('transformer engine does not support ascend coc')
        if args.use_ascend_mc2 and args.fp8 and args.fp8_recipe != 'mxfp8':
            raise AssertionError('MC2 is supported only by the mxfp8 recipe in fp8.')
        if getattr(args, "transformer_impl", "transformer_engine") == "transformer_engine" and getattr(
            args, "use_legacy_models", False
        ):
            raise AssertionError('transformer engine only support for mcore models')
        if args.fp8 == 'hif8':
            if args.fp8_recipe != 'tensorwise':
                raise ValueError("hif8 only support tensorwise scaling type")
        if args.use_gmm_fp8:
            if args.fp8_recipe not in ('mxfp8', 'mxfp8-32x32', 'tensorwise', 'delayed'):
                warnings.warn(
                    f"gmm fp8 only supports tensorwise, mxfp8, mxfp8-32x32, and delayed recipe, but {args.fp8_recipe} provided, "
                    f"using bf16 gmm instead."
                )
        if getattr(args, "fp8_reuse_quantized_weight", False) and not args.fp8:
            raise ValueError("fp8_reuse_quantized_weight is only valid when FP8 training is enabled")

    def pre_register_patches(self, patch_manager, args):
        patch_manager.register_patch(
            'transformer_engine.pytorch.tensor.QuantizedTensor', torch.nn.Module, create_dummy=True
        )

    def register_patches(self, patch_manager: MindSpeedPatchesManager, args):
        from mindspeed.te.pytorch.module.checkpoint import (
            transformer_block_forward,
            transformer_block_checkpointed_forward,
        )

        patch_manager.register_patch(
            'megatron.core.transformer.transformer_block.TransformerBlock.forward', transformer_block_forward
        )
        # Keep the existing patch order for other components.
        if not (getattr(args, 'swap_attention', False) or getattr(args, 'recompute_method', False) == 'block'):
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
                transformer_block_checkpointed_forward,
            )

        if not getattr(args, 'te_gmm_mode', 'compatible') == 'performance':
            from mindspeed.te.pytorch.module.grouped_linear import (
                MindSpeedTEGroupedLinear,
                MindSpeedTEColumnParallelGroupedLinear,
                MindSpeedTERowParallelGroupedLinear,
            )

            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEGroupedLinear', MindSpeedTEGroupedLinear
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEColumnParallelGroupedLinear',
                MindSpeedTEColumnParallelGroupedLinear,
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TERowParallelGroupedLinear',
                MindSpeedTERowParallelGroupedLinear,
            )
        else:
            from mindspeed.te.pytorch.module.performance_grouped_linear import (
                MindSpeedTEPerformanceGroupedLinear,
                MindSpeedTEPerformanceColumnParallelGroupedLinear,
                MindSpeedTEPerformanceRowParallelGroupedLinear,
            )

            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEGroupedLinear', MindSpeedTEPerformanceGroupedLinear
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEColumnParallelGroupedLinear',
                MindSpeedTEPerformanceColumnParallelGroupedLinear,
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TERowParallelGroupedLinear',
                MindSpeedTEPerformanceRowParallelGroupedLinear,
            )

        if getattr(args, "fp8_format", False):
            from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import (
                MindSpeedTEDotProductAttention,
            )
            from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import (
                MindSpeedTELayerNormColumnParallelLinear,
            )
            from mindspeed.te.pytorch.module.grouped_linear import (
                MindSpeedTEGroupedLinear,
                MindSpeedTEColumnParallelGroupedLinear,
                MindSpeedTERowParallelGroupedLinear,
            )
            from mindspeed.te.pytorch.module.linear import TERowParallelLinear, TEColumnParallelLinear
            from mindspeed.te.pytorch.fp8.constants import Format, Fp8Recipe
            from mindspeed.core.fp8_utils import get_fp8_context
            from mindspeed.te.pytorch.fp8.fp8 import fp8_autocast, fp8_model_init
            from mindspeed.te.pytorch.fp8.recipes import Float8CurrentScaling, MXFP8BlockScaling, TEDelayedScaling
            from mindspeed.te.pytorch.fp8.padding import Fp8Padding, Fp8Unpadding

            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TEColumnParallelLinear', TEColumnParallelLinear
            )
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TERowParallelLinear', TERowParallelLinear
            )

            if int(getattr(args, 'context_parallel_size', 1)) == 1:
                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TEDotProductAttention', MindSpeedTEDotProductAttention
                )

            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                MindSpeedTELayerNormColumnParallelLinear,
            )

            patch_manager.register_patch('transformer_engine.common.recipe.Format', Format)
            patch_manager.register_patch('megatron.core.enums.Fp8Recipe', Fp8Recipe)

            patch_manager.register_patch('megatron.core.fp8_utils.get_fp8_context', get_fp8_context)
            patch_manager.register_patch('transformer_engine.pytorch.fp8_model_init', fp8_model_init)
            patch_manager.register_patch('transformer_engine.pytorch.fp8_autocast', fp8_autocast)
            patch_manager.register_patch("transformer_engine.common.recipe.Float8CurrentScaling", Float8CurrentScaling)
            patch_manager.register_patch('transformer_engine.common.recipe.MXFP8BlockScaling', MXFP8BlockScaling)
            patch_manager.register_patch(
                "megatron.core.extensions.transformer_engine.TEDelayedScaling", TEDelayedScaling
            )
            patch_manager.register_patch("megatron.core.extensions.transformer_engine.Fp8Padding", Fp8Padding)
            patch_manager.register_patch("megatron.core.extensions.transformer_engine.Fp8Unpadding", Fp8Unpadding)

            from mindspeed.te.pytorch.module.checkpoint import te_checkpoint

            patch_manager.register_patch('megatron.core.extensions.transformer_engine.te_checkpoint', te_checkpoint)

            if not getattr(args, "moe_fb_overlap", False):
                from mindspeed.core.transformer.moe.moe_feature.fb_overlap.adaptor import (
                    dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap,
                    get_moe_module_spec_wrapper,
                )

                patch_manager.register_patch(
                    'megatron.core.models.gpt.moe_module_specs.get_moe_module_spec', get_moe_module_spec_wrapper
                )
                if getattr(args, 'mtp_num_layers', None):
                    patch_manager.register_patch(
                        'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.forward',
                        dualpipev_fb_overlap_mtp_layer_forward_te_without_overlap,
                    )
            if getattr(args, "fp8_reuse_quantized_weight", False):
                init_weight_quantization_reuse(patch_manager, args)
        else:
            from mindspeed.te.pytorch.attention.dot_product_attention.dot_product_attention import (
                MindSpeedTEDotProductAttention,
            )
            from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
            from mindspeed.te.pytorch.module.layernorm_column_parallel_linear import (
                MindSpeedTELayerNormColumnParallelLinear,
            )

            if not getattr(args, 'use_ascend_mc2', False):
                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TEColumnParallelLinear', ColumnParallelLinear
                )
                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TERowParallelLinear', RowParallelLinear
                )
            else:
                from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2ColumnParallelLinear
                from mindspeed.core.tensor_parallel.mc2_feature.adaptor import MindSpeedMC2RowParallelLinear

                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TEColumnParallelLinear',
                    MindSpeedMC2ColumnParallelLinear,
                )
                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TERowParallelLinear', MindSpeedMC2RowParallelLinear
                )

            if int(getattr(args, 'context_parallel_size', 1)) == 1:
                patch_manager.register_patch(
                    'megatron.core.extensions.transformer_engine.TEDotProductAttention', MindSpeedTEDotProductAttention
                )

            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.TELayerNormColumnParallelLinear',
                MindSpeedTELayerNormColumnParallelLinear,
            )