import os
import argparse
from pathlib import Path
from functools import wraps
from mindspeed.features_manager.features_manager import MindSpeedFeaturesManager
from mindspeed_llm.features_manager import FEATURES_LIST
# Absolute path of the directory that contains this source file.
cur_file_dir = Path(__file__).absolute().parent
# Path of "configs/finetune/templates.json" located two directory levels above
# ``cur_file_dir``. Despite the ``_DIR`` suffix this is the path of a JSON file.
TEMPLATES_DIR = os.path.join(cur_file_dir.parent.parent, "configs/finetune/templates.json")
def print_rank0_by_args(args, message):
    """Print ``message`` (flushed) only when ``args.rank`` equals 0.

    Meant for use before the distributed runtime is initialized, when the
    rank is taken from the parsed arguments.
    """
    if args.rank != 0:
        return
    print(message, flush=True)
def extra_args_provider_decorator(extra_args_provider):
    """Wrap an optional extra-args provider so ``process_args_v2`` always runs after it.

    The returned callable takes a parser, applies ``extra_args_provider`` to it
    when one was supplied, then passes the result through ``process_args_v2``
    and returns that parser.
    """

    @wraps(extra_args_provider)
    def wrapper(parser):
        # ``None`` means "no user provider": go straight to process_args_v2.
        prepared = parser if extra_args_provider is None else extra_args_provider(parser)
        return process_args_v2(prepared)

    return wrapper
def parse_args_decorator(parse_args):
    """Decorate ``parse_args`` so the provider it receives is wrapped by ``extra_args_provider_decorator``.

    The wrapper keeps the ``(extra_args_provider=None, ignore_unknown_args=False)``
    signature and forwards both values positionally to ``parse_args``.
    """

    @wraps(parse_args)
    def wrapper(extra_args_provider=None, ignore_unknown_args=False):
        provider = extra_args_provider_decorator(extra_args_provider)
        parsed = parse_args(provider, ignore_unknown_args)
        return parsed

    return wrapper
def process_args(parser):
    """Register this module's argument groups on ``parser``, then each feature's arguments.

    The parser's conflict handler is set to ``'resolve'`` before anything is
    registered, so an option string that is registered again replaces the
    earlier registration instead of raising. The registrars are applied in the
    order listed below; each one receives the parser returned by the previous
    one. Finally every entry of ``FEATURES_LIST`` registers its own arguments.

    Returns:
        The parser after all registrations.
    """
    parser.conflict_handler = 'resolve'

    registrars = (
        _add_fusion_op_args,
        _add_network_size_args,
        _add_lora_args,
        _add_lu_lora_args,
        _add_data_args,
        _add_moe_args,
        _add_num_layer_allocation,
        _add_profile_args,
        _add_network_args,
        _add_training_args,
        _add_distributed_args,
        _add_algorithm_args,
        _add_alibi_args,
        _add_dataset_args,
        _add_cp_args,
        _add_mla_args,
        _add_yarn_args,
        _add_deepseek_moe_args,
        _add_rl_args,
        _add_ndmm_args,
        _add_2d_tp_args,
        _add_hccl_group_buffer_args,
        _add_default_model_args,
        _add_megatron2_args,
        _add_inference_args,
        _add_dualpipe_args,
        _add_ai_framework_args,
        _add_ckpt_args,
        _add_zerc_args,
    )
    for register in registrars:
        parser = register(parser)

    for feature in FEATURES_LIST:
        feature.register_args(parser)

    return parser
def process_args_v2(parser):
    """Have ``MindSpeedFeaturesManager`` register its feature arguments on ``parser``; return ``parser``."""
    register_features = MindSpeedFeaturesManager.register_features_args
    register_features(parser)
    return parser
def _add_zerc_args(parser):
    """Add the 'zerc' argument group (the ``--moe-zerc`` switch) to ``parser`` and return ``parser``."""
    zerc_group = parser.add_argument_group(title='zerc')
    zerc_group.add_argument(
        '--moe-zerc',
        action='store_true',
        default=False,
    )
    return parser
def _add_ai_framework_args(parser):
    """Add the 'ai framework' group to ``parser`` and return ``parser``.

    ``--ai-framework`` accepts 'pytorch' (the default) or 'mindspore'.
    """
    framework_group = parser.add_argument_group(title='ai framework')
    supported_frameworks = ['pytorch', 'mindspore']
    framework_group.add_argument(
        '--ai-framework',
        type=str,
        choices=supported_frameworks,
        default='pytorch',
        help='support pytorch and mindspore',
    )
    return parser
def _add_default_model_args(parser):
    """Add the 'default model mode' group (``--use-mcore-models``) to ``parser`` and return ``parser``."""
    model_group = parser.add_argument_group(title='default model mode')
    model_group.add_argument(
        '--use-mcore-models',
        dest='use_mcore_models',
        action='store_true',
        help='Use Megatron-Core models, will be DEPRECATED in future',
    )
    return parser
def _add_ckpt_args(parser):
    """Add the 'ckpt arguments' group to ``parser`` and return ``parser``."""
    ckpt_group = parser.add_argument_group(title='ckpt arguments')

    ckpt_group.add_argument(
        '--ckpt-format',
        choices=['torch', 'torch_dist', 'zarr'],
        default='torch',
        help='Checkpoint format to use.',
    )
    # store_true combined with default=True: the parsed value is True whether
    # or not the flag appears on the command line.
    ckpt_group.add_argument(
        '--use-mp-args-from-checkpoint-args',
        action='store_true',
        default=True,
        help='Copy model parallelism command-line arguments from checkpoint',
    )
    return parser
def _add_mla_args(parser):
    """Add the 'multi-head latent attention' group to ``parser`` and return ``parser``.

    One integer option (``--padded-base-length``) is followed by a set of
    boolean switches that all default to ``False``.
    """
    mla_group = parser.add_argument_group(title='multi-head latent attention')
    mla_group.add_argument(
        '--padded-base-length',
        type=int,
        default=128,
        help='Fill Q K V of multi-latent-attention to an integer multiple of this parameter.',
    )

    # (option string, help text) for every store_true switch, in registration order.
    switches = (
        ('--mla-fa-without-pad', 'Do not pad v_head_dim to q_head_dim in MLA'),
        ('--mla-mm-split', 'Split 2 up-proj matmul into 4 in MLA'),
        ('--mla-zero-memory', 'Save activation memory in multi-latent-attention.'),
        ('--mla-up-proj-tp-overlap', 'overlap up proj tp comm'),
        ('--recompute-mla-up-proj', 'recompute up projection in mla'),
        ('--mla-swap-core-attn-out', 'swap core_attn_out only in mla.'),
        ('--mla-fa-divide-qk', 'Flash attn support mla with seperate q and k.'),
    )
    for option, description in switches:
        mla_group.add_argument(option, action='store_true', default=False, help=description)

    return parser
def _add_yarn_args(parser):
    """Add the 'yarn' rope-scaling options to ``parser`` and return ``parser``."""
    yarn_group = parser.add_argument_group(title='yarn')

    # (option string, value type, default, help text), in registration order.
    options = (
        ('--beta-fast', int, 32, 'Yarn rope: rope beta fast'),
        ('--beta-slow', int, 1, 'Yarn rope: rope beta slow'),
        ('--rope-scaling-factor', float, 1.0, 'Yarn rope: rope factor'),
        ('--rope-scaling-mscale', float, 1.0, 'Yarn rope: rope mscale'),
        ('--rope-scaling-mscale-all-dim', float, 0.0, 'Yarn rope: rope mscale all dim'),
        ('--rope-scaling-original-max-position-embeddings', int, None,
         'Yarn rope: rope original max position embeddings'),
    )
    for option, value_type, default_value, description in options:
        yarn_group.add_argument(option, type=value_type, default=default_value, help=description)

    return parser
def _add_deepseek_moe_args(parser):
    """Add the 'deepseek moe' group to ``parser`` and return ``parser``.

    All four options are integers that default to ``None``.
    """
    deepseek_group = parser.add_argument_group(title='deepseek moe')

    # (option string, help text), in registration order.
    int_options = (
        ('--moe-intermediate-size', 'The ffn hidden size of MoE layer'),
        ('--n-shared-experts',
         'This value is the number of shared experts, which is equal to the intermediate_size '
         'of the shared experts divided by the moe_intermediate_size.'),
        ('--first-k-dense-replace', 'Set first k layer as dense layer'),
        ('--moe-layer-freq', 'Set the occurrence frequency of the moe layer'),
    )
    for option, description in int_options:
        deepseek_group.add_argument(option, type=int, default=None, help=description)

    return parser
def _add_dualpipe_args(parser):
    """Add the 'dualpipe' group to ``parser`` and return ``parser``."""
    dualpipe_group = parser.add_argument_group(title='dualpipe')
    register = dualpipe_group.add_argument

    register('--moe-fb-overlap', action='store_true', default=False)
    register(
        '--schedules-method',
        type=str,
        default=None,
        choices=['dualpipev'],
    )
    register(
        '--dualpipev-dw-detach',
        action='store_true',
        help='detach dw in cooldown to reduce bubble',
    )
    register(
        '--moe-unperm2-mem-optim',
        action='store_true',
        default=False,
        help='deallocate unperm2 activation memory by multiplying prob after act func.',
    )
    register('--moe-unperm2-mem-optim-swap', action='store_true', default=False)

    return parser
def _add_profile_args(parser):
    """Add the 'profiler' group to ``parser`` and return ``parser``."""
    profiler_group = parser.add_argument_group(title='profiler')

    profiler_group.add_argument(
        '--profile-ranks',
        nargs='+',
        type=int,
        default=[-1],
        help='Global ranks to profile.The default value of -1 means to profile all ranks',
    )
    profiler_group.add_argument(
        '--profile-export-type',
        type=str,
        default='text',
        choices=['text', 'db'],
        help='choose the export mode as text or db.',
    )
    profiler_group.add_argument(
        '--profile-level',
        type=str,
        default='level0',
        choices=['level_none', 'level0', 'level1', 'level2'],
        help='profiling level_none, level0, level1, level2',
    )

    # Plain on/off switches: (option string, help text), in registration order.
    switches = (
        ('--profile-data-simplification', 'use data simplification mode'),
        ('--profile-with-stack', 'profiling with stack info'),
        ('--profile-with-memory', 'profiling with memory info'),
        ('--profile-record-shapes', 'profiling with shape info'),
        ('--profile-with-cpu', 'profiling with cpu info'),
    )
    for option, description in switches:
        profiler_group.add_argument(option, action='store_true', help=description)

    profiler_group.add_argument(
        '--profile-save-path',
        type=str,
        default='./profile_dir',
        help='path to save profiling files',
    )
    return parser
def _add_cp_args(parser):
    """Add the 'cp parallel' (context parallelism) group to ``parser`` and return ``parser``.

    ``--attention-mask-type`` is registered exactly once. A second identical
    registration makes argparse raise ``ArgumentError`` on any parser whose
    conflict handler is not ``'resolve'``; under ``'resolve'`` the outcome is
    the same single option.
    """
    group = parser.add_argument_group(title='cp parallel')
    group.add_argument('--context-parallel-size', type=int, default=1,
                       help='Degree of context parallelism.')
    group.add_argument('--context-parallel-algo', type=str, default='ulysses_cp_algo',
                       choices=['ulysses_cp_algo', 'megatron_cp_algo', 'hybrid_cp_algo', 'mamba_cp_algo'],
                       help='context parallel algorithm')
    group.add_argument('--ulysses-degree-in-cp', type=int, default=None)
    group.add_argument('--attention-mask-type', type=str, default='causal',
                       choices=['causal', 'general'], help='context parallel attention mask type')
    group.add_argument('--use-cp-send-recv-overlap', action='store_true',
                       help='use it to enable cp send-recv-overlap.')
    group.add_argument('--cp-window-size', type=int, default=1,
                       help='inner window size of double ring attention')
    group.add_argument('--attention-mask-on-cpu', action='store_true',
                       help='store full attention mask on CPU instead of NPU')
    group.add_argument('--adaptive-cp-without-coarse', action='store_true',
                       help='does not coarse the attention mask in adaptive_cp feature, only recommended when full'
                            'sequence length is less than 8K and dynamic attention mask is not feasible')
    group.add_argument('--adaptive-cp-dynamic-attn-mask', action='store_true',
                       help='if the attention mask is dynamic across batches')
    group.add_argument('--adaptive-cp-only-reschedule', action='store_true',
                       help='not apply remapping but only rescheduling process in adaptive-cp feature')
    group.add_argument('--adaptive-cp-manually-set-mask-list', action='store_true',
                       help='manually set pre-cooked attention mask list')
    # store_true combined with default=True: the parsed value is True whether
    # or not the flag is given on the command line.
    group.add_argument('--kv-head-repeat-before-uly-alltoall', action='store_true', default=True,
                       help='use it to expand key and value for ulysses when GQA/MQA is used.')
    return parser
def _validate_varlen_fa_args(args):
    """Set ``args.shape_order`` to ``'TND'`` when ``args.reset_position_ids`` is truthy.

    A notice is printed through ``print_rank0_by_args`` when the override
    happens; otherwise ``args`` is left untouched.
    """
    if not args.reset_position_ids:
        return
    args.shape_order = 'TND'
    print_rank0_by_args(args, "When reset_position_ids is enabled, shape_order should be TND.")
def _validate_cp_args(args):
    """Validate the context-parallel (CP) settings on ``args`` and normalize a few of them.

    When ``args.context_parallel_size <= 1`` the function only resets
    ``args.kv_head_repeat_before_uly_alltoall`` to ``False`` (printing a
    notice) and rejects ``args.use_fused_ring_attention_update``. Otherwise it
    forces ``args.use_flash_attn = True`` and runs the checks that belong to
    the selected ``args.context_parallel_algo``.

    Raises:
        ValueError: a size or divisibility constraint between sequence length,
            head counts, CP size, ulysses degree, window size or micro batch
            size is violated.
        AssertionError: fused ring attention update is requested without CP,
            CP is requested without Mcore models, the KV heads cannot be split
            without ``kv_head_repeat_before_uly_alltoall``, or a sliding window
            is configured together with CP.
    """

    def _check_attention_head(args, uly_size):
        """Check that attention heads (and GQA groups) split evenly over ulysses * TP."""
        head, remainder = divmod(args.num_attention_heads, uly_size * args.tensor_model_parallel_size)
        if not (head >= 1 and remainder == 0):
            raise ValueError("num_attention_heads must be divisible by ulysses_size * tensor_model_parallel_size")

        if args.group_query_attention and args.num_query_groups >= 1:
            head_split_by_tp, remainder = divmod(args.num_query_groups, args.tensor_model_parallel_size)
            if not (head_split_by_tp >= 1 and remainder == 0):
                raise ValueError("num_query_groups must be divisible by tensor_model_parallel_size")

            # Without KV-head repetition the per-TP groups must also split over ulysses.
            if not args.kv_head_repeat_before_uly_alltoall:
                head_split_by_tp_cp, remainder = divmod(head_split_by_tp, uly_size)
                if not (head_split_by_tp_cp >= 1 and remainder == 0):
                    raise AssertionError(
                        'num_query_groups must be divisible by ulysses_size * tensor_model_parallel_size.\n'
                        'Solution 1. adjust the ulysses_size\n'
                        'Solution 2. You can enable --kv-head-repeat-before-uly-alltoall to roll on.\n'
                        'However, performance would be affected since it would increase communication volume \n'
                        'for ulysses alltoall as well as memory usage.')

    def _check_cp_window(ring_size, range_error, divisibility_error):
        """Require 1 <= cp_window_size < ring_size and ring_size % cp_window_size == 0."""
        if not (args.cp_window_size >= 1 and args.cp_window_size < ring_size):
            raise ValueError(range_error)
        n_window, remainder = divmod(ring_size, args.cp_window_size)
        if not (n_window >= 1 and remainder == 0):
            raise ValueError(divisibility_error)

    def _check_general_mask_batch():
        """A 'general' attention mask is only accepted with micro_batch_size == 1."""
        if args.attention_mask_type == 'general' and args.micro_batch_size != 1:
            raise ValueError('When attention_mask_type is set to general, the value of mbs can only be 1.')

    # CP disabled: only clean up CP-specific switches, then stop.
    if args.context_parallel_size <= 1:
        if args.kv_head_repeat_before_uly_alltoall:
            args.kv_head_repeat_before_uly_alltoall = False
            print_rank0_by_args(args, "When context_parallel is not activated, kv_head_repeat_before_uly_alltoall would be set to False for reducing memory usage.")
        if args.use_fused_ring_attention_update:
            raise AssertionError("fused_ring_attention_update only works when context parallel is activated.")
        return

    args.use_flash_attn = True
    print_rank0_by_args(args, f"[INFO] Setting args.use_flash_attn={args.use_flash_attn} since context parallel is enabled.")
    if not args.use_mcore_models:
        raise AssertionError("Context parallel is only supported in Mcore.")

    if args.context_parallel_algo == 'ulysses_cp_algo' or args.context_parallel_algo == 'mamba_cp_algo':
        if args.seq_length % args.context_parallel_size != 0:
            raise ValueError("sequence length must be divisible by context_parallel_size")
        _check_attention_head(args, args.context_parallel_size)

    if args.context_parallel_algo == 'megatron_cp_algo':
        if args.seq_length % (2 * args.context_parallel_size) != 0:
            raise ValueError("sequence length must be divisible by 2 * context_parallel_size")
        # The window validation used to be written out twice in a row; once is enough.
        _check_cp_window(
            args.context_parallel_size,
            'cp_window_size should in range [1, context_parallel_size) when using double_ring_attention.',
            'context parallel size must be divisible by cp_window_size when using double ring attention.',
        )
        _check_general_mask_batch()

    if args.context_parallel_algo == 'hybrid_cp_algo':
        if args.ulysses_degree_in_cp is None:
            raise ValueError("--ulysses-degree-in-cp must be specified in hybrid_cp_algo")
        # Guard the divmod below: a zero degree would raise ZeroDivisionError.
        if args.ulysses_degree_in_cp < 1:
            raise ValueError("--ulysses-degree-in-cp must be a positive integer")
        ring_degree, remainder = divmod(args.context_parallel_size, args.ulysses_degree_in_cp)
        if not (ring_degree > 1 and remainder == 0):
            # The condition checks context_parallel_size % ulysses_degree_in_cp,
            # so the message names the operands in that order.
            raise ValueError("--context-parallel-size must be divisible by --ulysses-degree-in-cp, "
                             "and the resulting ring degree must be greater than 1")

        if args.seq_length % (2 * args.context_parallel_size) != 0:
            raise ValueError("sequence length must be divisible by 2 * context_parallel_size in hybrid cp")

        _check_cp_window(
            ring_degree,
            'cp_window_size should be in range [1, ring_degree) when using double ring attention with hybrid context parallelism.',
            'ring_degree should be divisible by cp_window_size when using double ring with hybrid context parallelism.',
        )
        _check_attention_head(args, args.ulysses_degree_in_cp)
        _check_general_mask_batch()

    if args.sliding_window:
        raise AssertionError("sliding window is not supported in context parallel.")
def _validate_tocken(args):
    """Print a notice for each of ``pre_tocken`` / ``next_tocken`` that exceeds ``args.seq_length``.

    Only a message is emitted (through ``print_rank0_by_args``); this function
    does not modify ``args``.
    """
    # NOTE(review): the options visible in this file are spelled
    # ``--pre-tockens`` / ``--next-tockens`` (plural) - confirm that the
    # singular attributes read here are defined elsewhere.
    for name in ('pre_tocken', 'next_tocken'):
        value = getattr(args, name)
        if value > args.seq_length:
            print_rank0_by_args(args, f"[INFO] {name}={value} would be adjusted to {args.seq_length} for better performance.")
def _add_lu_lora_args(parser):
    """Add the 'LU-LoRA arguments' group to ``parser`` and return ``parser``."""
    lu_lora_group = parser.add_argument_group(title='LU-LoRA arguments')

    lu_lora_group.add_argument(
        '--lu-lora-final-layer-index',
        type=int,
        default=None,
        help='Index of last transformer block trained with LU-LoRA.',
    )
    lu_lora_group.add_argument(
        '--lu-lora-lr',
        type=float,
        default=1.25e-6,
        help='Initial learning rate for LU-LoRA layers.',
    )
    lu_lora_group.add_argument(
        '--lu-lora-lr-ratio',
        type=float,
        default=1.,
        help='Ratio of learning rates of LU-LoRA B and LU-LoRA A adapters.',
    )
    return parser
def _add_lora_args(parser):
    """Add the 'lora' group (LoRA and QLoRA options) to ``parser`` and return ``parser``."""
    lora_group = parser.add_argument_group(title='lora')
    register = lora_group.add_argument

    register(
        '--lora-target-modules',
        nargs='+',
        type=str,
        default=[],
        help='Lora target modules.',
    )
    register(
        '--lora-load',
        type=str,
        default=None,
        help='Directory containing a lora model checkpoint.',
    )
    register('--lora-r', type=int, default=16, help='Lora r.')
    register('--lora-alpha', type=int, default=32, help='Lora alpha.')
    register(
        '--lora-modules-to-save',
        nargs='+',
        type=str,
        default=None,
        help='Lora modules to save.',
    )
    register(
        '--lora-register-forward-hook',
        nargs='+',
        type=str,
        default=['word_embeddings', 'input_layernorm'],
        help='Lora register forward hook.',
    )
    register(
        '--lora-fusion',
        action='store_true',
        help='use fusion to accelerate lora.',
    )
    register(
        '--lora-ckpt-filter',
        action='store_true',
        default=False,
        help='Enable only saving lora checkpoint.',
    )
    register(
        '--qlora',
        action='store_true',
        default=False,
        help='Enable QLoRA for fine-tuning with reduced memory usage.',
    )
    register(
        '--qlora-save-dequantize',
        action='store_true',
        default=False,
        help='Dequantize weights to original precision when saving in QLoRA tuning.',
    )
    return parser
def _add_moe_args(parser):
    """Add the 'moe' group to ``parser`` and return ``parser``."""
    moe_group = parser.add_argument_group(title='moe')
    register = moe_group.add_argument

    register(
        '--expert-interval',
        type=int,
        default=1,
        help='Use experts in every "expert-interval" layers',
    )
    register(
        '--moe-train-capacity-factor',
        type=float,
        default=1.0,
        help='The capacity of the MoE expert at training time used in legacy moe layer called SwitchMLP.',
    )
    register(
        '--use-fused-moe-token-permute-and-unpermute',
        action='store_true',
        help="Use fused moe permute and unpermute.",
    )
    register(
        '--gemm-gradient-accumulation-fusion',
        action='store_true',
        help="Use gradient-accumulation-fusion in gemm.",
    )
    register(
        '--moe-token-dispatcher-type',
        type=str,
        choices=['allgather', 'alltoall', 'alltoall_seq'],
        default='allgather',
        help='The dispatcher type for moe token dispatching.',
    )
    register(
        '--noisy-gate-policy',
        type=str,
        default=None,
        help="noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.",
    )
    register(
        '--enable-token-rearrange-opt',
        action='store_true',
        help="Use this flag to enable token rearrange optimize",
    )
    register(
        '--embedding-multiplier-scale',
        type=float,
        default=1.0,
        help='add scale for embedding.',
    )
    # store_false: the parsed value is True by default and becomes False when
    # the flag is given.
    register(
        '--input-jitter',
        action='store_false',
        help='Add noise to the input tensor.',
    )
    register(
        '--post-norm',
        action='store_true',
        help='post norm after attention or mlp.',
    )
    register(
        '--output-multiplier-scale',
        type=float,
        default=None,
        help='Add scale for logits output.',
    )
    register(
        '--moe-permutation-async-comm',
        action='store_true',
        help="overlap moe permutation 3 all gather communications",
    )
    register(
        '--shared-expert-gate',
        action='store_true',
        help="moe model has shared expert gate",
    )
    register(
        '--shared-expert-gate-output-dimension',
        type=int,
        default=1,
        help="moe model shared expert gate output dimension for qwen2 moe, this parameter can only configured with"
             "1 or hidden_state",
    )
    register(
        '--moe-alltoall-overlap-comm',
        action='store_true',
        default=False,
        help='moe_alltoall_overlap_comm',
    )
    register(
        '--cla-share-factor',
        type=int,
        default=1,
        help="Cross-Layer Attention share kv between cla-share-factor layers",
    )
    register(
        '--moe-tp-extend-ep',
        action='store_true',
        help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group",
    )
    register(
        '--moe-zero-memory',
        type=str,
        default='disable',
        choices=['disable', 'level0', 'level1'],
        help="Save activation memory in moe layer.",
    )
    register(
        '--moe-zero-memory-num-layers',
        type=int,
        default=None,
        help='the number of layers using moe-zero-memory level1'
             'in each pp stage.',
    )
    register(
        '--moe-allgather-overlap-comm',
        action='store_true',
        default=False,
        help='moe_allgather_overlap_comm',
    )
    return parser
def _parse_enable_thinking(value):
    """Convert 'true' / 'false' / 'none' (any letter case) to ``True`` / ``False`` / ``None``.

    Raises:
        argparse.ArgumentTypeError: ``value`` is not one of the three accepted
            words, so argparse reports a normal usage error instead of letting
            a bare ``KeyError`` traceback escape.
    """
    accepted = {"true": True, "false": False, "none": None}
    try:
        return accepted[value.lower()]
    except KeyError:
        raise argparse.ArgumentTypeError(
            f"expected 'true', 'false' or 'none' (case-insensitive), got {value!r}"
        ) from None


def _add_data_args(parser):
    """Add the 'data' group (dataset and tokenizer options) to ``parser`` and return ``parser``."""
    group = parser.add_argument_group(title='data')
    group.add_argument('--is-instruction-dataset', action='store_true', help='use instruction dataset or not')
    group.add_argument('--full-shuffle-instruction-dataset', action='store_true',
                       help='full shuffle instruction dataset or not')
    group.add_argument('--variable-seq-lengths', action='store_true', help='Use variable seq lengths or not.')
    group.add_argument("--tokenizer-kwargs", type=str, nargs='+', default=None,
                       help="Kwargs of the huggingface tokenizer.")
    group.add_argument('--tokenizer-padding-side', type=str, default='right',
                       help="tokenizer padding side")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer',
                                'Llama2Tokenizer',
                                'PretrainedFromHF',
                                'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument("--tokenizer-name-or-path", type=str, default=None,
                       help="Name or path of the huggingface tokenizer.")
    # store_false: ``tokenizer_not_use_fast`` is True by default and becomes
    # False when the flag is given.
    group.add_argument("--tokenizer-not-use-fast", action='store_false',
                       help="HuggingFace tokenizer not use the fast version.")
    group.add_argument("--input-layernorm-in-fp32", action='store_true',
                       help="Convert input-layernorm to fp32")
    group.add_argument("--no-shuffle", action='store_true',
                       help="Disable data shuffling, mainly for loss comparison.")
    group.add_argument('--neat-pack', action='store_true',
                       help='Use a zigzag attention mask.')
    group.add_argument('--padded-samples', action='store_true',
                       help='fill in the missing samples within an epoch, '
                            'starting at index 0, aligned with the LlamaFatory.')
    # Tri-state option: the default None means "not specified".
    group.add_argument("--enable-thinking", type=_parse_enable_thinking, default=None,
                       help="Whether or not to enable thinking mode for reasoning models.")
    return parser
def _add_num_layer_allocation(parser):
    """Add the 'num_layer_allocation' group (``--num-layer-list``) to ``parser`` and return ``parser``."""
    allocation_group = parser.add_argument_group(title='num_layer_allocation')
    layer_list_help = ('a list of number of layers, '
                       'seperated by comma; e.g., 4,4,4,4')
    allocation_group.add_argument('--num-layer-list', type=str, help=layer_list_help)
    return parser
def _add_fusion_op_args(parser):
    """Add the 'fusion_op_args' group to ``parser`` and return ``parser``.

    Every option in this group is a ``store_true`` switch.
    """
    fusion_group = parser.add_argument_group(title='fusion_op_args')

    # (option string, help text), in registration order.
    switches = (
        ("--use-fused-rmsnorm", "Use fused rmsnorm."),
        ("--use-fused-swiglu", "Use fused swiglu."),
        ("--use-fused-rotary-pos-emb", "Use fused rotary-pos-emb."),
        ("--use-fused-ring-attention-update", "Use fused ring attention update."),
        ("--use-ascend-mc2", "Use mc2 for compute-comm overlap in tp."),
        ("--use-fused-mlp", "Use fused mlp."),
    )
    for option, description in switches:
        fusion_group.add_argument(option, action='store_true', help=description)

    return parser
def _add_network_size_args(parser):
    """Add the 'network_size_args' group to ``parser`` and return ``parser``."""
    size_group = parser.add_argument_group(title='network_size_args')

    size_group.add_argument('--padded-vocab-size', type=int, default=None,
                            help='set padded vocab size')
    # NOTE(review): this help text is identical to --padded-vocab-size above;
    # it looks copied - confirm the intended wording.
    size_group.add_argument('--embed-layernorm', action='store_true', default=False,
                            help='set padded vocab size')
    size_group.add_argument('--use-glm-rope', action='store_true',
                            help='use custom partial rope in glm model.')
    size_group.add_argument('--sliding-window', type=int, default=None,
                            help='Window size when use sliding window attention.')
    size_group.add_argument('--output-layer-slice-num', type=int, default=1,
                            help='Set the number of slices for the weight of the output_layer')
    return parser
def _add_algorithm_args(parser):
    """Add the 'algorithm' group (recompute and memory-saving options) to ``parser`` and return ``parser``."""
    algo_group = parser.add_argument_group(title='algorithm')
    register = algo_group.add_argument

    register(
        "--noop-layers",
        type=str,
        default=None,
        help='Specity the noop layers.',
    )
    register(
        '--reuse-fp32-param',
        action='store_true',
        help='The distributed training optimizer frees up '
             'param copies of FP32 to save memory.',
    )
    register(
        '--recompute-activation-function',
        action='store_true',
        help='Recompute the activation function in MLP layers.',
    )
    register(
        '--recompute-activation-function-num-layers',
        type=int,
        default=None,
        help='Can be used together with "--recompute-method block." '
             'and "--recompute-num-layers". ',
    )
    register(
        '--recompute-in-advance',
        action='store_true',
        help='recompute early to reduce bubble and improve training.',
    )
    register(
        '--recompute-norm',
        action='store_true',
        help='Recompute norm in Transformer Layers',
    )
    register(
        '--recompute-norm-num-layers',
        type=int,
        default=None,
        help='Recompute norm num layers, can be used together with activation function recompute. ',
    )
    register(
        '--o2-optimizer',
        action='store_true',
        help='use bf16 exponential moving average to greatly save up memory.',
    )
    register(
        '--o2-gradient',
        action='store_true',
        help='use bf16 gradient accumulation to greatly save up memory.',
    )
    register(
        '--share-kvstates',
        action='store_true',
        help='CLA share kv states.',
    )
    return parser
def _add_network_args(parser):
    """Add the 'network' group to ``parser`` and return ``parser``."""
    network_group = parser.add_argument_group(title='network')
    register = network_group.add_argument

    # Bias switches that are off unless requested.
    for option, description in (
        ("--add-qkv-bias", 'Configuration for the qkv bias.'),
        ("--add-dense-bias", 'Configuration for the dense bias.'),
        ("--add-output-layer-bias", 'Configuration for the output layer bias.'),
    ):
        register(option, action="store_true", default=False, help=description)

    # store_false: ``skip_bias_add`` is True by default and becomes False when
    # the flag is given.
    register("--skip-bias-add", action="store_false", default=True,
             help='Configuration for the skip bias.')

    for option, description in (
        ('--add-rmsnorm-offset', 'RMSNorm unit offset.'),
        ('--geglu', 'Geglu activate function.'),
        ('--input-embeds-norm', 'input normalization.'),
        ('--gelu-tanh', 'Tanh Geglu activate function.'),
    ):
        register(option, action='store_true', default=False, help=description)

    # Numeric options without an explicit default (argparse uses None).
    for option, value_type, description in (
        ('--output-logit-softcapping', float, 'output logit softcapping.'),
        ('--attn-logit-softcapping', float, 'attention logit softcapping.'),
        ('--query-pre-attn-scalar', int, 'attention scalar.'),
        ('--interleave-sliding-window', int, 'Window size when use interleave sliding window attention.'),
    ):
        register(option, type=value_type, help=description)

    register(
        '--stage',
        default=None,
        choices=["sft", "dpo", "orm", "prm", "simpo", "ray_ppo", "ray_online_dpo", "ray_grpo", "trl_ppo"],
        help='Determine training mode',
    )
    # NOTE(review): help text is identical to --stage; it looks copied -
    # confirm the intended wording.
    register('--cut-max-seqlen', action="store_true", help='Determine training mode')
    return parser
def _add_rl_args(parser):
    """Add the 'reinforce learning in dpo' group to ``parser`` and return ``parser``.

    ``--reward-tokens`` is registered exactly once. A second identical
    registration makes argparse raise ``ArgumentError`` on any parser whose
    conflict handler is not ``'resolve'``. The surviving registration sits at
    the position of the later of the two former ones, which is the one a
    ``'resolve'`` parser ended up keeping.
    """
    group = parser.add_argument_group(title='reinforce learning in dpo')
    add = group.add_argument

    # --- DPO / SimPO loss configuration ---
    add('--dpo-beta', default=0.1, type=float,
        help='The beta parameter for the DPO loss.')
    add('--simpo-beta', default=2.5, type=float,
        help='The beta parameter for the SimPO loss.')
    add('--gamma-beta-ratio', default=1.4, type=float,
        help='The beta parameter for the SimPO loss.')
    add('--dpo-loss-type', type=str, default="sigmoid", choices=["sigmoid", "hinge", "ipo"],
        help='The type of DPO loss to use.')
    add('--simpo-loss-type', type=str, default="sigmoid", choices=["sigmoid", "hinge", "ipo"],
        help='The type of SimPO loss to use.')
    add('--simpo-label-smoothing', default=0.0, type=float,
        help='The robust SimPO label smoothing parameter.')
    add('--ref-model', default=None, type=str,
        help='Path to the reference model used for the PPO or DPO training.')
    add('--refer-model-iter', type=int, default=1,
        help='iteration of the reference model used for the PPO or DPO training.')
    add('--dpo-label-smoothing', default=0.0, type=float,
        help="The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5.")
    add('--pref-ftx', default=0.0, type=float,
        help="The supervised fine-tuning loss coefficient in DPO training.")
    add("--is-pairwise-dataset", action='store_true',
        help="Whether the dataset is pairwise format that has a chosen sequence and rejected "
             "sequence, which usually used in reinforce learning.")
    add('--placeholder-token', default='ки',
        help="A special placeholder token marking the end of each step where the PRM can make predictions.")
    add("--md5-validate", action='store_true',
        help="Enable md5 validate.")

    # --- PPO / GRPO rollout and optimisation settings ---
    add("--max-prompt-length", default=512, type=int,
        help="Max prompt length in ppo.")
    add("--num-samples-per-step", default=1, type=int,
        help="Number of samples per step in generation.")
    add("--rollout-batch-size", default=None, type=int,
        help="actor rollout batch size.")
    add("--cliprange-value", default=0.2, type=float,
        help="Clip range value.")
    add("--critic-mini-batch-size", default=1, type=int,
        help="critic mini batch size.")
    add("--critic-update-epochs", default=1, type=int,
        help="critic update epochs.")
    add("--ppo-mini-batch-size", default=1, type=int,
        help="ppo mini batch size.")
    add("--clip-ratio", default=0.2, type=float,
        help="ppo loss clip ratio.")
    add("--entropy-coeff", default=0.001, type=float,
        help="entropy coefficient.")
    add("--ppo-epochs", default=1, type=int,
        help="ppo epochs.")
    add("--shuffle-minibatch", action='store_true',
        help="Enable shuffle minibatches in PPO.")
    add("--do-sample", action='store_true',
        help="Enable doing sample in actor generations.")
    add("--missing-eos-penalty", default=0.0, type=float,
        help="eos penalty.")
    add("--n-samples-per-prompt", default=1, type=int,
        help="Number of samples per prompt in GRPO.")
    add('--reward-tokens', nargs='+', type=str, default=[],
        help="The labels represent the correctness of each reasoning step in the entire reasoning process.")
    # No ``type=`` is given for --reward-model, --advantage-whiten and
    # --extract-content-for-reward: a value passed on the command line is
    # stored as the raw string, while the defaults are real booleans.
    add('--reward-model', default=False,
        help="Path to the reference model used for the PPO training.")
    add("--verifier", action='store_true',
        help="Enable verifier in cal scores.")
    add('--kl-coef', default=0.3, type=float,
        help="KL coefficient in PPO training.")
    add('--gamma', default=1.0, type=float,
        help="Discount factor in PPO training.")
    add('--lam', default=0.95, type=float,
        help="Lambda value for GAE in PPO training.")
    add("--advantage-whiten", default=True,
        help="advantage whiten in GRPO.")
    add('--dataset-category', type=str, default=None,
        help='Comma-separated list of dataset category for training,'
             ' 0 for general problems with no accurate answer, 1 for math problems')
    add("--extract-content-for-reward", default=False,
        help="Extract content in answer tag for reward model to judge.")
    return parser
def _add_megatron2_args(parser):
    """Add the 'run two megatrons at once' group to ``parser`` and return ``parser``."""
    dual_group = parser.add_argument_group(title='run two megatrons at once')

    dual_group.add_argument(
        '--num-gpus-for-train',
        type=int,
        default=None,
        help='num of GPUs for train in two megatrons, training is set to the first group.',
    )
    dual_group.add_argument(
        '--num-gpus-for-infer',
        type=int,
        default=None,
        help='num of GPUs for inference in two megatrons, inference is set to the first group.',
    )
    dual_group.add_argument(
        '--inference-tensor-model-parallel-size',
        type=int,
        default=1,
        help='TP size for inference group in two megatron',
    )
    return parser
def _add_inference_args(parser):
    """Add the 'text generation' group (sampling and chat options) to ``parser`` and return ``parser``."""
    generation_group = parser.add_argument_group(title='text generation')
    register = generation_group.add_argument

    register("--task", nargs='*', default=None, help='The task id to run.')

    # Scalar sampling / length options: (option string, type, default, help text).
    for option, value_type, default_value, description in (
        ("--top-p", float, 0.95, 'Top p sampling.'),
        ("--top-k", int, 50, 'Top k sampling.'),
        ("--temperature", float, 0.7, 'Sampling temperature.'),
        ("--max-length", int, 256, 'Total length of text.'),
        ("--max-new-tokens", int, 128, 'Size of the output generated text.'),
    ):
        register(option, type=value_type, default=default_value, help=description)

    register(
        '--hf-chat-template',
        action='store_true',
        default=False,
        help="Using Huggingface chat template",
    )
    register(
        '--add-eos-token',
        nargs='+',
        type=str,
        default=[],
        help="Use additional eos tokens",
    )
    register(
        '--use-kv-cache',
        action="store_true",
        default=False,
        help="Use kv cache to accelerate inference",
    )
    register('--history-turns', type=int, default=3, help='Chat turns of histories.')
    return parser
def _add_training_args(parser):
    """Register the training-related command-line options.

    Covers the transformer implementation, flash-attention parameters,
    prompt templates, embedding/depth scaling, swap features and checkpoint
    loading switches.

    Args:
        parser: argparse parser to extend.

    Returns:
        The same parser with the 'training' group added.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--transformer-impl', default='local',
                       choices=['local', 'transformer_engine'],
                       help='Which Transformer implementation to use.')
    group.add_argument('--enable-recompute-layers-per-pp-rank',
                       action='store_true', default=False,
                       help='If enabled, --recompute-num-layers will mean the number of '
                            'layers recomputed in each pp rank. Otherwise it means the number '
                            'of layers recomputed in each vpp rank.')

    # Flash attention parameters (option spellings are part of the CLI and kept as-is).
    group.add_argument('--pre-tockens', type=int, default=65536,
                       help='pre-tockens is used by Flash attention')
    group.add_argument('--next-tockens', type=int, default=0,
                       help='next-tockens is used by Flash attention')
    group.add_argument('--sparse-mode', type=int, default=0,
                       help='different modes of flash attention mask')
    group.add_argument('--shape-order', type=str, default='SBH',
                       choices=['SBH', 'BSH', 'BSND', 'BNSD'],
                       help='input shape order used by Flash attention')

    group.add_argument('--use-deter-comp',
                       action='store_true',
                       default=False,
                       help='enable deterministic computing for npu')
    group.add_argument('--jit-compile', action='store_true', default=False,
                       help='Setting jit compile mode to True')

    # Prompt templates. The help text previously lost a space at the implicit concatenation.
    group.add_argument('--prompt-type', type=str, default=None,
                       choices=['default', 'empty', 'trl', 'chatglm2', 'chatglm3', 'chatglm3_system', 'glm4',
                                'chatml', 'bailing_mini', 'chatml_de', 'qwen', 'qwen_r1', "qwen_math_r1",
                                'llama3', 'llama2', 'mistral', 'mixtral', 'gemma', 'alpaca', 'deepseek2',
                                'deepseek2-lite', 'minicpm3', 'cpm', 'baichuan2', 'deepseek3', 'intern2',
                                'hunyuan', 'qwen3'],
                       help='Which template to use for constructing prompts in training/inference. '
                            'e.g., "qwen"')
    group.add_argument('--prompt-type-path', type=str, default=TEMPLATES_DIR,
                       help='Path to the json file of templates.')
    group.add_argument('--pad-to-multiple-of', type=int, default=8,
                       help='Used for Padding multiple in finetune. The default is 8.')

    group.add_argument('--scale-emb', type=float, default=None,
                       help='scale embed tokens')
    group.add_argument('--dim-model-base', type=float, default=None,
                       help='dim-model-base')
    group.add_argument('--no-cut-token', action='store_true', default=False,
                       help='Used for not cut token in finetune.')
    group.add_argument('--scale-depth', type=float, default=None,
                       help='scale-depth')

    # Same missing-space fix as above.
    group.add_argument('--swap-attention', action='store_true', default=False,
                       help='switch to open swap-attention feature. '
                            'The default is False.')
    group.add_argument('--swap-modules', type=str, default=None,
                       help='Swap modules for model. Should be used together with "--swap-attention."')
    group.add_argument('--load-checkpoint-loosely', action='store_true', default=False,
                       help='Enable loading checkpoint not strictly.')
    group.add_argument('--no-post-layer-norm', action='store_true', default=False,
                       help='Disable final layer norm.')
    group.add_argument('--return-document-ids', action='store_true', default=False,
                       help='Return document ids when get batch.')
    # The help used to be a copy of the --return-document-ids text.
    group.add_argument('--reset-attention-mask', action='store_true', default=False,
                       help='Reset self attention mask after end-of-document token.')
    group.add_argument('--swap-optimizer', action='store_true', default=False,
                       help='swap optimizer to cpu.')
    group.add_argument('--swap-optimizer-times', type=int, default=16,
                       help='Each swap will be moved (len(shard_fp32_from_float16) // swap_optimizer_times) elements')
    return parser
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
help='Local rank passed from distributed launcher for torch2.x.')
group.add_argument('--distributed-timeout-minutes', type=int, default=30,
help='Timeout minutes for torch.distributed.')
return parser
def _add_ndmm_args(parser):
group = parser.add_argument_group(title='ndmm')
group.add_argument('--use-nd-matmul', action='store_true', default=False,
help='use use-nd-matmul to replace megatron-style tensor parallel')
group.add_argument('--nd1-dim1-size', type=int, default=1,
help='Dim1 of the first nd matmul when use-3d-matmul is True')
group.add_argument('--nd2-dim1-size', type=int, default=1,
help='Dim1 of the second nd matmul when use-3d-matmul is True')
return parser
def _add_2d_tp_args(parser):
group = parser.add_argument_group(title='2d-tp')
group.add_argument('--tp-2d', action='store_true', default=False,
help='use use-2d-tp to replace megatron-style tensor parallel')
group.add_argument('--tp-x', type=int, default=1,
help='the fist dim tensor parallel size for Linear')
group.add_argument('--tp-y', type=int, default=1,
help='the second dim tensor parallel size for Linear')
group.add_argument('--enable-overlap-ag-with-matmul', action='store_true', default=False,
help='use enable-overlap-ag-with-matmul to overlap all-gather with matmul')
group.add_argument('--enable-overlap-matmul-with-rs', action='store_true', default=False,
help='use enable-overlap-matmul-with-rs to overlap matmul with reduce-scatter')
group.add_argument('--enable-backward-overlap-ag-with-matmul', action='store_true', default=False,
help='use enable-backward-overlap-ag-with-matmul to overlap all-gather with matmul in backward')
return parser
def add_parser_argument_choices_value(parser, argument_name, value):
    """Make ``value`` an accepted choice of an already registered option.

    Args:
        parser: argparse parser (or argument group) whose ``_actions`` are searched.
        argument_name: option string to look for, e.g. ``"--position-embedding-type"``.
        value: choice to add for every matching action.
    """
    for action in parser._actions or ():
        if isinstance(action, argparse._ArgumentGroup):
            # The recursive call previously omitted ``value`` and raised TypeError.
            add_parser_argument_choices_value(action, argument_name, value)
            continue
        if not isinstance(action, argparse.Action) or argument_name not in action.option_strings:
            continue
        if action.choices is None:
            # No ``choices`` restriction: every value is already accepted.
            continue
        if value in action.choices:
            continue
        if isinstance(action.choices, list):
            # Extend in place so other references to the list see the new value.
            action.choices.append(value)
        else:
            # Tuples, ranges, sets, ... have no ``append``; rebuild as a list.
            action.choices = [*action.choices, value]
def _add_alibi_args(parser):
    """Enable the 'alibi' position embedding and register its options.

    First adds 'alibi' to the choices of ``--position-embedding-type``, then
    registers the two alibi mask switches.

    Args:
        parser: argparse parser to extend.

    Returns:
        The same parser with the 'alibi' group added.
    """
    add_parser_argument_choices_value(parser, "--position-embedding-type", 'alibi')

    alibi_group = parser.add_argument_group(title='alibi')
    for switch, description in (
        ('--square-alibi-mask', 'attention mask of alibi is squared'),
        ('--fill-neg-inf', 'fill alibi with negative inf'),
    ):
        alibi_group.add_argument(switch, action='store_true', default=False, help=description)
    return parser
def _add_dataset_args(parser):
group = parser.add_argument_group(title='dataset_args')
group.add_argument('--no-shared-storage',
action='store_true',
default=False,
help='if no shared storage, set it'
)
group.add_argument('--dataset-additional-keys',
nargs='*',
default=[],
help='Additional keys need to be add from dataset.'
)
return parser
def _add_hccl_group_buffer_args(parser):
group = parser.add_argument_group(title='hccl-group-buffer')
group.add_argument('--hccl-group-buffer', type=str, default=None,
help='the hccl buffer for group')
return parser
def _validate_create_attention_mask_in_dataloader(args):
    """Decide ``args.create_attention_mask_in_dataloader`` and check packing options.

    The flag is turned on when the attention mask is reset, when alibi is used
    without flash attention, or when the tokenizer pads on the left; it is
    turned off again when the mask is reset with a 'causal' attention mask type.

    Raises:
        ValueError: if ``--neat-pack`` is set without ``--reset-position-ids``,
            or if the causal reset mode is combined with context parallelism
            and an algorithm other than 'megatron_cp_algo'.
    """
    reset_data = args.reset_attention_mask
    alibi_without_flash_attn = args.position_embedding_type == 'alibi' and not args.use_flash_attn

    if reset_data and args.attention_mask_type == 'causal':
        # The causal reset mode overrides every other trigger.
        build_mask_in_dataloader = False
    elif reset_data or alibi_without_flash_attn or args.tokenizer_padding_side == "left":
        build_mask_in_dataloader = True
    else:
        build_mask_in_dataloader = False
    args.create_attention_mask_in_dataloader = build_mask_in_dataloader

    print_rank0_by_args(args, f"[INFO] Setting args.create_attention_mask_in_dataloader to {args.create_attention_mask_in_dataloader} "
                              f"since reset_data={reset_data} or alibi_without_flash_attn={alibi_without_flash_attn} or "
                              f"args.tokenizer_padding_side={args.tokenizer_padding_side}")

    if args.neat_pack and not args.reset_position_ids:
        raise ValueError("Require set `--reset-position-ids` when `--neat-pack` is set.")

    causal_reset_with_cp = (
        args.context_parallel_size > 1
        and args.reset_attention_mask
        and args.attention_mask_type == 'causal'
    )
    if causal_reset_with_cp and args.context_parallel_algo != 'megatron_cp_algo':
        raise ValueError('accelerated eod reset mode only support ring attention')
def _validate_position_embedding(args):
"""
validate position embedding arguments.
"""
if args.use_glm_rope and args.use_fused_rotary_pos_emb:
raise AssertionError('Fused rotary embedding is not supported in glm rope.')
if args.position_embedding_type == 'alibi' and args.sliding_window is not None:
raise AssertionError('Sliding Window Attention is forbidden when use alibi.')
if args.tokenizer_padding_side == 'left' and args.position_embedding_type == 'alibi':
raise AssertionError('Alibi is not support tokenizer-padding-side left now.')
if not args.use_fused_rmsnorm:
if args.swap_attention and len(args.lora_target_modules) != 0:
raise AssertionError('When not use_fused_rmsnorm, swap_attention cannot be used in lora fune-tuning.')
def _validate_recompute_args(args):
"""
validate re-computation arguments.
"""
enable_pp_vpp = args.num_layers_per_virtual_pipeline_stage
enable_vanilla_recomputation = args.recompute_granularity is not None and args.recompute_method == 'block'
enable_swap = args.swap_attention
enable_recompute_activation = args.recompute_activation_function
enable_recomputation = enable_vanilla_recomputation or enable_swap or enable_recompute_activation
if args.enable_recompute_layers_per_pp_rank and not (enable_pp_vpp and enable_recomputation):
raise AssertionError("enable-recompute-layers-per-pp-rank should be works with pipeline and virtual pipeline, when enabling re-computation.")
if args.recompute_activation_function:
if args.recompute_method == "uniform":
raise AssertionError('uniform recomputation is not compatible with activation function recomputation.')
if args.recompute_granularity == "selective":
raise AssertionError('--recompute-activation-function is not compatible with selective recomputation.')
if args.recompute_norm:
if args.recompute_method == "uniform":
raise AssertionError('uniform recomputation is not compatible with norm recomputation.')
if args.recompute_granularity == "selective":
raise AssertionError('--recompute-norm is not compatible with selective recomputation')
if not args.use_mcore_models:
raise AssertionError('--recompute-norm is only supported with mcore models')
if args.swap_attention and args.swap_modules is None:
if args.use_mcore_models:
args.swap_modules = "input_layernorm,self_attention,pre_cross_attn_layernorm"
else:
args.swap_modules = "input_norm,self_attention,post_attention_norm"
def _validate_instruction_finetune(args):
if args.variable_seq_lengths:
if args.context_parallel_size > 1:
raise AssertionError('Context parallelism is forbidden when use variable seq lengths.')
if args.num_experts is not None and args.moe_token_dispatcher_type == "allgather":
raise AssertionError('moe_token_dispatcher_type "allgather" is forbidden when use variable seq lengths. you can choose "alltoall_seq"')
def _validate_train_mode_conflict(args):
if args.stage == 'sft':
if args.variable_seq_lengths:
if args.neat_pack or args.reset_position_ids:
raise AssertionError('pack mode is forbidden when using variable seq lengths.')
def _validate_inference_args(args):
if args.prompt_type is not None and hasattr(args, "hf_chat_template") and args.hf_chat_template:
raise AssertionError('Prompt-type is forbidden when use huggingface chat template.')
if hasattr(args, "history_turns") and args.history_turns < 0:
raise AssertionError('History turns of chat must greater than 0.')
def _validate_evaluation_args(args):
if args.prompt_type is not None and hasattr(args, "task") and (args.task == "mmlu" or args.task == "ceval"):
train_dir = os.path.join(os.path.dirname(args.task_data_path), "dev")
if not os.path.isdir(train_dir) or not os.path.isdir(args.task_data_path):
raise ValueError(f"Test and dev directory must exists when specify prompt_type in evaluation")
def _validate_moe_args(args):
if not args.use_mcore_models and args.num_experts and args.num_experts > 1:
raise ValueError(f'MOE is not supported in legacy model. Please activate `--use-mcore-models` to enable moe features.')
if args.moe_alltoall_overlap_comm or args.moe_allgather_overlap_comm:
if not args.moe_permutation_async_comm or not args.moe_grouped_gemm:
raise AssertionError(
'`--moe-alltoall-overlap-comm` or `--moe-allgather-overlap-comm` only support with `--moe-permutation-async-comm` and `--moe-grouped-gemm`.')
if args.moe_alltoall_overlap_comm and not args.moe_token_dispatcher_type == 'alltoall_seq':
raise AssertionError('`--moe-alltoall-overlap-comm` only support with `--moe-token-dispatcher-type alltoall_seq`.')
if not args.moe_tp_extend_ep and args.moe_alltoall_overlap_comm and args.tensor_model_parallel_size > 1:
raise AssertionError(
'`--moe-alltoall-overlap-comm` do not support tp for now. only support with moe_tp_extend_ep when tp > 1.')
if args.moe_zero_memory_num_layers is not None:
num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
if args.moe_zero_memory_num_layers < 0 or args.moe_zero_memory_num_layers > num_layers_per_pipeline_stage:
raise AssertionError('`--moe-zero-memory-num-layers` must be between 0 and num layers per pipeline stage')
if args.moe_zero_memory == "disable":
raise AssertionError('`--moe-zero-memory` must be enabled when using `--moe-zero-memory-num-layers`')
if args.moe_zero_memory != "disable" and not (args.moe_alltoall_overlap_comm or args.moe_fb_overlap):
raise AssertionError('`--moe-zero-memory` only supports `--moe-alltoall-overlap-comm` and `--moe-fb-overlap` for now.')
if args.moe_zero_memory != "disable" and args.recompute_method is not None:
raise AssertionError('`--moe-zero-memory` does not support full recomputation for now.')
if args.moe_allgather_overlap_comm and not args.moe_token_dispatcher_type == 'allgather':
raise AssertionError('`--moe-allgather-overlap-comm` only support with `--moe-token-dispatcher-type allgather`.')
if args.moe_allgather_overlap_comm and not args.tensor_model_parallel_size > 1 and not args.expert_model_parallel_size > 1:
raise AssertionError('`--moe_allgather_overlap_comm` requires enabling tp or ep.')
if args.shared_expert_gate and args.gradient_accumulation_fusion:
raise AssertionError('args.shared_expert_gate does not support gradient_accumulation_fusion.')
if args.moe_tp_extend_ep and args.moe_token_dispatcher_type == 'allgather':
raise AssertionError('moe_tp_extend_ep is not supported when moe_token_dispatcher_type = allgather')
def _validate_mla(args):
if args.multi_latent_attention:
if args.padded_base_length < 1:
raise AssertionError('The value of padded_base_length cannot be less than 1.')
if args.mla_up_proj_tp_overlap:
if not args.mla_mm_split:
raise ValueError('--mla-up-proj-tp-overlap can only be used with mla-mm-split by now')
if not args.sequence_parallel:
raise ValueError('--mla-up-proj-tp-overlap should be used with sequence parallel')
if args.recompute_mla_up_proj:
if not args.mla_up_proj_tp_overlap:
raise ValueError('--recompute-mla-up-proj can only be used with --mla-up-proj-tp-overlap')
if args.mla_zero_memory:
raise ValueError('--recompute-mla-up-proj is incompatible with --mla-zero-memory')
if args.mla_swap_core_attn_out:
if args.schedules_method != "dualpipev":
raise AssertionError('--mla-swap-core-attn-out can only be used with dualpipev by now.')
if not args.moe_fb_overlap:
raise AssertionError('--mla-swap-core-attn-out can only be used with --moe-fb-overlap by now.')
def _validate_yarn(args):
if args.rope_scaling_type == "yarn":
if args.rope_scaling_original_max_position_embeddings is None:
raise AssertionError('The parameter rope_scaling_original_max_position_embeddings should be set '
'when use yarn.')
if args.multi_head_latent_attention and not args.rope_scaling_factor:
raise AssertionError('The parameter rope_scaling_factor should be set when use yarn.')
if args.max_position_embeddings / args.rope_scaling_original_max_position_embeddings != args.rope_scaling_factor:
raise AssertionError('The parameter rope_scaling_factor should be equal to max_position_embeddings'
'divided by rope_scaling_original_max_position_embeddings.')
def _validate_transformer_block_build_layers(args):
if args.num_experts is None:
if args.first_k_dense_replace is not None or args.moe_layer_freq is not None:
raise AssertionError('First-k-dense-replace and moe-layer-freq must be None when not using MoEs')
else:
if args.first_k_dense_replace and args.num_layers <= args.first_k_dense_replace:
raise AssertionError('Num-layer ({}) must be greater than first-k-dense-replace ({}) when first-k-dense-replace is set.'.format(args.num_layers,
args.first_k_dense_replace))
if args.first_k_dense_replace and args.pipeline_model_parallel_size > 1:
if args.first_k_dense_replace >= args.num_layers // args.pipeline_model_parallel_size:
raise AssertionError('When using first-k-dense-replace, it is not allowed for all layers within a pp stage to be dense layers.')
if args.num_experts is not None and args.use_ascend_mc2 and args.moe_grouped_gemm:
raise AssertionError('Moe Grouped Gemm is not supported with mc2 in MOE model.')
if args.num_layer_list:
if len(args.num_layer_list.split(',')) != args.pipeline_model_parallel_size:
raise ValueError("len(args.num_layer_list) != args.pipeline_model_parallel_size")
if not args.pipeline_model_parallel_size > 1:
raise ValueError("Dynamic pipeline model should work with pipeline parallel.")
if args.num_layers_per_virtual_pipeline_stage or args.noop_layers:
raise ValueError("Dynamic pipeline model is not support work with virtual pipeline or noop layers.")
if args.use_ascend_mc2 and args.use_ascend_coc:
raise AssertionError('--mc2 and coc can not be used together')
def get_layer_offset(pp_size, num_layer_list):
    """
    Build the cumulative layer counts over the pipeline stages.

    The returned list has ``pp_size + 1`` entries: entry ``i`` is the number
    of layers held by the stages before stage ``i`` and the last entry is the
    total. For instance, num-layer-list=1,3,3,1 with pp_size=4 gives
    [0, 1, 4, 7, 8].
    """
    cumulative = [0] * (pp_size + 1)
    running_total = 0
    for stage_number, stage_layers in enumerate(num_layer_list, start=1):
        running_total = running_total + stage_layers
        cumulative[stage_number] = running_total
    return cumulative
def _validate_output_layer_slice_num(args):
if args.output_layer_slice_num < 1:
raise AssertionError('Output_layer_slice_num must be greater than 0.')
elif args.output_layer_slice_num > 1:
if args.tensor_model_parallel_size > 1:
raise AssertionError('When output_layer_slice_num is greater than 1, only support TP size is 1.')
if (args.padded_vocab_size is not None) and (args.padded_vocab_size % args.output_layer_slice_num != 0):
raise AssertionError('Output_layer_slice_num needs to be divisible by padded_vocab_size.')
elif (args.vocab_size is not None) and (args.vocab_size % args.output_layer_slice_num != 0):
raise AssertionError('Output_layer_slice_num needs to be divisible by vocab_size.')
if args.gradient_accumulation_fusion:
args.gradient_accumulation_fusion = False
print_rank0_by_args(f"gradient_accumulation_fusion would be set to {args.gradient_accumulation_fusion} "
f"since args.output_layer_slice_num > 1")
def core_transformer_config_from_args_wrapper(fn):
    """Decorate a config builder so the result carries extra settings from ``args``.

    After ``fn(args, config_class)`` builds the config, the wrapper disables
    ``batch_p2p_comm`` for pp == 2 with virtual pipeline stages, copies the
    MoE capacity settings when a capacity factor is given, and stores the
    per-stage layer list together with its cumulative offsets.

    Raises (from the wrapper):
        ValueError: if the entries of ``num_layer_list`` do not sum to
            ``args.num_layers``.
    """
    @wraps(fn)
    def wrapper(args, config_class=None):
        cfg = fn(args, config_class)
        pp_size = args.pipeline_model_parallel_size

        if pp_size == 2 and args.num_layers_per_virtual_pipeline_stage is not None:
            cfg.batch_p2p_comm = False

        if args.moe_expert_capacity_factor:
            cfg.moe_expert_capacity_factor = args.moe_expert_capacity_factor
            cfg.moe_pad_expert_input_to_capacity = args.moe_pad_expert_input_to_capacity
            cfg.moe_token_drop_policy = args.moe_token_drop_policy

        if not args.num_layer_list:
            cfg.num_layer_list = None
            return cfg

        cfg.num_layer_list = [int(entry) for entry in args.num_layer_list.split(',')]
        cfg.layer_offset = get_layer_offset(pp_size, cfg.num_layer_list)
        if cfg.layer_offset[pp_size] != args.num_layers:
            raise ValueError(f"Incorrect num_layer_list config since its sum({cfg.layer_offset[pp_size]} is unequal to total num layers({args.num_layers}).")
        return cfg

    return wrapper
def _validate_optimizer(args):
if args.reuse_fp32_param and not args.bf16:
raise AssertionError('--reuse-fp32-param only support for `bf16`')
if args.reuse_fp32_param and args.swap_optimizer:
raise AssertionError('--swap-optimizer dose not support `--reuse-fp32-param`')
def _store_variables(args):
"""
To bypass megatron validation, we store variables and restore them shortly afterward.
"""
variable_dict = dict()
variable_dict["variable_seq_lengths"] = args.variable_seq_lengths
variable_dict["spec"] = args.spec
args.spec = None
if args.pipeline_model_parallel_size == 2 and args.num_layers_per_virtual_pipeline_stage is not None:
variable_dict["num_layers_per_virtual_pipeline_stage"] = args.num_layers_per_virtual_pipeline_stage
variable_dict["overlap_p2p_comm"] = args.overlap_p2p_comm
args.num_layers_per_virtual_pipeline_stage = None
args.overlap_p2p_comm = None
if args.o2_gradient:
args.accumulate_allreduce_grads_in_fp32 = True
if args.mtp_num_layers is not None and args.mtp_num_layers >= 1 and args.context_parallel_size > 1:
variable_dict["mtp_num_layers"] = args.mtp_num_layers
args.mtp_num_layers = 0
if hasattr(args, "role") and args.role == "actor_rollout":
variable_dict["true_world_size"] = args.world_size
args.world_size = args.num_gpus_for_train
if not args.data_path:
args.data_path = True
args.mock_data = 0
return variable_dict
def _restore_variables(args, variable_dict):
args.variable_seq_lengths = variable_dict["variable_seq_lengths"]
if variable_dict.get("num_layers_per_virtual_pipeline_stage") and args.pipeline_model_parallel_size == 2:
args.num_layers_per_virtual_pipeline_stage = variable_dict["num_layers_per_virtual_pipeline_stage"]
args.overlap_p2p_comm = variable_dict["overlap_p2p_comm"]
if args.o2_gradient:
args.accumulate_allreduce_grads_in_fp32 = False
if variable_dict.get("mtp_num_layers"):
args.mtp_num_layers = variable_dict["mtp_num_layers"]
if hasattr(args, "role") and args.role == "actor_rollout":
args.world_size = variable_dict["true_world_size"]
if args.rank >= args.num_gpus_for_train:
args.data_parallel_size = args.num_gpus_for_infer // args.tensor_model_parallel_size
args.micro_batch_size = args.num_samples_per_step
args.spec = variable_dict["spec"]
def _add_dummy_args(args):
"""
For arguments in mindspeed-core which is currently unsupported in mindspeed-llm.
"""
args.adaptive_recompute_device_swap = False
args.adaptive_recompute_device_size = -1
args.adaptive_recompute_profiling_step = 10
args.recompute_in_bubble = False
args.use_nanopipe = False
args.moe_without_activation = False
args.ampipe_degree = 0
args.hccl_group_buffer_adaptive = False
args.moe_bmm_mc2 = False
args.moe_hierarchical_alltoallv = False
args.moe_experts_pipeline_degree = 0
args.context_parallel_kv_cache_policy = None
args.context_parallel_cache_interval = 0
args.use_ulysses_allgather_kv = False
args.use_pipe_experts = False
args.megatron_cp_in_bnsd = False
args.use_fusion_attn_v2 = False
args.npu_deterministic = False
args.prof_file = False
args.moe_adaptive_recompute_activation = False
args.op_cal_tflops = False
args.profile_operator = False
args.delay_grad_reduce = False
args.delay_param_gather = False
args.alibi_fusion_attn_type = None
args.moe_adaptive_recompute_activation_scale = 2.0
args.yarn_scaling_factor = 1.0
args.moe_intermediate_size = args.moe_ffn_hidden_size
def _add_dummy_args_v2(args):
"""
For arguments in feature_list which is currently unsupported in mindspeed-llm.
"""
args.unaligned_linear = False
args.embed_layernorm = False
args.enable_share_memory = False
args.return_document_ids = False
args.attention_mask_on_cpu = False
args.output_layer_slice_num = 1
args.use_fused_mlp = False
args.transformer_impl = 'local'
args.ckpt_format = 'torch'
args.use_dist_ckpt = False
def _validate_noop_layer(args):
if isinstance(args.noop_layers, str):
noop_layers = set()
for x in args.noop_layers.split(','):
if int(x) >= args.num_layers or int(x) < 0:
raise AssertionError(f'each element in args.noop_layers({args.noop_layers}) should bigger or equal '
f'to 0 and smaller than args.num_layers({args.num_layers})')
noop_layers.add(int(x))
args.noop_layers = noop_layers
if args.num_layer_list:
print_rank0_by_args("num layer list would be disabled when noop-layer is activated.")
args.num_layer_list = None
def _valid_tp_2d_args(args):
if args.tp_2d:
if args.sequence_parallel:
raise AssertionError('2d tp does not support sequence parallel')
if args.use_fused_rmsnorm:
raise AssertionError('2d tp does not support fused rmsnorm')
if hasattr(args, "use_nanopipe") and args.use_nanopipe:
raise AssertionError('tp-2d does not support nano-pipe')
if hasattr(args, "ampipe_degree") and args.ampipe_degree > 1:
raise AssertionError('tp-2d does not support ampipe')
if hasattr(args, "context_parallel_algo") and args.context_parallel_algo not in ['megatron_cp_algo', 'ulysses_cp_algo']:
raise AssertionError('tp-2d now only support megatron_cp_algo or ulysses_cp_algo')
if hasattr(args, "use_ascend_coc") and args.use_ascend_coc:
raise AssertionError('tp-2d does not support ascend coc')
if args.tensor_model_parallel_size // args.tp_x != args.tp_y:
raise AssertionError('need satisfy tp = tp_x * tp_y')
if args.expert_model_parallel_size > 1:
raise AssertionError('2d tp does not support moe')
def _valid_fa_div_args(args):
if args.mla_fa_divide_qk:
if args.context_parallel_size > 1:
raise AssertionError('MLA FA currently not support CP>1.')
if not args.reset_position_ids:
raise AssertionError('MLA FA currently only support TND.')
def _validate_vpp(args):
"""validate scenario that vpp is enabled when pp=2."""
if args.pipeline_model_parallel_size != 2 or args.num_layers_per_virtual_pipeline_stage is None:
return
num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
if num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage != 0:
raise ValueError('number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage')
pp_stage_layers = args.num_layers / args.pipeline_model_parallel_size
if args.num_layers_per_virtual_pipeline_stage and args.num_layers_per_virtual_pipeline_stage >= pp_stage_layers:
raise ValueError("Num of layers in vpp stage should be less than pp stage, "
"please turn down args.num_layers_per_virtual_pipeline_stage.")
args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
args.num_layers_per_virtual_pipeline_stage
print_rank0_by_args(args, f'vpp_size would be {args.virtual_pipeline_model_parallel_size} since '
f'num_layers_per_virtual_pipeline_stage is {args.num_layers_per_virtual_pipeline_stage}')
def _validate_recompute_in_advance(args):
args.reduce_recompute_for_last_chunk = False
if args.recompute_in_advance:
args.reduce_recompute_for_last_chunk = True
if args.recompute_method == "uniform":
raise AssertionError('recompute_in_advance does not support uniform recompute_method')
if args.recompute_granularity == 'selective':
raise AssertionError('recompute_in_advance does not support vanilla recompute_activations.')
if not args.recompute_num_layers:
raise AssertionError('recompute_num_layers must be configured when using recompute_in_advance')
if args.pipeline_model_parallel_size <= 1 or args.num_layers_per_virtual_pipeline_stage != 1:
raise AssertionError('recompute_in_advance only support pipelining with interleaving and vpp stage should be 1.')
def _validate_mlp_fusion(args):
if args.use_fused_mlp:
if not args.swiglu:
raise AssertionError(
'use_fused_mlp only support activation func swiglu')
if not args.sequence_parallel and not args.lora_fusion:
raise AssertionError(
'use_fused_mlp only support sequence_parallel with cclora')
def _validate_long_rope(args):
if args.rope_scaling_type == "longrope":
if args.rope_scaling_original_max_position_embeddings is None:
raise AssertionError('The parameter rope_scaling_original_max_position_embeddings should be set '
'when use longrope.')
if args.long_factor is None:
raise AssertionError('The parameter long_factor should be set when use longrope.')
else:
args.long_factor = list(map(float, args.long_factor.split(',')))
if args.short_factor is None:
raise AssertionError('The parameter short_factor should be set when use longrope.')
else:
args.short_factor = list(map(float, args.short_factor.split(',')))
if bool(args.short_mscale) ^ bool(args.long_mscale):
raise AssertionError('The parameter short_mscale and long_mscale must be set at the same time')
def _validate_o2(args):
if args.o2_gradient or args.o2_optimizer:
print_rank0_by_args(args, "[WARNING] Using half precision gradient or optimizer would definitely impact precision.")
if not args.bf16:
raise ValueError("[ERROR] Should only use o2 feature during bf16 mix-precision training.")
if args.o2_gradient:
if args.gradient_accumulation_fusion:
raise ValueError("gradient_accumulation_fusion only works with fp32 gradient.")
def _validate_fused_opts(args):
if args.use_fused_rmsnorm:
if args.normalization != "RMSNorm":
raise AssertionError(
'--use-fused-rmsnorm must enable with '
'--normalization=RMSNorm, but got normalization'
'={}.'.format(args.normalization))
if args.use_fused_swiglu:
if not args.swiglu:
raise AssertionError(
'--use-fused-swiglu must enable with --swiglu, '
'but --swiglu={}.'.format(args.swiglu))
if args.use_fused_rotary_pos_emb:
if args.position_embedding_type != 'rope':
raise AssertionError(
'--use-fused-rotary-pos-emb must enable with'
'--position-embedding-type=rope')
def validate_args_decorator(megatron_validate_args):
    """Wrap megatron's argument validation with the extra checks of this module.

    Args:
        megatron_validate_args: the validation callable being decorated; it is
            invoked as ``megatron_validate_args(args, defaults)``.

    Returns:
        A wrapper ``(args, defaults=None)`` that runs the wrapped validation,
        then the module's validators and feature hooks, prints the arguments
        and returns ``args``.
    """
    @wraps(megatron_validate_args)
    def wrapper(args, defaults=None):
        if defaults is None:
            defaults = {}
        # Save (and temporarily override) selected values; they are put back
        # by _restore_variables after the wrapped validation has run.
        variable_dict = _store_variables(args)
        if args.num_layer_list:
            # While the wrapped validation runs, num_layers is replaced by the
            # number of entries in num_layer_list; the real value is restored
            # afterwards and copied to encoder_num_layers.
            _num_layer = args.num_layers
            args.num_layers = len(args.num_layer_list.split(','))
            megatron_validate_args(args, defaults)
            args.num_layers = _num_layer
            args.encoder_num_layers = args.num_layers
        else:
            megatron_validate_args(args, defaults)
        _restore_variables(args, variable_dict)
        # mc2 is unconditionally switched off here, before the validators run.
        args.use_ascend_mc2 = False
        args.use_legacy_models = not args.use_mcore_models
        # Module-level validators; each raises on an unsupported combination
        # and some of them also write derived values onto args.
        _validate_o2(args)
        _validate_varlen_fa_args(args)
        _validate_cp_args(args)
        _validate_transformer_block_build_layers(args)
        _validate_vpp(args)
        _validate_recompute_args(args)
        _validate_recompute_in_advance(args)
        _validate_create_attention_mask_in_dataloader(args)
        _validate_instruction_finetune(args)
        _validate_train_mode_conflict(args)
        _validate_position_embedding(args)
        _validate_inference_args(args)
        _validate_moe_args(args)
        _validate_mla(args)
        _validate_yarn(args)
        _validate_evaluation_args(args)
        _validate_output_layer_slice_num(args)
        _validate_optimizer(args)
        _validate_long_rope(args)
        _validate_mlp_fusion(args)
        _validate_fused_opts(args)
        _validate_noop_layer(args)
        _valid_tp_2d_args(args)
        _valid_fa_div_args(args)
        _add_dummy_args(args)
        # Feature hooks run for every level-0 feature, and for higher-level
        # features only when the attribute named after the feature is truthy.
        for feature in FEATURES_LIST:
            if (getattr(args, feature.feature_name, None) and feature.optimization_level > 0) or feature.optimization_level == 0:
                feature.pre_validate_args(args)
                feature.validate_args(args)
                feature.post_validate_args(args)
        # Imported inside the wrapper rather than at module import time.
        from mindspeed_llm.training.utils import print_args
        print_args('MindSpeed-LLM Arguments', args)
        return args
    return wrapper
def validate_args_v2_decorator(megatron_validate_args):
    """Decorate megatron's argument validation with the feature-manager hooks.

    The wrapper runs the feature manager's pre-validation, the wrapped
    validation (whose return value replaces ``args``), the post-validation,
    sets the v2 placeholder arguments, runs the feature validation, prints
    the arguments and returns them.
    """
    @wraps(megatron_validate_args)
    def wrapper(args, defaults=None):
        effective_defaults = {} if defaults is None else defaults
        feature_manager = MindSpeedFeaturesManager

        feature_manager.pre_validate_features_args(args)
        args = megatron_validate_args(args, effective_defaults)
        feature_manager.post_validate_features_args(args=args)

        _add_dummy_args_v2(args)
        feature_manager.validate_features_args(args=args)

        # Imported inside the wrapper rather than at module import time.
        from mindspeed_llm.training.utils import print_args
        print_args('MindSpeed-LLM Arguments', args)
        return args

    return wrapper