from typing import Any, Tuple, List
import torch
from torch import Tensor
from torch.nn import Module
import torch.distributed as dist
from megatron.core import mpu
from megatron.training import get_args
from mindspeed.megatron_adaptor import get_mindspeed_args
from mindspeed.patch_utils import MindSpeedPatchesManager as pm
try:
from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import all_to_all
native_all_to_all = False
except ImportError:
from mindspeed_mm.models.common.communications import all_to_all
native_all_to_all = True
from mindspeed_mm.models.common.communications import cal_split_sizes
class UlyssesContextAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
"""
def __init__(
self,
local_attention: Module,
sequence_process_group: torch.distributed.ProcessGroup,
scatter_idx: int = 2,
gather_idx: int = 0,
) -> None:
super(UlyssesContextAttention, self).__init__()
self.local_attn = local_attention
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs: Any) -> Tensor:
""" forward
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
args: other args
Returns:
* output (Tensor): context output
"""
attention_mask = args[0]
packed_seq_params = kwargs['packed_seq_params']
if attention_mask is None:
act_seq_len = packed_seq_params.cu_seqlens_q[-1]
else:
act_seq_len = attention_mask.shape[-1]
if getattr(self.local_attn.config, "use_remove_padding", False):
from mindspeed.utils import get_actual_seq_len
act_seq_len = get_actual_seq_len()[0]
attention_mask = torch.triu(
torch.ones([2048, 2048], dtype=torch.bool, device=query.device), diagonal=1
)
args_list = list(args)
args_list[0] = attention_mask
args = tuple(args_list)
if packed_seq_params is not None:
query = query.unsqueeze(1)
key = key.unsqueeze(1)
value = value.unsqueeze(1)
scatter_sizes_query = cal_split_sizes(query.shape[self.scatter_idx], dist.get_world_size(self.spg))
scatter_sizes_key = cal_split_sizes(key.shape[self.scatter_idx], dist.get_world_size(self.spg))
scatter_sizes_value = cal_split_sizes(value.shape[self.scatter_idx], dist.get_world_size(self.spg))
gather_sizes = cal_split_sizes(act_seq_len, dist.get_world_size(self.spg))
if not native_all_to_all:
query_layer = all_to_all(query, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
key_layer = all_to_all(key, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
value_layer = all_to_all(value, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
else:
query_layer = all_to_all(query, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_query, gather_sizes)
key_layer = all_to_all(key, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_key, gather_sizes)
value_layer = all_to_all(value, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_value, gather_sizes)
context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)
if get_mindspeed_args().context_parallel_algo == "hybrid_cp_algo" and context_layer.dim() == 3:
context_layer = context_layer.unsqueeze(1)
else:
context_shape = context_layer.shape
context_layer = context_layer.reshape(context_shape[0], context_shape[1],
scatter_sizes_query[dist.get_rank(self.spg)], -1)
if not native_all_to_all:
output = all_to_all(context_layer, self.spg, self.gather_idx, self.scatter_idx, query.shape[self.scatter_idx])
else:
output = all_to_all(context_layer, self.spg, self.gather_idx, self.scatter_idx, gather_sizes, scatter_sizes_query)
output = output.reshape(output.shape[0], output.shape[1], -1)
if packed_seq_params is not None:
output = output.squeeze(1)
return output
def get_rotary_seq_len(
self,
inference_params,
transformer,
transformer_input,
transformer_config
) -> float:
if inference_params is not None:
rotary_seq_len = inference_params.max_sequence_length
else:
if transformer.input_tensor is not None:
rotary_seq_len = transformer.input_tensor.size(0)
else:
rotary_seq_len = transformer_input.size(0)
if hasattr(get_args(), "context_parallel_size"):
if int(get_args().context_parallel_size) > 1:
rotary_seq_len = torch.Tensor([rotary_seq_len]).cuda().int()
torch.distributed.all_reduce(rotary_seq_len, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
rotary_seq_len = rotary_seq_len.item()
return rotary_seq_len
mindspeed_args = get_mindspeed_args()
if hasattr(mindspeed_args, 'context_parallel_algo') and hasattr(mindspeed_args, 'context_parallel_size'):
if mindspeed_args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"] and int(mindspeed_args.context_parallel_size) > 1:
pm.register_patch('mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel.UlyssesContextAttention',
UlyssesContextAttention, force_patch=True)
pm.register_patch('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.get_rotary_seq_len',
get_rotary_seq_len, force_patch=True)
pm.apply_patches()