from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class TflopsCalculateFeature(MindSpeedFeature):
    """Feature plugin that wires in the ``--op-cal-tflops`` option.

    When the option is set, three patches are registered that replace
    ``training_log``, ``checkpoint_function_backward`` and ``train_step``
    with the implementations from ``mindspeed.functional.tflops_calculate``.
    """

    def __init__(self):
        """Register the feature under the name ``op-cal-tflops`` with level 2."""
        super().__init__('op-cal-tflops', 2)

    def register_args(self, parser: ArgumentParser) -> None:
        """Add the ``--op-cal-tflops`` boolean flag to ``parser``.

        Args:
            parser: Argument parser that receives a new argument group titled
                with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--op-cal-tflops', action='store_true', default=False,
                           help='use for cal mfu and hfu')

    def validate_args(self, args) -> None:
        """Reject the combination of ``op_cal_tflops`` and multi-latent attention.

        Args:
            args: Parsed argument namespace.

        Raises:
            AssertionError: If both ``op_cal_tflops`` and
                ``multi_latent_attention`` are truthy on ``args``.
        """
        # ``multi_latent_attention`` is not registered by this feature, so it
        # may be absent from the namespace; a missing flag counts as disabled
        # instead of surfacing as an AttributeError during validation.
        tflops_enabled = getattr(args, 'op_cal_tflops', False)
        mla_enabled = getattr(args, 'multi_latent_attention', False)
        if tflops_enabled and mla_enabled:
            raise AssertionError("Multi-head latent attention currently does not support op-cal-tflops")

    def register_patches(self, patch_manager, args) -> None:
        """Register the TFLOPS-accounting patches when the feature is enabled.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed argument namespace.
        """
        # argparse stores ``--op-cal-tflops`` under ``op_cal_tflops``. The name
        # handed to the base class contains hyphens, so normalise it before the
        # lookup; this is a no-op if the base class already normalised it.
        flag_name = self.feature_name.replace('-', '_')
        if not getattr(args, flag_name, None):
            return

        # Import the patch implementations only once the feature is known to
        # be enabled, so disabled runs do not load these modules.
        from mindspeed.functional.tflops_calculate.adaptor import training_log
        from mindspeed.functional.tflops_calculate.tflops_utils import checkpoint_function_backward_wrapper
        from mindspeed.functional.tflops_calculate.tflops_utils import train_step_wrapper

        patch_manager.register_patch('megatron.training.training.training_log', training_log)
        patch_manager.register_patch('mindspeed.core.tensor_parallel.random.checkpoint_function_backward', checkpoint_function_backward_wrapper)
        patch_manager.register_patch('megatron.training.training.train_step', train_step_wrapper)