from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class UlyssesContextParallelFeature(MindSpeedFeature):
def __init__(self):
super().__init__('context-parallel-size')
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
self.add_parser_argument_choices_value(
parser,
"--context-parallel-algo",
'ulysses_cp_algo'
)
group.add_argument('--ulysses-degree-in-cp', type=int, default=None)
def validate_args(self, args):
if args.context_parallel_size > 1 and args.context_parallel_algo == 'ulysses_cp_algo':
if args.seq_length % args.context_parallel_size != 0:
raise AssertionError("sequence length must be divisible by context_parallel_size")
head, remainder = divmod(args.num_attention_heads,
args.context_parallel_size * args.tensor_model_parallel_size)
if not (head >= 1 and remainder == 0):
raise AssertionError("num_attention_heads must be divisible by context_parallel_size * tensor_model_parallel_size")
args.use_flash_attn = True