# Commit beba3143, created on 2025-07-31 (from the commit history).
import time
from argparse import ArgumentParser

from mindspeed.features_manager.feature import MindSpeedFeature

# Delay, in seconds, passed to time.sleep() in
# SmartSwapFeature.register_patches right before change_allocator() is called.
SWAP_DELAY = 10


class SmartSwapFeature(MindSpeedFeature):
    """Feature plugin that exposes ``--smart-swap`` and installs its patches.

    When the flag is set (and adaptive recompute is not in use) the plugin
    switches the allocator via ``change_allocator`` and wraps Megatron's
    ``train_step`` with ``train_step_wrapper``.
    """

    def __init__(self):
        # NOTE(review): the second argument is presumably the feature's
        # optimization level -- confirm against MindSpeedFeature.__init__.
        super().__init__('smart-swap', 2)

    @staticmethod
    def _adaptive_recompute_enabled(args) -> bool:
        """Return True when any adaptive-recompute option is active on *args*.

        argparse stores ``--adaptive-recompute-device-swap`` under the
        attribute ``adaptive_recompute_device_swap`` (dashes become
        underscores), so the underscore form must be used for the lookup;
        a hyphenated attribute name would always fall back to the default.
        Missing attributes are treated as "disabled".
        """
        device_size = getattr(args, "adaptive_recompute_device_size", -1)
        device_swap = getattr(args, "adaptive_recompute_device_swap", False)
        return device_size > 0 or bool(device_swap)

    def register_args(self, parser: ArgumentParser):
        """Add the ``--smart-swap`` boolean flag to *parser*.

        Args:
            parser: Argument parser that receives a new argument group
                titled with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--smart-swap',
                           action='store_true', default=False, help='Enable the smart swap feature.')

    def validate_args(self, args):
        """Reject configurations that combine smart swap with adaptive recompute.

        Args:
            args: Parsed command-line namespace.

        Raises:
            AssertionError: If ``--smart-swap`` is set while an adaptive
                recompute option is also enabled.
        """
        if args.smart_swap and self._adaptive_recompute_enabled(args):
            raise AssertionError('smart swap is not compatible with adaptive selective recompute')

    def register_patches(self, patch_manager, args):
        """Switch the allocator and register the ``train_step`` wrapper.

        Does nothing unless the feature flag is set on *args* and adaptive
        recompute is disabled.

        Args:
            patch_manager: Object exposing ``register_patch(target, wrapper)``.
            args: Parsed command-line namespace.
        """
        # NOTE(review): relies on self.feature_name matching the argparse
        # destination (``smart_swap``) -- confirm the base class normalizes it.
        if not getattr(args, self.feature_name, None):
            return
        if self._adaptive_recompute_enabled(args):
            return

        # Imported lazily so the smart-swap modules are only loaded when the
        # feature is actually enabled.
        from mindspeed.core.memory.smart_swap.swap_adaptor import change_allocator
        # NOTE(review): the reason for this fixed delay before swapping the
        # allocator is not visible in this file.
        time.sleep(SWAP_DELAY)
        change_allocator()

        from mindspeed.core.memory.smart_swap.swap_megatron_adaptor import train_step_wrapper
        patch_manager.register_patch('megatron.training.training.train_step', train_step_wrapper)