from functools import wraps
def process_args_decorator(process_args):
@wraps(process_args)
def wrapper(parser):
parser = process_args(parser)
parser = process_mindspore_args(parser)
return parser
return wrapper
def process_mindspore_args(parser):
"""
Process MindSpore arguments.
:param parser: Argument parser to add arguments to.
:return: Updated argument parser.
"""
parser = _add_moba_args(parser)
parser = _add_communication_overlap_args(parser)
return parser
def _add_communication_overlap_args(parser):
group = parser.add_argument_group(title='overlap_p2p_comm_or_async_log_allreduce_')
group.add_argument('--async-log-allreduce', action='store_true',
help='Transform the AllReduce operation used for transmitting log information into an '
'asynchronous operation to reduce communication overhead. '
'This is useful in cross-DataCenter (DC) training.')
return parser
def _validate_optimizer(args):
if args.reuse_fp32_param and not args.bf16:
raise AssertionError('--reuse-fp32-param only support for `bf16`')
if args.reuse_fp32_param and args.swap_optimizer:
raise AssertionError('--swap-optimizer dose not support `--reuse-fp32-param`')
if args.reuse_fp32_param and not args.use_distributed_optimizer:
raise ValueError(
"When using the --reuse-fp32-param feature, the --use-distributed-optimizer feature must also be enabled.")
def _add_moba_args(parser):
group = parser.add_argument_group(title='moba')
group.add_argument('--use-moba-attn', action='store_true', default=False,
help='use moba attention')
group.add_argument('--moba-chunk-size', type=int, default=64,
help='moba attention chunk size. default: 64')
group.add_argument('--moba-topk', type=int, default=2,
help='moba attention topk')
group.add_argument('--moba-calc-method', type=int, default=1,
help='moba calculation method. 1: naive attention with naive attention operations; 2: use flash'
'attention. default: 1')
return parser