import torch
import torch_npu
import math
import numpy as np
import npu_ops_transformer
from npu_ops_transformer.ops import npu_flash_attn
# Smoke run of the `npu_flash_attn` operator: builds seeded random float16
# query/key/value tensors, moves them to the NPU, invokes the operator once
# per (B, Sq) configuration, and prints the result and its shape.
torch.manual_seed(42)  # fixed seed so the random inputs are reproducible

# Problem sizes swept by the loops below.
B_list = [1]
numHeads = 1
numKeyValueHeads = 1
Sq_list = [128]
Skv = 128
D = 64
dtype = torch.float16  # named `dtype` so the builtin `type` is not shadowed

# 1/sqrt(D) scaling factor; it depends only on D, so it is computed once
# instead of on every loop iteration.
scale_value = 1 / math.sqrt(D)

for B in B_list:
    for Sq in Sq_list:
        # Inputs are generated on the host (in the order query, key, value so
        # the seeded random stream is consumed deterministically) and then
        # copied to the NPU. The (B, heads, seq, D) shapes presumably
        # correspond to the "BNSD" layout strings passed below -- TODO confirm
        # against the operator documentation.
        query = torch.randn(B, numHeads, Sq, D, dtype=dtype).npu()
        key = torch.randn(B, numKeyValueHeads, Skv, D, dtype=dtype).npu()
        value = torch.randn(B, numKeyValueHeads, Skv, D, dtype=dtype).npu()

        # Single invocation; raise the range to repeat the call.
        for _run in range(1):
            # The operator returns two values; only the first is used here.
            # NOTE(review): no mask tensor is passed and mask_mode/win_*/
            # max_seqlen_* are all 0 -- confirm these are the intended
            # "no masking" settings for npu_flash_attn.
            out, _ = npu_flash_attn(
                query, key, value,
                block_table=None,
                cu_seqlens_q=None,
                cu_seqlens_kv=None,
                seqused_q=None,
                seqused_kv=None,
                sinks=None,
                metadata=None,
                softmax_scale=scale_value,
                mask_mode=0,
                win_left=0,
                win_right=0,
                max_seqlen_q=0,
                max_seqlen_kv=0,
                layout_q="BNSD",
                layout_kv="BNSD",
                layout_out="BNSD",
                return_softmax_lse=0)

        # Bring the result back to the host before printing it.
        out = out.cpu()
        print("************end*************", out, out.shape)