import os
import warnings
import torch
from mindspeed.arguments import validate_args_wrapper
from megatron.core.rerun_state_machine import RerunStateMachine
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import get_torch_version, is_torch_min_version
from mindspeed.patch_utils import MindSpeedPatchesManager as pm
from megatron.training.arguments import load_retro_args, _check_arg_is_not_none, _print_args
from megatron.training import print_rank_0
from megatron.training.utils import get_device_arch_version, update_use_dist_ckpt
from mindspeed_mm.configs.config import merge_mm_args
from mindspeed_mm.utils.utils import ensure_valid
def safe_getattr(mm_object, mm_name, mm_default_value):
mm_value = getattr(mm_object, mm_name, mm_default_value)
if mm_value != mm_default_value:
print_rank_0(f'[INFO] the original value of {mm_name} is {mm_default_value}, now changed as {mm_value} which comes from model.json')
return mm_value
def validate_args(args, defaults=None):
if defaults is None:
defaults = {}
merge_mm_args(args)
if hasattr(args.mm.model, 'text_decoder'):
args.num_layers = safe_getattr(args.mm.model.text_decoder, 'num_layers', args.num_layers)
args.hidden_size = safe_getattr(args.mm.model.text_decoder, 'hidden_size', args.hidden_size)
args.num_attention_heads = safe_getattr(args.mm.model.text_decoder, 'num_attention_heads', args.num_attention_heads)
args.num_query_groups = safe_getattr(args.mm.model.text_decoder, 'num_query_groups', args.num_query_groups)
args.max_position_embeddings = safe_getattr(args.mm.model.text_decoder, 'max_position_embeddings', args.max_position_embeddings)
args.ffn_hidden_size = safe_getattr(args.mm.model.text_decoder, 'ffn_hidden_size', args.ffn_hidden_size)
if hasattr(args.mm.model.text_decoder, 'num_moe_experts'):
args.num_experts = safe_getattr(args.mm.model.text_decoder, 'num_moe_experts', args.num_experts)
args.n_shared_experts = safe_getattr(args.mm.model.text_decoder, 'n_shared_experts', args.n_shared_experts)
args.mm.model.text_decoder.moe_token_dispatcher_type = safe_getattr(args.mm.model.text_decoder, 'moe_token_dispatcher_type', args.moe_token_dispatcher_type)
args.add_bias_linear = safe_getattr(args.mm.model.text_decoder, 'add_bias_linear', True)
args.mm.model.text_decoder.v_head_dim = safe_getattr(args, 'v_head_dim', 0)
args.rope_scaling_type = safe_getattr(args.mm.model.text_decoder, 'rope_scaling_type', None)
if hasattr(args.mm.model, 'predictor'):
if hasattr(args.mm.model.predictor, 'mm_single_blocks_depth') and hasattr(args.mm.model.predictor, 'mm_double_blocks_depth'):
mm_double_blocks_depth = getattr(args.mm.model.predictor, 'mm_double_blocks_depth', 0)
if mm_double_blocks_depth <= 0:
raise AssertionError(f"MindSpeed-MM Error: mm_double_blocks_depth must > 0, actually:{mm_double_blocks_depth}")
args.num_layers = safe_getattr(args.mm.model.predictor, 'mm_single_blocks_depth', args.num_layers)
args.num_layers = safe_getattr(args.mm.model.predictor, 'num_layers', args.num_layers)
if isinstance(args.num_layers, (tuple, list)):
args.num_layers = sum(args.num_layers)
args.num_attention_heads = safe_getattr(args.mm.model.predictor, 'num_heads', args.num_attention_heads)
head_dim = getattr(args.mm.model.predictor, 'head_dim', 1)
if isinstance(head_dim, int):
hidden_size = args.num_attention_heads * head_dim
print_rank_0(f"[INFO] the original value of normalization is {args.hidden_size}, now changed as {hidden_size} which comes from model.json")
args.hidden_size = hidden_size
args.hidden_size = safe_getattr(args.mm.model.predictor, 'hidden_size', args.hidden_size)
args.attention_dropout = safe_getattr(args.mm.model.predictor, 'dropout', args.attention_dropout)
args.hidden_dropout = safe_getattr(args.mm.model.predictor, 'hidden_dropout', args.hidden_dropout)
args.swiglu = safe_getattr(args.mm.model.predictor, 'swiglu', args.swiglu)
args.masked_softmax_fusion = safe_getattr(args.mm.model.predictor, 'masked_softmax_fusion', args.masked_softmax_fusion)
norm_type = getattr(args.mm.model.predictor, 'qk_norm_type', "")
if isinstance(norm_type, str) and norm_type == "rmsnorm":
print_rank_0(f"[INFO] the original value of normalization is {args.normalization}, now changed as RMSNorm which comes from model.json")
args.normalization = "RMSNorm"
args.use_fused_rmsnorm = safe_getattr(args.mm.model.predictor, 'use_fused_rmsnorm', args.use_fused_rmsnorm)
args.seq_length = safe_getattr(args.mm.model.predictor, 'seq_length', 3072)
args.max_position_embeddings = safe_getattr(args.mm.model.predictor, 'max_position_embeddings', args.seq_length)
args.position_embedding_type = safe_getattr(args.mm.model.predictor, 'position_embedding_type', 'rope')
args.rotary_base = safe_getattr(args.mm.model.predictor, 'rotary_base', 500000)
args.tokenizer_type = safe_getattr(args.mm.model.predictor, 'tokenizer_type', "NullTokenizer")
args.vocab_size = safe_getattr(args.mm.model.predictor, "vocab_size", 0)
elif hasattr(args.mm.model, 'ae'):
args.num_layers = 1
args.hidden_size = 3072
args.num_attention_heads = 48
args.seq_length = 24
args.attention_dropout = 0.0
args.hidden_dropout = 0.0
args.swiglu = True
args.masked_softmax_fusion = False
args.max_position_embeddings = 24
args.position_embedding_type = "rope"
args.rotary_base = 500000
args.tokenizer_type = 'NullTokenizer'
args.vocab_size = 0
load_retro_args(args)
if args.use_legacy_models:
ensure_valid(args.ckpt_format == "torch", \
"legacy model format only supports the 'torch' checkpoint format.")
update_use_dist_ckpt(args)
if args.encoder_pipeline_model_parallel_size == 0 and args.num_experts == 0:
ensure_valid(args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder.")
if args.encoder_tensor_model_parallel_size > 0:
ensure_valid(args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0)
ensure_valid(args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder.")
if args.encoder_pipeline_model_parallel_size > 0 and args.encoder_tensor_model_parallel_size == 0:
args.encoder_tensor_model_parallel_size = args.tensor_model_parallel_size
encoder_model_size = args.encoder_tensor_model_parallel_size * args.encoder_pipeline_model_parallel_size * args.context_parallel_size
decoder_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size
total_model_size = encoder_model_size + decoder_model_size
ensure_valid(args.world_size % total_model_size == 0,
f"world size ({args.world_size}) is not divisible by total_model_size ({encoder_model_size=} + {decoder_model_size=})"
)
if args.attention_backend == AttnBackend.local:
ensure_valid(args.spec[0] == 'local', '--attention-backend local is only supported with --spec local')
args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size
args.data_parallel_size = args.world_size // total_model_size
if args.rank == 0:
print('using world size: {}, data-parallel size: {}, '
'context-parallel size: {}, '
'hierarchical context-parallel sizes: {}'
'tensor-model-parallel size: {}, '
'encoder-tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {}, '
'encoder-pipeline-model-parallel size: {}'.format(
args.world_size, args.data_parallel_size,
args.context_parallel_size,
args.hierarchical_context_parallel_sizes,
args.tensor_model_parallel_size,
args.encoder_tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.encoder_pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_split_rank is not None:
args.encoder_pipeline_model_parallel_size = args.pipeline_model_parallel_split_rank
args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size
ensure_valid(args.pipeline_model_parallel_size > 0)
if args.hierarchical_context_parallel_sizes:
from numpy import prod
ensure_valid(args.context_parallel_size == prod(args.hierarchical_context_parallel_sizes))
if "a2a+p2p" in args.cp_comm_type:
ensure_valid(args.hierarchical_context_parallel_sizes is not None, \
"--hierarchical-context-parallel-sizes must be set when a2a+p2p is used in cp comm")
if args.expert_tensor_parallel_size is None:
args.expert_tensor_parallel_size = args.tensor_model_parallel_size
ensure_valid(args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead')
del args.batch_size
ensure_valid(args.warmup is None, '--warmup argument is no longer valid, use ' \
'--lr-warmup-fraction instead')
del args.warmup
ensure_valid(args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead')
del args.model_parallel_size
if args.checkpoint_activations:
if args.rank == 0:
print('--checkpoint-activations is no longer valid, use --recompute-activations, '
'or, for more control, --recompute-granularity and --recompute-method.')
exit()
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
del args.recompute_activations
for key in defaults:
if getattr(args, key, None) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
if args.data_path is not None and args.split is None:
legacy_default_split_value = '969, 30, 1'
if args.rank == 0:
print('WARNING: Please specify --split when using --data-path. Using legacy default value '
f'of "{legacy_default_split_value}"')
args.split = legacy_default_split_value
use_data_path = (args.data_path is not None) or (args.data_args_path is not None)
if use_data_path:
ensure_valid((args.data_path is None) or (args.data_args_path is None))
use_per_split_data_path = any(
elt is not None
for elt in [args.train_data_path, args.valid_data_path, args.test_data_path]) or \
args.per_split_data_args_path is not None
if use_per_split_data_path:
ensure_valid(any(elt is not None
for elt in [args.train_data_path, args.valid_data_path, args.test_data_path]) is False or \
args.per_split_data_args_path is None
)
ensure_valid(args.micro_batch_size is not None)
ensure_valid(args.micro_batch_size > 0)
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
ensure_valid(args.global_batch_size > 0)
ensure_valid(args.num_layers_per_virtual_pipeline_stage is None or args.num_virtual_stages_per_pipeline_rank is None, \
'--num-layers-per-virtual-pipeline-stage and --num-virtual-stages-per-pipeline-rank cannot be set at the same time')
if args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None:
raise AssertionError('MindSpeed-MM Error: --num-layers-per-virtual-pipeline-stage is deprecated, please use --virtual-pipeline-model-parallel-size instead')
if hasattr(args, 'virtual_pipeline_model_parallel_size') and args.virtual_pipeline_model_parallel_size is not None and args.virtual_pipeline_model_parallel_size > 1:
if args.overlap_p2p_comm:
ensure_valid(args.pipeline_model_parallel_size > 1, \
'when interleaved schedule is used, pipeline-model-parallel size '\
'should be greater than 1')
else:
ensure_valid(args.pipeline_model_parallel_size > 2, \
'when interleaved schedule is used and p2p communication overlap is disabled, '\
'pipeline-model-parallel size should be greater than 2 to avoid having multiple '\
'p2p sends and recvs between same 2 ranks per communication batch')
if hasattr(args.mm.model, 'text_decoder'):
_pipeline_num_layers = getattr(args.mm.model.text_decoder, 'pipeline_num_layers', None)
if _pipeline_num_layers is None or len(_pipeline_num_layers) != args.virtual_pipeline_model_parallel_size:
raise AssertionError('MindSpeed-MM Error: vpp must enabled by --virtual-pipeline-model-parallel-size in shell and pipeline_num_layers in model.json, \
and virtual-pipeline-model-parallel-size must equal the length of pipeline_num_layers')
elif hasattr(args.mm.model, 'predictor'):
_pipeline_num_layers = getattr(args.mm.model.predictor, 'pipeline_num_layers', None)
if _pipeline_num_layers is None or len(_pipeline_num_layers) != args.virtual_pipeline_model_parallel_size:
raise AssertionError('MindSpeed-MM Error: vpp must enabled by --virtual-pipeline-model-parallel-size in shell and pipeline_num_layers in model.json, \
and virtual-pipeline-model-parallel-size must equal the length of pipeline_num_layers')
else:
raise AssertionError('MindSpeed-MM Error: vpp must enabled by --virtual-pipeline-model-parallel-size in shell and pipeline_num_layers in model.json')
else:
args.virtual_pipeline_model_parallel_size = None
args.overlap_p2p_comm = False
args.align_param_gather = False
if args.rank == 0 and args.pipeline_model_parallel_size > 1:
print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False '
'since non-interleaved schedule does not support overlapping p2p communication '
'and aligned param AG')
if args.rank == 0:
print(f"Number of virtual stages per pipeline stage: {args.virtual_pipeline_model_parallel_size}")
if args.data_parallel_sharding_strategy == "optim_grads_params":
args.overlap_param_gather = True
args.overlap_grad_reduce = True
if args.data_parallel_sharding_strategy == "optim_grads":
args.overlap_grad_reduce = True
if args.overlap_param_gather:
ensure_valid(args.use_distributed_optimizer, \
'--overlap-param-gather only supported with distributed optimizer')
ensure_valid(args.overlap_grad_reduce, \
'Must use --overlap-param-gather with --overlap-grad-reduce')
ensure_valid(not args.use_legacy_models, \
'--overlap-param-gather only supported with MCore models')
if args.use_torch_fsdp2:
ensure_valid(is_torch_min_version("2.4.0"), \
'FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.')
ensure_valid(args.pipeline_model_parallel_size == 1, \
'--use-torch-fsdp2 is not supported with pipeline parallelism')
ensure_valid(args.expert_model_parallel_size == 1, \
'--use-torch-fsdp2 is not supported with expert parallelism')
ensure_valid(not args.use_distributed_optimizer, \
"--use-torch-fsdp2 is not supported with MCore's distributed optimizer")
ensure_valid(not args.gradient_accumulation_fusion, \
'--use-torch-fsdp2 is not supported with gradient accumulation fusion')
ensure_valid(args.ckpt_format in ('torch_dist', 'torch_dcp'), \
'--use-torch-fsdp2 requires --ckpt-format torch_dist or torch_dcp')
ensure_valid(args.untie_embeddings_and_output_weights, \
'--use-torch-fsdp2 requires --untie-embeddings-and-output-weights')
ensure_valid(not args.fp16, \
'--use-torch-fsdp2 not supported with fp16 yet')
ensure_valid(os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one')
if args.overlap_param_gather_with_optimizer_step:
ensure_valid(args.use_distributed_optimizer, \
'--overlap-param-gather-with-optimizer-step only supported with distributed optimizer')
ensure_valid(args.overlap_param_gather, \
'Must use --overlap-param-gather-with-optimizer-step with --overlap-param-gather')
ensure_valid(args.virtual_pipeline_model_parallel_size is not None, \
'--overlap-param-gather-with-optimizer-step only supported with interleaved pipeline parallelism')
ensure_valid(not args.use_dist_ckpt, \
'--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet')
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
def map_dtype(d):
if isinstance(d, torch.dtype):
return d
else:
return dtype_map.get(d)
args.main_grads_dtype = map_dtype(args.main_grads_dtype)
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
if args.fp8_param_gather:
ensure_valid(args.use_distributed_optimizer or args.use_torch_fsdp2, \
'--fp8-param-gather only supported with distributed optimizer or torch fsdp2')
if args.use_custom_fsdp:
ensure_valid(args.use_distributed_optimizer, \
'--use-custom-fsdp only supported with distributed optimizer')
if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]:
warnings.warn('Please make sure your TransformerEngine support FSDP + gradient accumulation fusion')
ensure_valid(args.gradient_accumulation_fusion is False, \
"optim_grads_params optim_grads are not supported with gradient accumulation fusion")
if args.data_parallel_sharding_strategy == "optim_grads_params":
ensure_valid(args.check_weight_hash_across_dp_replicas_interval is None, \
'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params')
ensure_valid(os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1", \
'FSDP always requires CUDA_DEVICE_MAX_CONNECTIONS value large than one')
args.params_dtype = torch.float
if args.fp16:
ensure_valid(not args.bf16)
args.params_dtype = torch.half
if not args.loss_scale:
args.check_for_nan_in_loss_and_grad = False
if args.rank == 0:
print('WARNING: Setting args.check_for_nan_in_loss_and_grad to False since '
'dynamic loss scaling is being used')
if args.bf16:
ensure_valid(not args.fp16)
args.params_dtype = torch.bfloat16
if args.accumulate_allreduce_grads_in_fp32:
ensure_valid(args.main_grads_dtype == torch.float32, \
"--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set")
if args.grad_reduce_in_bf16:
args.accumulate_allreduce_grads_in_fp32 = False
elif not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
if args.dataloader_type is None:
args.dataloader_type = 'single'
ensure_valid(args.num_dataset_builder_threads > 0)
args.consumed_train_samples = 0
args.skipped_train_samples = 0
args.consumed_valid_samples = 0
args.variable_seq_lengths = False
if args.train_iters:
ensure_valid(args.train_samples is None, \
'expected iteration-based training')
ensure_valid(args.lr_decay_samples is None, \
'expected iteration-based learning rate decay')
ensure_valid(args.lr_warmup_samples == 0, \
'expected iteration-based learning rate warmup')
ensure_valid(args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training')
if args.lr_warmup_fraction is not None:
ensure_valid(args.lr_warmup_iters == 0, \
'can only specify one of lr-warmup-fraction and lr-warmup-iters')
if args.train_samples:
ensure_valid(args.train_iters is None, \
'expected sample-based training')
ensure_valid(args.lr_decay_iters is None, \
'expected sample-based learning rate decay')
ensure_valid(args.lr_warmup_iters == 0, \
'expected sample-based learnig rate warmup')
if args.lr_warmup_fraction is not None:
ensure_valid(args.lr_warmup_samples == 0, \
'can only specify one of lr-warmup-fraction ' \
'and lr-warmup-samples')
if args.num_layers is not None:
ensure_valid(args.encoder_num_layers is None, \
'cannot have both num-layers and encoder-num-layers specified')
args.encoder_num_layers = args.num_layers
else:
ensure_valid(args.encoder_num_layers is not None, \
'either num-layers or encoder-num-layers should be specified')
args.num_layers = args.encoder_num_layers
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
if args.ffn_hidden_size is None:
if args.swiglu:
args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
else:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
ensure_valid(args.hidden_size % args.num_attention_heads == 0)
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None and args.context_parallel_size > 1:
ensure_valid(args.seq_length % (args.context_parallel_size * 2) == 0, \
'seq-length should be a multiple of 2 * context-parallel-size ' \
'if context-parallel-size > 1.')
if args.seq_length is not None:
ensure_valid(args.encoder_seq_length is None)
args.encoder_seq_length = args.seq_length
else:
ensure_valid(args.encoder_seq_length is not None)
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
ensure_valid(args.max_position_embeddings >= args.seq_length, \
f"max_position_embeddings ({args.max_position_embeddings}) must be greater than " \
f"or equal to seq_length ({args.seq_length}).")
if args.decoder_seq_length is not None:
ensure_valid(args.max_position_embeddings >= args.decoder_seq_length)
if args.lr is not None:
ensure_valid(args.min_lr <= args.lr)
if args.save is not None:
ensure_valid(args.save_interval is not None)
if args.fp16_lm_cross_entropy:
ensure_valid(args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.')
if args.fp32_residual_connection:
ensure_valid(args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.')
if args.moe_grouped_gemm:
ensure_valid(args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.')
dc = torch.cuda.get_device_capability()
ensure_valid(dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels.")
if args.weight_decay_incr_style == 'constant':
ensure_valid(args.start_weight_decay is None)
ensure_valid(args.end_weight_decay is None)
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
ensure_valid(args.start_weight_decay is not None)
ensure_valid(args.end_weight_decay is not None)
if not is_torch_min_version("1.11.0a0"):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
if args.distribute_saved_activations:
ensure_valid(args.tensor_model_parallel_size > 1, 'can distribute ' \
'recomputed activations only across tensor model ' \
'parallel groups')
ensure_valid(args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'application to full recompute granularity')
ensure_valid(args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method ')
ensure_valid(is_torch_min_version("1.10.0a0"), \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
f'pytorch version is v{get_torch_version()}.')
if args.recompute_granularity == 'selective':
ensure_valid(args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity')
if args.tensor_model_parallel_size == 1:
if args.sequence_parallel:
warnings.warn("Disabling sequence parallelism because tensor model parallelism is disabled")
args.sequence_parallel = False
if args.tp_comm_overlap:
ensure_valid(args.sequence_parallel, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled')
if args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1 and get_device_arch_version() < 10:
if args.use_torch_fsdp2 or args.use_custom_fsdp:
fsdp_impl = "Torch-FSDP2" if args.use_torch_fsdp2 else "Custom-FSDP"
warnings.warn(
f"Using tensor model parallelism or context parallelism with {fsdp_impl} together. "
"Try not to using them together since they require different CUDA_MAX_CONNECTIONS "
"settings for best performance. sequence parallelism requires setting the "
f"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 while {fsdp_impl} "
"requires not setting CUDA_DEVICE_MAX_CONNECTIONS=1 for better parallelization.")
else:
ensure_valid(os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') == "1", \
"Using tensor model parallelism or context parallelism require setting the environment variable " \
"CUDA_DEVICE_MAX_CONNECTIONS to 1")
if not args.add_bias_linear:
args.bias_gelu_fusion = False
if args.add_bias_linear:
args.add_qkv_bias = True
if args.retro_add_retriever:
ensure_valid(args.train_samples is not None, \
"args.train_samples should be auto-loaded from the retro config.")
ensure_valid(not args.sequence_parallel, \
"retro currently does not support sequence parallelism.")
ensure_valid(args.pipeline_model_parallel_size == 1, \
"retro currently does not support pipeline parallelism.")
if args.decoupled_lr is not None or args.decoupled_min_lr is not None:
ensure_valid(not args.use_legacy_models, \
'--decoupled-lr and --decoupled-min-lr is not supported in legacy models.')
if args.use_rotary_position_embeddings:
args.position_embedding_type = 'rope'
if args.rotary_interleaved and args.apply_rope_fusion:
raise RuntimeError('--rotary-interleaved does not work with rope_fusion.')
if args.rotary_interleaved and args.use_legacy_models:
raise RuntimeError('--rotary-interleaved is not supported in legacy models.')
if args.position_embedding_type != 'rope':
args.apply_rope_fusion = False
if not args.add_position_embedding and args.position_embedding_type != 'rope':
raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
if args.position_embedding_type == 'relative':
ensure_valid((
args.transformer_impl == "transformer_engine"
), 'Local transformer implementation currently does not support attention bias-based position embeddings.')
if args.position_embedding_type == "mrope":
ensure_valid(args.mrope_section is not None, \
'--mrope-section should be set when using --position-embedding-type mrope.')
if args.num_experts == 0:
args.num_experts = None
if args.num_experts is not None:
ensure_valid(args.spec is None, "Model Spec must be None when using MoEs")
if args.moe_ffn_hidden_size is None:
args.moe_ffn_hidden_size = args.ffn_hidden_size
if args.context_parallel_size > 1:
ensure_valid(not args.use_legacy_models, "Context parallelism is not supported in legacy models.")
if args.expert_model_parallel_size > 1:
ensure_valid(args.num_experts is not None, "num_experts must be non None to use expert model parallelism")
ensure_valid(args.num_experts % args.expert_model_parallel_size == 0, \
"Number of experts should be a multiple of expert model parallel_size.")
ensure_valid(not args.fp16, \
"Expert parallelism is not supported with fp16 training.")
if args.use_dist_ckpt and args.use_legacy_models:
raise RuntimeError('--use-dist-ckpt is not supported in legacy models.')
if args.ckpt_format == "torch_dcp":
ensure_valid(args.use_torch_fsdp2, "--ckpt-format torch_dcp is only tested with FSDP.")
ensure_valid(args.tensor_model_parallel_size <= 1 and args.encoder_tensor_model_parallel_size <= 1, \
"--ckpt-format torch_dcp is not tested with megatron tensor parallelism.")
ensure_valid(args.pipeline_model_parallel_size <= 1 and args.encoder_pipeline_model_parallel_size <= 1, \
"--ckpt-format torch_dcp is not tested with megatron pipeline parallelism.")
ensure_valid(args.mock_data + \
bool(args.data_path) + \
any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
<= 1, "A single data source must be provided in training mode, else None")
if args.deterministic_mode:
ensure_valid(not args.use_flash_attn, "Flash attention can not be used in deterministic mode.")
ensure_valid(not args.cross_entropy_loss_fusion, "Cross Entropy Fusion is currently not deterministic.")
all_reduce_choices = ["Tree", "Ring", "CollnetDirect", "CollnetChain", "^NVLS"]
ensure_valid(os.getenv("NCCL_ALGO", -1) != -1 and os.getenv("NCCL_ALGO") in all_reduce_choices, \
f"NCCL_ALGO must be one of {all_reduce_choices}.")
torch.use_deterministic_algorithms(True)
if args.apply_query_key_layer_scaling:
args.attention_softmax_in_fp32 = True
if args.result_rejected_tracker_filename is not None:
iterations_to_skip_from_file = RerunStateMachine.get_skipped_iterations_from_tracker_file(
args.result_rejected_tracker_filename
)
args.iterations_to_skip.extend(iterations_to_skip_from_file)
if not args.enable_gloo_process_groups:
if args.use_distributed_optimizer:
ensure_valid(args.use_dist_ckpt)
if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
print('--ckpt-fully-parallel-save flag is deprecated and has no effect.'
' Use --no-ckpt-fully-parallel-save to disable parallel save.')
use_dist_ckpt_and_not_ckpt_fully_parallel_save = args.use_dist_ckpt and not args.ckpt_fully_parallel_save
use_distributed_optimizer_and_rank = args.use_distributed_optimizer and args.rank == 0
if use_dist_ckpt_and_not_ckpt_fully_parallel_save and use_distributed_optimizer_and_rank:
print('Warning: With non-parallel ckpt save and DistributedOptimizer,'
' it will be impossible to resume training with different parallelism.'
' Consider removing flag --no-ckpt-fully-parallel-save.')
if args.use_dist_ckpt_deprecated and args.rank == 0:
print('--use-dist-ckpt is deprecated and has no effect.'
' Use --ckpt-format to select the checkpoint format.')
if args.dist_ckpt_format_deprecated and args.rank == 0:
print('--dist-ckpt-format is deprecated and has no effect.'
' Use --ckpt-format to select the checkpoint format.')
if args.inference_batch_times_seqlen_threshold > -1:
ensure_valid(args.pipeline_model_parallel_size > 1, \
"--inference-batch-times-seqlen-threshold requires setting --pipeline-model-parallel-size > 1.")
if args.inference_dynamic_batching:
ensure_valid(args.inference_dynamic_batching_buffer_size_gb is not None)
ensure_valid(args.inference_dynamic_batching_buffer_guaranteed_fraction is not None)
if args.moe_use_upcycling:
ensure_valid(args.save is not None, "When using upcycling, the --save option must be specified.")
if not args.no_load_optim:
args.no_load_optim = True
print('Warning: disabling --no-load-optim for upcycling.')
if not args.no_load_rng:
args.no_load_rng = True
print('Warning: disabling --no-load-rng for upcycling.')
if args.optimizer_cpu_offload:
ensure_valid(args.use_precision_aware_optimizer, (
"The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, "
"as the hybrid device optimizer reuses the code path of this flag."
))
if args.non_persistent_ckpt_type == "local":
ensure_valid(args.non_persistent_local_ckpt_dir is not None, "Tried to use local checkpointing without specifying --local-ckpt-dir!")
if args.replication:
ensure_valid(args.replication_jump is not None, "--replication requires the value of --replication-jump!")
ensure_valid(args.non_persistent_ckpt_type == "local", f"--replication requires args.non_persistent_ckpt_type == 'local', but got: {args.non_persistent_ckpt_type}")
elif args.replication_jump:
print("Warning: --replication-jump was specified despite not using replication. Ignoring.")
args.replication_jump = None
if args.mtp_num_layers:
ensure_valid(not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP).")
ensure_valid(args.context_parallel_size == 1, "Multi-Token Prediction (MTP) is not supported with Context Parallelism.")
ensure_valid(args.position_embedding_type == "rope" or args.position_embedding_type == "none", (
f"Multi-Token Prediction (MTP) is not supported with {args.position_embedding_type} position embedding type."
+ f"The supported position embedding types are rope and none."
))
_print_args("arguments", args)
ensure_valid(not (getattr(args, 'calculate_per_sample_loss', False) and getattr(args, 'calculate_per_token_loss', False)),
("calculate_per_sample_loss and calculate_per_token_loss cannot be used simultaneously; "
"only one of them can be enabled. If neither is specified, loss computation will fall back to the default behavior. "
"For more details, please refer to the relevant documentation: vlm_model_loss_calculate_type.md"
))
return args
pm.register_patch("megatron.training.arguments.validate_args", validate_args, force_patch=True)
pm.register_patch("mindspeed_mm.patchs.validate_args_patch.validate_args", validate_args_wrapper)
pm.apply_patches()