from typing import Any
from mindspeed.features_manager.feature import MindSpeedFeature
class GenerateMaskFeature(MindSpeedFeature):
def __init__(self):
super().__init__(
'no-create-attention-mask-in-dataloader',
optimization_level=2
)
def is_need_apply(self, args: Any) -> bool:
"""Check the feature is need to apply."""
need_apply = True
if getattr(args, self.feature_name, None):
need_apply = False
return (
self.optimization_level <= args.optimization_level and
need_apply
) or self.default_patches
def validate_args(self, args):
args.create_attention_mask_in_dataloader = False
def register_patches(self, patch_manager, args):
_cp_algo = getattr(args, 'context_parallel_algo', 'megatron_cp_algo')
_cp_expanded_by_2d_tp = getattr(args, 'tp_2d', False) and getattr(args, 'tp_y', 1) > 1
if int(getattr(args, 'context_parallel_size', 1)) < 2 and not (_cp_expanded_by_2d_tp and _cp_algo == 'megatron_cp_algo'):
from mindspeed.core.transformer.flash_attention.generate_mask.adaptor import dot_product_attention_forward_wrapper
patch_manager.register_patch(
'megatron.core.transformer.dot_product_attention.DotProductAttention.forward',
dot_product_attention_forward_wrapper
)