from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class AnsCompressTensorFeature(MindSpeedFeature):
def __init__(self):
super().__init__("compress-dense", 2)
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument("--compress-dense", type=str, default='disable',
choices=['disable', 'level0', 'level1'],
help='Compress activation in dense layer.')
def validate_args(self, args):
if args.compress_dense != "disable":
import torch_npu
if not hasattr(torch_npu, "npu_hans_encode") or not hasattr(torch_npu, "npu_hans_decode") \
or not hasattr(torch_npu, "empty_with_swapped_memory"):
raise AssertionError("`--compress-dense` is invalid, please update the latest PTA version.")
self.incompatible_check(args, "recompute_activation_function")
def register_patches(self, patch_manager, args):
if getattr(args, self.feature_name, "disable") != "disable":
from mindspeed.core.memory.compress_dense.adaptor import mindspeed_compress_dense_forward
patch_manager.register_patch('megatron.core.transformer.mlp.MLP.forward', mindspeed_compress_dense_forward)