from typing import Any
import torch
from torch import Tensor
import torch.distributed as dist
import megatron.core.parallel_state as mpu
from megatron.core.transformer.dot_product_attention import DotProductAttention as MegatronDotProductAttention
from mindspeed.megatron_adaptor import get_mindspeed_args
from mindspeed.core.context_parallel.model_parallel_utils import get_context_parallel_group_for_hybrid_ulysses
from mindspeed.core.context_parallel.dot_product_attention import CPDotProductAttentionImpl
try:
from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp.mapping import all_to_all
native_all_to_all = False
except ImportError:
from mindspeed_mm.models.common.communications import all_to_all
native_all_to_all = True
from mindspeed_mm.models.common.communications import cal_split_sizes
def get_hetero_dotproductattention(config):
    """Pick the dot-product attention class that matches ``config``.

    Args:
        config: Object exposing a ``context_parallel_size`` attribute.

    Returns:
        The ``HeteroCPDotProductAttention`` class when
        ``config.context_parallel_size`` is greater than 1, otherwise the
        plain Megatron ``DotProductAttention`` class. A class is returned,
        not an instance.
    """
    uses_context_parallel = config.context_parallel_size > 1
    return HeteroCPDotProductAttention if uses_context_parallel else MegatronDotProductAttention
class HeteroCPDotProductAttention(CPDotProductAttentionImpl, MegatronDotProductAttention):
    """Dot-product attention wrapped in a pair of all-to-all exchanges.

    ``forward`` redistributes query/key/value across the process group stored
    in ``self.spg`` (scattering dimension ``scatter_idx`` and gathering
    dimension ``gather_idx``), runs the inherited attention forward on the
    redistributed tensors, and then applies the reverse exchange to the
    attention output. Per-rank split sizes are obtained from
    ``cal_split_sizes``.
    """

    def __init__(self, *args, scatter_idx=2, gather_idx=0, **kwargs):
        """Initialise the parent attention and resolve the all-to-all group.

        Args:
            *args: Forwarded unchanged to ``CPDotProductAttentionImpl.__init__``.
            scatter_idx: Dimension of query/key/value that is split across
                ranks by the forward all-to-all (and gathered back by the
                reverse one).
            gather_idx: Dimension that is gathered across ranks by the forward
                all-to-all (and split again by the reverse one).
            **kwargs: Forwarded unchanged to ``CPDotProductAttentionImpl.__init__``.

        Raises:
            NotImplementedError: If ``config.context_parallel_algo`` is not one
                of ``hybrid_cp_algo``, ``hybrid_adaptive_cp_algo`` or
                ``ulysses_cp_algo``.
        """
        CPDotProductAttentionImpl.__init__(self, *args, **kwargs)
        # self.config is read right after the parent initialiser, so the
        # parent is expected to have set it.
        config = self.config
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

        algo = config.context_parallel_algo
        if algo in ('hybrid_cp_algo', 'hybrid_adaptive_cp_algo'):
            # Hybrid algorithms run the all-to-all on the dedicated
            # hybrid-ulysses group instead of the full context-parallel group.
            self.spg = get_context_parallel_group_for_hybrid_ulysses()
        elif algo == 'ulysses_cp_algo':
            self.spg = mpu.get_context_parallel_group()
        else:
            raise NotImplementedError(f'algorithm {config.context_parallel_algo} not implemented yet')

        # super() resolves past this class in the MRO, so this is the
        # inherited attention forward, not the override defined below.
        self.DPA_forward = super().forward
        self.spg_world_size = dist.get_world_size(self.spg)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs: Any):
        """Run attention between a forward and a reverse all-to-all.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: remaining positional arguments for the inherited forward;
                the first one, when present, is the attention mask
            kwargs: remaining keyword arguments for the inherited forward;
                ``packed_seq_params`` (optional) and ``attention_mask`` (only
                consulted when no positional mask was given) are read here

        Returns:
            * output (Tensor): context output, reshaped to three dimensions
              (dimension 1 is squeezed again when ``packed_seq_params`` was given)

        Raises:
            ValueError: If both the attention mask and ``packed_seq_params``
                are ``None``, so the sequence length cannot be determined.
        """
        # The mask is normally the first extra positional argument; accept it
        # by keyword too instead of failing with IndexError on args[0].
        mask_is_positional = len(args) > 0
        attention_mask = args[0] if mask_is_positional else kwargs.get('attention_mask')
        # packed_seq_params is optional for callers; a missing key means None.
        packed_seq_params = kwargs.get('packed_seq_params')

        if attention_mask is not None:
            act_seq_len = attention_mask.shape[-1]
        elif packed_seq_params is not None:
            # Last entry of the cumulative query sequence lengths.
            # NOTE(review): this is whatever type cu_seqlens_q holds (likely a
            # 0-dim tensor rather than an int) -- confirm cal_split_sizes and
            # all_to_all accept it.
            act_seq_len = packed_seq_params.cu_seqlens_q[-1]
        else:
            raise ValueError(
                'HeteroCPDotProductAttention.forward needs either an attention mask '
                'or packed_seq_params to determine the sequence length'
            )

        if getattr(self.config, "use_remove_padding", False):
            from mindspeed.utils import get_actual_seq_len
            act_seq_len = get_actual_seq_len()[0]
            # Fixed 2048x2048 boolean mask, True strictly above the diagonal.
            # NOTE(review): presumably the compressed causal-mask layout the
            # underlying attention kernel expects -- confirm.
            attention_mask = torch.triu(
                torch.ones([2048, 2048], dtype=torch.bool, device=query.device), diagonal=1
            )
            # Hand the replacement mask to the inherited forward through the
            # same channel the caller used.
            if mask_is_positional:
                args = (attention_mask,) + args[1:]
            else:
                kwargs['attention_mask'] = attention_mask

        if packed_seq_params is not None:
            # Insert a size-1 dimension at index 1; it is removed again from
            # the output at the end of this method.
            query = query.unsqueeze(1)
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)

        scatter_sizes_query = cal_split_sizes(query.shape[self.scatter_idx], self.spg_world_size)
        scatter_sizes_key = cal_split_sizes(key.shape[self.scatter_idx], self.spg_world_size)
        scatter_sizes_value = cal_split_sizes(value.shape[self.scatter_idx], self.spg_world_size)
        gather_sizes = cal_split_sizes(act_seq_len, self.spg_world_size)

        # native_all_to_all is set at import time: it is False when MindSpeed's
        # unaligned-CP all_to_all was importable (that variant takes the total
        # gathered size) and True when the mindspeed_mm fallback is in use
        # (that variant takes explicit per-rank scatter and gather size lists).
        if not native_all_to_all:
            query_layer = all_to_all(query, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
            key_layer = all_to_all(key, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
            value_layer = all_to_all(value, self.spg, self.scatter_idx, self.gather_idx, act_seq_len)
        else:
            query_layer = all_to_all(
                query, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_query, gather_sizes
            )
            key_layer = all_to_all(
                key, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_key, gather_sizes
            )
            value_layer = all_to_all(
                value, self.spg, self.scatter_idx, self.gather_idx, scatter_sizes_value, gather_sizes
            )

        context_layer = self.DPA_forward(query_layer, key_layer, value_layer, *args, **kwargs)

        if get_mindspeed_args().context_parallel_algo == "hybrid_cp_algo" and context_layer.dim() == 3:
            context_layer = context_layer.unsqueeze(1)
        else:
            # Split the trailing dimension(s) so that dimension 2 has this
            # rank's share of the scattered query dimension.
            context_shape = context_layer.shape
            context_layer = context_layer.reshape(
                context_shape[0], context_shape[1], scatter_sizes_query[dist.get_rank(self.spg)], -1
            )

        # Reverse exchange: scatter/gather indices (and size lists) swap roles.
        if not native_all_to_all:
            output = all_to_all(
                context_layer, self.spg, self.gather_idx, self.scatter_idx, query.shape[self.scatter_idx]
            )
        else:
            output = all_to_all(
                context_layer, self.spg, self.gather_idx, self.scatter_idx, gather_sizes, scatter_sizes_query
            )

        # Merge everything after the first two dimensions into one.
        output = output.reshape(output.shape[0], output.shape[1], -1)
        if packed_seq_params is not None:
            output = output.squeeze(1)
        return output