from functools import wraps
from mindspeed.arguments import _add_auto_settings_args
def extra_args_provider_decorator(extra_args_provider):
    """Wrap an extra-args provider so ``process_args`` always runs after it.

    Args:
        extra_args_provider: Callable that takes a parser and returns a
            parser, or ``None`` when there is no upstream provider.

    Returns:
        A ``wrapper(parser)`` function that first applies
        ``extra_args_provider`` (when it is not ``None``) and then passes the
        result through ``process_args``.
    """
    @wraps(extra_args_provider)
    def wrapper(parser):
        # Skip the upstream provider entirely when none was supplied.
        upstream = parser if extra_args_provider is None else extra_args_provider(parser)
        return process_args(upstream)

    return wrapper
def process_args(parser):
    """Register every argument group defined by this module on ``parser``.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The parser returned by the last registration function in the chain.
    """
    # With the "resolve" handler, argparse replaces an already-registered
    # option string instead of raising when it is declared again.
    parser.conflict_handler = "resolve"

    # Registration order is preserved exactly; each step receives the parser
    # returned by the previous one.
    registrars = (
        _add_lora_args,
        _add_training_args,
        _add_network_size_args,
        _add_dummy_optimizer_args,
        _add_logging_args,
        _add_security_args,
        _add_auto_parallel_mm_args,
        _add_rlfh_args,
        _add_network_args,
        _add_data_balance_args,
        _add_auto_settings_args,
        _add_optim_arguments,
        _add_muon_optim_arguments,
        _add_text_dynamic_batching_args,
        _add_image_mbs_balance_args,
    )
    for register in registrars:
        parser = register(parser)
    return parser
def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    Args:
        value: The raw command-line token (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` according to the spelling of ``value``.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling. argparse
            turns this into a normal "invalid value" usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays False, matching what ``bool('')`` used to give.
    if lowered in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def _add_lora_args(parser):
    """Register the LoRA options under the 'lora' argument group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    group = parser.add_argument_group(title='lora')

    group.add_argument('--lora-target-modules', nargs='+', type=str, default=[],
                       help='Use lora in target modules.')
    group.add_argument('--lora-target-parameters', nargs='+', type=str, default=[],
                       help='Use lora in target parameters.')
    group.add_argument('--lora-apply-modules', nargs='+', type=str, default=["all"],
                       help='Use lora exclude modules')
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so the value is parsed explicitly instead.
    group.add_argument('--lora-mixed-training', type=_str_to_bool, default=False,
                       help='Mixed training for lora and non-lora args')
    group.add_argument('--load-base-model', type=str, default=None,
                       help='Directory containing a base model checkpoint for lora.')
    group.add_argument('--lora-dropout', type=float, default=0.0, help="lora dropout rate")
    group.add_argument('--lora-r', type=int, default=8,
                       help='Lora rank.')
    group.add_argument('--lora-alpha', type=int, default=16,
                       help='Lora alpha.')
    group.add_argument('--lora-register-forward-hook', nargs='+', type=str,
                       default=['word_embeddings', 'input_layernorm', 'final_layernorm'],
                       help='Lora register forward hook.')
    return parser
def _add_training_args(parser):
    """Register the training options under the 'training' argument group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    training = parser.add_argument_group(title='training')

    def switch(option, description):
        # Every boolean switch in this group is off unless the flag is given.
        training.add_argument(option, action='store_true', default=False,
                              help=description)

    switch('--use-deter-comp', 'Enable deterministic computing for npu')
    switch('--jit-compile', 'Setting jit compile mode to True')
    switch('--allow-tf32', 'Use tf32 to train')
    switch('--downcast-to-bf16',
           'whether to downcast model weight from fp32 to bf16 while loading ckpt')
    switch('--allow-internal-format', 'Use internal format to train')
    training.add_argument('--virtual-pipeline-model-parallel-size', type=int,
                          default=None, help='vpp size')
    switch('--encoder-dp-balance', 'Balance for encoder')
    switch('--recompute-skip-core-attention',
           'Recomputing will skip the Flash attention if True')
    # This option is registered without any help text.
    training.add_argument('--recompute-num-layers-skip-core-attention', type=int,
                          default=0)
    switch('--hetero-parallel', 'apply different parallelism to different models')
    training.add_argument(
        '--hetero-encoder-mbs-scale', type=int, default=1,
        help='Adjust ViT/audio encoder MBS to x-times LLM decoder MBS (x = this param)')
    switch('--calculate-per-sample-loss',
           'Calculate the loss at the sample level: perform token-level mean '
           'within each sample, and sequence-level mean across samples.')
    switch('--calculate-square-loss', 'Calculate the loss.')
    switch('--calculate-token-loss', 'Calculate the loss.')
    training.add_argument('--optimizer', type=str, default='adam',
                          choices=['adam', 'sgd', 'muon'],
                          help='Optimizer function')
    return parser
def _add_network_size_args(parser):
    """Register ``--padded-vocab-size`` under the 'network_size_args' group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    size_group = parser.add_argument_group(title='network_size_args')
    vocab_option = dict(type=int, default=None, help='set padded vocab size')
    size_group.add_argument('--padded-vocab-size', **vocab_option)
    return parser
def _add_dummy_optimizer_args(parser):
    """Register ``--enable-dummy-optimizer`` under 'dummy optimizer args'.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    dummy_group = parser.add_argument_group(title='dummy optimizer args')
    dummy_group.add_argument(
        '--enable-dummy-optimizer',
        default=False,
        action='store_true',
        help='enable dummy optimizer',
    )
    return parser
def _add_logging_args(parser):
    """Register ``--log-tps`` under the 'logging' argument group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    logging_group = parser.add_argument_group(title='logging')
    tps_help = 'calculate and log average tokens per sample'
    logging_group.add_argument('--log-tps', action='store_true', default=False,
                               help=tps_help)
    return parser
def _add_security_args(parser):
    """Register ``--trust-remote-code`` under 'security configuration'.

    The switch is off by default and must be passed explicitly to enable it.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    security_group = parser.add_argument_group(title='security configuration')
    remote_code_help = (
        'Whether or not to allow for custom models defined on the Hub '
        'in their own modeling files.'
    )
    security_group.add_argument('--trust-remote-code', action='store_true',
                                default=False, help=remote_code_help)
    return parser
def _add_auto_parallel_mm_args(parser):
    """Register the profiling/simulation options under 'auto_parallel_mm'.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    mm_group = parser.add_argument_group(title='auto_parallel_mm')
    mm_group.add_argument('--profile-subgraph-seg', action='store_true',
                          default=False, help='model segmentation')

    # The remaining options are all optional integers with no default value.
    int_options = (
        ('--profile-stage', 'model profile stage'),
        ('--simulated-nnodes', 'the simulated number of node in the cluster'),
        ('--simulated-nproc-per-node', 'the simulated number of NPU on each node'),
    )
    for option, description in int_options:
        mm_group.add_argument(option, type=int, default=None, help=description)
    return parser
def _add_rlfh_args(parser):
    """Register the DPO options under the 'dpo' argument group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    dpo_group = parser.add_argument_group(title='dpo')

    # (option string, add_argument keyword arguments), in registration order.
    dpo_options = (
        ('--dpo-beta', dict(
            type=float, default=0.1,
            help="The beta parameter for the DPO loss")),
        ('--dpo-loss-type', dict(
            default="sigmoid", choices=["sigmoid"],
            help="The type of DPO loss to use")),
        ("--dpo-label-smoothing", dict(
            type=float, default=0.0,
            help="The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5.")),
        ('--ref-model', dict(
            default=None, type=str,
            help='Path to the reference model used for the PPO or DPO training.')),
        ('--pref-ftx', dict(
            default=0.0, type=float,
            help="The supervised fine-tuning loss coefficient in DPO training.")),
    )
    for option, settings in dpo_options:
        dpo_group.add_argument(option, **settings)
    return parser
def _add_network_args(parser):
    """Register ``--stage`` under the 'network' argument group.

    The option is unset (``None``) by default and only accepts ``ray_grpo``.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    network_group = parser.add_argument_group(title='network')
    allowed_stages = ["ray_grpo"]
    network_group.add_argument('--stage', default=None, choices=allowed_stages,
                               help='Determine training mode')
    return parser
def _add_data_balance_args(parser):
    """Register the data-balance options under the 'GBS_data_balance' group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    group = parser.add_argument_group(title="GBS_data_balance")
    group.add_argument("--use-data-balance",
                       action='store_true',
                       default=False,
                       help="Enable data balance")
    # The original underscore spelling stays first and keeps working; the
    # hyphenated alias matches the naming of the other flags in this file.
    # ``dest`` is pinned so both spellings fill the same attribute.
    group.add_argument("--data_balance_sorting_algo", "--data-balance-sorting-algo",
                       dest="data_balance_sorting_algo",
                       type=str, default="post_global_balancing_greedy_without_pad",
                       help="data balance sorting algorithm:"
                            "post_global_balancing_greedy_without_pad: a greedy post global balancing algorithm without padding")
    return parser
def _add_image_mbs_balance_args(parser):
    """Register the image data-balance options under 'MBS_data_balance'.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    group = parser.add_argument_group(title="MBS_data_balance")
    group.add_argument("--use-image-mbs-data-balance",
                       action='store_true',
                       default=False,
                       help="Enable data balance")
    # The original underscore spelling stays first and keeps working; the
    # hyphenated alias matches the naming of the other flags in this file.
    # ``dest`` is pinned so both spellings fill the same attribute.
    group.add_argument("--mbs_data_balance_sorting_algo", "--mbs-data-balance-sorting-algo",
                       dest="mbs_data_balance_sorting_algo",
                       type=str, default="post_mbs_balancing_greedy_without_pad",
                       help="data balance sorting algorithm:"
                            "post_mbs_balancing_greedy_without_pad: a greedy post local balancing algorithm without padding")
    return parser
def _add_text_dynamic_batching_args(parser):
    """Register the dynamic-batching options under 'text_dynamic_batching'.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    batching_group = parser.add_argument_group(title="text_dynamic_batching")
    batching_group.add_argument("--use-txt-dynamic-batching", action='store_true',
                                default=False,
                                help="Enable dynamic batching for LLM")

    # Integer options: (option string, default value, help text).
    sized_options = (
        ("--max-seq-len", 2048,
         "max sequence length of concatenated text for each micro batch"),
        ("--dynamic-batch-buffer-size", 200,
         "the size of dynamic batching buffer"),
    )
    for option, fallback, description in sized_options:
        batching_group.add_argument(option, type=int, default=fallback,
                                    help=description)
    return parser
def _add_optim_arguments(parser):
    """Register the weight-decay and learning-rate filter options.

    The options are added under the 'optimization_filter' argument group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    group = parser.add_argument_group(title='optimization_filter')
    group.add_argument(
        '--weight-decay-exclude-modules',
        nargs='+',
        type=str,
        default=[],
        help='Keywords in parameter names to exclude from weight decay. Empty list disables this feature.'
    )
    group.add_argument(
        '--lr-scale-modules',
        nargs='+',
        type=str,
        default=[],
        help='Keywords in parameter names to apply learning rate scaling. Empty list disables this feature.'
    )
    # The help previously pointed at a non-existent "scale-lr-keywords"
    # option; the keywords are actually supplied via --lr-scale-modules.
    group.add_argument(
        '--lr-mult',
        type=float,
        default=1.0,
        help='Learning rate multiplier for parameters matching --lr-scale-modules.'
    )
    return parser
def _add_muon_optim_arguments(parser):
    """Register the Muon optimizer options under the 'muon_optimizer' group.

    Args:
        parser: The argparse parser to extend.

    Returns:
        The same ``parser`` object, for chaining.
    """
    muon_group = parser.add_argument_group(title='muon_optimizer')

    # (option string, value type, default value, help text), in order.
    muon_options = (
        ('--matched-adamw-rms', float, 0.2,
         'Matched AdamW RMS value for Muon optimizer. '
         'Controls how closely Muon matches AdamW update magnitude. '
         'Typical range: 0.1-0.4. Default: 0.2'),
        ('--muon-momentum', float, 0.95,
         'Momentum coefficient for Muon internal SGD. '
         'Higher values give more weight to previous updates. '
         'Range: 0.0-1.0. Default: 0.95'),
        ('--ns-steps', int, 5,
         'Number of Newton-Schulz iterations for orthogonalization. '
         'More steps give better orthogonalization but slower training. '
         'Range: 1-10. Default: 5'),
    )
    for option, value_type, fallback, description in muon_options:
        muon_group.add_argument(option, type=value_type, default=fallback,
                                help=description)
    return parser