import warnings
from argparse import ArgumentParser, Namespace
import torch
import torch_npu
from mindspeed.features_manager.feature import MindSpeedFeature
class FusedMoEPermuteFeature(MindSpeedFeature):
    """Enable NPU-fused MoE token permute/unpermute.

    The feature is switched on by ``--moe-permute-fusion`` or by the
    equivalent alias ``--use-fused-moe-token-permute-and-unpermute`` that this
    class registers. ``validate_args`` keeps the two flags in sync and turns
    both off when the installed ``torch_npu`` lacks the required fused ops;
    ``register_patches`` then swaps Megatron's MoE permute/unpermute (and,
    per dispatcher type, some dispatcher methods) for MindSpeed versions.
    """

    # argparse dest of the alias flag added in ``register_args``.
    _ALIAS_DEST = "use_fused_moe_token_permute_and_unpermute"

    # Functions under ``transformer_engine.pytorch.permutation`` that receive
    # dummy stand-ins in ``pre_register_patches`` (registration order kept).
    _TE_PERMUTATION_FUNCS = (
        'moe_permute',
        'moe_permute_with_probs',
        'moe_sort_chunks_by_index',
        'moe_sort_chunks_by_index_with_probs',
        'moe_unpermute',
    )

    def __init__(self):
        super().__init__('moe-permute-fusion', optimization_level=2)

    def register_args(self, parser: ArgumentParser):
        """Add the ``--use-fused-moe-token-permute-and-unpermute`` alias flag."""
        group = parser.add_argument_group(title='fusion')
        group.add_argument("--use-fused-moe-token-permute-and-unpermute", action='store_true',
                           help="Use fused moe permute and unpermute.")

    def _fusion_requested(self, args) -> bool:
        """Return True if the primary flag or its alias is set on ``args``."""
        # NOTE(review): assumes ``self.feature_name`` is the argparse dest form
        # of 'moe-permute-fusion' (i.e. "moe_permute_fusion") as set up by
        # MindSpeedFeature — confirm in the base class.
        return bool(getattr(args, self.feature_name, None)
                    or getattr(args, self._ALIAS_DEST, None))

    def is_need_apply(self, args) -> bool:
        """Return True when the optimization level allows this feature and
        either permute-fusion flag is set.

        Always returns a plain bool; missing flags on ``args`` count as unset.
        """
        return (self.optimization_level <= args.optimization_level
                and self._fusion_requested(args))

    def validate_args(self, args: Namespace):
        """Reconcile and sanity-check the permute-fusion flags.

        Does nothing unless one of the two flags is set. If ``torch_npu`` lacks
        ``npu_moe_token_permute_with_routing_map`` or
        ``npu_moe_token_unpermute_with_routing_map``, warns and forces both
        flags to False. Otherwise sets both flags to True so either spelling
        enables the feature consistently.

        Raises:
            AssertionError: fusion is combined with the ``allgather``
                dispatcher, or ``moe_expert_capacity_factor`` is set while
                ``moe_pad_expert_input_to_capacity`` is not.
        """
        if not (args.use_fused_moe_token_permute_and_unpermute or args.moe_permute_fusion):
            return

        npu_ops_available = all(
            hasattr(torch_npu, op_name)
            for op_name in ("npu_moe_token_permute_with_routing_map",
                            "npu_moe_token_unpermute_with_routing_map"))
        if not npu_ops_available:
            warnings.warn("torch_npu should have attribute npu_moe_token_permute_with_routing_map and "
                          "npu_moe_token_unpermute_with_routing_map, but at least one of these attributes is missing. "
                          "Please upgrade CANN to 8.3.RC1 and higher, and PTA to 7.2.RC1 and higher to enable --moe-permute-fusion. "
                          "This run will set --moe-permute-fusion to False.")
            args.moe_permute_fusion = False
            args.use_fused_moe_token_permute_and_unpermute = False
            return

        # The two flags are aliases; make downstream code see one consistent value.
        args.moe_permute_fusion = True
        args.use_fused_moe_token_permute_and_unpermute = True

        if args.moe_token_dispatcher_type == 'allgather':
            raise AssertionError(
                "Only alltoall and alltoall_seq dispatcher supports --moe-permute-fusion(or --use-fused-moe-token-permute-and-unpermute"
                "). allgather dispatcher don't support it.")
        if args.moe_expert_capacity_factor and not args.moe_pad_expert_input_to_capacity:
            raise AssertionError(
                "The --moe-permute-fusion can only be enabled if '--moe-expert-capacity-factor' and "
                "'--moe-pad-expert-input-to-capacity' are enabled at the same time. If only "
                "--moe-expert-capacity-factor is enabled, the moe permute fusion is not supported.")

    def pre_register_patches(self, pm, args):
        """Register ``torch.nn.Module`` as a dummy for each TE permutation function.

        Each target is registered with ``create_dummy=True``; presumably this
        lets ``transformer_engine.pytorch.permutation`` be imported without a
        real Transformer Engine install — confirm against the patch manager.
        """
        for func_name in self._TE_PERMUTATION_FUNCS:
            pm.register_patch(f'transformer_engine.pytorch.permutation.{func_name}', torch.nn.Module,
                              create_dummy=True)

    def register_patches(self, patch_manager, args: Namespace):
        """Patch Megatron MoE permute/unpermute and dispatcher hooks.

        No-op unless a permute-fusion flag is set. Always patches
        ``moe_utils.permute``/``unpermute``. Uses the triton sort-chunks
        kernels when importable, otherwise warns and patches
        ``moe_utils.sort_chunks_by_idxs`` with ``sort_chunks_by_idxs_wrapper``.
        Finally patches dispatcher methods for the ``alltoall`` or
        ``alltoall_seq`` dispatcher type.
        """
        if not self._fusion_requested(args):
            return
        # The advice "prefer --moe-permute-fusion" is only relevant when the
        # alias was used.
        if getattr(args, self._ALIAS_DEST, None):
            warnings.warn(
                "Parameters --moe-permute-fusion and --use-fused-moe-token-permute-and-unpermute are equivalent. "
                "Use only one; prefer --moe-permute-fusion.")

        from mindspeed.core.fusions.fused_moe_permute import permute, unpermute, sort_chunks_by_idxs_wrapper
        patch_manager.register_patch('megatron.core.transformer.moe.moe_utils.permute', permute)
        patch_manager.register_patch('megatron.core.transformer.moe.moe_utils.unpermute', unpermute)

        try:
            from mindspeed.lite.ops.triton.sort_chunks_by_idx import (
                moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs)
        except ImportError:
            # Fused sort kernels unavailable: fall back to the wrapped
            # Megatron sort_chunks_by_idxs instead.
            warnings.warn("Currently, fused_sort_chunks_by_index is not supported")
            patch_manager.register_patch('megatron.core.transformer.moe.moe_utils.sort_chunks_by_idxs',
                                         sort_chunks_by_idxs_wrapper)
        else:
            patch_manager.register_patch('megatron.core.extensions.transformer_engine.fused_sort_chunks_by_index',
                                         moe_sort_chunks_by_index)
            patch_manager.register_patch(
                'megatron.core.extensions.transformer_engine.fused_sort_chunks_by_index_with_probs',
                moe_sort_chunks_by_index_with_probs)

        dispatcher_type = getattr(args, "moe_token_dispatcher_type", None)
        if dispatcher_type == "alltoall":
            from mindspeed.core.fusions.fused_moe_permute import (
                moe_alltoall_token_dispatcher_init_wrapper, maybe_dtoh_and_synchronize, preprocess_sync_wrapper)
            patch_manager.register_patch(
                'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher._maybe_dtoh_and_synchronize',
                maybe_dtoh_and_synchronize)
            patch_manager.register_patch(
                'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.__init__',
                moe_alltoall_token_dispatcher_init_wrapper)
            patch_manager.register_patch(
                'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.preprocess',
                preprocess_sync_wrapper)
        elif dispatcher_type == "alltoall_seq":
            from mindspeed.core.fusions.fused_moe_permute import (
                transformer_config_post_init_wrapper, alltoall_seq_token_permutation, alltoall_seq_token_unpermutation)
            patch_manager.register_patch("megatron.core.transformer.transformer_config.TransformerConfig.__post_init__",
                                         transformer_config_post_init_wrapper)
            patch_manager.register_patch(
                "megatron.core.transformer.moe.legacy_a2a_token_dispatcher.MoEAlltoAllSEQTokenDispatcher.token_permutation",
                alltoall_seq_token_permutation)
            patch_manager.register_patch(
                "megatron.core.transformer.moe.legacy_a2a_token_dispatcher.MoEAlltoAllSEQTokenDispatcher.token_unpermutation",
                alltoall_seq_token_unpermutation)