from argparse import ArgumentParser
from mindspeed.features_manager.context_parallel.context_parallel_feature import ContextParallelFeature as MindspeedContextParallelFeature
class ContextParallelFeature(MindspeedContextParallelFeature):
def __init__(self):
super().__init__()
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--context-parallel-algo', type=str, default='megatron_cp_algo',
choices=['megatron_cp_algo', 'hybrid_cp_algo', 'kvallgather_cp_algo'],
help='context parallel algorithm')
group.add_argument('--cp-window-size', type=int, default=1)
group.add_argument('--attention-mask-type', type=str, default='causal',
choices=['causal', 'general'], help='attention mask type')
group.add_argument('--use-cp-send-recv-overlap', action='store_true',
help='use this flag to enable cp send-recv-overlap.')
group.add_argument("--use-fused-ring-attention-update", action='store_true',
help="Use fused ring attention update.")
group.add_argument("--megatron-cp-in-bnsd", action='store_true',
help="Megatron CP in bnsd.")
def validate_args(self, args):
super().validate_args(args)
if args.context_parallel_size > 1:
if args.position_embedding_type == 'alibi':
raise AssertionError("Context parallel does not support alibi")
if args.use_kv_cache:
raise AssertionError("Context parallel does not support use_kv_cache")
if args.sliding_window is not None and args.seq_length > args.sliding_window:
raise AssertionError("Context parallel does not support sliding_windows")
if args.context_parallel_size > 1 and args.context_parallel_algo == 'kvallgather_cp_algo':
if args.attention_mask_type != "causal":
raise AssertionError("kvallgather_cp_algo only supports causal attention mask type")
if not getattr(args, 'reset_attention_mask', False):
if hasattr(args, 'seq_length') and args.seq_length % (2 * args.context_parallel_size) != 0:
raise AssertionError("sequence length must be divisible by 2 * context_parallel_size in kvallgather_cp_algo with SBHD format")
else:
if hasattr(args, 'seq_length') and args.seq_length % args.context_parallel_size != 0:
raise AssertionError("sequence length must be divisible by context_parallel_size in kvallgather_cp_algo with THD format")
def register_patches(self, patch_manager, args):
if int(getattr(args, 'context_parallel_size', 1)) > 1 and not getattr(args, 'reset_attention_mask', None):
from mindspeed.core.context_parallel.model_parallel_utils import initialize_model_parallel_cp_wrapper, \
destroy_model_parallel_cp_wrapper, get_context_parallel_group_for_send_recv_overlap
from mindspeed.core.context_parallel.rotary_pos_embedding_utils import get_pos_emb_on_this_cp_rank
from mindspeed_llm.core.context_parallel.adaptor import CPDotProductAttention
from mindspeed_llm.core.context_parallel.adaptor import attention_init_wrapper
from mindspeed_llm.core.context_parallel.get_batch_utils import get_batch_on_this_cp_rank
patch_manager.register_patch('megatron.core.parallel_state.initialize_model_parallel',
initialize_model_parallel_cp_wrapper)
patch_manager.register_patch('megatron.core.parallel_state.destroy_model_parallel',
destroy_model_parallel_cp_wrapper)
patch_manager.register_patch('megatron.core.parallel_state.get_context_parallel_group_for_send_recv_overlap',
get_context_parallel_group_for_send_recv_overlap)
patch_manager.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.get_pos_emb_on_this_cp_rank',
get_pos_emb_on_this_cp_rank)
patch_manager.register_patch('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
patch_manager.register_patch('megatron.core.transformer.attention.Attention.__init__',
attention_init_wrapper)
patch_manager.register_patch('megatron.core.transformer.dot_product_attention.DotProductAttention',
CPDotProductAttention)
if getattr(args, 'context_parallel_algo', 'megatron_cp_algo') == 'kvallgather_cp_algo':
from mindspeed_llm.core.transformer.custom_dot_product_attention import CustomDotProductAttention
patch_manager.register_patch(
'megatron.core.transformer.dot_product_attention.DotProductAttention.forward',
CustomDotProductAttention.forward)
if args.transformer_impl == 'transformer_engine':
from mindspeed_llm.te.pytorch.attention.dot_product_attention.te_cp_dot_product_attention import TECPDotProductAttention
patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEDotProductAttention',
TECPDotProductAttention)
else:
patch_manager.register_patch('megatron.core.extensions.transformer_engine.TEDotProductAttention',
CPDotProductAttention)