from typing import Optional
import torch
try:
import torch_npu
except ImportError:
pass
from torch import nn
from mindspeed_llm.fsdp2.distributed.context_parallel.utils import gather_heads_scatter_seq, \
gather_seq_scatter_heads
from mindspeed_llm.fsdp2.distributed.parallel_state import ParallelState
def fixed_cross_entropy_with_cp(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[torch.Tensor] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy loss whose value is aggregated over the context-parallel ("cp") group.

    With context parallelism each CP rank only sees a slice of the sequence, so the
    per-rank partial loss must be combined across the ``"cp"`` process group.

    Args:
        source: Logits, passed unchanged to ``nn.functional.cross_entropy``.
        target: Targets, passed unchanged to ``nn.functional.cross_entropy``
            (assumed to be class indices - TODO confirm for all callers).
        num_items_in_batch: Optional normaliser. When given, the result is the token loss
            summed over all CP ranks divided by this value. When ``None``, the result is
            the mean loss over non-ignored targets of *all* CP ranks.
        ignore_index: Target value excluded from the loss and from the token count.
        **kwargs: Accepted for signature compatibility; ignored.

    Returns:
        A scalar loss tensor.

    Note:
        ``torch.distributed.all_reduce`` is the in-place c10d collective, not the
        autograd-aware ``torch.distributed.nn`` variant. The reduced *value* is global,
        but each rank's gradient presumably remains that of its local summed loss
        divided by the (global) normaliser.
    """
    ps = ParallelState()
    use_cp = ps.get_group_size("cp") > 1

    if num_items_in_batch is None and not use_cp:
        # No normaliser and no sequence sharding: a plain token mean is already correct.
        return nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="mean")

    # Sum (rather than average) so partial results from CP ranks can simply be added.
    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")

    if num_items_in_batch is None:
        # Only reachable when CP is enabled. Adding per-rank *means* would inflate the
        # loss by roughly the CP size, so reduce the valid-token count as well.
        num_items_in_batch = (target != ignore_index).sum()
        torch.distributed.all_reduce(num_items_in_batch, op=torch.distributed.ReduceOp.SUM,
                                     group=ps.get_group("cp"))

    if use_cp:
        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ps.get_group("cp"))

    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = num_items_in_batch.to(loss.device)
    return loss / num_items_in_batch
def flash_attention_forward_fa(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Causal attention with per-head attention sinks via ``torch_npu.npu_fusion_attention_v2``.

    Inputs are consumed in the ``"BNSD"`` layout (batch, heads, seq, head_dim). When context
    parallelism is enabled, the sequence is gathered and heads are scattered across the CP
    group before the kernel runs (all-to-all style), and the inverse is applied afterwards.

    Args:
        module: Attention module. Reads ``config.num_attention_heads``,
            ``config.num_key_value_heads``, ``sinks`` and ``sliding_window``.
        query, key, value: Attention inputs in BNSD layout.
        attention_mask: Not used. Causality comes from the compressed mask plus
            ``sparse_mode=4``, so padding masks are NOT applied.
        scaling: Softmax scale forwarded to the kernel.
        dropout: Dropout probability; the kernel receives ``keep_prob = 1 - dropout``.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` is transposed to
        (batch, seq, heads, head_dim) and made contiguous.
    """
    ps = ParallelState()
    cp_size = ps.context_parallel_size
    config = module.config

    # Band (sparse_mode=4) with no look-ahead. Without a sliding window the huge
    # pre_tokens value makes the band equivalent to plain causal attention.
    # NOTE(review): confirm NPU pre_tokens counts match the model's sliding_window definition.
    pre_tokens = module.sliding_window if module.sliding_window else 1048576
    next_tokens = 0

    if cp_size > config.num_key_value_heads:
        # Too few KV heads to split over the CP group: expand KV to one head per query head.
        num_groups = config.num_attention_heads // config.num_key_value_heads
        key = torch.repeat_interleave(key, dim=1, repeats=num_groups)
        value = torch.repeat_interleave(value, dim=1, repeats=num_groups)

    if cp_size > 1:
        query = gather_seq_scatter_heads(query, seq_dim=2, head_dim=1, gather_size=query.shape[2] * cp_size)
        key = gather_seq_scatter_heads(key, seq_dim=2, head_dim=1, gather_size=key.shape[2] * cp_size)
        value = gather_seq_scatter_heads(value, seq_dim=2, head_dim=1, gather_size=value.shape[2] * cp_size)
        # Keep only the sinks of the heads this rank now owns
        # (assumes module.sinks is one value per attention head, in query-head order).
        sinks = torch.chunk(module.sinks, cp_size)[ps.get_rank("cp")]
    else:
        sinks = module.sinks

    # Compressed 2048x2048 upper-triangular mask used by the NPU sparse modes. It is built
    # on the inputs' device so it always sits next to query/key/value.
    atten_mask = torch.triu(torch.ones((2048, 2048), device=query.device, dtype=torch.bool), diagonal=1)

    attn_output = torch_npu.npu_fusion_attention_v2(
        query, key, value,
        query.shape[1],  # number of (local) query heads
        "BNSD",
        pse=None,
        sparse_mode=4,
        sink=sinks.float(),
        atten_mask=atten_mask,
        scale=scaling,
        pre_tokens=pre_tokens,
        next_tokens=next_tokens,
        keep_prob=1 - dropout,
    )[0]

    if cp_size > 1:
        # Inverse of the earlier all-to-all: regather all heads, re-scatter the sequence.
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=1, seq_dim=2,
                                               gather_size=config.num_attention_heads)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
def flash_attention_forward_fa_gqa(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Causal attention for GQA models via ``torch_npu.npu_fusion_attention``.

    KV heads are expanded to match the query heads. Then, under context parallelism, the
    sequence is gathered and heads are scattered across the CP group before the kernel
    runs, and the inverse is applied afterwards. Inputs use the ``"BNSD"`` layout
    (batch, heads, seq, head_dim).

    Args:
        module: Attention module. Reads ``config.num_attention_heads`` and
            ``config.num_key_value_heads``.
        query, key, value: Attention inputs in BNSD layout.
        attention_mask: Not used; padding masks are NOT applied.
        dropout: Dropout probability; the kernel receives ``keep_prob = 1 - dropout``.
        scaling: Softmax scale. ``None`` falls back to ``1/sqrt(head_dim)``, the same
            default as ``torch.nn.functional.scaled_dot_product_attention``.
        is_causal: Not used; attention is always causal (``sparse_mode=2``).
        **kwargs: Ignored.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` is transposed to
        (batch, seq, heads, head_dim) and made contiguous.
    """
    ps = ParallelState()
    cp_size = ps.context_parallel_size
    config = module.config

    if scaling is None:
        # The NPU kernel needs a concrete float scale; mirror the SDPA default.
        scaling = query.shape[-1] ** -0.5

    num_groups = config.num_attention_heads // config.num_key_value_heads
    if num_groups > 1:
        # Expand KV so every query head has its own KV head.
        key = torch.repeat_interleave(key, dim=1, repeats=num_groups)
        value = torch.repeat_interleave(value, dim=1, repeats=num_groups)

    if cp_size > 1:
        query = gather_seq_scatter_heads(query, seq_dim=2, head_dim=1, gather_size=query.shape[2] * cp_size)
        key = gather_seq_scatter_heads(key, seq_dim=2, head_dim=1, gather_size=key.shape[2] * cp_size)
        value = gather_seq_scatter_heads(value, seq_dim=2, head_dim=1, gather_size=value.shape[2] * cp_size)

    # Compressed 2048x2048 upper-triangular causal mask, created on the inputs' device.
    atten_mask = torch.triu(torch.ones((2048, 2048), device=query.device, dtype=torch.bool), diagonal=1)

    attn_output = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        head_num=query.shape[1],
        input_layout="BNSD",
        atten_mask=atten_mask,
        keep_prob=1 - dropout,
        scale=scaling,
        sparse_mode=2,
    )[0]

    if cp_size > 1:
        # Inverse of the earlier all-to-all: regather all heads, re-scatter the sequence.
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=1, seq_dim=2,
                                               gather_size=config.num_attention_heads)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None