"""
Benchmark driver for torch_bsa.blitz_sparse_attention.
Measures:
* latency (usec), via NPU event timers
* correctness, by comparison with a reference implementation (a torch_npu
  kernel or a pure-PyTorch model, selected by `TORCH_REFERENCE`)
Sweeps every `(B, H, S, D, sparsity)` configuration across every entry in
`BLOCK_SHAPES`. The default is all 16 `(block_size_q, block_size_kv)` pairs
drawn from {128, 256, 512, 1024}; the active pair is passed to the kernel at
runtime via the `block_shape` op-attr. Each sweep row is labelled with the
active block shape; narrow the list to a single pair to benchmark only one
granularity.
"""
import math
import random
import sys
import itertools
import logging
from collections import namedtuple
from typing import Callable, List, Tuple
import torch
import torch_npu
import torch_bsa
# Border ("frame") of always-kept blocks, in block units: left/right KV block
# columns kept in every sparse row, top/bottom Q block rows kept fully dense.
FrameWidths = namedtuple('FrameWidths', 'left_cols right_cols top_rows bottom_rows')
# Inputs of `create_attention_mask`: problem dims, block shape, target
# sparsity and optional frame.
MaskSpec = namedtuple(
    "MaskSpec",
    "b h s_q s_kv d block_size_q block_size_kv sparsity frame",
)
# Outputs of `create_attention_mask`, forwarded to the BSA kernel and to the
# reference implementation.
AttentionMask = namedtuple(
    "AttentionMask",
    "pfa_atten_mask bsa_atten_mask sabi sparse_mode scale pre_tokens next_tokens",
)

logging.basicConfig(level=logging.INFO, format='%(message)s', stream=sys.stdout)
logger = logging.getLogger(__name__)

# --- Run configuration ---------------------------------------------------
DEVICE = 'npu'
N_REPEATS = 10          # timed iterations per sweep cell
N_WARMUP = 2            # untimed iterations run before timing
BLOCK_MASK_SEED = 1234  # base RNG seed for sabi generation
DTYPE = torch.bfloat16
INPUT_LAYOUT = "BNSD"

# --- Sweep axes ----------------------------------------------------------
B_VALS = [1]
H_VALS = [3]
S_VALS = [118_806]      # used for both s_q and s_kv
D_VALS = [128]
SPARSITY_VALS = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
_BLOCK_SIZES = (128, 256, 512, 1024)
# Every (block_size_q, block_size_kv) pair, ordered block_size_q-major.
BLOCK_SHAPES = list(itertools.product(_BLOCK_SIZES, repeat=2))
SABI_SORTED = True      # keep selected KV block ids ascending within each sabi row

# Left/top frame widths shrink as blocks grow, so the border stays roughly
# 3.7k-4.1k tokens wide for every block size; right/bottom are one block.
_LEFT_COLS_BY_KV = {128: 29, 256: 15, 512: 8, 1024: 4}
_TOP_ROWS_BY_Q = {128: 29, 256: 15, 512: 8, 1024: 4}
FRAMES_BY_BLOCK_SHAPE = {
    (bsq, bskv): FrameWidths(
        left_cols=left_cols,
        right_cols=1,
        top_rows=top_rows,
        bottom_rows=1,
    )
    for (bsq, top_rows), (bskv, left_cols) in itertools.product(
        _TOP_ROWS_BY_Q.items(), _LEFT_COLS_BY_KV.items()
    )
}

# --- Debug / reference switches ------------------------------------------
PRINT_OUTPUTS = False
PRINT_MASK = False
PRINT_BLOCK_EQUALITY = False
PRINT_HEIGHT = 128      # block_h of the printed equality map
PRINT_WIDTH = 8         # block_w of the printed equality map
RUN_REFERENCE = False
TORCH_REFERENCE = "npu_fusion_attention"

torch.set_printoptions(
    threshold=100_000_000,
    linewidth=400,
    edgeitems=3,
    precision=4,
    sci_mode=False,
)
def ref_blitz_sparse_attention_launcher(
    torch_reference: str,
    pfa_inputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, float, torch.Tensor, str],
    force_dense_sm: bool,
    return_lse: bool = False,
):
    """Run a reference attention kernel for correctness checks and baseline timing.

    Args:
        torch_reference: which reference to launch:
            "custom" - the pure-PyTorch model `ref_blitz_sparse_attention_fp32`;
            "npu_fusion_attention" - `torch_npu.npu_fusion_attention` (PFA);
            "npu_fused_infer_attention_score" - `torch_npu.npu_fused_infer_attention_score` (FIAS).
        pfa_inputs: (q, k, v, head_num, scale, atten_mask, input_layout).
        force_dense_sm: only read by the npu_* references. For PFA, True selects
            sparse_mode=0 and False selects sparse_mode=1; `atten_mask` is
            forwarded unchanged in both modes. FIAS requires True, because it
            is always launched without a mask here.
        return_lse: if True, return (attention_out, softmax_lse) with the FIAS
            lse's trailing dim squeezed. Only the FIAS reference supports this.

    Returns:
        attention_out, or (attention_out, softmax_lse) when `return_lse` is True.

    Raises:
        ValueError: unknown `torch_reference`; `return_lse` requested from a
            reference that cannot produce it; FIAS with force_dense_sm=False; or
            an input_layout the FIAS branch does not handle.
    """
    q, k, v, head_num, scale, atten_mask, input_layout = pfa_inputs

    if torch_reference == "custom":
        if return_lse:
            raise ValueError("return_lse=True is not supported by the 'custom' reference "
                             "(ref_blitz_sparse_attention_fp32 does not compute softmax_lse).")
        return ref_blitz_sparse_attention_fp32(q, k, v, scale, atten_mask=atten_mask)

    if torch_reference == "npu_fusion_attention":
        if return_lse:
            raise ValueError("return_lse=True is not supported by 'npu_fusion_attention'; "
                             "use 'npu_fused_infer_attention_score' instead.")
        pfa_sparse_mode = 0 if force_dense_sm else 1
        pfa_outputs = torch_npu.npu_fusion_attention(
            q, k, v,
            head_num=head_num,
            input_layout=input_layout,
            scale=scale,
            pre_tockens=0,
            next_tockens=0,
            atten_mask=atten_mask,
            sparse_mode=pfa_sparse_mode,
        )
        return pfa_outputs[0]

    if torch_reference != "npu_fused_infer_attention_score":
        raise ValueError(
            f"Unknown torch_reference value: {torch_reference!r}. "
            f"Expected one of: 'custom', 'npu_fusion_attention', 'npu_fused_infer_attention_score'."
        )

    # --- FIAS ----------------------------------------------------------
    if not force_dense_sm:
        raise ValueError(
            "torch_reference='npu_fused_infer_attention_score' requires "
            "force_dense_sm=True. FIAS does not support our atten_mask layout "
            "(sparse_mode=1 needs [B, 1, S, S]); call "
            "torch_npu.npu_fused_infer_attention_score(...) directly for sparse runs."
        )
    # Batch is axis 0 in every supported layout; only the sequence axis moves.
    if input_layout == "BNSD":
        seq_axis = 2
    elif input_layout in ("BSND", "BSH"):
        seq_axis = 1
    else:
        raise ValueError(f"Unsupported input_layout for FIAS reference: {input_layout!r}")
    batch = q.shape[0]
    s_q, s_kv = q.shape[seq_axis], k.shape[seq_axis]
    fias_out, fias_lse = torch_npu.npu_fused_infer_attention_score(
        q, k, v,
        num_heads=head_num,
        scale=scale,
        input_layout=input_layout,
        num_key_value_heads=head_num,
        actual_seq_lengths=[s_q] * batch,
        actual_seq_lengths_kv=[s_kv] * batch,
        softmax_lse_flag=return_lse,
    )
    return (fias_out, fias_lse.squeeze(-1)) if return_lse else fias_out
def _frame_exceeds_dims(frame, n_block_rows, n_block_cols):
    """Report whether `frame` needs more block cols or rows than the grid has.

    A side is too big when its width (left/right, compared with `n_block_cols`)
    or height (top/bottom, compared with `n_block_rows`) is larger than the
    grid dimension it lies along.
    """
    widest_col_side = max(frame.left_cols, frame.right_cols)
    tallest_row_side = max(frame.top_rows, frame.bottom_rows)
    return widest_col_side > n_block_cols or tallest_row_side > n_block_rows
def _compute_frame_sets(frame, n_block_rows, n_block_cols):
    """Split the block grid into its frame-forced and free parts.

    Returns:
        (sparse_forced_col_ids, sparse_candidate_col_ids, dense_row_ids):
        - set of KV block cols kept in every sparse row (the left/right frame),
        - ascending list of the remaining cols, eligible for random selection,
        - set of Q block rows kept fully dense (the top/bottom frame).
        With `frame=None` nothing is forced: every col is a candidate and no
        row is dense.

    Raises:
        ValueError: a frame side is wider/taller than the block grid.
    """
    every_col = range(n_block_cols)
    if frame is None:
        return set(), list(every_col), set()
    if _frame_exceeds_dims(frame, n_block_rows, n_block_cols):
        raise ValueError(
            f"Frame configuration exceeds block dimensions: "
            f"{frame=}, {n_block_rows=}, {n_block_cols=}"
        )
    left_end = frame.left_cols
    right_start = n_block_cols - frame.right_cols
    top_end = frame.top_rows
    bottom_start = n_block_rows - frame.bottom_rows

    forced_cols = {c for c in every_col if c < left_end or c >= right_start}
    candidate_cols = [c for c in every_col if c not in forced_cols]
    dense_rows = {r for r in range(n_block_rows) if r < top_end or r >= bottom_start}
    return forced_cols, candidate_cols, dense_rows
def _build_sparse_row(rng, frame_sets, n_kept_per_row, n_block_cols, pad_value):
    """Return one sparse sabi row.

    The row holds the frame-forced cols plus `n_kept_per_row` cols drawn from the
    candidates with a single `rng.sample` call (skipped when n_kept_per_row <= 0).
    The ids are sorted ascending and right-padded with `pad_value` to
    `n_block_cols` entries. `rng.sample` raises ValueError if more cols are
    requested than there are candidates.
    """
    forced_cols, candidate_cols, _dense_rows = frame_sets
    chosen = list(forced_cols)
    if n_kept_per_row > 0:
        chosen += rng.sample(candidate_cols, n_kept_per_row)
    chosen.sort()
    padding = [pad_value] * (n_block_cols - len(chosen))
    return chosen + padding
def _shuffle_row(rng, cols, pad_value):
    """Return a copy of `cols` with the non-pad ids shuffled by `rng`, pads moved to the end.

    `cols` itself is not modified. This is used when sort_within_row=False.
    """
    real_ids = list(filter(lambda c: c != pad_value, cols))
    rng.shuffle(real_ids)
    real_ids.extend([pad_value] * (len(cols) - len(real_ids)))
    return real_ids
def generate_sparse_blocks_by_row(
    s_q: int,
    s_kv: int,
    block_size_q: int,
    block_size_kv: int,
    sparsity: float,
    seed: int,
    frame: FrameWidths,
    pad_value: int = -1,
    sort_within_row: bool = True,
) -> list[list[int]]:
    """Build the sabi (selected KV block ids per Q block row) for a single head.

    The pattern is a "frame" of always-kept blocks (see `_compute_frame_sets`)
    plus random blocks chosen from the remaining candidates to approach the
    target sparsity. Frame blocks count toward the sparsity budget.

    The random budget, floor(n_total * (1 - sparsity)) - n_frame_forced, is spread
    evenly over the sparse rows with ceil(). Every sparse row therefore gets the
    same count, and gets at least one block whenever the budget is positive. The
    total can exceed the budget by fewer than n_sparse_rows blocks.

    Output:
        - shape [n_block_rows][n_block_cols]
        - each row: selected cols, then pad_value to the end. Cols are ascending
          when sort_within_row=True (default); otherwise shuffled deterministically
          per seed (kernel sparse-loop assumes sorted — see SABI_SORTED).

    Raises:
        ValueError: empty block grid, a frame larger than the grid, or a frame
            that alone exceeds the density allowed by `sparsity`.
    """
    sparsity = max(0.0, min(1.0, float(sparsity)))
    n_block_rows = math.ceil(s_q / block_size_q)
    n_block_cols = math.ceil(s_kv / block_size_kv)
    if n_block_rows == 0 or n_block_cols == 0:
        raise ValueError(f"Invalid block configuration: {s_q=}, {s_kv=}, {block_size_q=}, {block_size_kv=}")
    n_total_blocks = n_block_rows * n_block_cols

    frame_sets = _compute_frame_sets(frame, n_block_rows, n_block_cols)
    forced_cols, _, dense_rows = frame_sets
    n_sparse_rows = max(n_block_rows - len(dense_rows), 0)
    n_frame_forced_blocks = len(forced_cols) * n_sparse_rows + n_block_cols * len(dense_rows)
    min_density = n_frame_forced_blocks / n_total_blocks
    if min_density > (1.0 - sparsity):
        raise ValueError(
            f"Target sparsity violation: frame forces {n_frame_forced_blocks} blocks out of "
            f"{n_total_blocks} total attention blocks, implying a minimum density of "
            f"{min_density:.2%}, which violates the target sparsity "
            f"of {sparsity:.2%}. And this is even before allocating the random blocks inside the "
            f"frame. Consider adjusting frame sizes or target sparsity."
        )

    random_budget = int(math.floor(n_total_blocks * (1.0 - sparsity))) - n_frame_forced_blocks
    n_random_per_row = math.ceil(random_budget / n_sparse_rows) if n_sparse_rows > 0 else 0
    rng = random.Random(seed)

    def build_row(row_idx):
        # Rows are built strictly in order, so `rng` is consumed identically per seed.
        if row_idx in dense_rows:
            row = list(range(n_block_cols))
        else:
            row = _build_sparse_row(rng, frame_sets, n_random_per_row, n_block_cols, pad_value)
        return row if sort_within_row else _shuffle_row(rng, row, pad_value)

    return [build_row(row_idx) for row_idx in range(n_block_rows)]
BlockGrid = namedtuple("BlockGrid", ["s_q", "s_kv", "block_size_q", "block_size_kv"])
def is_block_sparse_pattern_feasible(
    grid: BlockGrid, frame, sparsity: float,
    *, require_nonempty_per_row: bool = True,
) -> bool:
    """Return True iff `generate_sparse_blocks_by_row` is satisfiable for the given grid.

    `grid` bundles (s_q, s_kv, block_size_q, block_size_kv); see `BlockGrid`.
    `frame` is a `FrameWidths` or None (no border).

    True means `generate_sparse_blocks_by_row` will (a) not raise ValueError and
    (b) when `require_nonempty_per_row` is set, produce a sabi where every Q-row
    keeps at least one block. Pure arithmetic — no tensor allocation — so it is
    safe to call from pytest parametrize lists at module import time.

    The checks mirror `generate_sparse_blocks_by_row`:
      * the block grid is non-empty and the frame fits inside it;
      * the frame fits the sparsity budget. This uses the same comparison the
        generator raises on:
            n_forced / n_total <= 1 - sparsity
      * if `require_nonempty_per_row` (the default — BSA's per-Q-row LSE
        sentinel is the cliff we are avoiding): dense frame rows are always
        full, and frame columns are kept in every sparse row. So only a frame
        with no left/right columns can leave a sparse row empty. In that case
        the random budget must give every sparse row at least one block:
            floor(n_total * (1 - sparsity)) - n_forced >= 1
    Set `require_nonempty_per_row=False` if you intentionally want to test the
    "row fully masked" path (and accept BSA returning its sentinel).
    """
    sparsity = max(0.0, min(1.0, float(sparsity)))
    n_block_rows = math.ceil(grid.s_q / grid.block_size_q)
    n_block_cols = math.ceil(grid.s_kv / grid.block_size_kv)
    if n_block_rows == 0 or n_block_cols == 0:
        return False
    n_total = n_block_rows * n_block_cols
    if frame is None:
        n_forced_cols = 0
        n_dense_rows = 0
    else:
        if _frame_exceeds_dims(frame, n_block_rows, n_block_cols):
            return False
        # Same clamped set construction as `_compute_frame_sets`, so overlapping
        # left/right (or top/bottom) sides are counted once.
        sparse_forced_left = set(range(min(frame.left_cols, n_block_cols)))
        sparse_forced_right = set(range(max(n_block_cols - frame.right_cols, 0), n_block_cols))
        n_forced_cols = len(sparse_forced_left | sparse_forced_right)
        dense_top = set(range(min(frame.top_rows, n_block_rows)))
        dense_bottom = set(range(max(n_block_rows - frame.bottom_rows, 0), n_block_rows))
        n_dense_rows = len(dense_top | dense_bottom)
    n_sparse_rows = max(n_block_rows - n_dense_rows, 0)
    n_forced = n_forced_cols * n_sparse_rows + n_block_cols * n_dense_rows
    # Exactly the generator's "Target sparsity violation" condition.
    if n_forced / n_total > (1.0 - sparsity):
        return False
    # Sparse rows always contain the forced frame columns; only a frame without
    # left/right columns relies on the random budget to keep rows non-empty.
    if require_nonempty_per_row and n_sparse_rows > 0 and n_forced_cols == 0:
        n_max_kept = math.floor(n_total * (1.0 - sparsity))
        if n_max_kept - n_forced < 1:
            return False
    return True
def make_block_mask(
    s_q: int,
    s_kv: int,
    block_size_q: int,
    block_size_kv: int,
    sparse_blocks_by_row: list[list[int]],
    device: str = "cpu",
) -> torch.Tensor:
    """
    Builds a dense boolean mask [1, 1, s_q, s_kv] from sparse block columns.
    True = masked
    False = allowed

    Args:
        s_q, s_kv: token lengths of the query / key-value axes.
        block_size_q, block_size_kv: block edge lengths. The last block on each
            axis may be partial.
        sparse_blocks_by_row: for each Q block row, the KV block ids to keep.
            Ids outside [0, n_block_cols) (e.g. the -1 padding) are ignored.
            Rows beyond the block grid are ignored, and Q block rows with no
            entry stay fully masked.
        device: device on which the returned mask is created.
    """
    n_block_rows = math.ceil(s_q / block_size_q)
    n_block_cols = math.ceil(s_kv / block_size_kv)

    # Mark kept blocks on a small block-level grid first. Expanding it once
    # replaces one device slice-assignment per kept block, which meant ~860k
    # separate device ops per head at S=118806 with 128x128 blocks.
    kept_rows: list[int] = []
    kept_cols: list[int] = []
    for r, row_cols in enumerate(sparse_blocks_by_row[:n_block_rows]):
        for c in row_cols:
            if 0 <= c < n_block_cols:
                kept_rows.append(r)
                kept_cols.append(int(c))
    block_mask = torch.ones((n_block_rows, n_block_cols), dtype=torch.bool)
    block_mask[torch.tensor(kept_rows, dtype=torch.long),
               torch.tensor(kept_cols, dtype=torch.long)] = False

    # Token i (j) belongs to block row i // block_size_q (col j // block_size_kv);
    # this also gives partial edge blocks the right extent.
    token_block_rows = (torch.arange(s_q) // block_size_q).to(device)
    token_block_cols = (torch.arange(s_kv) // block_size_kv).to(device)
    mask = (block_mask.to(device)
            .index_select(0, token_block_rows)
            .index_select(1, token_block_cols))
    return mask.unsqueeze(0).unsqueeze(0)
def generate_sparse_blocks_by_row_per_head(
    s_q: int,
    s_kv: int,
    block_size_q: int,
    block_size_kv: int,
    sparsity: float,
    num_heads: int,
    frame: FrameWidths,
    base_seed: int,
    sort_within_row: bool = True,
) -> list[list[list[int]]]:
    """Build one independent sabi per head.

    Head `i` is generated by `generate_sparse_blocks_by_row` with seed
    `base_seed + i`; all other settings are shared.

    Output shape:
        [num_heads][n_block_rows][num_kept_elems]
    """
    per_head_sabi = []
    for head_idx in range(num_heads):
        head_sabi = generate_sparse_blocks_by_row(
            s_q=s_q,
            s_kv=s_kv,
            block_size_q=block_size_q,
            block_size_kv=block_size_kv,
            sparsity=sparsity,
            seed=base_seed + head_idx,
            frame=frame,
            sort_within_row=sort_within_row,
        )
        per_head_sabi.append(head_sabi)
    return per_head_sabi
def make_block_mask_per_head(
    s_q: int,
    s_kv: int,
    block_size_q: int,
    block_size_kv: int,
    sparse_blocks_by_row_per_head: list[list[list[int]]],
    device: str = "cpu",
) -> torch.Tensor:
    """Stack per-head token masks into one bool tensor [1, num_heads, s_q, s_kv].

    Each head's mask comes from `make_block_mask` (True = masked,
    False = allowed).
    """
    head_planes = []
    for head_blocks in sparse_blocks_by_row_per_head:
        head_mask_4d = make_block_mask(
            s_q=s_q,
            s_kv=s_kv,
            block_size_q=block_size_q,
            block_size_kv=block_size_kv,
            sparse_blocks_by_row=head_blocks,
            device=device,
        )
        head_planes.append(head_mask_4d[0, 0])  # drop the two leading singleton dims
    stacked = torch.stack(head_planes, dim=0)
    return stacked[None]
def block_allclose_map(
    out: torch.Tensor,
    ref: torch.Tensor,
    block_h: int,
    block_w: int,
    rtol: float = 0.01,
    atol: float = 0.001,
    print_map: bool = False,
) -> torch.Tensor:
    """
    Compare two 4D tensors [dim0, dim1, dim2, dim3] in (block_h x block_w) blocks over the last two dims.

    Returns a bool tensor of shape [dim0, dim1, nby, nbx], with
    nby = ceil(dim2 / block_h) and nbx = ceil(dim3 / block_w). An entry is True
    iff every element of that block satisfies torch.isclose(out, ref, rtol, atol).
    Blocks at the edges can be smaller if dim2 or dim3 is not divisible by the
    block sizes; only their real elements are compared.
    - First two dims: batch and head (iterated and printed as headers)
    - Last two dims: "drawn" dimensions (split into blocks)
    If print_map, logs a 0/1 block matrix per (batch, head), each followed by
    an empty line.

    Raises:
        ValueError: shape mismatch, non-4D inputs, or non-positive block sizes.
    """
    if out.shape != ref.shape:
        raise ValueError(f"Shape mismatch: out {tuple(out.shape)} vs ref {tuple(ref.shape)}")
    if out.ndim != 4:
        raise ValueError(f"Expected 4D tensors [dim0, dim1, dim2, dim3], got out.ndim={out.ndim}")
    if block_h <= 0 or block_w <= 0:
        raise ValueError("block_h and block_w must be positive integers")
    dim0, dim1, dim2, dim3 = out.shape
    nby = (dim2 + block_h - 1) // block_h
    nbx = (dim3 + block_w - 1) // block_w
    close = torch.isclose(out, ref, rtol=rtol, atol=atol)
    padded_h, padded_w = nby * block_h, nbx * block_w
    if (padded_h, padded_w) != (dim2, dim3):
        # Pad with True so partial edge blocks are judged on their real elements only.
        padded = torch.ones((dim0, dim1, padded_h, padded_w), dtype=torch.bool, device=close.device)
        padded[:, :, :dim2, :dim3] = close
        close = padded
    # One vectorised reduction replaces a Python loop with one .all() per block.
    block_ok = close.reshape(dim0, dim1, nby, block_h, nbx, block_w).all(dim=5).all(dim=3)
    if print_map:
        for b, h in itertools.product(range(dim0), range(dim1)):
            logger.info(f"batch={b}, head={h}")
            grid = block_ok[b, h].to(dtype=torch.int32)
            for by in range(nby):
                row = " ".join(str(int(v)) for v in grid[by].tolist())
                logger.info(row)
            # Logger.info requires a message; a bare logger.info() raises TypeError.
            logger.info("")
    return block_ok
def ref_blitz_sparse_attention_fp32(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale_value: float,
atten_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Reference implementation for correctness comparisons.
Scaled dot-product attention in float32 throughout.
Assumes:
- input_layout == "BNSD"
- atten_mask (if provided) is bool broadcastable
to [batch_size, num_heads, s_q, s_kv], where True means "masked out".
"""
q_f = q.to(torch.float32)
k_f = k.to(torch.float32)
v_f = v.to(torch.float32)
attn_scores = torch.matmul(q_f, k_f.transpose(-1, -2))
attn_scores = attn_scores * scale_value
if atten_mask is not None:
if atten_mask.dtype == torch.bool:
attn_scores = attn_scores.masked_fill(
atten_mask, torch.finfo(attn_scores.dtype).min
)
elif atten_mask.dtype in (torch.int8, torch.uint8):
attn_scores = attn_scores.masked_fill(
atten_mask == 1, torch.finfo(attn_scores.dtype).min
)
attn_probs = torch.softmax(attn_scores, dim=-1)
out = torch.matmul(attn_probs, v_f)
return out.to(dtype=q.dtype)
def gen_pfa_inputs(
    batch_size: int, num_heads: int, s_q: int, s_kv: int, head_dim: int,
    device: str = "npu:0", dtype: torch.dtype = DTYPE
):
    """Draw standard-normal BNSD q/k/v plus full-length actual-sequence-length lists.

    The inputs are shared by blitz_sparse_attention and the prompt_flash_attention
    references.

    Returns:
        (q [B, N, s_q, D], k [B, N, s_kv, D], v [B, N, s_kv, D],
         [s_q] * B, [s_kv] * B)
    """
    q_shape = (batch_size, num_heads, s_q, head_dim)
    kv_shape = (batch_size, num_heads, s_kv, head_dim)
    # Draw order q, k, v is kept fixed so a given RNG state yields the same inputs.
    q = torch.randn(q_shape, dtype=dtype, device=device)
    k = torch.randn(kv_shape, dtype=dtype, device=device)
    v = torch.randn(kv_shape, dtype=dtype, device=device)
    return q, k, v, [s_q] * batch_size, [s_kv] * batch_size
def create_attention_mask(spec: MaskSpec, device: str, emit_atten_mask: bool = True) -> AttentionMask:
    """Build BSA's block-sparse `sabi` and, optionally, the matching token-level mask.

    Each batch entry `bidx` gets independent per-head sabis from
    `generate_sparse_blocks_by_row_per_head` with base seed
    `BLOCK_MASK_SEED + bidx`. `spec.frame` forces a fixed sink/streaming border
    on top of the sparsity budget.

    `sabi` is a uint16 tensor [b, h, n_block_rows, n_block_cols] on `device`.
    `pfa_atten_mask` ([b, h, s_q, s_kv] bool, True = masked) is built only when
    `emit_atten_mask` is set and `spec.sparsity > 0`; otherwise it is None. BSA
    itself gets no token mask (bsa_atten_mask=None), sparse_mode=0,
    scale = 1 / sqrt(d), pre_tokens=2147483647 and next_tokens=0.
    """
    scale = 1.0 / math.sqrt(float(spec.d))

    batch_sabis = []
    for bidx in range(spec.b):
        batch_sabis.append(
            generate_sparse_blocks_by_row_per_head(
                s_q=spec.s_q,
                s_kv=spec.s_kv,
                block_size_q=spec.block_size_q,
                block_size_kv=spec.block_size_kv,
                sparsity=spec.sparsity,
                num_heads=spec.h,
                frame=spec.frame,
                base_seed=BLOCK_MASK_SEED + bidx,
                sort_within_row=SABI_SORTED,
            )
        )
    sabi = torch.tensor(batch_sabis, dtype=torch.uint16, device=device)

    pfa_atten_mask = None
    if emit_atten_mask and spec.sparsity > 0:
        per_batch_masks = [
            make_block_mask_per_head(
                spec.s_q, spec.s_kv, spec.block_size_q, spec.block_size_kv,
                head_sabis, device=device,
            )
            for head_sabis in batch_sabis
        ]
        pfa_atten_mask = torch.cat(per_batch_masks, dim=0)

    return AttentionMask(
        pfa_atten_mask=pfa_atten_mask,
        bsa_atten_mask=None,
        sabi=sabi,
        sparse_mode=0,
        scale=scale,
        pre_tokens=2147483647,
        next_tokens=0,
    )
def _run_timed(kernel_fn: Callable, input_sets: List, n_warmup: int, n_repeat: int):
    """Return the mean latency of `kernel_fn` in microseconds, measured with NPU events.

    The first `n_warmup` argument tuples of `input_sets` are run untimed. The
    next `n_repeat` tuples are timed between two NPU events. The mean divides
    by the number of tuples actually timed, so the result stays correct when
    `input_sets` holds more or fewer than `n_warmup + n_repeat` entries.

    Raises:
        ValueError: no argument tuples are left to time. This is checked before
            any kernel is launched.
    """
    warmup_sets = input_sets[:n_warmup]
    timed_sets = input_sets[n_warmup:n_warmup + n_repeat]
    if not timed_sets:
        raise ValueError(
            f"No timed iterations: got {len(input_sets)} input sets for "
            f"{n_warmup=} and {n_repeat=}."
        )
    for args in warmup_sets:
        kernel_fn(*args)
    torch.npu.synchronize()
    start = torch.npu.Event(enable_timing=True)
    end = torch.npu.Event(enable_timing=True)
    start.record()
    for args in timed_sets:
        kernel_fn(*args)
    end.record()
    torch.npu.synchronize()
    # elapsed_time() is in milliseconds; convert the per-iteration mean to usec.
    return start.elapsed_time(end) / len(timed_sets) * 1000.0
def _check_correctness(
    our_fn: Callable, ref_fn: Callable, b: int, h: int, s_q: int, s_kv: int, d: int, device: str
) -> str:
    """Run both implementations on one shared random input set; return 'yes' if they match.

    The outputs match when `torch.allclose(rtol=0.01, atol=0.001)` holds on their
    CPU copies. On mismatch, a per-block 0/1 equality map is logged if
    PRINT_BLOCK_EQUALITY is set. The outputs are logged if PRINT_OUTPUTS is set.
    """
    rtol, atol = 0.01, 0.001
    q, k, v, seq, seqkv = gen_pfa_inputs(b, h, s_q, s_kv, d, device=device, dtype=DTYPE)
    out_our = our_fn(q, k, v, seq, seqkv).cpu()
    out_ref = ref_fn(q, k, v, seq, seqkv).cpu()
    if PRINT_OUTPUTS:
        # The shape goes through a %s placeholder. An extra positional argument
        # without a placeholder makes logging report a formatting error instead
        # of emitting the message.
        logger.info("OURS: %s", out_our.shape)
        logger.info(out_our)
        logger.info("REF: %s", out_ref.shape)
        logger.info(out_ref)
    equal = torch.allclose(out_our, out_ref, rtol=rtol, atol=atol)
    if not equal and PRINT_BLOCK_EQUALITY:
        block_allclose_map(out_our, out_ref,
                           block_h=PRINT_HEIGHT, block_w=PRINT_WIDTH,
                           rtol=rtol, atol=atol, print_map=True)
    return "yes" if equal else "no"
def _fmt_or_na(value, width, spec=".2f"):
    """Render `value` with format spec `{width}{spec}`; None renders as right-aligned 'N/A'."""
    if value is not None:
        return format(value, f"{width}{spec}")
    return format("N/A", f">{width}")
def _make_our_fn(mask: AttentionMask, h, block_shape):
    """Bind mask settings, head count and block shape into `fn(q, k, v, seq, seqkv) -> out`.

    `fn` calls `torch_bsa.blitz_sparse_attention` with softmax_lse_flag=False,
    discards the second returned value, and returns only the attention output.
    """
    def run_bsa(q, k, v, seq, seqkv):
        attention_out, _unused_lse = torch_bsa.blitz_sparse_attention(
            q, k, v,
            sabi=mask.sabi,
            actual_seq_lengths=seq,
            actual_seq_lengths_kv=seqkv,
            num_heads=h,
            num_key_value_heads=h,
            input_layout=INPUT_LAYOUT,
            scale_value=mask.scale,
            atten_mask=mask.bsa_atten_mask,
            sparse_mode=mask.sparse_mode,
            pre_tokens=mask.pre_tokens,
            next_tokens=mask.next_tokens,
            softmax_lse_flag=False,
            block_shape=list(block_shape),
        )
        return attention_out
    return run_bsa
def _make_ref_fn(h, scale, atten_mask, run_ref_sparsity_0):
    """Bind reference-kernel settings into `fn(q, k, v, seq, seqkv) -> out`.

    `seq`/`seqkv` are accepted so the signature matches `_make_our_fn` but are
    not used. `run_ref_sparsity_0` is forwarded as `force_dense_sm`.
    """
    def run_reference(q, k, v, seq, seqkv):
        launcher_inputs = (q, k, v, h, scale, atten_mask, INPUT_LAYOUT)
        return ref_blitz_sparse_attention_launcher(
            torch_reference=TORCH_REFERENCE,
            pfa_inputs=launcher_inputs,
            force_dense_sm=run_ref_sparsity_0,
        )
    return run_reference
# One sweep point: block shape (with its label and frame) and problem dims.
_BenchCell = namedtuple(
    "_BenchCell",
    "bs_label block_shape frame_for_shape b h s_q s_kv d sparsity",
)
# Sweep-wide switches and iteration counts.
_BenchCfg = namedtuple("_BenchCfg", "run_our run_ref n_warmup n_repeat")
def _format_frame_str(frame, sparsity):
    """Return the frame as "(L,R,T,B)", or "-" when there is no frame or the cell is dense."""
    if frame is not None and sparsity != 0:
        sides = (frame.left_cols, frame.right_cols, frame.top_rows, frame.bottom_rows)
        return "(" + ",".join(str(side) for side in sides) + ")"
    return "-"
def _run_one_benchmark_cell(cell: _BenchCell, cfg: _BenchCfg):
    """Build the mask for one sweep cell, check and time both kernels, and log one results row.

    The reference runs when `cfg.run_ref` is set. It also always runs for dense
    cells (sparsity == 0), where it is launched with force_dense_sm=True as a
    dense baseline. The correctness check compares both kernels, so it also
    requires `cfg.run_our`.
    """
    mask = create_attention_mask(
        MaskSpec(b=cell.b, h=cell.h, s_q=cell.s_q, s_kv=cell.s_kv, d=cell.d,
                 block_size_q=cell.block_shape[0], block_size_kv=cell.block_shape[1],
                 sparsity=cell.sparsity, frame=cell.frame_for_shape),
        device=DEVICE, emit_atten_mask=cfg.run_ref,
    )
    if PRINT_MASK and mask.pfa_atten_mask is not None:
        logger.info(mask.pfa_atten_mask.int())
        logger.info(mask.pfa_atten_mask.shape)
    run_ref_sparsity_0 = cell.sparsity == 0
    run_ref = cfg.run_ref or run_ref_sparsity_0
    our_fn = _make_our_fn(mask, cell.h, cell.block_shape)
    ref_fn = _make_ref_fn(cell.h, mask.scale, mask.pfa_atten_mask, run_ref_sparsity_0)
    are_equal_ref = "N/A"
    # Explicit grouping: `run_our and run_ref or sparsity_0` would parse as
    # `(run_our and run_ref) or sparsity_0` and run our kernel even when run_our is off.
    if cfg.run_our and run_ref:
        are_equal_ref = _check_correctness(
            our_fn, ref_fn, cell.b, cell.h, cell.s_q, cell.s_kv, cell.d, device=DEVICE,
        )
    inputs = [gen_pfa_inputs(cell.b, cell.h, cell.s_q, cell.s_kv, cell.d, device=DEVICE, dtype=DTYPE)
              for _ in range(cfg.n_warmup + cfg.n_repeat)]
    our_duration = _run_timed(our_fn, inputs, cfg.n_warmup, cfg.n_repeat) if cfg.run_our else None
    ref_duration = _run_timed(ref_fn, inputs, cfg.n_warmup, cfg.n_repeat) if run_ref else None
    frame_str = _format_frame_str(cell.frame_for_shape, cell.sparsity)
    logger.info(
        f"{cell.bs_label:>13} {cell.h:>3} {cell.b:>3} {cell.s_q:>6} {cell.s_kv:>6} {cell.d:>4} "
        f"{frame_str:>15} {cell.sparsity:>9.2f} "
        f"{are_equal_ref:>15} {_fmt_or_na(ref_duration, 18)} {_fmt_or_na(our_duration, 18)}"
    )
def benchmark_blitz_sparse_attention():
    """Run the full sweep and log a fixed-width results table.

    For every entry of BLOCK_SHAPES, iterates every combination of B_VALS,
    H_VALS, S_VALS (used for both s_q and s_kv), D_VALS and SPARSITY_VALS, and
    emits one row per combination via `_run_one_benchmark_cell`.
    """
    cfg = _BenchCfg(
        run_our=True,
        run_ref=RUN_REFERENCE,
        n_warmup=N_WARMUP,
        n_repeat=N_REPEATS,
    )
    if not (cfg.run_our or cfg.run_ref):
        logger.info("Nothing to run, must set run_our=True or run_ref=True")
        return

    double_rule = "=" * 120
    header = (
        f"{'block_shape':>13} {'H':>3} {'B':>3} {'s_q':>6} {'s_kv':>6} {'D':>4} "
        f"{'frame(L,R,T,B)':>15} {'sparsity':>9} "
        f"{'Outputs_equal':>15} {'Ref_Latency_[usec]':>18} {'Our_Latency_[usec]':>18}"
    )
    for banner_line in (
        double_rule,
        f" {DTYPE=} {INPUT_LAYOUT=} {SABI_SORTED=} {TORCH_REFERENCE=}",
        double_rule,
        header,
        "-" * 120,
    ):
        logger.info(banner_line)

    sweep_axes = (B_VALS, H_VALS, S_VALS, D_VALS, SPARSITY_VALS)
    for block_shape in BLOCK_SHAPES:
        bs_q, bs_kv = block_shape
        label = f"{bs_q}x{bs_kv}"
        frame_for_shape = FRAMES_BY_BLOCK_SHAPE.get(block_shape)
        for b, h, seq_len, d, sparsity in itertools.product(*sweep_axes):
            cell = _BenchCell(
                bs_label=label, block_shape=block_shape, frame_for_shape=frame_for_shape,
                b=b, h=h, s_q=seq_len, s_kv=seq_len, d=d, sparsity=sparsity,
            )
            _run_one_benchmark_cell(cell, cfg)
    logger.info(double_rule)


if __name__ == "__main__":
    # Bind this process to DEVICE, then run the whole sweep.
    torch.npu.set_device(DEVICE)
    benchmark_blitz_sparse_attention()