from mindspeed.features_manager.feature import MindSpeedFeature
class AffinityFeature(MindSpeedFeature):
    """Feature registered under the name 'affinity'.

    When applied, it registers a patch that points Megatron's
    ``VocabParallelCrossEntropy.calculate_predicted_logits`` at the
    ``calculate_predicted_logits`` function shipped in
    ``mindspeed.core.tensor_parallel.cross_entropy``.
    """

    def __init__(self):
        """Initialise the base feature as 'affinity' with optimization level 1."""
        super().__init__('affinity', optimization_level=1)

    def is_need_apply(self, args):
        """Report whether this feature should be applied for ``args``.

        Args:
            args: Object exposing an ``optimization_level`` attribute.

        Returns:
            The result of ``self.optimization_level <= args.optimization_level``.
        """
        # NOTE(review): self.optimization_level is presumably stored by the
        # base-class constructor -- confirm against MindSpeedFeature.
        feature_level = self.optimization_level
        requested_level = args.optimization_level
        return feature_level <= requested_level

    def register_patches(self, patch_manager, args):
        """Register the cross-entropy patch on ``patch_manager``.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Not used by this method.
        """
        # The replacement is imported at call time, not at module import time.
        from mindspeed.core.tensor_parallel.cross_entropy import (
            calculate_predicted_logits as replacement,
        )

        target = (
            'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits'
        )
        patch_manager.register_patch(target, replacement)