from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class ReplaceIndexPutFeature(MindSpeedFeature):
    """Feature that swaps in MindSpeed's forward for ``VocabParallelEmbedding``.

    The feature is registered under the name ``'vocab-parallel'`` with an
    optimization level of 2. When applied, it patches
    ``megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward``
    with ``mindspeed_vocab_parallel_embedding_forward``.

    NOTE(review): the class name suggests the replacement forward avoids an
    ``index_put`` operation -- confirm against the adaptor implementation.
    """

    def __init__(self):
        super().__init__('vocab-parallel', optimization_level=2)

    def is_need_apply(self, args):
        """Return whether this feature should be applied for ``args``.

        The decision depends only on the optimization level: the feature is
        applied when its own level does not exceed ``args.optimization_level``.
        """
        required_level = self.optimization_level
        requested_level = args.optimization_level
        return required_level <= requested_level

    def register_patches(self, patch_manager, args):
        """Register the vocab-parallel embedding forward replacement.

        ``args`` is accepted as part of the method signature but is not read
        here.
        """
        # Imported inside the method, so the adaptor module is only loaded
        # when patches are actually being registered.
        from mindspeed.core.tensor_parallel.vocab_parallel.adaptor import (
            mindspeed_vocab_parallel_embedding_forward,
        )

        patch_target = 'megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward'
        patch_manager.register_patch(
            patch_target,
            mindspeed_vocab_parallel_embedding_forward,
        )