from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MindSporePatchFeature(MindSpeedFeature):
    """Feature that hands patch registration to the MindSpore adaptor.

    It contributes the ``--ai-framework`` command-line option and, when that
    option is set to ``mindspore``, calls ``mindspore_adaptation`` with the
    patch manager and the parsed arguments.
    """

    def __init__(self):
        """Register the feature as 'mindspore-patch' at optimization level 0."""
        super().__init__("mindspore-patch", optimization_level=0)

    def register_args(self, parser: ArgumentParser):
        """Add the ``--ai-framework`` option to *parser*.

        The option lives in an argument group titled with this feature's
        name and defaults to ``'pytorch'``.
        """
        framework_group = parser.add_argument_group(title=self.feature_name)
        framework_group.add_argument(
            "--ai-framework",
            type=str,
            default="pytorch",
            help="support pytorch and mindspore",
        )

    def register_patches(self, patch_manager, args):
        """Apply the MindSpore adaptation patches when MindSpore is selected.

        Does nothing unless ``args.ai_framework`` exists and equals
        ``'mindspore'`` and ``args.optimization_level`` is not negative.
        """
        # The checks are ordered: ``optimization_level`` is only read once
        # the framework has been confirmed to be MindSpore.
        if not hasattr(args, "ai_framework"):
            return
        if args.ai_framework != "mindspore":
            return
        if args.optimization_level < 0:
            return

        # Imported here rather than at module level, so the adaptor module is
        # only loaded on the MindSpore path (presumably to keep PyTorch runs
        # free of the MindSpore dependency -- confirm).
        from mindspeed.mindspore.mindspore_adaptor import mindspore_adaptation

        mindspore_adaptation(patch_manager, args)