import math
import torch
import torch_npu
import numpy as np
from stepvideo.parallel import (
get_llm_tensor_model_parallel_rank,
get_llm_tensor_model_parallel_world_size
)
DTYPE_FP16_MIN = float(np.finfo(np.float16).min)
def _get_alibi_slopes(n_heads):
n = 2**math.floor(math.log2(n_heads))
m0 = torch.tensor(2.0**(-8.0 / n), dtype=torch.float32).to("cpu")
slopes = torch.pow(m0, torch.arange(1, n + 1, dtype=torch.float32).to("cpu"))
if n < n_heads:
m1 = torch.tensor(2.0**(-4.0 / n), dtype=torch.float32).to("cpu")
mm = torch.pow(m1, torch.arange(1, 1 + 2 * (n_heads - n), 2, dtype=torch.float32).to("cpu"))
slopes = torch.cat([slopes, mm])
return slopes
def _get_mask(seq_len, b, n):
slopes = _get_alibi_slopes(n)
tril = torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.bool)).to(torch.int32)
bias_row = torch.arange(seq_len).view(1, -1)
bias_cols = torch.arange(seq_len).view(-1, 1)
bias = -torch.sqrt(bias_cols - bias_row)
bias = bias.view(1, seq_len, seq_len) * slopes.view(-1, 1, 1)
bias = bias.masked_fill(tril == 0, DTYPE_FP16_MIN)
return bias
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
attention_dropout=0.0,
):
super().__init__()
self.dropout_p = attention_dropout
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
if cu_seqlens is None:
tp_rank = get_llm_tensor_model_parallel_rank()
tp_size = get_llm_tensor_model_parallel_world_size()
alibi_mask = _get_mask(q.size(1), q.size(0), q.size(2) * tp_size)
alibi_mask = alibi_mask[:, (tp_rank * q.size(2)):(tp_rank * q.size(2) + q.size(2)), :, :].to(q.dtype).to(q.device)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=alibi_mask)
output = output.transpose(1, 2)
else:
raise ValueError('cu_seqlens is not supported!')
return output