from typing import List
from mindspeed.features_manager import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
class MindSpeedFeaturesManager:
    """Class-level dispatcher that fans each lifecycle hook out to every feature.

    All methods are classmethods operating on the shared ``FEATURES_LIST``;
    the class is never instantiated by the code in this file.
    """

    # Shared registry. It is always mutated in place (never rebound), so any
    # reference to this list object keeps observing the current features.
    FEATURES_LIST = []

    @classmethod
    def set_features_list(cls, features_list: List[MindSpeedFeature]):
        """Replace the contents of ``FEATURES_LIST`` with ``features_list``."""
        # Slice assignment keeps the same list object and swaps its contents.
        cls.FEATURES_LIST[:] = features_list

    @classmethod
    def apply_features_pre_patches(cls, mindspeed_args):
        """Register the pre-patches of every applicable feature, then apply them.

        A feature takes part only when ``feature.is_need_apply(mindspeed_args)``
        is truthy. ``MindSpeedPatchesManager.apply_patches()`` is called once,
        after all features have been visited.
        """
        patches_manager = MindSpeedPatchesManager
        for candidate in cls.FEATURES_LIST:
            if not candidate.is_need_apply(mindspeed_args):
                continue
            candidate.pre_register_patches(patches_manager, mindspeed_args)
        patches_manager.apply_patches()

    @classmethod
    def apply_features_patches(cls, mindspeed_args):
        """Register the patches of every applicable feature, then apply them.

        A feature takes part only when ``feature.is_need_apply(mindspeed_args)``
        is truthy. ``MindSpeedPatchesManager.apply_patches()`` is called once,
        after all features have been visited.
        """
        patches_manager = MindSpeedPatchesManager
        for candidate in cls.FEATURES_LIST:
            if not candidate.is_need_apply(mindspeed_args):
                continue
            candidate.register_patches(patches_manager, mindspeed_args)
        patches_manager.apply_patches()

    @classmethod
    def register_features_args(cls, parser):
        """Call ``register_args(parser)`` on every feature, in list order."""
        for item in cls.FEATURES_LIST:
            item.register_args(parser)

    @classmethod
    def pre_validate_features_args(cls, args):
        """Call ``pre_validate_args(args)`` on every feature, in list order.

        Intended to run before megatron's argument validation so that features
        can work around it.

        Example:
            pre_validate_features_args(args)        # old_x = args.x; args.x = new_x
            args = validate_args(args, defaults)    # args.x validation is bypassed
            post_validate_features_args(args=args)  # args.x = old_x
        """
        for item in cls.FEATURES_LIST:
            item.pre_validate_args(args)

    @classmethod
    def post_validate_features_args(cls, args):
        """Call ``post_validate_args(args)`` on every feature, in list order.

        Intended to run after megatron's argument validation, as the
        counterpart of ``pre_validate_features_args``.

        Example:
            pre_validate_features_args(args)        # old_x = args.x; args.x = new_x
            args = validate_args(args, defaults)    # args.x validation is bypassed
            post_validate_features_args(args=args)  # args.x = old_x
        """
        for item in cls.FEATURES_LIST:
            item.post_validate_args(args)

    @classmethod
    def validate_features_args(cls, args):
        """Call ``validate_args(args)`` on every feature, in list order."""
        for item in cls.FEATURES_LIST:
            item.validate_args(args)

    @classmethod
    def remove_patches(cls):
        """Delegate to ``MindSpeedPatchesManager.remove_patches()``."""
        MindSpeedPatchesManager.remove_patches()