from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class AlibiFeature(MindSpeedFeature):
    """Feature wiring for the ``alibi`` position embedding type.

    The feature adds ``alibi`` as an accepted value of
    ``--position-embedding-type``, declares the alibi-specific command-line
    options, and swaps the attention implementations for ``AlibiAttention``
    when the parsed arguments call for it.

    Usage:
        "--position-embedding-type alibi"
    """

    def __init__(self):
        """Register the feature under the name ``position-embedding-type``."""
        # NOTE(review): the meaning of optimization_level=2 is defined by
        # MindSpeedFeature and is not visible from this file.
        super().__init__('position-embedding-type', optimization_level=2)

    def register_args(self, parser: ArgumentParser):
        """Declare the alibi-related command-line arguments on ``parser``.

        Args:
            parser: The argument parser that receives a new argument group
                titled with this feature's name, plus ``alibi`` as a value
                of ``--position-embedding-type``.
        """
        group = parser.add_argument_group(title=self.feature_name)
        self.add_parser_argument_choices_value(
            parser, "--position-embedding-type", 'alibi'
        )

        # Boolean switches that are off unless given on the command line.
        switches = (
            ('--square-alibi-mask', 'attention mask of alibi is squared'),
            ('--fill-neg-inf', 'fill alibi with negative inf'),
        )
        for flag, description in switches:
            group.add_argument(
                flag, action='store_true', default=False, help=description
            )

        group.add_argument(
            '--alibi-fusion-attn-type',
            type=int,
            default=None,
            help='alibi pse type, support for 0,2',
        )

    def register_patches(self, patch_manager, args):
        """Replace the attention classes with ``AlibiAttention`` if applicable.

        The patches are registered only when ``position_embedding_type`` is
        ``"alibi"``, ``use_flash_attn`` is falsy, and
        ``context_parallel_size`` is not greater than 1. Attributes missing
        from ``args`` fall back to ``None``, ``False`` and ``1`` respectively.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed arguments namespace inspected via ``getattr``.
        """
        from mindspeed_llm.core.transformer.alibi_attention import AlibiAttention

        # Guard clauses, evaluated in the same order as the conditions they replace.
        if getattr(args, "position_embedding_type", None) != "alibi":
            return
        if getattr(args, "use_flash_attn", False):
            return
        if getattr(args, "context_parallel_size", 1) > 1:
            return

        attention_targets = (
            'megatron.core.transformer.dot_product_attention.DotProductAttention',
            'megatron.core.transformer.custom_layers.transformer_engine.TEDotProductAttention',
        )
        for target in attention_targets:
            patch_manager.register_patch(target, AlibiAttention)