import os
import sys
import math
import unittest
import numpy as np
import torch
import torch_npu
from mindiesd.utils.get_platform import is_a5_device
# Skip the custom-operator library load when the run is restricted to CPU
# tests (MINDIE_TEST_MODE=CPU); any other value, or an unset variable, loads it.
# NOTE(review): presumably this is what makes torch.ops.mindiesd.* available -- confirm.
if os.getenv("MINDIE_TEST_MODE", "ALL") != "CPU":
    from mindiesd.layers.register_ops import _load_mindie_ops_library

    _load_mindie_ops_library()
def block_sparse_attention_cpu(query, key, value, block_sparse_mask, blocksize=128):
    """CPU reference implementation of block-sparse attention.

    Every query block attends only to the key/value blocks whose entry in
    ``block_sparse_mask`` is non-zero (1 = attend, 0 = skip). A query block
    whose mask row selects no key at all yields zeros.

    Args:
        query: Tensor indexed as [batch, q_heads, q_seq, dim].
        key: Tensor indexed as [batch, kv_heads, kv_seq, dim]. ``kv_seq`` may
            differ from ``q_seq``.
        value: Tensor indexed like ``key``.
        block_sparse_mask: Integer tensor indexed as
            [batch, q_heads, q_blocks, kv_blocks].
        blocksize: Tokens per block, used on both the query and the
            key/value axis.

    Returns:
        A float32 CPU tensor of shape [batch, q_heads, q_seq, dim].
    """
    bs, nq, seq, dim = query.shape
    nkv, kv_seq = key.shape[1], key.shape[2]
    # Grouped-query attention: `gqa` consecutive query heads share one KV head.
    gqa = nq // nkv
    output = torch.zeros(bs, nq, seq, dim, dtype=torch.float32)

    query_f = query.float().cpu().numpy()
    key_f = key.float().cpu().numpy()
    value_f = value.float().cpu().numpy()
    mask_np = block_sparse_mask.cpu().numpy()

    # Loop invariants. The KV block count comes from the key length (not the
    # query length) so that K/V sequences of a different length are handled.
    q_blocks = math.ceil(seq / blocksize)
    kv_blocks = math.ceil(kv_seq / blocksize)
    scale = np.sqrt(dim)

    for bi in range(bs):
        for ni in range(nq):
            k_idx = ni // gqa
            keys_head = key_f[bi, k_idx]
            values_head = value_f[bi, k_idx]
            for s1 in range(q_blocks):
                start = s1 * blocksize
                end = min(start + blocksize, seq)
                # Expand the per-block flags to one flag per key token; the
                # trailing block may be partial, hence the truncation.
                mask_block = mask_np[bi, ni, s1, :kv_blocks]
                mask_seq = np.repeat(mask_block, blocksize)[:kv_seq].astype(bool)
                k = keys_head[mask_seq]
                if k.shape[0] == 0:
                    # Nothing to attend to: keep the zero-initialised rows.
                    continue
                v = values_head[mask_seq]
                p = query_f[bi, ni, start:end] @ k.T / scale
                # Subtract the row maximum for a numerically stable softmax.
                p = p - p.max(axis=-1, keepdims=True)
                exp_p = np.exp(p)
                attn = exp_p / (exp_p.sum(axis=-1, keepdims=True) + 1e-12)
                output[bi, ni, start:end] = torch.from_numpy(attn @ v)
    return output
def ref_compare(golden, actual, err):
    """Compare ``actual`` against ``golden`` under a relative tolerance.

    Both tensors are evaluated in float32. Each element's error is measured
    relative to ``max(|golden|, 1)``. The comparison passes only when every
    element is within ``err`` of its reference and the mean relative error
    does not exceed ``err / 2``.

    Args:
        golden: Reference tensor.
        actual: Tensor under test.
        err: Relative tolerance.

    Returns:
        Tuple ``(mean_relative_error, passed, max_absolute_error)`` of plain
        Python values.
    """
    reference = golden.to(torch.float32)
    deviation = (actual.to(torch.float32) - reference).abs()
    # Magnitudes below 1 are clamped so small references use an absolute bound.
    magnitude = reference.abs().clamp(min=1)
    mean_rel_err = (deviation / magnitude).mean()
    all_within = (deviation <= err * magnitude).all()
    verdict = all_within and mean_rel_err <= err / 2
    return mean_rel_err.item(), verdict.item(), deviation.max().item()
def make_block_sparse_mask(batch, head_num, seq_len, sparse_size, sparsity=0.5, seed=42):
    """Build a random int8 block-sparse mask of shape [B, N, q_blocks, kv_blocks].

    An entry is 1 where a uniform sample in [0, 1) exceeds ``sparsity`` and 0
    otherwise. Any query-block row that ends up empty has its first key block
    switched on, so every row selects at least one block.

    Args:
        batch: Batch size B.
        head_num: Number of heads N.
        seq_len: Sequence length, used for both the query and key axis.
        sparse_size: Tokens per block.
        sparsity: Threshold the uniform samples are compared against.
        seed: Seed for the NumPy random generator.

    Returns:
        ``torch.int8`` tensor of shape [batch, head_num, blocks, blocks] with
        ``blocks = ceil(seq_len / sparse_size)``.
    """
    n_blocks = math.ceil(seq_len / sparse_size)
    generator = np.random.default_rng(seed)
    samples = generator.random((batch, head_num, n_blocks, n_blocks))
    mask = (samples > sparsity).astype(np.int8)
    # Guarantee at least one attended key block per query block.
    for row in np.ndindex(batch, head_num, n_blocks):
        if not mask[row].any():
            mask[row + (0,)] = 1
    return torch.from_numpy(mask)
@unittest.skipIf(
    os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
    "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
@unittest.skipIf(not is_a5_device(), "Block Sparse Attention requires A5 (950) NPU.")
class TestNpuBlockSparseAttentionNPU(unittest.TestCase):
    """Smoke tests for ``torch.ops.mindiesd.block_sparse_attention`` on an NPU.

    Each test only checks the shape and dtype of the attention output; no
    numerical comparison against a reference is made here.
    """

    def setUp(self):
        """Bind NPU 0 and define the problem size shared by all tests."""
        self.device = torch.device("npu:0")
        torch.npu.set_device(self.device)

        self.batch, self.head_num = 1, 1
        self.head_dim = 128
        self.seq_len = 75392
        self.sparse_size = 128
        self.scale = self.head_dim**-0.5

        # Devices whose name contains "950" run with inner_precise=4, others with 1.
        props = torch.npu.get_device_properties(self.device)
        self.inner_precise = 4 if "950" in props.name else 1

    def _full_mask(self):
        """Return an all-ones int8 block_sparse_mask [B, N, q_blocks, kv_blocks]."""
        n_blocks = math.ceil(self.seq_len / self.sparse_size)
        return torch.ones(self.batch, self.head_num, n_blocks, n_blocks, dtype=torch.int8)

    def _random_qkv(self, shape, dtype):
        """Draw query, key and value (in that order) with ``torch.randn``."""
        query = torch.randn(*shape, dtype=dtype)
        key = torch.randn(*shape, dtype=dtype)
        value = torch.randn(*shape, dtype=dtype)
        return query, key, value

    def _assert_output(self, attn_out, shape, dtype):
        """Assert that the attention output has the expected shape and dtype."""
        self.assertEqual(tuple(attn_out.shape), shape)
        self.assertEqual(attn_out.dtype, dtype)

    def _call_op(
        self,
        q,
        k,
        v,
        mask,
        layout="BNSD",
        actual_seq_lengths=None,
        actual_seq_lengths_kv=None,
        softmax_lse_flag=0,
        block_shape=None,
        q_dequant_scale=None,
        k_dequant_scale=None,
        v_dequant_scale=None,
    ):
        """Move the inputs to the NPU and invoke the block-sparse attention op.

        Omitted sequence lengths default to ``seq_len`` per batch entry and an
        omitted ``block_shape`` defaults to a square ``sparse_size`` block. The
        three dequant scales are forwarded only when ``q_dequant_scale`` is
        given. Returns whatever the operator returns.
        """
        op_kwargs = {
            "query": q.to(self.device),
            "key": k.to(self.device),
            "value": v.to(self.device),
            "block_sparse_mask": mask.to(self.device),
            "block_shape": (
                [self.sparse_size, self.sparse_size] if block_shape is None else block_shape
            ),
            "q_input_layout": layout,
            "kv_input_layout": layout,
            "num_key_value_heads": self.head_num,
            "scale_value": self.scale,
            "inner_precise": self.inner_precise,
            "softmax_lse_flag": softmax_lse_flag,
            "actual_seq_lengths": (
                [self.seq_len] * self.batch if actual_seq_lengths is None else actual_seq_lengths
            ),
            "actual_seq_lengths_kv": (
                [self.seq_len] * self.batch
                if actual_seq_lengths_kv is None
                else actual_seq_lengths_kv
            ),
        }
        if q_dequant_scale is not None:
            op_kwargs["q_dequant_scale"] = q_dequant_scale.to(self.device)
            op_kwargs["k_dequant_scale"] = k_dequant_scale.to(self.device)
            op_kwargs["v_dequant_scale"] = v_dequant_scale.to(self.device)
        return torch.ops.mindiesd.block_sparse_attention(**op_kwargs)

    def test_smoke_bnsd(self):
        """FP16 BNSD smoke test: output keeps the query's shape and dtype."""
        shape = (self.batch, self.head_num, self.seq_len, self.head_dim)
        q, k, v = self._random_qkv(shape, torch.float16)
        attn_out, _lse = self._call_op(q, k, v, self._full_mask(), layout="BNSD")
        self._assert_output(attn_out, shape, torch.float16)

    def test_smoke_tnd(self):
        """FP16 TND smoke test: output is [T, N, D] with T = B * S."""
        total_tokens = self.batch * self.seq_len
        shape = (total_tokens, self.head_num, self.head_dim)
        q, k, v = self._random_qkv(shape, torch.float16)
        mask = self._full_mask()
        per_batch_lengths = [self.seq_len] * self.batch
        attn_out, _lse = self._call_op(
            q,
            k,
            v,
            mask,
            layout="TND",
            actual_seq_lengths=per_batch_lengths,
            actual_seq_lengths_kv=per_batch_lengths,
        )
        self._assert_output(attn_out, shape, torch.float16)

    def test_smoke_bnsd_fp8(self):
        """FP8 BNSD smoke test: FP8 Q/K/V plus dequant scales give BF16 output."""
        shape = (self.batch, self.head_num, self.seq_len, self.head_dim)
        from mindiesd.layers.quant.block_quant import fa_block_quant_preprocess

        q_bf16, k_bf16, v_bf16 = (t.npu() for t in self._random_qkv(shape, torch.bfloat16))

        # Query and key/value are quantised with different block sizes.
        q_block, kv_block = 128, 256
        fp8_dtype = torch_npu.float8_e4m3fn
        (q_fp8, q_scale), (k_fp8, k_scale), (v_fp8, v_scale) = [
            fa_block_quant_preprocess(tensor, block_size=size, dst_type=fp8_dtype, layout="BNSD")
            for tensor, size in ((q_bf16, q_block), (k_bf16, kv_block), (v_bf16, kv_block))
        ]

        mask = torch.ones(
            self.batch,
            self.head_num,
            math.ceil(self.seq_len / q_block),
            math.ceil(self.seq_len / kv_block),
            dtype=torch.int8,
        )
        attn_out, _lse = self._call_op(
            q_fp8,
            k_fp8,
            v_fp8,
            mask,
            layout="BNSD",
            block_shape=[q_block, kv_block],
            q_dequant_scale=q_scale,
            k_dequant_scale=k_scale,
            v_dequant_scale=v_scale,
        )
        self._assert_output(attn_out, shape, torch.bfloat16)

    def test_bnsd_bf16_no_scales(self):
        """BF16 BNSD without dequant scales: backward compatible with V1 behavior."""
        shape = (self.batch, self.head_num, self.seq_len, self.head_dim)
        q, k, v = self._random_qkv(shape, torch.bfloat16)
        attn_out, _lse = self._call_op(q, k, v, self._full_mask(), layout="BNSD")
        self._assert_output(attn_out, shape, torch.bfloat16)
if __name__ == "__main__":
    # argv=[""] makes unittest ignore the process's command-line arguments;
    # exit=False makes it return instead of calling sys.exit() afterwards.
    unittest.main(exit=False, argv=[""])