from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class HighAvailabilityFeature(MindSpeedFeature):
    """Feature plug-in for the high-availability options.

    The class declares the ``--enable-high-availability`` family of
    command-line switches, rejects argument combinations that are marked as
    unsupported below, and registers the related patches through the patch
    manager handed to it by the feature framework.
    """

    def __init__(self):
        """Register the feature as ``high-availability`` with optimization level 0."""
        super().__init__(feature_name='high-availability', optimization_level=0)

    def register_args(self, parser: ArgumentParser):
        """Add the high-availability switches to ``parser``.

        All options are boolean ``store_true`` flags placed in an argument
        group titled with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--enable-high-availability', action='store_true',
                           help='switch of the high availability feature')
        group.add_argument('--enable-hbmfault-repair', action='store_true',
                           help='high availability feature, enable hbmfault repair')
        group.add_argument('--enable-worker-reboot', action='store_true',
                           help='high availability feature, enable worker reboot')
        group.add_argument('--distributed-optimizer-no-replica', action='store_true',
                           help='high availability feature, repair from ckpt and disable replica optimizer')
        group.add_argument('--enable-elastic-training', action='store_true',
                           help='high availability feature, enable elastic training')

    def pre_validate_args(self, args):
        """Pass ``args`` to ``get_env_args`` before validation runs."""
        # Imported lazily so the helper module is only loaded when this hook is called.
        from mindspeed_llm.tasks.high_availability.high_availability_helper import get_env_args
        get_env_args(args)

    def validate_args(self, args):
        """Reject argument combinations the high-availability feature does not support.

        Checks are evaluated in order and the first violated one raises.

        Raises:
            AssertionError: if ``mindio_ttp`` cannot be imported while the
                feature is enabled, or if any of the incompatible flag
                combinations listed below is requested.
        """
        if args.enable_high_availability:
            try:
                # Imported only to prove the package is installed.
                import mindio_ttp  # noqa: F401
            except ModuleNotFoundError as e:
                raise AssertionError(
                    "High availability feature requires the mindio_ttp package but is not installed.") from e

        if args.enable_hbmfault_repair and not args.enable_high_availability:
            raise AssertionError(
                'switch of the enable hbmfault repair is unsupported, please enable high availability feature first.')

        # Flags that cannot be combined with the high availability switch.
        if args.enable_high_availability and args.use_dist_ckpt:
            raise AssertionError('switch of the high availability feature is unsupported')
        if args.enable_high_availability and args.swap_attention:
            raise AssertionError(
                'switch of the high availability feature is unsupported, please disable swap attention first.')
        if args.enable_high_availability and args.disable_gloo_group:
            raise AssertionError(
                'switch of the high availability feature is unsupported, please disable disable-gloo-group first.')
        if args.swap_optimizer and args.enable_high_availability:
            raise AssertionError('switch of the high availability feature is unsupported')

        # Prerequisites and exclusions for elastic training.
        if args.enable_elastic_training and not args.enable_high_availability:
            raise AssertionError(
                'switch of the enable elastic training is unsupported, please enable high availability feature first.')
        if args.enable_elastic_training and not args.use_distributed_optimizer:
            raise AssertionError(
                'switch of the enable elastic training is unsupported, please enable use-distributed-optimizer first.')
        if args.enable_elastic_training and args.use_custom_fsdp:
            # The message names the flag that is actually checked here; it
            # previously repeated the reuse-fp32-param text from the next check.
            raise AssertionError(
                'switch of the enable elastic training is unsupported when use-custom-fsdp is enabled.')
        if args.enable_elastic_training and args.reuse_fp32_param:
            raise AssertionError(
                'switch of the enable elastic training is unsupported when reuse-fp32-param is enabled.')
        if args.enable_elastic_training and (args.expert_model_parallel_size > 1 or args.context_parallel_size > 1):
            raise AssertionError(
                'switch of the enable elastic training is unsupported when expert-model-parallel-size, context '
                'parallel size is set.')

        if args.enable_high_availability and args.lora_target_modules:
            raise AssertionError(
                'switch of the high availability feature is unsupported, please disable lora-target-modules first.')

    def pre_register_patches(self, patch_manager, args):
        """Register wrappers around ``torch.distributed`` communication entry points.

        The registrations here are unconditional: none of them is guarded by a
        flag in ``args``. The method also replaces
        ``ReuseFP32Param.register_patches`` with the result of
        ``skip_reuse_register_patches``.
        """
        from mindspeed_llm.tasks.high_availability.communication_patch import communication_wrapper, barrier_wrapper
        from mindspeed_llm.tasks.high_availability.high_availability_helper import skip_reuse_register_patches

        patch_manager.register_patch('torch.distributed.barrier',
                                     barrier_wrapper)
        for communication in ['all_reduce', '_all_gather_base', 'broadcast', 'all_gather_into_tensor']:
            patch_manager.register_patch('torch.distributed.distributed_c10d.' + communication,
                                         communication_wrapper)

        from mindspeed_llm.tasks.high_availability.communication_patch import (group_index_two_torch_wrapper,
                                                                               group_index_three_torch_wrapper,
                                                                               all_to_all_single_wrapper)
        patch_manager.register_patch('torch.distributed.all_to_all_single',
                                     all_to_all_single_wrapper)
        # The two lists below are routed to different wrappers.
        # NOTE(review): presumably "two"/"three" refers to where the process
        # group argument sits in each call - confirm against communication_patch.
        for communication in ['all_gather', 'all_to_all', 'all_reduce_coalesced', 'all_gather_object',
                              'broadcast_object_list', 'all_gather_coalesced', 'irecv', 'isend']:
            patch_manager.register_patch('torch.distributed.' + communication,
                                         group_index_two_torch_wrapper)
        for communication in ['gather', 'scatter', 'reduce', 'reduce_scatter', 'gather_object',
                              'scatter_object_list', 'reduce_scatter_tensor', '_reduce_scatter_base']:
            patch_manager.register_patch('torch.distributed.' + communication,
                                         group_index_three_torch_wrapper)

        from mindspeed.features_manager import ReuseFP32Param
        ReuseFP32Param.register_patches = skip_reuse_register_patches(ReuseFP32Param.register_patches, args)

    def register_patches(self, patch_manager, args):
        """Register the Megatron patches selected by the high-availability flags.

        Registration order is kept exactly as written because several targets
        (for example ``DistributedOptimizer.__init__`` and
        ``get_forward_backward_func``) are registered more than once.

        NOTE(review): this block was reconstructed from a copy whose
        indentation had been lost. Two nesting choices are assumptions and
        should be confirmed against the canonical file: (1) only the three
        ``_ParamAndGradBucketGroup`` patches are placed under
        ``if not no_replica``; (2) the reuse-fp32-param, worker-reboot and
        elastic-training sections are placed under ``enable_high_availability``.
        """
        from mindspeed_llm.tasks.high_availability.initialize_patch import initialize_distributed_wrapper
        from mindspeed_llm.core import (start_grad_sync_wrapper,
                                        start_param_sync_wrapper, param_and_grad_bucket_group_init_wrapper,
                                        get_megatron_optimizer_wrapper, get_grad_norm_fp32_wrapper,
                                        distributed_optimizer_init_wrapper,
                                        distributed_optimizer_init_for_reuse_fp32_wrapper,
                                        get_parameter_state_dp_zero_with_high_availability_wrapper)
        from mindspeed_llm.core.pipeline_parallel.schedules import high_availability_get_forward_backward_func_wrapper

        if args.enable_high_availability:
            # getattr with a default: the flag may be absent from args.
            no_replica = getattr(args, 'distributed_optimizer_no_replica', False)
            if not no_replica:
                patch_manager.register_patch(
                    'megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.start_grad_sync',
                    start_grad_sync_wrapper)
                patch_manager.register_patch(
                    'megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.__init__',
                    param_and_grad_bucket_group_init_wrapper)
                patch_manager.register_patch(
                    'megatron.core.distributed.param_and_grad_buffer._ParamAndGradBucketGroup.start_param_sync',
                    start_param_sync_wrapper)
            patch_manager.register_patch('megatron.training.training.get_megatron_optimizer',
                                         get_megatron_optimizer_wrapper)
            patch_manager.register_patch('megatron.training.initialize._initialize_distributed',
                                         initialize_distributed_wrapper)
            patch_manager.register_patch('megatron.core.optimizer.clip_grads.get_grad_norm_fp32',
                                         get_grad_norm_fp32_wrapper)
            patch_manager.register_patch('megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
                                         distributed_optimizer_init_wrapper)
            patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
                                         high_availability_get_forward_backward_func_wrapper)

            from mindspeed_llm.core.high_availability.tft_acp_compatibility import (
                distrib_optimizer_load_parameter_state_patch, chained_optimizer_load_parameter_state_patch,
                checkpointing_load_base_checkpoint_patch, initialize_model_parallel_wrapper)
            patch_manager.register_patch('megatron.core.optimizer.optimizer.ChainedOptimizer.load_parameter_state',
                                         chained_optimizer_load_parameter_state_patch)
            patch_manager.register_patch(
                'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.load_parameter_state',
                distrib_optimizer_load_parameter_state_patch)
            patch_manager.register_patch('megatron.training.checkpointing._load_base_checkpoint',
                                         checkpointing_load_base_checkpoint_patch)
            patch_manager.register_patch('megatron.core.parallel_state.initialize_model_parallel',
                                         initialize_model_parallel_wrapper)

            if args.reuse_fp32_param:
                from mindspeed.core.memory.reuse_param.adaptor import (reuse_fp32_param_init_wrapper,
                                                                       optimizer_config_init_wrapper)
                patch_manager.register_patch(
                    'megatron.core.optimizer.optimizer.Float16OptimizerWithFloat16Params.__init__',
                    reuse_fp32_param_init_wrapper)
                patch_manager.register_patch('megatron.core.optimizer.optimizer_config.OptimizerConfig.__init__',
                                             optimizer_config_init_wrapper)
                # Same target as the DistributedOptimizer.__init__ registration above,
                # registered again with the reuse-fp32 variant of the wrapper.
                patch_manager.register_patch(
                    'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.__init__',
                    distributed_optimizer_init_for_reuse_fp32_wrapper)
                patch_manager.register_patch(
                    'mindspeed_llm.core.high_availability.TTPReplicaOptimizer.get_parameter_state_dp_zero_for_ttp',
                    get_parameter_state_dp_zero_with_high_availability_wrapper)

            if args.enable_worker_reboot or args.enable_elastic_training:
                from mindspeed_llm.tasks.high_availability.initialize_patch import (
                    build_train_valid_test_data_iterators_wrapper)
                from mindspeed_llm.tasks.high_availability.communication_patch import new_group_wrapper
                patch_manager.register_patch('megatron.training.training.build_train_valid_test_data_iterators',
                                             build_train_valid_test_data_iterators_wrapper)
                patch_manager.register_patch('torch.distributed.distributed_c10d.new_group',
                                             new_group_wrapper)

            if args.enable_elastic_training:
                from mindspeed_llm.core.pipeline_parallel.schedules import forward_step_wrapper
                from mindspeed_llm.core.optimizer.distrib_optimizer import get_parameter_state_dp_zero_wrapper
                from mindspeed_llm.core.timers import patch_world_size_func_wrapper, log_wrapper
                from mindspeed_llm.training.utils import is_last_rank_wrapper, print_rank_last_wrapper
                from mindspeed_llm.core.optimizer_param_scheduler import optimizer_param_scheduler_step_wrapper
                from mindspeed_llm.core.pipeline_parallel.schedules import (
                    elastic_training_get_forward_backward_func_wrapper)
                from mindspeed_llm.training.training import num_floating_point_operations_wrapper
                from mindspeed_llm.training.one_logger_utils import track_app_tag_wrapper
                patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.forward_step',
                                             forward_step_wrapper)
                patch_manager.register_patch(
                    'megatron.core.optimizer.distrib_optimizer.DistributedOptimizer.get_parameter_state_dp_zero',
                    get_parameter_state_dp_zero_wrapper)
                patch_manager.register_patch('megatron.core.timers.Timers._get_elapsed_time_all_ranks',
                                             patch_world_size_func_wrapper)
                patch_manager.register_patch('megatron.core.timers.Timers._get_all_ranks_time_string',
                                             patch_world_size_func_wrapper)
                patch_manager.register_patch('megatron.core.timers.Timers.log',
                                             log_wrapper)
                patch_manager.register_patch('megatron.training.utils.is_last_rank',
                                             is_last_rank_wrapper)
                patch_manager.register_patch('megatron.core.optimizer_param_scheduler.OptimizerParamScheduler.step',
                                             optimizer_param_scheduler_step_wrapper)
                # Same target as the high-availability registration above,
                # registered again with the elastic-training wrapper.
                patch_manager.register_patch('megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
                                             elastic_training_get_forward_backward_func_wrapper)
                patch_manager.register_patch('megatron.training.one_logger_utils.track_app_tag',
                                             track_app_tag_wrapper)
                patch_manager.register_patch('megatron.training.training.num_floating_point_operations',
                                             num_floating_point_operations_wrapper)
                patch_manager.register_patch('megatron.training.utils.print_rank_last',
                                             print_rank_last_wrapper)