from argparse import ArgumentParser

from mindspeed.features_manager.feature import MindSpeedFeature


class SwapOptimizerFeature(MindSpeedFeature):
    """Feature plugin for the ``swap-optimizer`` option.

    Registers the ``--swap-optimizer`` / ``--swap-optimizer-times`` command
    line arguments, validates their combination with other arguments, and
    registers the optimizer patches that implement the feature.
    """

    def __init__(self):
        """Initialise the base feature with the name ``'swap-optimizer'``."""
        super().__init__('swap-optimizer')

    def register_args(self, parser: ArgumentParser):
        """Add the swap-optimizer arguments to ``parser``.

        Args:
            parser: Argument parser that receives a new argument group titled
                with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--swap-optimizer', action='store_true', help='swap optimizer to cpu')
        group.add_argument(
            '--swap-optimizer-times',
            type=int,
            default=16,
            help='Each swap will be moved (len(shard_fp32_from_float16) // swap_optimizer_times) elements',
        )

    def validate_args(self, args):
        """Check that the parsed arguments are usable with swap-optimizer.

        Args:
            args: Parsed argument namespace.

        Raises:
            ValueError: If the feature is enabled while neither
                ``use_distributed_optimizer`` nor
                ``use_layer_wise_distributed_optimizer`` is set, or while
                ``swap_optimizer_times`` is not a positive number.
        """
        # NOTE(review): incompatible_check is inherited from the base class;
        # presumably it rejects enabling this feature together with
        # reuse_fp32_param -- confirm against MindSpeedFeature.
        self.incompatible_check(args, 'reuse_fp32_param')

        # NOTE(review): self.feature_name is set by the base class; this lookup
        # assumes it matches the argparse attribute name for --swap-optimizer.
        if not getattr(args, self.feature_name, None):
            return

        uses_distributed_optimizer = getattr(args, "use_distributed_optimizer", None)
        uses_layer_wise_optimizer = getattr(args, "use_layer_wise_distributed_optimizer", None)
        if not uses_distributed_optimizer and not uses_layer_wise_optimizer:
            raise ValueError(
                "Swap-optimizer only support use_distributed_optimizer/use_layer_wise_distributed_optimizer"
            )

        # Per the option's help text the value is used as a floor-division
        # divisor, so zero would fail later with ZeroDivisionError and a
        # negative value would yield a meaningless chunk size. Fail early with
        # a clear message instead. A missing attribute is left unchecked.
        swap_times = getattr(args, 'swap_optimizer_times', None)
        if swap_times is not None and swap_times <= 0:
            raise ValueError(f"swap-optimizer-times must be a positive integer, got {swap_times}")

    def register_patches(self, patch_manager, args):
        """Register the swap-optimizer patches matching the chosen optimizer.

        Nothing is registered unless the feature is enabled. The optimizer
        name is read from ``args.optimizer`` (defaulting to ``'adam'`` when
        the attribute is absent) and matched by substring: names containing
        ``'adam'`` get the Adam patches, otherwise names containing ``'muon'``
        get the Muon patches. Any other name registers no patches.

        Args:
            patch_manager: Object whose ``register_patch(target, replacement)``
                method records each patch.
            args: Parsed argument namespace.
        """
        if not getattr(args, self.feature_name, None):
            return

        optimizer_name = getattr(args, 'optimizer', 'adam')

        if 'adam' in optimizer_name:
            # Imported inside the branch, so the module is only loaded when
            # this path is actually taken.
            from mindspeed.core.optimizer.swap_optimizer.swap_optimizer import (
                SwapDistributedOptimizer,
                swap_adamw_step,
            )

            patch_manager.register_patch(
                'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer', SwapDistributedOptimizer
            )
            patch_manager.register_patch('mindspeed.core.optimizer.adamw.AdamW.step', swap_adamw_step)
        elif 'muon' in optimizer_name:
            # Imported inside the branch, so the module is only loaded when
            # this path is actually taken.
            from mindspeed.core.optimizer.swap_muon.swap_muon import (
                swap_layer_wise_distributed_optimizer_init_wrapper,
                swap_muon_step,
            )

            patch_manager.register_patch(
                'mindspeed.core.optimizer.muon.layer_wise_optimizer.LayerWiseDistributedOptimizer.__init__',
                swap_layer_wise_distributed_optimizer_init_wrapper,
            )
            patch_manager.register_patch(
                'mindspeed.core.optimizer.muon.emerging_optimizers.TensorParallelMuon.step', swap_muon_step
            )