import argparse
from contextvars import ContextVar
from logging import getLogger
# Module-level logger named after this module.
LOG = getLogger(__name__)
# Context variable holding an explicitly installed argument namespace.
# Written by set_full_args() and read first by get_full_args(); None means unset.
_full_args: ContextVar = ContextVar('full_args', default=None)
# Cache of the namespace built by get_mindspeed_args() for the non-default
# (get_defaults=False) path; None until that path has run once.
_MINDSPEED_ARGS = None
def add_args(args, key, value):
    """Store one parsed command-line option in ``args``.

    Args:
        args: Mutable mapping collecting option names and values; it is
            updated in place.
        key: Raw option token such as ``'--foo-bar'`` or ``'-x'``, or ``None``
            when no option is pending (in which case nothing is stored).
        value: ``None`` for a bare flag (stored as ``True``), otherwise a list
            of string values. A one-element list is unwrapped to its single
            item; longer lists are stored as-is.
    """
    if key is None:
        return
    # Strip every leading dash instead of assuming exactly two, so that a
    # single-dash flag like ``-x`` becomes ``x`` rather than an empty name.
    name = key.lstrip('-').replace('-', '_')
    if not name:
        # A token made only of dashes (e.g. the bare ``--`` separator) does
        # not name an option, so there is nothing to record.
        return
    if value is None:
        value = True
    elif len(value) == 1:
        value = value[0]
    args[name] = value
def _is_option_token(token):
    """Return True when ``token`` starts a new option rather than being a value."""
    if not token.startswith('-'):
        return False
    # Negative numeric literals (``-1``, ``-0.5``, ``-1e-5``) also begin with a
    # dash but are values of the preceding option, not option names.
    if token[1:2].isdigit() or token[1:2] == '.':
        try:
            float(token)
        except ValueError:
            return True
        return False
    return True


def parser_unknown_args(args, unknown):
    """Fold unrecognised command-line tokens into ``args``.

    Tokens are scanned left to right. A token that starts with ``-`` (and is
    not a negative number) opens a new option; ``--name=value`` carries its
    first value inline. Following non-option tokens are collected as values of
    the currently open option. Each completed option is handed to
    ``add_args``. Value tokens that appear before any option are dropped.

    Args:
        args: Mutable mapping that receives the parsed options (updated in
            place).
        unknown: Sequence of leftover argument strings.

    Returns:
        The same ``args`` mapping that was passed in.
    """
    pending_key = None
    pending_values = None
    for token in unknown:
        if _is_option_token(token):
            # Flush the previous option before starting the next one.
            add_args(args, pending_key, pending_values)
            name, sep, inline = token.partition('=')
            pending_key = name
            # Only an explicit ``=`` supplies an inline value (possibly empty).
            pending_values = [inline] if sep else None
        elif pending_values is None:
            pending_values = [token]
        else:
            pending_values.append(token)
    # Flush the option that was still open when the input ended.
    add_args(args, pending_key, pending_values)
    return args
def get_mindspeed_args(get_defaults=False):
    """Build (or fetch the cached) adaptor argument namespace.

    Args:
        get_defaults: When true, parse an empty argument list so only the
            registered defaults are produced; the module cache is neither
            consulted nor updated. When false, parse the process command line
            and cache the result in ``_MINDSPEED_ARGS``.

    Returns:
        An ``argparse.Namespace`` holding the options registered through
        ``FeaturesManager`` plus any unrecognised options folded in by
        ``parser_unknown_args``.
    """
    global _MINDSPEED_ARGS

    use_cache = not get_defaults
    if use_cache and _MINDSPEED_ARGS is not None:
        return _MINDSPEED_ARGS

    parser = argparse.ArgumentParser(
        description='MegatronAdaptor Arguments',
        allow_abbrev=False,
    )
    # Imported inside the function rather than at module top level.
    from megatron_adaptor.features_manager.features_manager import FeaturesManager
    FeaturesManager.register_features_args(parser)

    if get_defaults:
        known, extras = parser.parse_known_args([])
    else:
        known, extras = parser.parse_known_args()

    # Merge the leftover tokens into the known options, then rebuild a namespace.
    merged = parser_unknown_args(vars(known), extras)
    namespace = argparse.Namespace(**merged)

    if use_cache:
        _MINDSPEED_ARGS = namespace
    return namespace
def get_full_args() -> argparse.Namespace:
    """Return the argument namespace to use, trying several sources in order.

    1. The value installed through ``set_full_args`` (the ``_full_args``
       context variable), if it is not ``None``.
    2. Megatron's ``get_args()``.
    3. ``get_mindspeed_args()`` when Megatron cannot be imported, when
       ``get_args()`` returns ``None``, or when it fails with an
       ``AssertionError`` whose message contains ``'args is not initialized.'``.

    Raises:
        AssertionError: Re-raised when it does not carry the
            "args is not initialized." message.
    """
    installed = _full_args.get()
    if installed is not None:
        return installed

    try:
        # NOTE(review): _GLOBAL_ARGS is imported but not referenced below;
        # confirm whether the import is still needed.
        from megatron.training.global_vars import get_args, _GLOBAL_ARGS
        resolved = get_args()
        if resolved is None:
            resolved = get_mindspeed_args()
    except ImportError:
        LOG.warning('Failed from megatron.training import get_args, use megatron_adaptor arguments.')
        resolved = get_mindspeed_args()
    except AssertionError as err:
        # Only the "not initialized" assertion is treated as a fallback signal.
        if 'args is not initialized.' not in str(err):
            raise err
        resolved = get_mindspeed_args()
    return resolved
def set_full_args(args: argparse.Namespace):
    """Install ``args`` as the namespace that ``get_full_args`` returns first.

    The value is stored in the ``_full_args`` context variable; the token
    returned by ``ContextVar.set`` is discarded.

    Args:
        args: Namespace to expose through ``get_full_args``.
    """
    _full_args.set(args)