from mindspeed.features_manager.feature import MindSpeedFeature
class NPUDataDumpFeature(MindSpeedFeature):
def __init__(self):
super(NPUDataDumpFeature, self).__init__("npu-datadump")
def register_args(self, parser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--npu-datadump', action='store_true', default=False,
help='enable npu data dump with mstt.')
def register_patches(self, patch_manager, args):
if args.npu_datadump:
try:
from msprobe.pytorch import PrecisionDebugger
except ImportError as e:
raise AssertionError('Mstt not found. You can install it with `pip install mindstudio-probe`.') from e
from mindspeed.functional.npu_datadump.npu_datadump import dump_start_wrapper, dump_end_wrapper
patch_manager.register_patch('megatron.training.training.train_step', dump_start_wrapper)
patch_manager.register_patch('megatron.training.ft_integration.on_training_step_end', dump_end_wrapper)