""" Triton Implementation of the flex_attention Kernel"""
import logging
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Dict, Optional, Union
import sympy
import torch
from torch._inductor.virtualized import V, ops
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map
from torch._inductor import config
from torch._inductor.ir import (
Buffer,
ComputedBuffer,
ExternKernel,
FixedLayout,
FlexibleLayout,
get_fill_order,
InputBuffer,
IRNode,
MutationLayoutSHOULDREMOVE,
Scatter,
StorageBox,
Subgraph,
TensorBox,
)
from torch._inductor.lowering import (
_full,
check_and_broadcast_indices,
empty,
empty_strided,
expand,
index_output_size_and_inner_fn,
lowerings,
register_lowering,
to_dtype,
)
from torch._inductor.select_algorithm import autotune_select_algorithm, realize_inputs, TritonTemplate
from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel
from torch._inductor.kernel.flex_attention import (
Mode,
lower_cpu,
maybe_realize,
is_power_of_2,
next_power_of_two,
construct_strides,
create_placeholder,
set_head_dim_values,
create_indices_fake,
flex_attention_grid,
infer_dense_strides,
validate_joint_graph,
process_joint_outputs,
build_subgraph_buffer,
flex_attention_backward_grid,
create_num_blocks_fake_generator
)
from torch_npu._inductor.select_algorithm import NPUTritonTemplate
aten = torch.ops.aten
Expr = sympy.Expr
def _get_flex_attention_additional_lowerings():
"""
Get additional lowerings for flex_attention subgraph.
These lowerings are used to allow index and bitwise operations to be lowered
as pointwise ops instead of fallback in the mask_mod subgraph.
"""
from torch._inductor.lowering import make_pointwise, index_impl
from torch._inductor.subgraph_lowering import PointwiseSubgraphLowering
additional_lowerings = {}
def index_pointwise(x, indices):
return index_impl(x, indices, check=True)
additional_lowerings[aten.index] = index_pointwise
additional_lowerings[aten.index.Tensor] = index_pointwise
bitwise_and_fn = make_pointwise(ops.bitwise_and)
def bitwise_and_tensor(a, b):
return bitwise_and_fn(a, b)
bitwise_or_fn = make_pointwise(ops.bitwise_or)
def bitwise_or_tensor(a, b):
return bitwise_or_fn(a, b)
bitwise_not_fn = make_pointwise(ops.bitwise_not)
def bitwise_not_default(a):
return bitwise_not_fn(a)
additional_lowerings[aten.bitwise_and.Tensor] = bitwise_and_tensor
additional_lowerings[aten.bitwise_or.Tensor] = bitwise_or_tensor
additional_lowerings[aten.bitwise_not.default] = bitwise_not_default
return additional_lowerings
def _build_subgraph_buffer_with_additional_lowerings(args, subgraph):
    """
    Build subgraph buffer with additional lowerings for flex_attention.

    This function creates a PointwiseSubgraphLowering with additional_lowerings
    to handle index and bitwise operations as pointwise ops, runs it on
    ``args`` and converts every graph output into a ``ComputedBuffer``.

    Args:
        args: positional inputs forwarded to ``PointwiseSubgraphLowering.run``.
        subgraph: object exposing ``graph_module`` (the subgraph to lower).

    Returns:
        ``pw_subgraph.graph_outputs`` with every leaf mapped through
        ``convert_output_node_to_buffer`` (``None`` leaves stay ``None``).

    Raises:
        AssertionError: if an output is neither ``None``, a ``ComputedBuffer``
            nor a ``TensorBox`` wrapping a ``StorageBox``.
    """
    from torch._inductor.subgraph_lowering import PointwiseSubgraphLowering

    additional_lowerings = _get_flex_attention_additional_lowerings()
    pw_subgraph = PointwiseSubgraphLowering(
        subgraph.graph_module,
        root_graph_lowering=V.graph,
        additional_lowerings=additional_lowerings,
    )
    # Run the subgraph with pw_subgraph installed as the active graph handler.
    with V.set_graph_handler(pw_subgraph):
        pw_subgraph.run(*args)

    def convert_output_node_to_buffer(output_buffer):
        """Turn one subgraph output into a ComputedBuffer (or pass None through)."""
        if output_buffer is None:
            return None
        if isinstance(output_buffer, ComputedBuffer):
            return output_buffer
        assert isinstance(output_buffer, TensorBox), (
            "The output node for flex attention's subgraph must be a TensorBox, "
            f"but got: {type(output_buffer)}"
        )
        storage = output_buffer.data
        # Report the type of the object actually being checked (the wrapped
        # data), not the enclosing TensorBox.
        assert isinstance(storage, StorageBox), (
            "The output node for the flex attention subgraph must be a StorageBox, "
            f"but got: {type(storage)}"
        )
        return ComputedBuffer(
            name=None,
            layout=FlexibleLayout(
                device=storage.get_device(),
                dtype=storage.get_dtype(),
                size=storage.get_size(),
            ),
            data=storage.data,
        )

    return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
def _use_flex_decoding(query, kernel_options):
    """Decide whether the flex-decoding kernel should be used for ``query``.

    Args:
        query: IR node exposing ``get_size()``; index -2 is used as the query
            length, indices 0 and 1 as batch and head counts.
        kernel_options: mapping of kernel options; a truthy
            ``"FORCE_USE_FLEX_ATTENTION"`` entry disables flex decoding.

    Returns:
        Truthy only when flex attention is not forced, the query length is
        in ``(0, 128)``, and both the batch and head dimensions are static
        (``int`` or ``sympy.Integer``).
    """
    force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False)
    query_size = query.get_size()
    short_query_length = V.graph.sizevars.evaluate_expr(
        sympy.Lt(query_size[-2], 128)
    )
    non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query_size[-2], 0))
    static_batch = isinstance(query_size[0], (int, sympy.Integer))
    static_num_heads = isinstance(query_size[1], (int, sympy.Integer))
    # non_zero_length was previously computed but dropped from the decision,
    # which let empty queries be routed to the decoding kernel.
    return (
        not force_flex
        and short_query_length
        and static_batch
        and static_num_heads
        and non_zero_length
    )
def _validate_device(query, key, value):
    """No-op: performs no device validation on ``query``, ``key`` or ``value``."""
    return None
# Triton source for the helper `get_offset_for_next_block`; this string is
# concatenated into the flex_attention template source further down.
# The helper returns how far the caller should advance for the next loop
# iteration:
#   * BLOCKS_ARE_CONTIGUOUS: always BLOCK.
#   * otherwise: BLOCK, except when (loop_iter + 1) is a multiple of
#     SPARSE_BLOCK_MULTIPLE, in which case the offset is derived from the
#     difference of two consecutive `col_indices` entries. The load of the
#     next entry is masked with `cur_block_idx + 1 < total_blocks`.
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
"""
# Triton source for the helper `get_bounded_indices`: returns
# `indices % max_len` when `max_len` is not None, otherwise `indices`
# unchanged. `forward_block_mn` calls it to produce the m/n indices handed to
# the score/mask modifications when CHECK_BLOCK_BOUNDARY is set.
get_bounded_indices_func = r"""
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
"""
# Triton source for the helper `load_checked_block`: loads through a block
# pointer and selects the `boundary_check` dimensions from two constexpr flags:
#   * IS_DIVISIBLE and SAFE_HEAD_DIM -> plain load, no boundary check
#   * not SAFE_HEAD_DIM              -> dimension 1 is checked
#   * not IS_DIVISIBLE               -> dimension 0 is checked
# Checked loads use padding_option="zero".
# NOTE: this Python name holds the source string; the Triton function of the
# same name only exists inside the rendered template.
load_checked_block = r"""
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
"""
# Triton source for the helper `load_checked_2d`: a masked 2-D load from a raw
# pointer.
#   * When both strides are given, the pointer is expanded to a 2-D grid of
#     addresses from `offs_m` / `offs_n`.
#   * Rows are masked with `offs_m < M_LEN` unless IS_DIVISIBLE_M, columns with
#     `offs_n < N_DIM` unless IS_DIVISIBLE_N; masked-out elements read as 0.0.
# It is called by the backward kernel template later in this file; it is not
# part of the forward `flex_attention_template` source concatenation.
load_checked_2d = r"""
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
"""
# Template source for the body of the forward flex_attention kernel. It is the
# first piece of the `flex_attention_template` source assembled below.
# The `{{ ... }}` placeholders (def_kernel, stride, size, gen_argdefs,
# store_output) are presumably expanded by the template machinery when the
# kernel is rendered -- confirm against NPUTritonTemplate.
# Outline of the kernel text:
#   1. Derive batch/head offsets from the program ids and offset Q/K/V.
#   2. Load the Q tile once with load_checked_block.
#   3. Run forward_inner over the blocks listed in KV_IDX / KV_NUM_BLKS with
#      IS_FULL_BLOCKS=False.
#   4. If HAS_FULL_BLOCKS, run forward_inner again over FULL_KV_IDX /
#      FULL_KV_NUM_BLKS with IS_FULL_BLOCKS=True.
#   5. Replace l_i == 0 by 1, normalise acc by l_i, store the output, and when
#      OUTPUT_LOGSUMEXP is set store m_i + log2(l_i) into LSE.
compute_flex_attention = r"""
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
# contiguous? If so, we don't need to do an indirect jump for every block
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_zq = tl.program_id(1) // HQ
off_hq = tl.program_id(1) % HQ
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_zq * stride_qz + off_hq * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(Q_LEN, QK_HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(q_start * BLOCK_M, 0),
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
order=(1, 0)
)
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# [Note] Handle fully masked out rows:
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1) // HQ
idx_hq = tl.program_id(1) % HQ
idx_m = offs_m[:, None]
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
{{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
if OUTPUT_LOGSUMEXP:
off_hz = tl.program_id(1)
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
if IS_DIVISIBLE:
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
"""
# Triton source for `forward_inner`, the KV loop of the forward kernel.
# For each `start_n` in [block_n_start, block_n_end) it:
#   * calls forward_block_mn (with CHECK_BLOCK_BOUNDARY=True when
#     IS_DIVISIBLE is false) to fold one K/V tile into (acc, l_i, m_i);
#   * asks get_offset_for_next_block for the step and advances the K/V block
#     pointers and `offs_n` by it.
# When PRESCALE_QK is set, q is multiplied by SM_SCALE * RCP_LN2 once before
# the loop. `{{gen_argdefs()}}` / `{{gen_defines()}}` are template
# placeholders filled in when the template is rendered.
compute_forward_inner = r"""
@triton.jit
def forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
else:
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
# it's on par or slightly faster than only applying to the last block in fwd.
# However, we choose different strategy for bwd, where we only apply mod & mask
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
# update pointers
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
offs_n = offs_n + offset
return acc, l_i, m_i
"""
# Triton source for `forward_block_mn`, which folds a single K/V tile into the
# running softmax state (acc, l_i, m_i):
#   1. Load k, compute qk = dot(q, k); scale by SM_SCALE unless PRESCALE_QK.
#   2. Apply template subgraph 0 (score modification) to get post_mod_scores;
#      with CHECK_BLOCK_BOUNDARY, columns at or beyond KV_LEN become -inf.
#   3. Unless IS_FULL_BLOCKS, apply template subgraph 1 (mask) and set masked
#      positions to -inf.
#   4. Update the running maximum and sum using exp2 (scores are multiplied by
#      RCP_LN2 unless PRESCALE_QK already did so), rescale acc by alpha, load
#      v and accumulate dot(p, v).
# Unless ROWS_GUARANTEED_SAFE, rows whose maximum is -inf use 0 in place of
# the maximum when forming alpha and p.
compute_forward_block_mn = r"""
@triton.jit
def forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
# -- load k --
k = load_checked_block(K_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
# -- compute qk ---
qk = tl.dot(q, k)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid access memory out of bound,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=1,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
"""
# Drop the templates already stored under these two names in
# TritonTemplate.all_templates -- presumably so the NPU variants can be
# registered under the same names (confirm against TritonTemplate).
# A missing entry raises KeyError.
TritonTemplate.all_templates.pop("flex_attention")
TritonTemplate.all_templates.pop("flex_attention_backward")

# Forward template: the kernel body followed by the Triton helpers it calls.
flex_attention_template = NPUTritonTemplate(
    name="flex_attention",
    grid=flex_attention_grid,
    source=(
        compute_flex_attention
        + compute_forward_inner
        + compute_next_offset_func
        + compute_forward_block_mn
        + load_checked_block
        + get_bounded_indices_func
    ),
)
def _get_npu_config(query, mode: Mode) -> tuple[int, int, int, int]:
    """Pick the default kernel configuration tuple for the given mode.

    Args:
        query: IR node providing ``get_dtype()`` and ``get_size()``; the last
            size entry is evaluated as a static head dimension.
        mode: ``Mode.fwd`` or ``Mode.bwd``.

    Returns:
        A 4-tuple of ints. NOTE(review): presumably
        ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` -- confirm with the
        caller that unpacks it.

    Raises:
        AssertionError: if ``mode`` is neither ``Mode.fwd`` nor ``Mode.bwd``.
    """
    # Both lookups happen for every mode, as in the original control flow.
    elem_dtype = query.get_dtype()
    last_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1])

    if mode != Mode.fwd:
        assert mode == Mode.bwd
        return (16, 16, 4, 1)

    if last_dim <= 256:
        return (64, 64, 4, 3)
    # Head dims above 256: float32 gets the narrower second entry.
    if elem_dtype == torch.float32:
        return (32, 16, 4, 3)
    return (32, 32, 4, 3)
def _get_default_config_fwd(query) -> tuple[int, int, int, int]:
    """Return the default forward config for ``query`` via ``_get_npu_config``."""
    forward_mode = Mode.fwd
    return _get_npu_config(query, forward_mode)
def _get_default_config_bwd(query) -> tuple[int, int, int, int]:
    """Return the default backward config for ``query`` via ``_get_npu_config``."""
    backward_mode = Mode.bwd
    return _get_npu_config(query, backward_mode)
flex_attention_backward_template = NPUTritonTemplate(
name="flex_attention_backward",
grid=flex_attention_backward_grid,
source=r"""
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
HKV = {{size("K", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_hz = tl.program_id(2)
off_zq = off_hz // HKV # q batch idx
off_hkv = off_hz % HKV # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
stride_q_idx_h = {{stride("Q_IDX", 1)}}
stride_q_idx_n = {{stride("Q_IDX", 2)}}
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
index_v = offs_v[None, :]
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
dk *= SM_SCALE
if SAFE_HEAD_DIM:
mask = index_n < KV_LEN
else:
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
{{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
@triton.jit
def bwd_dq_inner(
{{gen_argdefs()}},
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
# pytorch
# kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
# vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
# pta
kT_ptrs = K + offs_n2[:, None] * stride_kn + offs_k[None, :] * stride_kd
vT_ptrs = V + offs_n2[:, None] * stride_vn + offs_v[None, :] * stride_vd
# ===
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_n in range(0, hi - 1):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_n in range(0, hi):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
return dq
@triton.jit
def bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1)}}
# NB reversed order to since K is transposed
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
# pta
kT = tl.trans(kT)
# ===
qk = tl.dot(q, kT)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
# that the M reads out of bounds prior to the last loop
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partial masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
# NB reversed order to since V is transposed
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
# pta
vT = tl.trans(vT)
# ===
dp = tl.dot(do, vT)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if WRITE_DQ:
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT))
return dq
@triton.jit
def bwd_dkdv_inner(
{{gen_argdefs()}},
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
# pytorch
# qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
# pta
qT_ptrs = Q + offs_m1[:, None] * stride_qm + offs_k[None, :] * stride_qd
# ===
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1) }}
# NB reversed order since Q is transposed
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
# pta
qT = tl.trans(qT)
# ===
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
lse = tl.load(LSE + offs_m1)
else:
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
qkT = tl.dot(k, qT)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
# that the n reads out of bounds prior to the last loop
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
pre_mod_scores = qkT
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qkT"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for fully masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do)
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="dsT"
) | indent_except_first(1) }}
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if not WRITE_DQ:
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="idx_b",
h="idx_h",
m="idx_m",
n="idx_n",
grad_score_mod="dsT"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT))
return dk, dv
"""
+ compute_next_offset_func
+ get_bounded_indices_func
+ load_checked_2d,
)
def _register_npu_inductor_flex_attention():
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(
    query,
    key,
    value,
    subgraph,
    block_mask,
    scale,
    kernel_options,
    score_mod_other_buffers,
    mask_mod_other_buffers,
):
    """Lower the ``flex_attention`` higher-order op onto the forward template.

    The lowering proceeds in these steps:

    1. Trace ``subgraph`` (score_mod) and the mask graph taken from
       ``block_mask`` into subgraph buffers over placeholder inputs.
    2. Return the result of ``create_flex_decoding_kernel`` when
       ``_use_flex_decoding`` selects the decoding path.
    3. Otherwise realize the inputs, allocate the output layout and the
       logsumexp buffer, and append one ``flex_attention_template`` choice
       per candidate ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` config.

    Returns:
        On the template path, a tuple of the node returned by
        ``autotune_select_algorithm`` and the logsumexp buffer.

    Raises:
        ValueError: the sparse block sizes are not multiples of the block
            sizes of the only available config.
        Exception: whatever ``maybe_append_choice`` returned as an error
            when only a single config exists.
    """
    (
        _,
        _,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_Q_BLOCK_SIZE,
        SPARSE_KV_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    # (b, h, m, n) index placeholders are shared by score_mod and mask_mod;
    # score_mod additionally receives the score itself as its first input.
    index_specs = [(axis, torch.int32) for axis in ("b", "h", "m", "n")]
    score_placeholders = [
        create_placeholder(ph_name, ph_dtype, query.get_device())
        for ph_name, ph_dtype in [("score", query.get_dtype()), *index_specs]
    ]
    subgraph_buffer = _build_subgraph_buffer_with_additional_lowerings(
        score_placeholders + list(score_mod_other_buffers), subgraph
    )
    mask_placeholders = [
        create_placeholder(ph_name, ph_dtype, query.get_device())
        for ph_name, ph_dtype in index_specs
    ]
    mask_graph_buffer = _build_subgraph_buffer_with_additional_lowerings(
        mask_placeholders + list(mask_mod_other_buffers), mask_graph
    )

    # Work on a private copy of the options; symbolic values become static.
    kernel_options = {
        opt: (
            V.graph.sizevars.evaluate_static_shape(val)
            if isinstance(val, sympy.Symbol)
            else val
        )
        for opt, val in dict(kernel_options).items()
    }

    if _use_flex_decoding(query, kernel_options):
        return create_flex_decoding_kernel(
            query,
            key,
            value,
            block_mask,
            scale,
            kernel_options,
            subgraph_buffer,
            mask_graph_buffer,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )

    realized = maybe_realize(
        [
            query,
            key,
            value,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )
    (
        query,
        key,
        value,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = realized
    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert V.graph.sizevars.evaluate_expr(
        sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
    ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
    B = Bq

    # Default IS_DIVISIBLE to True only when both sequence lengths are
    # multiples of 128; a caller-supplied value is left untouched.
    kernel_options.setdefault(
        "IS_DIVISIBLE",
        not (seq_len_q % 128 != 0 or seq_len_kv % 128 != 0),
    )

    q_strides = query.get_stride()
    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"

    # The output strides are built from the fill order of query's strides.
    out_size = [B, Hq, seq_len_q, v_head_dim]
    out_strides = construct_strides(
        out_size, get_fill_order(q_strides, V.graph.sizevars.shape_env)
    )
    layout = FixedLayout(
        query.get_device(),
        query.get_dtype(),
        list(out_size),
        stride=[sympy.sympify(s) for s in out_strides],
    )
    logsumexp = empty_strided(
        [B, Hq, seq_len_q],
        None,
        dtype=torch.float32,
        device=query.get_device(),
    )

    kernel_options.setdefault("SM_SCALE", scale)
    kernel_options.setdefault("GQA_SHARED_HEADS", Hq // Hkv)
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Empty stand-ins so the template still receives every input slot.
        full_kv_num_blocks = empty(0, device=query.get_device())
        full_kv_indices = empty(0, device=query.get_device())
    set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)

    choices: list[Any] = []
    configs: list[tuple[int, int, int, int]] = [_get_default_config_fwd(query)]
    if config.max_autotune:
        configs.extend(
            [
                (128, 64, 4, 3),
                (128, 128, 4, 3),
                (128, 128, 8, 2),
                (64, 128, 4, 3),
                (64, 64, 4, 3),
            ]
        )
    if torch.version.hip:
        # Every config is forced to num_stages == 1 on HIP builds.
        configs = [(*c[:3], 1) for c in configs]

    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)

    # Tensor inputs handed to the template, in the order it expects them.
    kernel_inputs = [
        query,
        key,
        value,
        logsumexp,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
    ]

    original_kernel_options = kernel_options.copy()
    only_one_config = len(configs) == 1
    for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
        if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0:
            if only_one_config:
                raise ValueError(
                    f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We "
                    f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}."
                )
            continue

        cur_kernel_options = original_kernel_options.copy()
        # "fwd_"-prefixed options apply here with the prefix stripped;
        # "bwd_"-prefixed options belong to the backward kernel and are dropped.
        for opt in tuple(cur_kernel_options):
            if opt.startswith("fwd_"):
                cur_kernel_options[opt[4:]] = cur_kernel_options.pop(opt)
            elif opt.startswith("bwd_"):
                cur_kernel_options.pop(opt)

        # Config values are only defaults: explicit options take precedence.
        for opt, default in (
            ("num_stages", num_stages),
            ("num_warps", num_warps),
            ("BLOCK_M", BLOCK_M),
            ("BLOCK_N", BLOCK_N),
            ("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE),
            ("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE),
        ):
            cur_kernel_options.setdefault(opt, default)

        error = flex_attention_template.maybe_append_choice(
            choices=choices,
            input_nodes=list(kernel_inputs),
            layout=layout,
            subgraphs=[
                subgraph_buffer,
                mask_graph_buffer,
            ],
            mutated_inputs=[
                logsumexp,
            ],
            call_sizes=query.get_size(),
            **cur_kernel_options,
        )
        if error is not None and only_one_config:
            raise error

    inputs_for_autotuning = (
        kernel_inputs
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
    )
    # Keys are positions in inputs_for_autotuning (the block-sparsity inputs).
    input_gen_fns = {
        4: create_num_blocks_fake_generator(kv_indices),
        5: create_indices_fake,
        6: create_num_blocks_fake_generator(full_kv_indices),
        7: create_indices_fake,
    }
    attention_out = autotune_select_algorithm(
        "flex_attention",
        choices,
        inputs_for_autotuning,
        layout,
        input_gen_fns=input_gen_fns,
    )
    return (attention_out, logsumexp)
@register_lowering(torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None)
def flex_attention_backward(*args, **kwargs):
    """Lower the ``flex_attention_backward`` higher-order op onto the backward template.

    ``args`` must hold exactly fourteen positional values: query, key, value,
    out, logsumexp, grad_out, grad_logsumexp, fw_graph, joint_graph,
    block_mask, scale, kernel_options, score_mod_other_buffers and
    mask_mod_other_buffers. ``kwargs`` is accepted but not read.

    The template's own output uses the key layout with the batch dimension
    taken from query. The query gradient and the (batch-broadcast) value
    gradient are buffers allocated here and passed as mutated inputs. When
    the key/value batch size is 1 while the query batch is larger, the key
    and value gradients are summed over the batch axis afterwards.

    Returns:
        ``(grad_query, grad_key, grad_value, captured_grads)`` where
        ``captured_grads`` is a tuple built from ``joint_outputs.captured_grads``.
    """
    (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) = args
    (
        _,
        _,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
        SPARSE_Q_BLOCK_SIZE,
        SPARSE_KV_BLOCK_SIZE,
        mask_graph,
    ) = block_mask

    realized = maybe_realize(
        [
            query,
            key,
            value,
            grad_out,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
        ]
    )
    (
        query,
        key,
        value,
        grad_out,
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        q_num_blocks,
        q_indices,
        full_q_num_blocks,
        full_q_indices,
    ) = realized

    device = query.get_device()
    dtype = query.get_dtype()
    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
    )

    # Work on a private copy of the options; symbolic values become static.
    kernel_options = {
        opt: (
            V.graph.sizevars.evaluate_static_shape(val)
            if isinstance(val, sympy.Symbol)
            else val
        )
        for opt, val in dict(kernel_options).items()
    }
    # Default IS_DIVISIBLE to True only when both sequence lengths are
    # multiples of 128; a caller-supplied value is left untouched.
    kernel_options.setdefault(
        "IS_DIVISIBLE",
        not (seq_len_q % 128 != 0 or seq_len_kv % 128 != 0),
    )

    # Placeholders: (score, b, h, m, n) for the forward graph; the joint
    # graph additionally takes the incoming gradient of the modified score.
    index_specs = [(axis, torch.int32) for axis in ("b", "h", "m", "n")]
    fwd_placeholder_inps = [
        create_placeholder(ph_name, ph_dtype, device)
        for ph_name, ph_dtype in [("score", dtype), *index_specs]
    ]
    fw_subgraph_buffer = _build_subgraph_buffer_with_additional_lowerings(
        fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph
    )

    joint_placeholder_inps = fwd_placeholder_inps + [
        create_placeholder("grad_score_mod", dtype, device)
    ]
    joint_fx_graph = joint_graph.graph_module.graph
    joint_fx_graph.eliminate_dead_code()
    validate_joint_graph(joint_graph.graph_module.graph)
    all_joint_outputs = _build_subgraph_buffer_with_additional_lowerings(
        joint_placeholder_inps + list(score_mod_other_buffers),
        joint_graph,
    )
    joint_outputs = process_joint_outputs(
        all_joint_outputs, len(joint_placeholder_inps)
    )

    mask_placeholders = [
        create_placeholder(ph_name, ph_dtype, device)
        for ph_name, ph_dtype in index_specs
    ]
    mask_graph_buffer = _build_subgraph_buffer_with_additional_lowerings(
        mask_placeholders + list(mask_mod_other_buffers), mask_graph
    )

    # Template output layout: key's shape with query's batch size.
    key_size = [Bq, Hkv, seq_len_kv, qk_head_dim]
    key_strides = infer_dense_strides(key_size, key.get_stride())
    layout_broadcasted_k = FixedLayout(
        key.get_device(),
        key.get_dtype(),
        key_size,
        stride=[sympy.sympify(s) for s in key_strides],
    )

    # delta = sum(out * grad_out, axis=-1) - grad_logsumexp / ln(2)
    grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2))
    delta = lowerings[aten.sub](
        lowerings[aten.sum](lowerings[aten.mul](out, grad_out), axis=-1),
        grad_lse_exp2,
    )
    delta = ExternKernel.require_contiguous(delta)
    grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])

    # Gradient buffers the template writes in place (see mutated_inputs).
    query_size = [Bq, Hq, seq_len_q, qk_head_dim]
    grad_query = empty_strided(
        query_size,
        stride=[
            sympy.sympify(s)
            for s in infer_dense_strides(query_size, query.get_stride())
        ],
        dtype=query.get_dtype(),
        device=query.get_device(),
    )
    value_size = [Bq, Hkv, seq_len_kv, v_head_dim]
    broadcasted_grad_value = empty_strided(
        value_size,
        stride=[
            sympy.sympify(s)
            for s in infer_dense_strides(value_size, value.get_stride())
        ],
        dtype=value.get_dtype(),
        device=value.get_device(),
    )

    kernel_options.setdefault("SM_SCALE", scale)
    kernel_options.setdefault("GQA_SHARED_HEADS", Hq // Hkv)
    has_full_blocks = full_kv_num_blocks is not None
    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
    if not has_full_blocks:
        # Empty stand-ins so the template still receives every input slot.
        full_kv_num_blocks = empty(0, device=query.get_device())
        full_kv_indices = empty(0, device=query.get_device())
        full_q_num_blocks = empty(0, device=query.get_device())
        full_q_indices = empty(0, device=query.get_device())
    set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars)

    SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
    SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)

    choices: list[Any] = []
    configs: list[tuple[int, int, int, int]] = [_get_default_config_bwd(query)]
    if config.max_autotune:
        num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]
        for blk1 in (32, 64):
            for blk2 in (32, 64, 128):
                # The second block size must be a multiple of the first.
                if blk2 % blk1 != 0:
                    continue
                warp_options = [4, 8] if blk1 >= 128 or blk2 >= 128 else [4]
                for warps in warp_options:
                    for stages in num_stages_list:
                        configs.append((blk1, blk2, warps, stages))

    # Tensor inputs handed to the template, in the order it expects them.
    kernel_inputs = [
        query,
        key,
        value,
        logsumexp,
        delta,
        grad_out,
        grad_query,
        broadcasted_grad_value,
        kv_num_blocks,
        kv_indices,
        q_num_blocks,
        q_indices,
        full_kv_num_blocks,
        full_kv_indices,
        full_q_num_blocks,
        full_q_indices,
    ]

    original_kernel_options = kernel_options.copy()
    for BLOCK1, BLOCK2, num_warps, num_stages in configs:
        # Skip configs whose block sizes do not evenly tile both sparse blocks.
        if any(
            sparse_size % blk != 0
            for blk in (BLOCK1, BLOCK2)
            for sparse_size in (SPARSE_KV_BLOCK_SIZE, SPARSE_Q_BLOCK_SIZE)
        ):
            continue

        cur_kernel_options = original_kernel_options.copy()
        # "bwd_"-prefixed options apply here with the prefix stripped;
        # "fwd_"-prefixed options belong to the forward kernel and are dropped.
        for opt in tuple(cur_kernel_options):
            if opt.startswith("bwd_"):
                cur_kernel_options[opt[4:]] = cur_kernel_options.pop(opt)
            elif opt.startswith("fwd_"):
                cur_kernel_options.pop(opt)

        # Config values are only defaults: explicit options take precedence.
        # BLOCK1 feeds BLOCK_M1/BLOCK_N2 and BLOCK2 feeds BLOCK_N1/BLOCK_M2.
        for opt, default in (
            ("num_warps", num_warps),
            ("num_stages", num_stages),
            ("BLOCK_M1", BLOCK1),
            ("BLOCK_N1", BLOCK2),
            ("BLOCK_M2", BLOCK2),
            ("BLOCK_N2", BLOCK1),
            ("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE),
            ("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE),
        ):
            cur_kernel_options.setdefault(opt, default)

        flex_attention_backward_template.maybe_append_choice(
            choices=choices,
            input_nodes=list(kernel_inputs),
            layout=layout_broadcasted_k,
            subgraphs=[
                fw_subgraph_buffer,
                joint_outputs.grad_input,
                mask_graph_buffer,
                joint_outputs.captured_grads_compute,
            ],
            mutated_inputs=[
                grad_query,
                broadcasted_grad_value,
                *joint_outputs.mutated_grads,
            ],
            call_sizes=query.get_size() + key.get_size()[1:3],
            **cur_kernel_options,
        )

    inputs_for_autotuning = (
        kernel_inputs
        + list(score_mod_other_buffers)
        + list(mask_mod_other_buffers)
        + joint_outputs.mutated_grads
    )
    # Keys are positions in inputs_for_autotuning (the block-sparsity inputs).
    input_gen_fns = {
        8: create_num_blocks_fake_generator(kv_indices),
        9: create_indices_fake,
        10: create_num_blocks_fake_generator(q_indices),
        11: create_indices_fake,
        12: create_num_blocks_fake_generator(full_kv_indices),
        13: create_indices_fake,
        14: create_num_blocks_fake_generator(full_q_indices),
        15: create_indices_fake,
    }
    broadcasted_grad_key = autotune_select_algorithm(
        "flex_attention_backward",
        choices,
        inputs_for_autotuning,
        layout_broadcasted_k,
        input_gen_fns=input_gen_fns,
    )

    if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)):
        # Batch sizes agree: the broadcast buffers already are the gradients.
        return (
            grad_query,
            broadcasted_grad_key,
            broadcasted_grad_value,
            tuple(joint_outputs.captured_grads),
        )

    assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
        f"Bq and Bkv must broadcastable. "
        f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
        f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
    )
    # Key/value were broadcast over the batch: reduce their gradients over it.
    grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
    grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)
    return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads))