import os
import argparse
from pathlib import Path
from functools import wraps
from mindspeed.features_manager.features_manager import MindSpeedFeaturesManager
from mindspeed_llm.features_manager import FEATURES_LIST
cur_file_dir = Path(__file__).absolute().parent
TEMPLATES_DIR = os.path.join(cur_file_dir.parent.parent, "configs/finetune/templates.json")
def extra_args_provider_decorator(extra_args_provider):
"""
Decorator for extra arguments provider to add MindSpeed-LLM specific arguments.
This decorator wraps the extra arguments provider function to inject
MindSpeed-LLM feature arguments into the argument parser.
Args:
extra_args_provider: The original extra arguments provider function.
Returns:
Callable: Wrapped function that adds MindSpeed-LLM arguments.
The wrapper:
1. Calls the original provider if it exists
2. Adds MindSpeed-LLM v2 arguments via process_args_v2
"""
@wraps(extra_args_provider)
def wrapper(parser):
if extra_args_provider is not None:
parser = extra_args_provider(parser)
parser = process_args_v2(parser)
return parser
return wrapper
def parse_args_decorator(parse_args):
"""
Decorator for parse_args to inject MindSpeed-LLM argument processing.
This decorator wraps the argument parsing function to ensure MindSpeed-LLM
specific arguments are properly processed.
Args:
parse_args: The original parse_args function.
Returns:
Callable: Wrapped function that processes MindSpeed-LLM arguments.
"""
@wraps(parse_args)
def wrapper(extra_args_provider=None, ignore_unknown_args=False):
decorated_provider = extra_args_provider_decorator(extra_args_provider)
return parse_args(decorated_provider, ignore_unknown_args)
return wrapper
def process_args_v2(parser):
"""
Process and register MindSpeed-LLM v2 feature arguments.
This function registers all MindSpeed-LLM specific feature arguments
using the MindSpeedFeaturesManager.
Args:
parser: Argument parser to add arguments to.
Returns:
argparse.ArgumentParser: Parser with MindSpeed-LLM arguments added.
"""
MindSpeedFeaturesManager.register_features_args(parser)
return parser
def get_layer_offset(pp_size, num_layer_list):
"""
Get layer number offset for pp stage. global_layer_number = local_layer_number + layer_number_offset
For instance, num-layer-list=1,3,3,1,
(1,123,123,1) + (0,1,4,7) = (1,234,567,8)
For each pp_stage, we have layer_number_offset = prefix_sum[pp_stage + 1]
"""
prefix_sum = [0] * (pp_size + 1)
for index, num_layers in enumerate(num_layer_list):
prefix_sum[index + 1] = prefix_sum[index] + num_layers
return prefix_sum
def core_transformer_config_from_args_wrapper(fn):
"""
Wrapper for creating TransformerConfig from arguments with MindSpeed-LLM extensions.
This decorator wraps the config creation function to add MindSpeed-LLM specific
configurations including MoE settings and custom layer distribution.
Args:
fn: The original config creation function.
Returns:
Callable: Wrapped function that creates config with MindSpeed-LLM extensions.
The wrapper adds:
- batch_p2p_comm optimization for PP2VPP
- MoE expert capacity factor settings
- Custom layer distribution via num_layer_list
"""
@wraps(fn)
def wrapper(args, config_class=None):
config = fn(args, config_class)
if args.pipeline_model_parallel_size == 2 and args.num_layers_per_virtual_pipeline_stage is not None:
config.batch_p2p_comm = False
if args.moe_expert_capacity_factor:
config.moe_expert_capacity_factor = args.moe_expert_capacity_factor
config.moe_pad_expert_input_to_capacity = args.moe_pad_expert_input_to_capacity
config.moe_token_drop_policy = args.moe_token_drop_policy
if args.num_layer_list:
config.num_layer_list = list(map(int, args.num_layer_list.split(',')))
config.layer_offset = get_layer_offset(args.pipeline_model_parallel_size, config.num_layer_list)
if config.layer_offset[args.pipeline_model_parallel_size] != args.num_layers:
raise ValueError(f"Incorrect num_layer_list config since its sum({config.layer_offset[args.pipeline_model_parallel_size]} is unequal to total num layers({args.num_layers}).")
else:
config.num_layer_list = None
return config
return wrapper
def _add_dummy_args_v2(args):
"""
Add dummy arguments for features currently unsupported in MindSpeed-LLM.
This function initializes unsupported feature arguments to False or default values
to maintain compatibility with the broader codebase.
Args:
args: Arguments namespace to add dummy arguments to.
Note:
These arguments exist in the feature list but are not yet supported
in MindSpeed-LLM implementation.
"""
args.unaligned_linear = False
args.embed_layernorm = False
args.enable_share_memory = False
args.return_document_ids = False
args.attention_mask_on_cpu = False
args.output_layer_slice_num = 1
args.use_fused_mlp = False
def validate_args_v2_decorator(megatron_validate_args):
"""
Decorator for Megatron arguments validation with MindSpeed-LLM extensions.
This decorator wraps the Megatron validation function to add MindSpeed-LLM
specific argument validation and feature management.
Args:
megatron_validate_args: The original Megatron validation function.
Returns:
Callable: Wrapped validation function with MindSpeed-LLM validation.
The validation process:
1. Pre-validate MindSpeed-LLM feature arguments
2. Call Megatron validation
3. Post-validate MindSpeed-LLM feature arguments
4. Add dummy arguments for unsupported features
5. Validate all MindSpeed-LLM arguments
6. Print MindSpeed-LLM arguments
"""
@wraps(megatron_validate_args)
def wrapper(args, defaults=None):
if defaults is None:
defaults = {}
MindSpeedFeaturesManager.pre_validate_features_args(args)
args = megatron_validate_args(args, defaults)
MindSpeedFeaturesManager.post_validate_features_args(args=args)
_add_dummy_args_v2(args)
MindSpeedFeaturesManager.validate_features_args(args=args)
from mindspeed_llm.training.utils import print_args
print_args('MindSpeed-LLM Arguments', args)
return args
return wrapper