import sys
from argparse import ArgumentParser
import torch
from mindspeed.features_manager.feature import MindSpeedFeature
class RequirementsBasicFeature(MindSpeedFeature):
    """Register the baseline CLI options and third-party dependency patches.

    The feature is created under the name ``'requirements-basic'`` with
    ``optimization_level=0``. Its patches replace or stub symbols from
    ``torch``, ``transformer_engine``, ``flash_attn``, ``apex`` and related
    extension modules with MindSpeed implementations.

    Many registrations pass ``create_dummy=True``.
    NOTE(review): presumably this lets the patch manager create the target
    when the real package is not importable -- confirm against the patch
    manager implementation.
    """

    def __init__(self):
        super().__init__('requirements-basic', optimization_level=0)

    def register_args(self, parser: ArgumentParser) -> None:
        """Add this feature's command-line options to ``parser``.

        Args:
            parser: Argument parser that receives a new argument group titled
                with ``self.feature_name``.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--optimizer-selection', type=str, default='fused_adamw',
                           choices=['fused_adamw', 'fused_torch_adamw', 'fused_ema_adamw'],
                           help='Select from the former fused AdamW optimizer and Torch fused AdamW optimizer')
        # The help text is assembled from adjacent literals, so every piece
        # except the last must end with ', ' to keep the entries separated.
        group.add_argument('--optimization-level', type=int, choices=[0, 1, 2], default=2,
                           help='0: The minimum patch set for megatron to adapt to NPU, '
                                '1: Affinity optimization (fusion operator, etc.), '
                                '2: Advanced acceleration algorithm')

    def pre_register_patches(self, patch_manager, args) -> None:
        """Register every patch group owned by this feature.

        The groups are registered in a fixed order: transformer_engine,
        apex, torch, then the optimizer selection.

        Args:
            patch_manager: Object exposing ``register_patch``.
            args: Parsed arguments; ``args.optimizer_selection`` is read by
                :meth:`optimizer_selection`.
        """
        self.te_adaptation(patch_manager, args)
        self.apex_adaptation(patch_manager, args)
        self.torch_adaptation(patch_manager, args)
        self.optimizer_selection(patch_manager, args)

    def te_adaptation(self, pm, args) -> None:
        """Register patches for compile helpers, transformer_engine and flash_attn.

        Args:
            pm: Patch manager exposing ``register_patch``.
            args: Parsed arguments (not used by this method).
        """
        from mindspeed.core.megatron_basic.requirements_basic import version_wrapper, dummy_compile
        from mindspeed.te.pytorch.module.layernorm import MindSpeedTELayernorm

        pm.register_patch('torch.compile', dummy_compile)
        pm.register_patch('torch.jit.script', dummy_compile)
        pm.register_patch('importlib.metadata.version', version_wrapper)
        pm.register_patch('transformer_engine.pytorch.LayerNorm', MindSpeedTELayernorm, create_dummy=True)

        # These transformer_engine symbols are all replaced by plain
        # torch.nn.Module; the tuple preserves the registration order.
        module_placeholder_targets = (
            'transformer_engine.pytorch.LayerNormLinear',
            'transformer_engine.pytorch.DotProductAttention',
            'transformer_engine.pytorch.Linear',
            'transformer_engine.pytorch.GroupedLinear',
            'transformer_engine.pytorch.distributed.CudaRNGStatesTracker',
            'transformer_engine.common.recipe.DelayedScaling',
        )
        for target in module_placeholder_targets:
            pm.register_patch(target, torch.nn.Module, create_dummy=True)

        # No replacement object is supplied here, only create_dummy=True.
        pm.register_patch('flash_attn.flash_attn_interface.flash_attn_unpadded_func', create_dummy=True)

    def apex_adaptation(self, pm, args) -> None:
        """Register patches for apex and its companion extension modules.

        Args:
            pm: Patch manager exposing ``register_patch``.
            args: Parsed arguments (not used by this method).
        """
        from mindspeed.core.megatron_basic.requirements_basic import multi_tensor_l2norm, multi_tensor_scale, multi_tensor_applier
        from mindspeed.core.fusions.fused_layer_norm import fused_layer_norm_affine
        from mindspeed.ops.npu_matmul_add import npu_matmul_add_fp32, npu_matmul_add_fp16
        # NOTE(review): FastLayerNormFN used to be imported here as well but was
        # never registered; the unused import was dropped. If a patch for it was
        # intended, it still needs to be added -- confirm.
        from mindspeed.core.fusions.fused_layer_norm import FusedLayerNormAffineFunction

        pm.register_patch('amp_C.multi_tensor_l2norm', multi_tensor_l2norm, create_dummy=True)
        pm.register_patch('amp_C.multi_tensor_scale', multi_tensor_scale, create_dummy=True)
        pm.register_patch('apex.multi_tensor_apply.multi_tensor_applier', multi_tensor_applier, create_dummy=True)
        pm.register_patch('apex.normalization.fused_layer_norm.fused_layer_norm_affine', fused_layer_norm_affine, create_dummy=True)
        # Module-level target with no replacement object, only create_dummy=True.
        pm.register_patch('fused_layer_norm_cuda', create_dummy=True)
        pm.register_patch('fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32', npu_matmul_add_fp32, create_dummy=True)
        pm.register_patch('fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16', npu_matmul_add_fp16, create_dummy=True)
        pm.register_patch('apex.normalization.fused_layer_norm.FusedLayerNormAffineFunction',
                          FusedLayerNormAffineFunction, create_dummy=True)

    def optimizer_selection(self, pm, args) -> None:
        """Register the apex optimizer replacements chosen on the command line.

        ``apex.optimizers.FusedAdam`` is mapped to ``FusedTorchAdamW`` when
        ``args.optimizer_selection`` is ``'fused_torch_adamw'`` and to
        ``AdamW`` when it is ``'fused_adamw'``. ``apex.optimizers.FusedSGD``
        is always mapped to ``torch.optim.SGD``.

        Args:
            pm: Patch manager exposing ``register_patch``.
            args: Parsed arguments; ``args.optimizer_selection`` is read.
        """
        from mindspeed.core.optimizer.adamw import FusedTorchAdamW, AdamW

        # NOTE(review): 'fused_ema_adamw' is an accepted choice but gets no
        # FusedAdam replacement here -- presumably handled elsewhere; confirm.
        if args.optimizer_selection == 'fused_torch_adamw':
            pm.register_patch('apex.optimizers.FusedAdam', FusedTorchAdamW, create_dummy=True)
        elif args.optimizer_selection == 'fused_adamw':
            pm.register_patch('apex.optimizers.FusedAdam', AdamW, create_dummy=True)
        pm.register_patch('apex.optimizers.FusedSGD', torch.optim.SGD, create_dummy=True)

    def torch_adaptation(self, pm, args) -> None:
        """Register patches applied to torch itself (and ``math.lcm`` if needed).

        Args:
            pm: Patch manager exposing ``register_patch``.
            args: Parsed arguments (not used by this method).
        """
        from torch.distributed import all_gather_into_tensor, reduce_scatter_tensor
        from mindspeed.core.megatron_basic.requirements_basic import type_wrapper, ensure_contiguous_wrapper, lcm, \
            dummy_function, torch_all_reduce_double_dtype_bypass_wrapper

        pm.register_patch('torch.nn.parameter.Parameter.type', type_wrapper)
        pm.register_patch('torch.Tensor.type', type_wrapper)
        pm.register_patch('torch.Tensor.view', ensure_contiguous_wrapper)
        # Point the underscore-prefixed collective names at the public
        # torch.distributed functions imported above.
        pm.register_patch('torch.distributed._all_gather_base', all_gather_into_tensor)
        pm.register_patch('torch.distributed._reduce_scatter_base', reduce_scatter_tensor)
        pm.register_patch('torch.distributed.all_reduce', torch_all_reduce_double_dtype_bypass_wrapper)
        pm.register_patch('torch._C._jit_set_nvfuser_enabled', dummy_function)
        # math.lcm was added to the standard library in Python 3.9, so the
        # MindSpeed implementation is only registered on older interpreters.
        if sys.version_info < (3, 9):
            pm.register_patch('math.lcm', lcm, create_dummy=True)