from logging import getLogger
from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class AlibiFeature(MindSpeedFeature):
"""
Attention positional embedding.
To enable this feature, the reference is as follows .
Usage:
"--position-embedding-type alibi"
"--alibi-fusion-attn-type 0 or 2"
"[--alibi-diagonal-opposite]"
"""
def __init__(self):
super().__init__(
'position-embedding-type',
optimization_level=2
)
def is_need_apply(self, args):
pse = getattr(args, self.feature_name, None)
need_apply = False
if pse == 'alibi':
need_apply = True
return (
self.optimization_level <= args.optimization_level and
need_apply
) or self.default_patches
def register_args(self, parser: ArgumentParser):
self.add_parser_argument_choices_value(
parser,
"--position-embedding-type",
'alibi'
)
group = parser.add_argument_group(title='alibi')
group.add_argument(
'--square-alibi-mask',
action='store_true',
default=False,
help='attention mask of alibi is squared'
)
group.add_argument(
'--fill-neg-inf',
action='store_true',
default=False,
help='fill alibi with negative inf'
)
group.add_argument(
'--alibi-fusion-attn-type',
type=int,
help='alibi pse type, support for 0,2'
)
group.add_argument(
'--alibi-diagonal-opposite',
action='store_true',
default=False,
help='make alibi diagonal opposite'
)
def validate_args(self, args):
if args.alibi_fusion_attn_type is not None:
if args.alibi_fusion_attn_type not in [0, 2]:
raise AssertionError(
'--alibi-fusion-attn-type only support for `0, 2`'
)
if args.alibi_fusion_attn_type in [0, 2]:
args.use_fusion_attn_v2 = True
def register_patches(self, patch_manager, args):
if int(getattr(args, 'context_parallel_size', 1)) == 1:
from mindspeed.core.transformer.flash_attention.alibi.adaptor import MindSpeedDotProductAttention
patch_manager.register_patch(
'megatron.core.transformer.dot_product_attention.DotProductAttention',
MindSpeedDotProductAttention
)