"""
npu_block_sparse_attention 正向算子单测。
测试场景覆盖:
- BNSD/TND 布局、actual_seq_lengths 可选/必传
- inner_precise=0/1(fp32/fp16 中间 softmax)
- softmax_lse_flag=0/1
- float16、bfloat16 数据类型(bf16 时 inner_precise 仅支持 0)
- GQA 约束 G=N1/N2:G<128 且 128%G==0(G=2,4,8,16,32,64)
- 多块稀疏(block_shape=[8,128] 产生 4x1 块)
- 稀疏掩码(对角块为 1,其余为 0)
"""
import math
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices, SkipIfNotGteCANNVersion
# Default dtype of the randomly generated query/key/value tensors.
DTYPE = torch.float16
# Default batch size, sequence length, number of query heads and head dimension.
B, S, N, D = 2, 32, 8, 64
# Default number of key/value heads (equal to N, i.e. one KV head per query head).
NUM_KV_HEADS = 8
# Default [q_block_size, kv_block_size]; with S=32 a single block covers the whole sequence.
BLOCK_SHAPE = [128, 128]
def _softmax_np(x):
    """Numerically stable softmax over the last axis, computed in float32.

    Args:
        x: array-like of scores; converted to a float32 ndarray.

    Returns:
        float32 ndarray of the same shape as ``x``. A row whose entries are all
        ``-inf`` (fully masked) yields all zeros instead of NaN.
    """
    x = np.asarray(x, dtype=np.float32)
    x_max = x.max(axis=-1, keepdims=True)
    # For an all -inf row the max is -inf and ``x - x_max`` would be NaN;
    # shifting by 0 instead keeps the row at -inf so exp() gives exact zeros.
    x_max = np.where(np.isneginf(x_max), np.float32(0.0), x_max)
    y = np.exp(x - x_max)
    s = y.sum(axis=-1, keepdims=True)
    # Guard the division for rows whose exponentials all vanished.
    s = np.where(s > 1e-10, s, 1.0)
    return y / s
def _logsumexp_np(scores):
    """Return log(sum(exp(scores))) for one row of scores as a Python float.

    The computation is done in float32 with the row maximum subtracted first.
    When every entry is at the mask value (row maximum <= -1e9) the maximum
    itself is returned, so fully masked rows stay stable.

    Args:
        scores: array-like holding one row of attention scores.

    Returns:
        float: the log-sum-exp of the row (always a built-in ``float``).
    """
    s = np.asarray(scores, dtype=np.float32)
    m = np.max(s)
    if m <= -1e9:
        # Fully masked row: convert so both branches return the same type.
        return float(m)
    return float(m + np.log(np.sum(np.exp(s - m))))
def cpu_block_sparse_attention_bnsd_with_lse(query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads):
    """CPU reference for BNSD block-sparse attention that also returns softmax_lse.

    Args:
        query: tensor of shape (B, N, S_q, D).
        key: tensor of shape (B, N2, S_kv, D); N must be a multiple of N2.
        value: tensor of shape (B, N2, S_kv, D).
        block_sparse_mask: tensor indexed as [b, n, q_block, kv_block]; a
            non-zero entry means that query block attends to that kv block.
        block_shape: (q_block_size, kv_block_size).
        scale_value: factor applied to q.k before the softmax.
        num_kv_heads: unused; the number of KV heads is read from ``key``.

    Returns:
        (attention_out, softmax_lse): attention_out has shape (B, N, S_q, D)
        in ``query.dtype``; softmax_lse is float32 of shape (B, N, S_q, 1).
        A row with no visible key gets uniform weights and an lse of -1e10.

    Raises:
        ValueError: if N is not a multiple of N2.
    """
    q = query.cpu().to(torch.float32).numpy()
    k = key.cpu().to(torch.float32).numpy()
    v = value.cpu().to(torch.float32).numpy()
    mask = block_sparse_mask.cpu().numpy()
    num_heads, seq_q = q.shape[1], q.shape[2]
    kv_heads, seq_kv = k.shape[1], k.shape[2]
    if num_heads % kv_heads != 0:
        raise ValueError(
            f"query heads ({num_heads}) must be a multiple of key/value heads ({kv_heads})")
    group = num_heads // kv_heads
    block_x, block_y = int(block_shape[0]), int(block_shape[1])
    # GQA: query heads [g * group, (g + 1) * group) share KV head g, so each
    # KV head is repeated ``group`` times consecutively along the head axis.
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)
    # Expand the block-level mask to element level by indexing it with the
    # block id of every query row / key column.
    q_block = (np.arange(seq_q) // block_x)[:, None]
    kv_block = (np.arange(seq_kv) // block_y)[None, :]
    keep = mask[:, :, q_block, kv_block] != 0
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) * np.float32(scale_value)
    scores = np.where(keep, scores, np.float32(-1e10))
    # Stable softmax: the row maximum contributes exp(0) == 1, so denom >= 1.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    out = np.matmul(weights / denom, v)
    # Fully masked rows (max still at the mask value) report the mask value.
    lse = np.where(row_max <= -1e9, row_max, row_max + np.log(denom)).astype(np.float32)
    return torch.from_numpy(out).to(query.dtype), torch.from_numpy(lse).to(torch.float32)
def cpu_block_sparse_attention_bnsd(query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads):
    """CPU reference for BNSD block-sparse attention (attention output only).

    ``block_sparse_mask[b, n, i, j] != 0`` means query block ``i`` attends to
    key/value block ``j`` for batch ``b`` and query head ``n``.

    Args:
        query: tensor of shape (B, N, S_q, D).
        key: tensor of shape (B, N2, S_kv, D); N must be a multiple of N2.
        value: tensor of shape (B, N2, S_kv, D).
        block_sparse_mask: tensor indexed as [b, n, q_block, kv_block].
        block_shape: (q_block_size, kv_block_size).
        scale_value: factor applied to q.k before the softmax.
        num_kv_heads: unused; the number of KV heads is read from ``key``.

    Returns:
        Tensor of shape (B, N, S_q, D) in ``query.dtype``. A row with no
        visible key gets uniform attention weights.

    Raises:
        ValueError: if N is not a multiple of N2.
    """
    q = query.cpu().to(torch.float32).numpy()
    k = key.cpu().to(torch.float32).numpy()
    v = value.cpu().to(torch.float32).numpy()
    mask = block_sparse_mask.cpu().numpy()
    num_heads, seq_q = q.shape[1], q.shape[2]
    kv_heads, seq_kv = k.shape[1], k.shape[2]
    if num_heads % kv_heads != 0:
        raise ValueError(
            f"query heads ({num_heads}) must be a multiple of key/value heads ({kv_heads})")
    group = num_heads // kv_heads
    block_x, block_y = int(block_shape[0]), int(block_shape[1])
    # GQA: query heads [g * group, (g + 1) * group) share KV head g, so each
    # KV head is repeated ``group`` times consecutively along the head axis.
    k = np.repeat(k, group, axis=1)
    v = np.repeat(v, group, axis=1)
    # Expand the block-level mask to element level by indexing it with the
    # block id of every query row / key column.
    q_block = (np.arange(seq_q) // block_x)[:, None]
    kv_block = (np.arange(seq_kv) // block_y)[None, :]
    keep = mask[:, :, q_block, kv_block] != 0
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) * np.float32(scale_value)
    scores = np.where(keep, scores, np.float32(-1e10))
    # Stable softmax: the row maximum contributes exp(0) == 1, so the sum >= 1.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = weights / weights.sum(axis=-1, keepdims=True)
    out = np.matmul(probs, v)
    return torch.from_numpy(out).to(query.dtype)
def cpu_block_sparse_attention_tnd_with_lse(query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads):
    """CPU reference for TND block-sparse attention that also returns softmax_lse.

    All tokens are treated as one sequence (batch of 1); only
    ``block_sparse_mask[0]`` is read.

    Args:
        query: tensor of shape (T_q, N, D).
        key: tensor of shape (T_kv, N2, D); N must be a multiple of N2.
        value: tensor of shape (T_kv, N2, D).
        block_sparse_mask: tensor indexed as [0, n, q_block, kv_block]; a
            non-zero entry means that query block attends to that kv block.
        block_shape: (q_block_size, kv_block_size).
        scale_value: factor applied to q.k before the softmax.
        num_kv_heads: unused; the number of KV heads is read from ``key``.

    Returns:
        (attention_out, softmax_lse): attention_out has shape (T_q, N, D) in
        ``query.dtype``; softmax_lse is float32 of shape (T_q, N, 1).
        A row with no visible key gets uniform weights and an lse of -1e10.

    Raises:
        ValueError: if N is not a multiple of N2.
    """
    # Work head-major, (N, T, D), so the matmuls batch over heads.
    q = np.transpose(query.cpu().to(torch.float32).numpy(), (1, 0, 2))
    k = np.transpose(key.cpu().to(torch.float32).numpy(), (1, 0, 2))
    v = np.transpose(value.cpu().to(torch.float32).numpy(), (1, 0, 2))
    mask = block_sparse_mask.cpu().numpy()
    num_heads, total_q = q.shape[0], q.shape[1]
    kv_heads, total_kv = k.shape[0], k.shape[1]
    if num_heads % kv_heads != 0:
        raise ValueError(
            f"query heads ({num_heads}) must be a multiple of key/value heads ({kv_heads})")
    group = num_heads // kv_heads
    block_x, block_y = int(block_shape[0]), int(block_shape[1])
    # GQA: query heads [g * group, (g + 1) * group) share KV head g.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    # Expand the block-level mask to element level via per-token block ids.
    q_block = (np.arange(total_q) // block_x)[:, None]
    kv_block = (np.arange(total_kv) // block_y)[None, :]
    keep = mask[0][:, q_block, kv_block] != 0
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) * np.float32(scale_value)
    scores = np.where(keep, scores, np.float32(-1e10))
    # Stable softmax: the row maximum contributes exp(0) == 1, so denom >= 1.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    out = np.matmul(weights / denom, v)
    # Fully masked rows (max still at the mask value) report the mask value.
    lse = np.where(row_max <= -1e9, row_max, row_max + np.log(denom)).astype(np.float32)
    # Back to token-major (T, N, ...) layout, contiguous for torch.
    out = np.ascontiguousarray(np.transpose(out, (1, 0, 2)))
    lse = np.ascontiguousarray(np.transpose(lse, (1, 0, 2)))
    return torch.from_numpy(out).to(query.dtype), torch.from_numpy(lse).to(torch.float32)
def cpu_block_sparse_attention_tnd(query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads):
    """CPU reference for TND block-sparse attention (attention output only).

    All tokens are treated as one sequence (batch of 1); only
    ``block_sparse_mask[0]`` is read.

    Args:
        query: tensor of shape (T_q, N, D).
        key: tensor of shape (T_kv, N2, D); N must be a multiple of N2.
        value: tensor of shape (T_kv, N2, D).
        block_sparse_mask: tensor indexed as [0, n, q_block, kv_block]; a
            non-zero entry means that query block attends to that kv block.
        block_shape: (q_block_size, kv_block_size).
        scale_value: factor applied to q.k before the softmax.
        num_kv_heads: unused; the number of KV heads is read from ``key``.

    Returns:
        Tensor of shape (T_q, N, D) in ``query.dtype``. A row with no visible
        key gets uniform attention weights.

    Raises:
        ValueError: if N is not a multiple of N2.
    """
    # Work head-major, (N, T, D), so the matmuls batch over heads.
    q = np.transpose(query.cpu().to(torch.float32).numpy(), (1, 0, 2))
    k = np.transpose(key.cpu().to(torch.float32).numpy(), (1, 0, 2))
    v = np.transpose(value.cpu().to(torch.float32).numpy(), (1, 0, 2))
    mask = block_sparse_mask.cpu().numpy()
    num_heads, total_q = q.shape[0], q.shape[1]
    kv_heads, total_kv = k.shape[0], k.shape[1]
    if num_heads % kv_heads != 0:
        raise ValueError(
            f"query heads ({num_heads}) must be a multiple of key/value heads ({kv_heads})")
    group = num_heads // kv_heads
    block_x, block_y = int(block_shape[0]), int(block_shape[1])
    # GQA: query heads [g * group, (g + 1) * group) share KV head g.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    # Expand the block-level mask to element level via per-token block ids.
    q_block = (np.arange(total_q) // block_x)[:, None]
    kv_block = (np.arange(total_kv) // block_y)[None, :]
    keep = mask[0][:, q_block, kv_block] != 0
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) * np.float32(scale_value)
    scores = np.where(keep, scores, np.float32(-1e10))
    # Stable softmax: the row maximum contributes exp(0) == 1, so the sum >= 1.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = weights / weights.sum(axis=-1, keepdims=True)
    # Back to token-major (T, N, D) layout, contiguous for torch.
    out = np.ascontiguousarray(np.transpose(np.matmul(probs, v), (1, 0, 2)))
    return torch.from_numpy(out).to(query.dtype)
class TestNPUBlockSparseAttention(TestCase):
    """Forward tests for npu_block_sparse_attention, checked against the CPU references."""

    @staticmethod
    def _block_counts(seq_len, block_shape):
        """Return (ceil_q, ceil_kv): how many q / kv blocks are needed to cover seq_len."""
        block_x, block_y = block_shape
        return (seq_len + block_x - 1) // block_x, (seq_len + block_y - 1) // block_y

    def _assert_out_close(self, cpu_out, npu_out):
        """Compare attention outputs in float32 on the CPU."""
        self.assertRtolEqual(
            cpu_out.cpu().to(torch.float32), npu_out.cpu().to(torch.float32),
            prec=0.005, prec16=0.005)

    def _assert_lse_close(self, cpu_lse, npu_lse):
        """Compare softmax_lse outputs in float32 on the CPU."""
        self.assertRtolEqual(
            cpu_lse.cpu().float(), npu_lse.cpu().float(), prec=0.005, prec16=0.01)

    def _compare_bnsd(self, device, *, num_heads=N, num_kv_heads=NUM_KV_HEADS, seq_len=S,
                      dtype=DTYPE, block_shape=None, block_sparse_mask=None,
                      inner_precise=1, pass_seq_lengths=True, check_lse=False):
        """Run one BNSD case on the NPU and compare it with the CPU reference.

        Args:
            device: device string the inputs are moved to.
            num_heads: number of query heads.
            num_kv_heads: number of key/value heads.
            seq_len: sequence length shared by query and key/value.
            dtype: dtype of the random query/key/value tensors.
            block_shape: [q_block, kv_block]; defaults to BLOCK_SHAPE.
            block_sparse_mask: CPU int8 mask; defaults to all ones.
            inner_precise: forwarded to the operator unchanged.
            pass_seq_lengths: whether actual_seq_lengths(_kv) are passed.
            check_lse: when True, softmax_lse_flag=1 is passed and softmax_lse
                is compared as well; otherwise the flag is left at its default.
        """
        if block_shape is None:
            block_shape = BLOCK_SHAPE
        scale_value = 1.0 / math.sqrt(D)
        query = torch.randn(B, num_heads, seq_len, D, dtype=dtype)
        key = torch.randn(B, num_kv_heads, seq_len, D, dtype=dtype)
        value = torch.randn(B, num_kv_heads, seq_len, D, dtype=dtype)
        if block_sparse_mask is None:
            ceil_q, ceil_kv = self._block_counts(seq_len, block_shape)
            block_sparse_mask = torch.ones(B, num_heads, ceil_q, ceil_kv, dtype=torch.int8)

        # The reference runs on the CPU copies before they are moved to the device.
        cpu_lse = None
        if check_lse:
            cpu_out, cpu_lse = cpu_block_sparse_attention_bnsd_with_lse(
                query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads)
        else:
            cpu_out = cpu_block_sparse_attention_bnsd(
                query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads)

        op_kwargs = {
            "q_input_layout": "BNSD",
            "kv_input_layout": "BNSD",
            "num_key_value_heads": num_kv_heads,
            "scale_value": scale_value,
            "inner_precise": inner_precise,
        }
        if pass_seq_lengths:
            op_kwargs["actual_seq_lengths"] = [seq_len] * B
            op_kwargs["actual_seq_lengths_kv"] = [seq_len] * B
        if check_lse:
            op_kwargs["softmax_lse_flag"] = 1
        npu_out, npu_lse = torch_npu.npu_block_sparse_attention(
            query.to(device), key.to(device), value.to(device),
            block_sparse_mask.to(device), block_shape, **op_kwargs)

        self._assert_out_close(cpu_out, npu_out)
        if check_lse:
            self._assert_lse_close(cpu_lse, npu_lse)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_cpu_compare(self, device="npu"):
        """BNSD with BLOCK_SHAPE and no actual_seq_lengths; checks attention_out and softmax_lse."""
        torch.npu.empty_cache()
        self._compare_bnsd(device, pass_seq_lengths=False, check_lse=True)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_with_optional_args_cpu_compare(self, device="npu"):
        """BNSD with block_shape and actual_seq_lengths passed explicitly; also checks softmax_lse."""
        torch.npu.empty_cache()
        self._compare_bnsd(device, check_lse=True)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_tnd_cpu_compare(self, device="npu"):
        """TND with actual_seq_lengths/actual_seq_lengths_kv passed; checks attention_out and softmax_lse."""
        torch.npu.empty_cache()
        T = S
        num_kv_heads = NUM_KV_HEADS
        scale_value = 1.0 / math.sqrt(D)
        block_shape = BLOCK_SHAPE
        ceil_q, ceil_kv = self._block_counts(T, block_shape)
        query = torch.randn(T, N, D, dtype=DTYPE)
        key = torch.randn(T, num_kv_heads, D, dtype=DTYPE)
        value = torch.randn(T, num_kv_heads, D, dtype=DTYPE)
        block_sparse_mask = torch.ones(1, N, ceil_q, ceil_kv, dtype=torch.int8)
        cpu_out, cpu_lse = cpu_block_sparse_attention_tnd_with_lse(
            query, key, value, block_sparse_mask, block_shape, scale_value, num_kv_heads)
        # Use the same block_shape object for the operator as for the reference.
        npu_out, npu_lse = torch_npu.npu_block_sparse_attention(
            query.to(device), key.to(device), value.to(device),
            block_sparse_mask.to(device), block_shape,
            q_input_layout="TND", kv_input_layout="TND",
            num_key_value_heads=num_kv_heads, scale_value=scale_value, inner_precise=1,
            actual_seq_lengths=[T], actual_seq_lengths_kv=[T],
            softmax_lse_flag=1,
        )
        self._assert_out_close(cpu_out, npu_out)
        self._assert_lse_close(cpu_lse, npu_lse)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_inner_precise0_cpu_compare(self, device="npu"):
        """BNSD with inner_precise=0 (fp32 intermediate softmax), compared with the CPU reference."""
        torch.npu.empty_cache()
        self._compare_bnsd(device, inner_precise=0)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_softmax_lse_flag0(self, device="npu"):
        """BNSD with softmax_lse_flag=0: only the shapes and dtypes of both outputs are checked."""
        torch.npu.empty_cache()
        num_kv_heads = NUM_KV_HEADS
        scale_value = 1.0 / math.sqrt(D)
        block_shape = BLOCK_SHAPE
        ceil_q, ceil_kv = self._block_counts(S, block_shape)
        query = torch.randn(B, N, S, D, dtype=DTYPE).to(device)
        key = torch.randn(B, num_kv_heads, S, D, dtype=DTYPE).to(device)
        value = torch.randn(B, num_kv_heads, S, D, dtype=DTYPE).to(device)
        block_sparse_mask = torch.ones(B, N, ceil_q, ceil_kv, dtype=torch.int8).to(device)
        attention_out, softmax_lse = torch_npu.npu_block_sparse_attention(
            query, key, value, block_sparse_mask, block_shape,
            q_input_layout="BNSD", kv_input_layout="BNSD",
            num_key_value_heads=num_kv_heads, scale_value=scale_value, inner_precise=1,
            actual_seq_lengths=[S] * B, actual_seq_lengths_kv=[S] * B,
            softmax_lse_flag=0,
        )
        self.assertEqual(attention_out.shape, (B, N, S, D))
        self.assertEqual(attention_out.dtype, DTYPE)
        self.assertEqual(softmax_lse.shape, (B, N, S, 1))
        self.assertEqual(softmax_lse.dtype, torch.float32)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_bfloat16_cpu_compare(self, device="npu"):
        """BNSD with bfloat16 inputs (inner_precise=0), compared with the CPU reference."""
        if hasattr(torch.npu, 'is_bf16_supported') and not torch.npu.is_bf16_supported():
            self.skipTest("NPU bfloat16 not supported")
        torch.npu.empty_cache()
        self._compare_bnsd(device, dtype=torch.bfloat16, inner_precise=0)

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_gqa_g_constraint_cpu_compare(self, device="npu"):
        """GQA with G=N1/N2 (G<128 and 128%G==0), compared with the CPU reference."""
        torch.npu.empty_cache()
        # (G, number of query heads N1, number of key/value heads N2)
        g_configs = [
            (1, 8, 8),
            (8, 8, 1),
        ]
        for G, num_heads, num_kv_heads in g_configs:
            with self.subTest(G=G, N1=num_heads, N2=num_kv_heads):
                # unittest assertions instead of bare assert, which -O strips.
                self.assertEqual(num_heads % num_kv_heads, 0)
                self.assertEqual(num_heads // num_kv_heads, G)
                self._compare_bnsd(device, num_heads=num_heads, num_kv_heads=num_kv_heads)
                torch.npu.empty_cache()

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_multi_block_cpu_compare(self, device="npu"):
        """BNSD with block_shape=[8,128], which splits S=32 into 4x1 blocks."""
        torch.npu.empty_cache()
        self._compare_bnsd(device, block_shape=[8, 128])

    @SkipIfNotGteCANNVersion("9.0.0")
    @SupportedDevices(['Ascend910B'])
    def test_npu_block_sparse_attention_bnsd_sparse_mask_cpu_compare(self, device="npu"):
        """BNSD sparse mask: each q block attends to exactly one kv block (i -> i % ceil_kv)."""
        torch.npu.empty_cache()
        block_shape = [32, 128]
        s_sparse = 256
        ceil_q, ceil_kv = self._block_counts(s_sparse, block_shape)
        block_sparse_mask = torch.zeros(B, N, ceil_q, ceil_kv, dtype=torch.int8)
        for i in range(ceil_q):
            block_sparse_mask[:, :, i, i % ceil_kv] = 1
        self._compare_bnsd(
            device, seq_len=s_sparse, block_shape=block_shape,
            block_sparse_mask=block_sparse_mask)
if __name__ == "__main__":
run_tests()