from typing import Tuple, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_npu
from einops import rearrange, repeat
from megatron import core
from megatron.core import mpu, tensor_parallel
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args
from megatron.legacy.model.rms_norm import RMSNorm
from megatron.legacy.model.enums import AttnType
from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed.core.context_parallel.ring_context_parallel.ring_context_parallel import ringattn_context_parallel
from mindspeed.core.context_parallel.model_parallel_utils import (
get_context_parallel_group_for_hybrid_ulysses,
get_context_parallel_group_for_hybrid_ring,
get_context_parallel_for_hybrid_ring_world_size,
get_context_parallel_for_hybrid_ring_rank,
get_context_parallel_for_hybrid_ring_global_ranks,
get_ring_ranks_for_intra_window,
get_ring_ranks_for_inter_window_kv,
get_ring_ranks_for_inter_window_dkv,
get_ring_group_for_intra_window,
get_ring_group_for_intra_window_send_recv_overlap
)
from mindspeed.core.context_parallel.ulysses_context_parallel.unaligned_cp import mapping
from mindspeed_mm.utils.utils import video_to_image, change_tensor_layout
from mindspeed_mm.models.common.normalize import normalize, FP32LayerNorm
from mindspeed_mm.models.common.embeddings.rope import RoPE3D, PositionGetter3D
from mindspeed_mm.models.common.conv import CausalConv3d, WfCausalConv3d
from mindspeed_mm.models.common.linear import MatmulAddLinear
from mindspeed_mm.utils.async_offload import async_save_on_cpu
from .communications import (
all_to_all,
all_to_all_SBH,
split_forward_gather_backward,
)
def do_ring_context_parallel(q, k, v, head_num, softmax_scale, attn_mask, dropout_p=0., pse=None, pse_type=None):
args = get_args()
in_hybrid_mode = get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None
if in_hybrid_mode:
cp_group = get_context_parallel_group_for_hybrid_ring()
cp_size = get_context_parallel_for_hybrid_ring_world_size()
rank = get_context_parallel_for_hybrid_ring_rank()
cp_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
else:
cp_group = mpu.get_context_parallel_group()
cp_size = mpu.get_context_parallel_world_size()
rank = mpu.get_context_parallel_rank()
cp_global_ranks = mpu.get_context_parallel_global_ranks()
cp_para = dict()
cp_para['causal'] = args.attention_mask_type == 'causal'
cp_para['cp_group'] = cp_group
cp_para['cp_size'] = cp_size
cp_para['rank'] = rank
cp_para['cp_global_ranks'] = cp_global_ranks
cp_para['cp_group_for_send_recv_overlap'] = mpu.get_context_parallel_group_for_send_recv_overlap() \
if args.use_cp_send_recv_overlap else None
cp_para['pse'] = pse
cp_para['pse_type'] = pse_type
cp_para['megatron_cp_in_bnsd'] = args.megatron_cp_in_bnsd
output = ringattn_context_parallel(q, k, v, head_num, cp_para, softmax_scale, attn_mask, dropout_p)
return output
def do_npu_fusion_attention(
q, k, v,
head_num,
softmax_scale,
layout="BNSD",
attn_mask=None,
actual_seq_qlen=None,
actual_seq_kvlen=None,
dropout_p=0.,
pse=None,
sparse_mode=0,
async_offload=False,
block_idx=0,
depth=0,
h2d_stream=None,
d2h_stream=None
):
if async_offload:
with async_save_on_cpu(
h2d_stream=h2d_stream,
d2h_stream=d2h_stream,
block_idx=block_idx,
depth=depth,
custom_check_fn=lambda x: x.storage().size() >= q.storage().size()
):
output = torch_npu.npu_fusion_attention(
q, k, v, head_num, layout,
pse=pse,
padding_mask=None,
atten_mask=attn_mask,
actual_seq_qlen=actual_seq_qlen,
actual_seq_kvlen=actual_seq_kvlen,
scale=softmax_scale,
keep_prob=1 - dropout_p,
inner_precise=0,
sparse_mode=sparse_mode
)[0]
else:
output = torch_npu.npu_fusion_attention(
q, k, v, head_num, layout,
pse=pse,
padding_mask=None,
atten_mask=attn_mask,
actual_seq_qlen=actual_seq_qlen,
actual_seq_kvlen=actual_seq_kvlen,
scale=softmax_scale,
keep_prob=1 - dropout_p,
inner_precise=0,
sparse_mode=sparse_mode
)[0]
return output
class FlashAttention(nn.Module):
"""Implement the multihead softmax attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
fa_layout: The input layout in Flash attention.
"""
def __init__(
self,
softmax_scale=None,
attention_dropout=0.0,
fa_layout="sbh"
):
super().__init__()
args = get_args()
self.softmax_scale = softmax_scale
self.attention_dropout = attention_dropout
self.fa_layout = fa_layout
self.pse = None
self.pse_type = 1
self.sparse_mode = args.sparse_mode
self.context_parallel_algo = args.context_parallel_algo
self.context_parallel_size = args.context_parallel_size
def forward(self, query, key, value, attention_mask=None, **kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the query, key, and value. (S, B, N, D)
"""
seq_length, batch_size, n_head, head_dim = query.shape[0], query.shape[1], query.shape[2], query.shape[3]
if attention_mask is not None and self.context_parallel_algo not in ["megatron_cp_algo", "hybrid_cp_algo"]:
attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
if self.context_parallel_size > 1 and self.context_parallel_algo in ["megatron_cp_algo", "hybrid_cp_algo"]:
if self.fa_layout.lower() != "sbh":
raise ValueError(f"Flash attention layout mulst be `sbh` when using Ring Attention, but got {self.fa_layout}!")
query, key, value = [rearrange(x, "s b n d -> s b (n d)") for x in [query, key, value]]
output = do_ring_context_parallel(
query,
key,
value,
head_num=n_head,
softmax_scale=self.softmax_scale,
attn_mask=attention_mask,
pse=self.pse,
pse_type=self.pse_type
)
else:
query, key, value = [change_tensor_layout(x, "sbnd", self.fa_layout, batch_size=batch_size) for x in [query, key, value]]
actual_seq_qlen = []
actual_seq_kvlen = []
if self.fa_layout == "tnd":
if attention_mask is not None:
ans = 0
for _ in range(batch_size):
ans += seq_length
actual_seq_qlen.append(ans)
ans = 0
for m in attention_mask:
ans += m
actual_seq_kvlen.append(ans)
else:
if attention_mask is not None:
attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1])
output = do_npu_fusion_attention(
query, key, value,
head_num=n_head,
softmax_scale=self.softmax_scale,
layout=self.fa_layout.upper(),
attn_mask=attention_mask,
actual_seq_qlen=actual_seq_qlen,
actual_seq_kvlen=actual_seq_kvlen,
dropout_p=self.attention_dropout,
sparse_mode=self.sparse_mode,
**kwargs,
)
output = change_tensor_layout(output, self.fa_layout, "sbh", batch_size=batch_size)
return output
class ParallelAttention(nn.Module):
"""
A multi-head attention layer for both self-atten and cross-atten".
Args:
query_dim: The number of channels in the query.
key_dim: The number of channels in the key, defaults to `query_dim`.
num_attention_heads: The number of heads to use for multi-head attention.
hidden_size: The hidden layer size.
proj_q_bias: Whether to use bias in query projection.
proj_k_bias: Whether to use bias in key projection.
proj_v_bias: Whether to use bias in value projection.
proj_out_bias: Whether to use bias in out projection.
dropout: The dropout probability to use.
use_qk_norm: Whether to use normalization for q and k.
norm_type: The type of normalization layer.
norm_elementwise_affine: A boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
norm_eps: A value added to the denominator for numerical stability. Default: 1e-5
is_qkv_concat: Whether to concatenate qkv in projection.
attention_type: Self attention or cross attention.
is_kv_concat: Whether to concatenate kv in projection.
fa_layout: The input layout of Flash Attention.
rope: The Rotary Position Embedding object, default to `None`.
split_kv_in_forward: Whether the input kv in the forward function is split or not.
This argument is valid only in ["ulysses_cp_algo", "hybrid_cp_algo"].
"""
def __init__(
self,
query_dim: int,
key_dim: Optional[int],
num_attention_heads: int,
hidden_size: int,
proj_q_bias: bool = False,
proj_k_bias: bool = False,
proj_v_bias: bool = False,
proj_out_bias: bool = False,
dropout: float = 0.0,
use_qk_norm: bool = False,
norm_type: str = None,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
is_qkv_concat: bool = False,
attention_type: int = AttnType.self_attn,
is_kv_concat: bool = False,
fa_layout: str = "sbh",
rope=None,
split_kv_in_forward: bool = True,
**kwargs
):
super().__init__()
args = get_args()
config = core_transformer_config_from_args(args)
key_dim = key_dim if key_dim is not None else query_dim
self.sequence_parallel = args.sequence_parallel
self.is_qkv_concat = is_qkv_concat
self.is_kv_concat = is_kv_concat
self.use_qk_norm = use_qk_norm
self.attention_type = attention_type
self.fa_layout = fa_layout
self.rope = rope
self.cp_size = mpu.get_context_parallel_world_size()
self.tp_size = mpu.get_tensor_model_parallel_world_size()
self.head_dim = core.utils.divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide(num_attention_heads, self.tp_size)
self.attention_mask_type = args.attention_mask_type
if self.attention_type == AttnType.self_attn and self.is_qkv_concat:
self.proj_qkv = tensor_parallel.ColumnParallelLinear(
query_dim,
hidden_size * 3,
config=config,
init_method=config.init_method,
bias=proj_q_bias,
gather_output=False)
elif self.attention_type == AttnType.cross_attn and self.is_kv_concat:
self.proj_q = tensor_parallel.ColumnParallelLinear(
query_dim,
hidden_size,
config=config,
init_method=config.init_method,
bias=proj_q_bias,
gather_output=False)
self.proj_kv = tensor_parallel.ColumnParallelLinear(
key_dim,
hidden_size * 2,
config=config,
init_method=config.init_method,
bias=proj_k_bias,
gather_output=False)
else:
self.proj_q = tensor_parallel.ColumnParallelLinear(
query_dim,
hidden_size,
config=config,
init_method=config.init_method,
bias=proj_q_bias,
gather_output=False
)
self.proj_k = tensor_parallel.ColumnParallelLinear(
key_dim,
hidden_size,
config=config,
init_method=config.init_method,
bias=proj_k_bias,
gather_output=False
)
self.proj_v = tensor_parallel.ColumnParallelLinear(
key_dim,
hidden_size,
config=config,
init_method=config.init_method,
bias=proj_v_bias,
gather_output=False
)
if self.use_qk_norm:
self.q_norm = normalize(
norm_type=norm_type,
in_channels=self.head_dim,
eps=norm_eps,
affine=norm_elementwise_affine,
**kwargs
)
self.k_norm = normalize(
norm_type=norm_type,
in_channels=self.head_dim,
eps=norm_eps,
affine=norm_elementwise_affine,
**kwargs
)
if isinstance(self.q_norm, nn.LayerNorm):
for param in self.q_norm.parameters():
setattr(param, "sequence_parallel", self.sequence_parallel)
if isinstance(self.k_norm, nn.LayerNorm):
for param in self.k_norm.parameters():
setattr(param, "sequence_parallel", self.sequence_parallel)
self.core_attention_flash = FlashAttention(
attention_dropout=dropout,
fa_layout=self.fa_layout,
softmax_scale=1 / math.sqrt(self.head_dim)
)
if mpu.get_context_parallel_world_size() > 1 and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]:
ulysses_group = mpu.get_context_parallel_group()
if args.context_parallel_algo == "hybrid_cp_algo":
ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
if split_kv_in_forward:
self.core_attention_flash = UlyssesContextAttention(self.core_attention_flash, ulysses_group)
self.proj_out = tensor_parallel.RowParallelLinear(
hidden_size,
query_dim,
config=config,
init_method=config.init_method,
bias=proj_out_bias,
input_is_parallel=True,
skip_bias_add=False
)
self.dropout = nn.Dropout(dropout)
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
Derives `query` tensor from `hidden_states`, and `key`/`value` tensor
from `hidden_states` or `key_value_states`.
"""
if self.attention_type == AttnType.self_attn and self.is_qkv_concat:
mixed_qkv = self.proj_qkv(hidden_states)[0]
(query, key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_qkv, 3)
elif self.attention_type == AttnType.cross_attn and self.is_kv_concat:
query = self.proj_q(hidden_states)[0]
mixed_kv = self.proj_kv(key_value_states)[0]
(key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)
else:
query = self.proj_q(hidden_states)[0]
key = self.proj_k(key_value_states)[0]
value = self.proj_v(key_value_states)[0]
batch_size = query.shape[1]
query = query.view(-1, batch_size, self.num_attention_heads_per_partition, self.head_dim)
key = key.view(-1, batch_size, self.num_attention_heads_per_partition, self.head_dim)
value = value.view(-1, batch_size, self.num_attention_heads_per_partition, self.head_dim)
if self.use_qk_norm:
query = self.q_norm(query)
key = self.k_norm(key)
return query, key, value
def function_before_core_attention(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
input_layout: str = "sbh",
rotary_pos_emb: Optional[torch.Tensor] = None
):
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
query = change_tensor_layout(query, input_layout, "sbh")
key = query if key is None else change_tensor_layout(key, input_layout, "sbh")
query, key, value = self.get_query_key_value_tensors(query, key)
if self.rope is not None and rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query = self.rope.apply_rotary_pos_emb(query, q_pos_emb)
key = self.rope.apply_rotary_pos_emb(key, k_pos_emb)
return query, key, value
def function_after_core_attention(
self,
core_attn_out,
output_layout: str = "sbh"
):
output, bias = self.proj_out(core_attn_out)
output = change_tensor_layout(output, "sbh", output_layout)
output = self.dropout(output)
return output
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
input_layout: str = "sbh",
rotary_pos_emb: Optional[torch.Tensor] = None
):
query, key, value = self.function_before_core_attention(query, key, input_layout, rotary_pos_emb)
core_attn_out = self.core_attention_flash(query, key, value, mask)
out = self.function_after_core_attention(core_attn_out, input_layout)
return out
class MultiHeadSparseAttentionSBH(ParallelAttention):
def __init__(
self,
query_dim: int,
key_dim: Optional[int],
num_attention_heads: int,
hidden_size: int,
proj_q_bias: bool = False,
proj_k_bias: bool = False,
proj_v_bias: bool = False,
proj_out_bias: bool = True,
dropout: float = 0.0,
use_qk_norm: bool = False,
norm_type: str = None,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
is_qkv_concat: bool = False,
attention_type: int = AttnType.self_attn,
is_kv_concat: bool = False,
fa_layout: str = "sbh",
rope=None,
sparse1d=False,
sparse_n=None,
sparse_group=None,
**kwargs
):
super().__init__(
query_dim,
key_dim,
num_attention_heads,
hidden_size,
proj_q_bias,
proj_k_bias,
proj_v_bias,
proj_out_bias,
dropout,
use_qk_norm,
norm_type,
norm_elementwise_affine,
norm_eps,
is_qkv_concat,
attention_type,
is_kv_concat,
fa_layout,
rope,
**kwargs
)
args = get_args()
self.sparse1d = sparse1d
self.sparse_n = sparse_n
self.sparse_group = sparse_group
if args.context_parallel_algo == 'ulysses_cp_algo':
self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, mpu.get_context_parallel_world_size())
elif args.context_parallel_algo == 'hybrid_cp_algo':
self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, args.ulysses_degree_in_cp)
else:
self.num_attention_heads_per_partition_per_cp = self.num_attention_heads_per_partition
if args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
in_hybrid_mode = False
if get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None:
in_hybrid_mode = True
if not in_hybrid_mode:
cp_group = mpu.get_context_parallel_group()
cp_size = mpu.get_context_parallel_world_size()
rank = mpu.get_context_parallel_rank()
cp_global_ranks = mpu.get_context_parallel_global_ranks()
else:
cp_group = get_context_parallel_group_for_hybrid_ring()
cp_size = get_context_parallel_for_hybrid_ring_world_size()
rank = get_context_parallel_for_hybrid_ring_rank()
cp_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
self.pse = None
self.pse_type = 1
self.cp_para = dict()
self.cp_para['causal'] = args.attention_mask_type == 'causal'
self.cp_para['cp_group'] = cp_group
self.cp_para['cp_size'] = cp_size
self.cp_para['rank'] = rank
self.cp_para['cp_global_ranks'] = cp_global_ranks
self.cp_para['cp_group_for_send_recv_overlap'] = mpu.get_context_parallel_group_for_send_recv_overlap() \
if args.use_cp_send_recv_overlap else None
self.cp_para['pse'] = self.pse
self.cp_para['pse_type'] = self.pse_type
self.cp_para['cp_inner_ranks'] = get_ring_ranks_for_intra_window()
self.cp_para['cp_outer_ranks'] = get_ring_ranks_for_inter_window_kv()
self.cp_para['cp_dkv_outer_ranks'] = get_ring_ranks_for_inter_window_dkv()
self.cp_para['cp_group_for_intra_window'] = get_ring_group_for_intra_window()
self.cp_para['cp_group_for_intra_window_send_recv_overlap'] = get_ring_group_for_intra_window_send_recv_overlap()
def _sparse_1d(self, x, frame, height, width):
"""
require the shape of (ntokens x batch_size x dim)
"""
_len = x.shape[0]
if _len != frame * height * width:
raise ValueError("shape mismatched.")
pad_len = 0
if _len % (self.sparse_n * self.sparse_n) != 0:
pad_len = self.sparse_n * self.sparse_n - _len % (self.sparse_n * self.sparse_n)
if pad_len != 0:
x = F.pad(x, (0, 0, 0, 0, 0, pad_len))
if not self.sparse_group:
x = rearrange(x, '(g k) b d -> g (k b) d', k=self.sparse_n)
else:
x = rearrange(x, '(n m k) b d -> (n k) (m b) d', m=self.sparse_n, k=self.sparse_n)
return x, pad_len
def _reverse_sparse_1d(self, x, frame, height, width):
"""
require the shape of (ntokens x batch_size x dim)
"""
if not self.sparse_group:
x = rearrange(x, 'g (k b) d -> (g k) b d', k=self.sparse_n)
else:
x = rearrange(x, '(n k) (m b) d -> (n m k) b d', m=self.sparse_n, k=self.sparse_n)
x = x[:frame * height * width, :, :]
return x
def _sparse_1d_kv(self, x):
"""
require the shape of (ntokens x batch_size x dim)
"""
x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n)
return x
def function_before_core_attention(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
frames: int = 8,
height: int = 16,
width: int = 16,
input_layout: str = "sbh",
rotary_pos_emb: Optional[torch.Tensor] = None,
):
args = get_args()
q, k, v = super().function_before_core_attention(query, key, input_layout, rotary_pos_emb=rotary_pos_emb)
total_frames = frames
if mpu.get_context_parallel_world_size() > 1:
if args.context_parallel_algo == 'ulysses_cp_algo':
cp_group = mpu.get_context_parallel_group()
total_frames = frames * mpu.get_context_parallel_world_size()
q = mapping.all_to_all(q, cp_group, scatter_dim=2, gather_dim=0)
k = mapping.all_to_all(k, cp_group, scatter_dim=2, gather_dim=0)
v = mapping.all_to_all(v, cp_group, scatter_dim=2, gather_dim=0)
if args.context_parallel_algo == 'hybrid_cp_algo':
cp_group = get_context_parallel_group_for_hybrid_ulysses()
total_frames = frames * args.ulysses_degree_in_cp
q = mapping.all_to_all(q, cp_group, scatter_dim=2, gather_dim=0)
k = mapping.all_to_all(k, cp_group, scatter_dim=2, gather_dim=0)
v = mapping.all_to_all(v, cp_group, scatter_dim=2, gather_dim=0)
batch_size = q.shape[1]
self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, mpu.get_context_parallel_world_size())
q = q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
k = k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
v = v.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
pad_len = None
if self.sparse1d:
q, pad_len = self._sparse_1d(q, total_frames, height, width)
if self.attention_type == AttnType.cross_attn:
k = self._sparse_1d_kv(k)
v = self._sparse_1d_kv(v)
else:
try:
k, pad_len = self._sparse_1d(k, total_frames, height, width)
except Exception as e:
print(f"An error occurred: {e}")
k, pad_len = None, None
try:
v, pad_len = self._sparse_1d(v, total_frames, height, width)
except Exception as e:
print(f"An error occurred: {e}")
v, pad_len = None, None
return q, k, v
def function_after_core_attention(
self,
out,
output_layout="sbh",
frames=None,
height=None,
width=None,
dtype=torch.bfloat16
):
args = get_args()
total_frames = frames
if mpu.get_context_parallel_world_size() > 1:
if args.context_parallel_algo == 'ulysses_cp_algo':
cp_group = mpu.get_context_parallel_group()
total_frames = frames * mpu.get_context_parallel_world_size()
if args.context_parallel_algo == 'hybrid_cp_algo':
cp_group = get_context_parallel_group_for_hybrid_ulysses()
total_frames = frames * args.ulysses_degree_in_cp
if self.sparse1d:
out = self._reverse_sparse_1d(out, total_frames, height, width)
if mpu.get_context_parallel_world_size() > 1:
if args.context_parallel_algo == 'ulysses_cp_algo':
cp_group = mpu.get_context_parallel_group()
out = mapping.all_to_all(out, cp_group, scatter_dim=0, gather_dim=2)
elif args.context_parallel_algo == 'hybrid_cp_algo':
cp_group = get_context_parallel_group_for_hybrid_ulysses()
out = mapping.all_to_all(out, cp_group, scatter_dim=0, gather_dim=2)
out = out.to(dtype)
out = super().function_after_core_attention(out, output_layout)
return out
def function_core_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
):
args = get_args()
batch_size = query.shape[1]
if mask is not None and args.context_parallel_algo not in ['megatron_cp_algo', 'hybrid_cp_algo']:
mask = mask.view(batch_size, 1, -1, mask.shape[-1])
if mpu.get_context_parallel_world_size() > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
scale = 1.0 / math.sqrt(self.head_dim)
head_num = self.num_attention_heads_per_partition_per_cp
cp_para = self.cp_para
out = ringattn_context_parallel(query, key, value, head_num, cp_para, scale, mask)
else:
out = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_attention_heads_per_partition_per_cp,
atten_mask=mask,
input_layout="SBH",
scale=1 / math.sqrt(self.head_dim)
)[0]
return out
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
frames: int = 8,
height: int = 16,
width: int = 16,
input_layout: str = "sbh",
rotary_pos_emb: Optional[torch.Tensor] = None,
):
q, k, v = self.function_before_core_attention(query, key, frames, height, width, input_layout, rotary_pos_emb)
out = self.function_core_attention(q, k, v, mask)
out = self.function_after_core_attention(out, frames=frames, height=height, width=width, dtype=query.dtype)
return out
class MultiHeadSparseMMAttentionSBH(MultiHeadSparseAttentionSBH):
"""
A multi-head attention layer for both self-atten and cross-atten, layout "SBH", MMdit.
"""
def __init__(
self,
query_dim: int,
key_dim: int,
num_heads: int,
head_dim: int,
added_kv_proj_dim: int = None,
dropout: float = 0.0,
proj_qkv_bias: bool = False,
proj_out_bias: bool = True,
sparse1d: bool = False,
sparse_n: int = None,
sparse_group: bool = None,
is_cross_attn: bool = False,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
eps: float = 1e-5,
elementwise_affine: bool = True,
):
proj_q_bias = proj_k_bias = proj_v_bias = proj_qkv_bias
super().__init__(
query_dim=query_dim,
key_dim=key_dim,
num_attention_heads=num_heads,
hidden_size=num_heads * head_dim,
proj_q_bias=proj_q_bias,
proj_k_bias=proj_k_bias,
proj_v_bias=proj_v_bias,
proj_out_bias=proj_out_bias,
dropout=dropout,
sparse1d=sparse1d,
sparse_n=sparse_n,
sparse_group=sparse_group,
elementwise_affine=elementwise_affine,
)
self.head_dim = head_dim
self.inner_dim = num_heads * self.head_dim
if qk_norm is None:
self.norm_proj_q = None
self.norm_proj_k = None
elif qk_norm == "fp32_layer_norm":
self.norm_proj_q = FP32LayerNorm(head_dim, eps=eps, sequence_parallel=True)
self.norm_proj_k = FP32LayerNorm(head_dim, eps=eps, sequence_parallel=True)
elif qk_norm == "rms_norm":
self.norm_proj_q = RMSNorm(head_dim, eps=eps, sequence_parallel=True)
self.norm_proj_k = RMSNorm(head_dim, eps=eps, sequence_parallel=True)
else:
raise ValueError(f"Unsupported qk_norm: {qk_norm}")
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
self.norm_added_proj_q = FP32LayerNorm(head_dim, eps=eps, sequence_parallel=True)
self.norm_added_proj_k = FP32LayerNorm(head_dim, eps=eps, sequence_parallel=True)
elif qk_norm == "rms_norm":
self.norm_added_proj_q = RMSNorm(head_dim, eps=eps, sequence_parallel=True)
self.norm_added_proj_k = RMSNorm(head_dim, eps=eps, sequence_parallel=True)
else:
raise ValueError(f"Unsupported qk_norm: {qk_norm}")
args = get_args()
config = core_transformer_config_from_args(args)
self.context_pre_only = context_pre_only
self.added_kv_proj_dim = added_kv_proj_dim
if self.added_kv_proj_dim is not None:
self.added_proj_k = tensor_parallel.ColumnParallelLinear(
self.inner_dim,
self.added_kv_proj_dim,
config=config,
init_method=config.init_method,
bias=proj_k_bias,
gather_output=False
)
self.added_proj_v = tensor_parallel.ColumnParallelLinear(
self.inner_dim,
self.added_kv_proj_dim,
config=config,
init_method=config.init_method,
bias=proj_v_bias,
gather_output=False
)
if self.context_pre_only is not None:
self.added_proj_q = tensor_parallel.ColumnParallelLinear(
self.inner_dim,
self.added_kv_proj_dim,
config=config,
init_method=config.init_method,
bias=proj_q_bias,
gather_output=False
)
if not self.context_pre_only:
self.added_proj_out = tensor_parallel.RowParallelLinear(
added_kv_proj_dim,
self.inner_dim,
config=config,
init_method=config.init_method,
bias=proj_out_bias,
input_is_parallel=True,
skip_bias_add=False
)
def rotate_half(self, x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rope1d(self, tokens, cos, sin):
"""
* tokens: ntokens x batch_size x nheads x dim
"""
return (tokens * cos) + (self.rotate_half(tokens) * sin)
def apply_rotary_emb(self, tokens, video_rotary_emb):
cos_t, sin_t, cos_y, sin_y, cos_x, sin_x = video_rotary_emb
dim = tokens.shape[-1]
D_t = dim // 16 * 4
D = dim // 16 * 6
origin_dtype = tokens.dtype
t, y, x = torch.split(tokens, [D_t, D, D], dim=-1)
t = self.apply_rope1d(t, cos_t, sin_t)
y = self.apply_rope1d(y, cos_y, sin_y)
x = self.apply_rope1d(x, cos_x, sin_x)
tokens = torch.cat((t, y, x), dim=-1).to(origin_dtype)
return tokens
def _reverse_sparse_1d_enc(self, x):
"""
require the shape of (ntokens x batch_size x dim)
"""
x = rearrange(x, 's (k b) d -> s k b d', k=self.sparse_n).mean(1)
return x
def _sparse_1d_enc(self, x):
"""
require the shape of (ntokens x batch_size x dim)
"""
x = repeat(x, 's b d -> s (k b) d', k=self.sparse_n)
return x
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
frames: int,
height: int,
width: int,
attention_mask: Optional[torch.Tensor] = None,
video_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
hidden_states: The hidden states of the visual stream.
encoder_hidden_states: The hidden states of the textual stream.
frames: The frame number of video
height: The height of the video
width: The width of the video
attention_mask: The attention mask to use.
video_rotary_emb: The rotary embeddings for the video
"""
q, _ = self.proj_q(hidden_states)
k, _ = self.proj_k(hidden_states)
v, _ = self.proj_v(hidden_states)
added_q, _ = self.added_proj_q(encoder_hidden_states)
added_k, _ = self.added_proj_k(encoder_hidden_states)
added_v, _ = self.added_proj_v(encoder_hidden_states)
batch_size = q.shape[1]
batch_size = added_q.shape[1]
total_frames = frames
q = q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp, self.head_dim)
k = k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp, self.head_dim)
q = self.norm_proj_q(q)
k = self.norm_proj_k(k)
added_q = added_q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp, self.head_dim)
added_k = added_k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp, self.head_dim)
added_q = self.norm_added_proj_q(added_q)
added_k = self.norm_added_proj_k(added_k)
q = self.apply_rotary_emb(q, video_rotary_emb)
k = self.apply_rotary_emb(k, video_rotary_emb)
q = q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
k = k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
v = v.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
added_q = added_q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
added_k = added_k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
added_v = added_v.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
if self.sparse1d:
q, _ = self._sparse_1d(q, total_frames, height, width)
k, _ = self._sparse_1d(k, total_frames, height, width)
v, _ = self._sparse_1d(v, total_frames, height, width)
added_q = self._sparse_1d_enc(added_q)
added_k = self._sparse_1d_enc(added_k)
added_v = self._sparse_1d_enc(added_v)
fa_visual_sequence_length = q.shape[0]
fa_text_sequence_length_length = added_q.shape[0]
q = torch.cat([q, added_q], dim=0)
k = torch.cat([k, added_k], dim=0)
v = torch.cat([v, added_v], dim=0)
out = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num=self.num_attention_heads_per_partition_per_cp,
atten_mask=attention_mask,
input_layout="SBH",
scale=1 / math.sqrt(self.head_dim)
)[0]
hidden_states, encoder_hidden_states = out.split([fa_visual_sequence_length, fa_text_sequence_length_length], dim=0)
if self.sparse1d:
hidden_states = self._reverse_sparse_1d(hidden_states, total_frames, height, width)
encoder_hidden_states = self._reverse_sparse_1d_enc(encoder_hidden_states)
hidden_states, _ = self.proj_out(hidden_states)
hidden_states = self.dropout(hidden_states)
if not self.context_pre_only:
encoder_hidden_states, _ = self.added_proj_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class ParallelMultiHeadAttentionSBH(nn.Module):
"""
A multi-head context parallel attention layer for both self-attention and cross-attention, layout "SBH".
Args:
query_dim: The number of channels in the query.
key_dim: The number of channels in the key, defaults to `query_dim`.
num_heads: The number of heads to use for multi-head attention.
head_dim: The number of channels in each head.
dropout: The dropout probability to use.
proj_qkv_bias: Whether to use bias in qkv projection.
proj_out_bias: Whether to use bias in out projection.
use_rope: Whether to use rope
interpolation_scale: The scale of interpolation.
"""
def __init__(
self,
query_dim: int,
key_dim: int,
num_heads: int,
head_dim: int,
dropout: float = 0.0,
proj_qkv_bias: bool = False,
proj_out_bias: bool = True,
use_rope: bool = False,
interpolation_scale: Tuple[int] = (1, 1, 1)
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = self.num_heads * self.head_dim
self.use_rope = use_rope
if self.use_rope:
self.rope = RoPE3D(interpolation_scale=interpolation_scale)
self.position_getter = PositionGetter3D(atten_layout="SBH")
args = get_args()
config = core_transformer_config_from_args(args)
self.sp_size = mpu.get_context_parallel_world_size()
self.tp_size = mpu.get_tensor_model_parallel_world_size()
self.num_attention_heads_per_partition = core.utils.divide(num_heads, self.tp_size)
self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition,
self.sp_size)
key_dim = key_dim if key_dim is not None else query_dim
self.proj_q = tensor_parallel.ColumnParallelLinear(
query_dim,
self.inner_dim,
config=config,
init_method=config.init_method,
bias=proj_qkv_bias,
gather_output=False
)
self.proj_k = tensor_parallel.ColumnParallelLinear(
key_dim,
self.inner_dim,
config=config,
init_method=config.init_method,
bias=proj_qkv_bias,
gather_output=False
)
self.proj_v = tensor_parallel.ColumnParallelLinear(
key_dim,
self.inner_dim,
config=config,
init_method=config.init_method,
bias=proj_qkv_bias,
gather_output=False
)
self.proj_out = tensor_parallel.RowParallelLinear(
self.inner_dim,
query_dim,
config=config,
init_method=config.init_method,
bias=proj_out_bias,
input_is_parallel=True,
skip_bias_add=False
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
**kwargs
) -> torch.Tensor:
"""
Args:
query: The hidden states of the query.
key: The hidden states of the key.
mask: The attention mask to use.
frames: The frame number of latents
height: The height of the frame
width: The width of the frame
**kwargs: Additional keyword arguments to pass along
"""
if len(query.shape) != 3:
raise AssertionError("Parallel attention only support SBH.")
is_cross_attention = key is not None
key = query if key is None else key
s, b, _ = query.shape
if mask is not None:
mask = mask.view(b, 1, -1, mask.shape[-1])
q, _ = self.proj_q(query)
k, _ = self.proj_k(key)
v, _ = self.proj_v(key)
q = q.view(-1, self.num_attention_heads_per_partition, self.head_dim)
k = k.view(-1, self.num_attention_heads_per_partition, self.head_dim)
v = v.view(-1, self.num_attention_heads_per_partition, self.head_dim)
sp_group = mpu.get_context_parallel_group()
q = all_to_all_SBH(q, sp_group, scatter_dim=1, gather_dim=0)
if not is_cross_attention:
k = all_to_all_SBH(k, sp_group, scatter_dim=1, gather_dim=0)
v = all_to_all_SBH(v, sp_group, scatter_dim=1, gather_dim=0)
else:
k = split_forward_gather_backward(k, sp_group, dim=1, grad_scale="down")
v = split_forward_gather_backward(v, sp_group, dim=1, grad_scale="down")
if self.use_rope:
q = q.view(-1, b, self.num_attention_heads_per_partition_per_cp, self.head_dim)
k = k.view(-1, b, self.num_attention_heads_per_partition_per_cp, self.head_dim)
if (frames is None) or (height is None) or (width is None):
raise ValueError("frames, height and width can not be none when use_rope")
pos_thw = self.position_getter(b, t=frames, h=height, w=width, device=query.device)
q = self.rope(q, pos_thw)
k = self.rope(k, pos_thw)
q = q.view(-1, b, self.num_attention_heads_per_partition_per_cp * self.head_dim)
k = k.view(-1, b, self.num_attention_heads_per_partition_per_cp * self.head_dim)
v = v.view(-1, b, self.num_attention_heads_per_partition_per_cp * self.head_dim)
out = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num=self.num_attention_heads_per_partition_per_cp,
atten_mask=mask,
input_layout="SBH",
scale=1 / math.sqrt(self.head_dim)
)[0]
out = out.view(-1, self.num_attention_heads_per_partition_per_cp, self.head_dim)
out = all_to_all_SBH(out, sp_group, scatter_dim=0, gather_dim=1)
out = out.view(-1, b, self.num_attention_heads_per_partition * self.head_dim)
out, _ = self.proj_out(out)
out = self.dropout(out)
return out
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
rope=None,
qk_norm_legacy: bool = False,
) -> None:
super().__init__()
if dim % num_heads != 0:
raise AssertionError(
"dim (%d) must be divisible by num_heads (%d)" % (dim, num_heads)
)
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.enable_flashattn = enable_flashattn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.qk_norm_legacy = qk_norm_legacy
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = False
if rope is not None:
self.rope = True
self.rotary_emb = rope
def npu_spatial_attention(self, qkv: torch.Tensor) -> torch.Tensor:
B, N, _ = qkv.shape
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape)
q, k, v = qkv.unbind(2)
q, k = self.q_norm(q), self.k_norm(k)
x = torch_npu.npu_fusion_attention(
q, k, v, self.num_heads, input_layout="BSND",
pse=None,
scale=self.scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1. - self.attn_drop.p if self.training else 1.,
sync=False,
inner_precise=0
)[0]
x = x.view(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
def npu_temporal_attention(self, qkv: torch.Tensor) -> torch.Tensor:
B, N, _ = qkv.shape
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if self.qk_norm_legacy:
q, k = self.rotary_emb(q), self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
else:
q, k = self.q_norm(q), self.k_norm(k)
q, k = self.rotary_emb(q), self.rotary_emb(k)
x = torch_npu.npu_fusion_attention(
q, k, v, self.num_heads, input_layout="BNSD",
pse=None,
scale=self.scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1. - self.attn_drop.p if self.training else 1.,
sync=False,
inner_precise=0,
)[0]
x = x.transpose(1, 2)
x = x.reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
enable_flashattn = self.enable_flashattn
qkv = self.qkv(x)
if enable_flashattn:
if qkv.dtype in [torch.float16, torch.bfloat16]:
if self.rope:
return self.npu_temporal_attention(qkv)
else:
return self.npu_spatial_attention(qkv)
else:
raise ValueError("The dtype of x must be torch.float16 or torch.bfloat16, got torch.float32 instead.")
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if self.qk_norm_legacy:
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
q, k = self.q_norm(q), self.k_norm(k)
else:
q, k = self.q_norm(q), self.k_norm(k)
if self.rope:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)
attn = attn.softmax(dim=-1)
attn = attn.to(dtype)
attn = self.attn_drop(attn)
x = attn @ v
x_output_shape = (B, N, C)
if not enable_flashattn:
x = x.transpose(1, 2)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SeqParallelAttention(Attention):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
enable_flashattn: bool = False,
) -> None:
super().__init__(
dim=dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
enable_flashattn=enable_flashattn,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = (
x.shape
)
qkv = self.qkv(x)
qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.view(qkv_shape)
sp_group = mpu.get_context_parallel_group()
qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1)
if self.enable_flashattn:
qkv_permute_shape = (2, 0, 1, 3, 4)
else:
qkv_permute_shape = (2, 0, 3, 1, 4)
qkv = qkv.permute(qkv_permute_shape)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.enable_flashattn and q.dtype in [torch.float16, torch.bfloat16]:
x = torch_npu.npu_fusion_attention(
q,
k,
v,
q.shape[-2],
input_layout="BSND",
pse=None,
scale=self.scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - self.attn_drop.p if self.training else 1.0,
sync=False,
inner_precise=0,
)[0]
else:
dtype = q.dtype
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.to(torch.float32)
attn = attn.softmax(dim=-1)
attn = attn.to(dtype)
attn = self.attn_drop(attn)
x = attn @ v
if not self.enable_flashattn:
x = x.transpose(1, 2)
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
x_output_shape = (B, N, C)
x = x.reshape(x_output_shape)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MultiHeadCrossAttention(nn.Module):
def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
super().__init__()
if d_model % num_heads != 0:
raise AssertionError(
"d_model (%d) must be divisible by num_heads (%d)"
% (d_model, num_heads)
)
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.q_linear = nn.Linear(d_model, d_model)
self.kv_linear = nn.Linear(d_model, d_model * 2)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(d_model, d_model)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, cond, mask=None):
B, N, C = x.shape
if x.dtype not in [torch.float16, torch.bfloat16]:
raise AssertionError("QKV's dtype must be in bf16 or fp16")
q = self.q_linear(x).view(-1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(-1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(1)
actual_seq_qlen = []
actual_seq_kvlen = []
if mask is not None:
ans = 0
for _ in range(B):
ans += N
actual_seq_qlen.append(ans)
ans = 0
for m in mask:
ans += m
actual_seq_kvlen.append(ans)
x = torch_npu.npu_fusion_attention(
q,
k,
v,
self.num_heads,
input_layout="TND",
pse=None,
scale=1.0 / math.sqrt(self.head_dim),
pre_tockens=65536,
next_tockens=65536,
actual_seq_qlen=tuple(actual_seq_qlen),
actual_seq_kvlen=tuple(actual_seq_kvlen),
keep_prob=1.0 - self.attn_drop.p,
sparse_mode=0,
)[0]
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention):
def __init__(
self,
d_model,
num_heads,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__(
d_model=d_model,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
def forward(self, x, cond, mask=None):
sp_group = mpu.get_context_parallel_group()
sp_size = mpu.get_context_parallel_world_size()
B, SUB_N, C = x.shape
N = SUB_N * sp_size
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
k, v = kv.unbind(2)
q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1)
k = split_forward_gather_backward(
k, mpu.get_context_parallel_group(), dim=2, grad_scale="down"
)
v = split_forward_gather_backward(
v, mpu.get_context_parallel_group(), dim=2, grad_scale="down"
)
if x.dtype not in [torch.float16, torch.bfloat16]:
raise AssertionError("QKV's dtype must be in bf16 or fp16")
q = q.view(-1, self.num_heads // sp_size, self.head_dim)
k = k.view(-1, self.num_heads // sp_size, self.head_dim)
v = v.view(-1, self.num_heads // sp_size, self.head_dim)
actual_seq_qlen = []
actual_seq_kvlen = []
if mask is not None:
ans = 0
for _ in range(B):
ans += N
actual_seq_qlen.append(ans)
ans = 0
for m in mask:
ans += m
actual_seq_kvlen.append(ans)
x = torch_npu.npu_fusion_attention(
q,
k,
v,
q.shape[-2],
input_layout="TND",
pse=None,
scale=1.0 / math.sqrt(self.head_dim),
pre_tockens=65536,
next_tockens=65536,
actual_seq_qlen=tuple(actual_seq_qlen),
actual_seq_kvlen=tuple(actual_seq_kvlen),
keep_prob=1.0 - self.attn_drop.p,
sparse_mode=0,
)[0]
x = x.view(B, -1, self.num_heads // sp_size, self.head_dim)
x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2)
x = x.view(B, -1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Conv2dAttnBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
num_groups=32,
eps=1e-6,
kernel_size=1,
stride=1,
padding=0,
affine=True,
use_sdp_attention=False,
):
super().__init__()
self.use_sdp_attention = use_sdp_attention
self.norm = nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=eps, affine=affine
)
self.q = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding
)
self.k = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding
)
self.v = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding
)
self.proj_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding
)
def attention_sdp(self, h_):
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def attention_manual(self, x):
y = x
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h * w)
z = torch.bmm(q, k)
z = z * (int(c) ** (-0.5))
z = torch.nn.functional.softmax(z, dim=2)
v = v.reshape(b, c, h * w)
z = z.permute(0, 2, 1)
y = torch.bmm(v, z)
y = y.reshape(b, c, h, w)
return y
@video_to_image
def forward(self, x):
if self.use_sdp_attention:
attention_output = self.attention_sdp(x)
else:
attention_output = self.attention_manual(x)
return x + self.proj_out(attention_output)
class CausalConv3dAttnBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
num_groups=32,
eps=1e-6,
affine=True,
):
super().__init__()
self.norm = nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=eps, affine=affine
)
self.q = CausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.k = CausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.v = CausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.proj_out = CausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
def forward(self, x):
y = x
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
b, c, t, h, w = q.shape
q = torch_npu.npu_confusion_transpose(
q, (0, 2, 1, 3, 4), (b * t, c, h * w), True
)
q = q.permute(0, 2, 1)
k = torch_npu.npu_confusion_transpose(
k, (0, 2, 1, 3, 4), (b * t, c, h * w), True
)
z = torch.bmm(q, k)
z = z * (int(c) ** (-0.5))
z = torch.nn.functional.softmax(z, dim=2)
v = torch_npu.npu_confusion_transpose(v, (0, 2, 1, 3, 4), (b * t, c, h * w), True)
z = z.permute(0, 2, 1)
y = torch.bmm(v, z)
y = torch_npu.npu_confusion_transpose(y, (0, 2, 1, 3, 4), (b, t, c, h, w), False)
y = self.proj_out(y)
return x + y
@dataclass
class AttentionParams:
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
input_layout: str
head_num: int = None
atten_mask: torch.Tensor = None
class WfCausalConv3dAttnBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
num_groups=32,
eps=1e-6,
affine=True,
norm_type="groupnorm",
):
super().__init__()
self.norm = normalize(in_channels, num_groups, eps, affine, norm_type=norm_type)
self.q = WfCausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.k = WfCausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.v = WfCausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
self.proj_out = WfCausalConv3d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, t, h, w = q.shape
q = q.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
k = k.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
v = v.permute(0, 2, 3, 4, 1).reshape(b * t, h * w, c).contiguous()
params = AttentionParams(
query=q,
key=k,
value=v,
input_layout="BSH",
head_num=1,
)
attn_output = self.run_attention(params, head_dim=c, enable_FA=c <= 512)
attn_output = attn_output.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
h_ = self.proj_out(attn_output)
return x + h_
def run_attention(self, params: AttentionParams, head_dim, enable_FA=True):
if enable_FA:
hidden_states = torch_npu.npu_fusion_attention(params.query, params.key, params.value,
head_num=params.head_num,
atten_mask=None,
input_layout=params.input_layout,
scale=1 / math.sqrt(head_dim))[0]
else:
hidden_states = self.scaled_dot_product_attention(params,
scale=1 / math.sqrt(head_dim))
return hidden_states
def scaled_dot_product_attention(self, params: AttentionParams, scale=None, dropout_p=0.0, is_causal=False) -> torch.Tensor:
def trans_tensor_shape(x, layout, head_num):
if layout == "BSH":
batch = x.shape[0]
x = x.view(batch, -1, head_num, x.shape[-1] // head_num).transpose(1, 2).contiguous()
elif layout == "SBH":
batch = x.shape[1]
x = x.view(-1, batch * head_num, x.shape[-1] // head_num).transpose(0, 1).contiguous()
x = x.view(batch, head_num, -1, x.shape[-1])
return x
query = trans_tensor_shape(params.query, params.input_layout, params.head_num)
key = trans_tensor_shape(params.key, params.input_layout, params.head_num)
value = trans_tensor_shape(params.value, params. input_layout, params.head_num)
attn_weight = query @ key.transpose(-2, -1) * scale
attn_bias = torch.zeros_like(attn_weight, dtype=query.dtype, device=query.device)
if is_causal:
if params.atten_mask is not None:
raise ValueError("atten_mask should be None when is_causal is True")
temp_mask = torch.zeros_like(attn_weight, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), -10000.0)
attn_bias.to(query.dtype)
if params.atten_mask is not None and self.enable_FA and params.atten_mask.dtype == torch.bool:
raise ValueError("attention_mask must not be bool type when use this function")
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
output = attn_weight @ value
if params.input_layout == "BSH":
output = output.transpose(1, 2).contiguous().view(output.shape[0], -1, params.head_num * output.shape[-1])
else:
output = output.view(output.shape[0] * params.head_num, -1, output.shape[-1]).transpose(0, 1).contiguous()
output = output.view(output.shape[0], -1, params.head_num * output.shape[-1])
return output
class WhisperAttention(nn.Module):
"""
A multi-head attention layer for both self-atten and cross-atten, layout "BSH".
Args:
query_dim: The number of channels in the query.
key_dim: The number of channels in the key, defaults to `query_dim`.
num_heads: The number of heads to use for multi-head attention.
head_dim: The number of channels in each head.
dropout: The dropout probability to use.
proj_qkv_bias: Whether to use bias in qkv projection.
proj_out_bias: Whether to use bias in out projection.
use_rope: Whether to use rope
interpolation_scale: The scale of interpolation.
"""
def __init__(
self,
query_dim: int,
key_dim: int,
num_heads: int,
head_dim: int,
dropout: float = 0.0,
proj_qv_bias: bool = False,
proj_out_bias: bool = True,
interpolation_scale: Tuple[int] = (1, 1, 1),
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.inner_dim = self.num_heads * self.head_dim
key_dim = key_dim if key_dim is not None else query_dim
self.proj_q = nn.Linear(query_dim, self.inner_dim, bias=proj_qv_bias)
self.proj_k = nn.Linear(key_dim, self.inner_dim, bias=False)
self.proj_v = nn.Linear(key_dim, self.inner_dim, bias=proj_qv_bias)
self.proj_out = nn.Linear(self.inner_dim, query_dim, bias=proj_out_bias)
self.dropout = nn.Dropout(dropout)
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
query: The hidden states of the query.
key: The hidden states of the key.
mask: The attention mask to use.
**kwargs: Additional keyword arguments to pass along
"""
input_ndim = query.ndim
if input_ndim == 4:
b, c, h, w = query.shape
query = query.view(b, c, h * w).transpose(1, 2)
key = query if key is None else key
b, _, _ = query.shape
if mask is not None:
mask = mask.view(b, 1, -1, mask.shape[-1])
q = self.proj_q(query)
k = self.proj_k(key)
v = self.proj_v(key)
q = q.view(b, -1, self.inner_dim)
k = k.view(b, -1, self.inner_dim)
out = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num=self.num_heads,
atten_mask=mask,
input_layout="BSH",
scale=1 / math.sqrt(self.head_dim)
)[0]
out = self.proj_out(out)
out = self.dropout(out)
if input_ndim == 4:
out = out.transpose(-1, -2).reshape(b, c, h, w)
return out