"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
* Original flash attention paper (https://arxiv.org/abs/2205.14135)
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import torch_npu
import triton
import triton.language as tl
import triton.language.extra.cann.extension as extension
# Device string passed to torch when allocating the test tensors below.
DEVICE = "npu"
@triton.jit
def _attn_fwd_inner(acc_ptr, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m, qk_scale,
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    """Apply the online-softmax update of one query tile over a range of K/V tiles.

    Args:
        acc_ptr: When ``HEAD_DIM < 256`` this is the accumulator tensor itself
            and is updated functionally. Otherwise it is a pointer to a
            ``BLOCK_M x HEAD_DIM`` row-major region that is loaded, updated and
            stored back on every iteration.
        l_i: Running per-row sum of exponentials.
        m_i: Running per-row maximum of the scaled logits.
        q: Query tile used as the left operand of the QK dot product.
        K_block_ptr, V_block_ptr: Block pointers positioned at row 0; they are
            advanced to the first row of the selected range inside this function.
        start_m: Index of the query tile along the sequence.
        qk_scale: Factor multiplied into the raw QK product.
        STAGE: 1 -> keys in ``[0, start_m * BLOCK_M)``;
            2 -> keys in ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)`` with the
            mask ``offs_m >= start_n + offs_n`` applied;
            anything else -> all keys in ``[0, N_CTX)``.
        offs_m, offs_n: Row / column index vectors used to build the mask.
        fp8_v: When true the probabilities are cast to ``tl.float8e5`` before
            the PV product; otherwise they are cast to the dtype of K.

    Returns:
        The updated ``(acc_ptr, l_i, m_i)`` triple.
    """
    # Select the half-open key range [lo, hi) handled by this call.
    if STAGE == 1:
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    else:
        lo, hi = 0, N_CTX

    K_block_ptr = tl.advance(K_block_ptr, (lo, 0))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # Row-major element offsets of the accumulator tile; only consumed by the
    # HEAD_DIM >= 256 path, which keeps the accumulator behind a pointer.
    row = tl.arange(0, BLOCK_M)[:, None]
    col_head_dim = tl.arange(0, HEAD_DIM)[None, :]
    block2d_acc = row * HEAD_DIM + col_head_dim

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # -- scaled logits and new running maximum --
        k = tl.load(K_block_ptr)
        trans_k = tl.trans(k)
        qk = tl.dot(q, trans_k)
        if STAGE == 2:
            # Masked positions receive a large negative bias instead of -inf.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk = qk * qk_scale
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        # -- unnormalised probabilities --
        p = tl.math.exp(qk)
        if fp8_v:
            p_cast = p.to(tl.float8e5)
        else:
            p_cast = p.to(k.dtype)

        v = tl.load(V_block_ptr)
        l_ij = tl.sum(p, 1)

        # alpha rescales everything accumulated under the previous maximum.
        alpha = tl.math.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        if HEAD_DIM < 256:
            # Accumulator is a live tensor: rescale, then fuse the PV product
            # into it through the dot's accumulator operand.
            acc_ptr = acc_ptr * alpha[:, None]
            acc_ptr = tl.dot(p_cast, v, acc_ptr)
        else:
            # The row slicing below covers exactly 4 * (BLOCK_M // 4) rows, so
            # any remainder rows would never be updated.
            tl.static_assert(BLOCK_M % 4 == 0, "BLOCK_M must be a multiple of 4 when HEAD_DIM >= 256")
            # The standalone PV product is only needed on this path.
            pv = tl.dot(p_cast, v)
            acc = tl.load(acc_ptr + block2d_acc)
            # Update the tile in four row slices
            # (presumably to bound on-chip buffer usage -- TODO confirm).
            for i in range(4):
                offset = i * (BLOCK_M // 4)
                acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
                alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1])
                pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
                acc_i = acc_i * alpha_i[:, None] + pv_i
                acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
            tl.store(acc_ptr + block2d_acc, acc)

        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))

    return acc_ptr, l_i, m_i
@triton.jit
def _attn_fwd(Q, K, V, M, Out, acc, sm_scale,
              stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr,
              stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr,
              stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr,
              stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr,
              Z: tl.constexpr, H: tl.constexpr,
              N_CTX: tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              STAGE: tl.constexpr
              ):
    """Fused attention forward kernel.

    Each program walks over (batch, head, query-tile) work items, starting at
    its own program id and stepping by the number of launched programs.

    Args:
        Q, K, V, Out: Base pointers; each (batch, head) slice is addressed as an
            ``(N_CTX, HEAD_DIM)`` matrix through the matching ``stride_*``
            arguments.
        M: Base pointer receiving ``m_i + log(l_i)`` per query row, indexed as
            ``(z * H + h) * N_CTX + row``.
        acc: Float32 scratch buffer used only when ``HEAD_DIM >= 256``; it is
            addressed as a contiguous ``[Z * H, N_CTX, HEAD_DIM]`` array. Its
            contents must be finite on entry because the first update
            multiplies the stored values by zero.
        sm_scale: Scale applied to the QK product.
        Z, H: Batch size and number of heads.
        N_CTX: Sequence length; ``N_CTX // BLOCK_M`` query tiles are processed.
        STAGE: Bit 0 set -> run the inner loop with stage ``4 - STAGE`` (whole
            context when ``STAGE == 1``, keys before the query tile when
            ``STAGE == 3``); bit 1 set -> additionally run the masked pass over
            the tile's own key range.
    """
    NUM_BLOCKS_M = N_CTX // BLOCK_M
    NUM_BLOCKS = NUM_BLOCKS_M * Z * H
    pid = tl.program_id(0)
    # Step by the real grid size so the work split stays correct for any 1-D launch.
    num_programs = tl.num_programs(0)

    for block_idx in range(pid, NUM_BLOCKS, num_programs):
        task_hz_idx = block_idx // NUM_BLOCKS_M
        task_m_idx = block_idx % NUM_BLOCKS_M
        off_z = task_hz_idx // H
        off_h = task_hz_idx % H

        # Each tensor gets its own (batch, head) base offset so that Q, K, V and
        # Out may use different memory layouts.
        z64 = off_z.to(tl.int64)
        h64 = off_h.to(tl.int64)
        q_offset = z64 * stride_qz + h64 * stride_qh
        k_offset = z64 * stride_kz + h64 * stride_kh
        v_offset = z64 * stride_vz + h64 * stride_vh
        o_offset = z64 * stride_oz + h64 * stride_oh

        Q_block_ptr = tl.make_block_ptr(
            base=Q + q_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_qm, stride_qk),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + v_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        K_block_ptr = tl.make_block_ptr(
            base=K + k_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_kn, stride_kk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        O_block_ptr = tl.make_block_ptr(
            base=Out + o_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_om, stride_on),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )

        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)

        # Online-softmax state: running maximum and running denominator.
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0

        if HEAD_DIM < 256:
            # Small heads: the accumulator lives in a tensor value.
            acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
        else:
            # Large heads: the accumulator lives in the scratch buffer. The
            # offset follows the buffer's own contiguous layout, independent of
            # how Q is strided.
            acc_offset = (task_hz_idx.to(tl.int64) * N_CTX + task_m_idx * BLOCK_M) * HEAD_DIM
            acc_ptr = acc + acc_offset

        q = tl.load(Q_block_ptr)

        if STAGE & 1:
            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                                task_m_idx, sm_scale,
                                                BLOCK_M, HEAD_DIM, BLOCK_N,
                                                4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5
                                                )
        if STAGE & 2:
            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                                task_m_idx, sm_scale,
                                                BLOCK_M, HEAD_DIM, BLOCK_N,
                                                2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5
                                                )

        # Epilogue: turn the running maximum into a log-sum-exp and normalise.
        m_i += tl.math.log(l_i)
        if HEAD_DIM < 256:
            accumulator = acc_ptr / l_i[:, None]
        else:
            row = tl.arange(0, BLOCK_M)[:, None]
            col_head_dim = tl.arange(0, HEAD_DIM)[None, :]
            block2d_acc = row * HEAD_DIM + col_head_dim
            accumulator = tl.load(acc_ptr + block2d_acc)
            accumulator = accumulator / l_i[:, None]

        m_ptrs = M + task_hz_idx * N_CTX + offs_m
        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, accumulator.to(Out.type.element_ty))
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the fused attention forward kernel.

    Only the forward pass is implemented in this class.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, BM, BN):
        """
        Forward computation interface:
        Args:
            ctx: Context object
            q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM]
            k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM]
            v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM]
            causal: Whether to enable causal attention
            sm_scale: Scaling factor for QK product
            BM: Q block size (BLOCK_M)
            BN: K/V block size (BLOCK_N)
        Returns:
            o: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM]
        Raises:
            AssertionError: If the head dimensions of q, k and v differ or are
                not one of 16, 32, 64, 128, 256.
            ValueError: If N_CTX is not a multiple of both BM and BN.
        """
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        # The kernel uses unguarded block loads and integer tile counts, so a
        # ragged sequence length would leave rows unwritten or read past the end.
        n_ctx = q.shape[2]
        if n_ctx % BM != 0 or n_ctx % BN != 0:
            raise ValueError(
                f"N_CTX ({n_ctx}) must be divisible by both BM ({BM}) and BN ({BN})"
            )

        out = torch.empty_like(q)
        # Stage 3 runs the unmasked pass plus the masked diagonal pass; stage 1
        # runs a single pass over the whole context.
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Number of programs launched on the 1-D grid.
        num_cores = 20

        # The kernel touches the float32 scratch accumulator only when
        # HEAD_DIM >= 256; otherwise pass a one-element placeholder so no
        # full-size zero fill is paid for.
        if HEAD_DIM_K >= 256:
            acc = torch.zeros(
                (q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K),
                dtype=torch.float32,
                device=q.device,
            )
        else:
            acc = torch.zeros((1,), dtype=torch.float32, device=q.device)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        _attn_fwd[(num_cores,)](
            q, k, v, M, out, acc, sm_scale,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            q.shape[0], q.shape[1], N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            STAGE=stage,
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, out, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return out
# Functional entry point: invokes _attention.forward through torch.autograd.Function.apply.
attention = _attention.apply
@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", [
    (1, 1, 128, 128, False, torch.float16, 32, 128),
    (1, 1, 128, 128, False, torch.bfloat16, 64, 128),
    (1, 2, 256, 256, False, torch.bfloat16, 32, 256),
    (2, 2, 128, 256, False, torch.float16, 64, 128),
    (4, 32, 64, 64, False, torch.float16, 32, 64),
    (4, 32, 1024, 64, False, torch.bfloat16, 64, 128),
    (4, 32, 4096, 64, False, torch.float16, 128, 128),
])
def test_attention_fused(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN):
    """Compare the Triton fused attention output against a reference.

    Non-causal cases use ``torch_npu.npu_fusion_attention`` as the reference.
    Causal cases use an explicit matmul/softmax reference with a
    lower-triangular mask, so the reference always matches the ``causal`` flag
    handed to the kernel.
    """
    if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0:
        pytest.skip("Skipping non-divisible case")

    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()
    sm_scale = 0.5

    tri_out = attention(q, k, v, causal, sm_scale, BM, BN)

    if causal:
        # Explicit reference: query row i may attend to key column j only when
        # j <= i. This materialises an N_CTX x N_CTX score matrix per head, so
        # keep causal cases small.
        keep = torch.tril(torch.ones((N_CTX, N_CTX), dtype=torch.bool, device=DEVICE))
        scores = torch.matmul(q.float(), k.float().transpose(2, 3)) * sm_scale
        scores = scores.masked_fill(~keep, float("-inf"))
        ref_out = torch.matmul(torch.softmax(scores, dim=-1), v.float()).to(dtype)
    else:
        ref_out = torch_npu.npu_fusion_attention(
            q, k, v, H,
            padding_mask=None,
            atten_mask=None,
            scale=sm_scale,
            keep_prob=1.0,
            input_layout="BNSD",
            pre_tockens=65535,
            next_tockens=65535,
            sparse_mode=0,
        )[0]

    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True)