import subprocess
import torch
import torch_npu
import npu_ops_transformer
# Select the NPU device; the tensors created below are moved onto it.
torch_npu.npu.set_device(0)

# Problem dimensions and operator attributes
# (optional arguments may be passed as None).
batch_size = 4
max_seqlen_q = 128
max_seqlen_kv = 128
num_heads_q = 1
num_heads_kv = 1
head_dim = 128
mask_mode = 0  # 0: band mask, 3: rightDownCausal, 4: bandCausal
win_left = None
win_right = None
layout_q = "TND"
layout_kv = "TND"
layout_out = "TND"

# Actual-sequence-length tensors (optional).
# cu_seqlens_* hold batch_size + 1 cumulative offsets:
# 0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen.
# They are created as int32 directly instead of being cast afterwards.
cu_seqlens_q = torch.arange(
    0, (batch_size + 1) * max_seqlen_q, max_seqlen_q, dtype=torch.int32
).npu()
cu_seqlens_kv = torch.arange(
    0, (batch_size + 1) * max_seqlen_kv, max_seqlen_kv, dtype=torch.int32
).npu()
# seqused_* hold one length per batch entry, all equal to max_seqlen here.
seqused_q = torch.full((batch_size,), max_seqlen_q, dtype=torch.int32).npu()
seqused_kv = torch.full((batch_size,), max_seqlen_kv, dtype=torch.int32).npu()

# Echo the inputs once each so a failing run can be diagnosed from the log.
print("cu_seqlens_q:", cu_seqlens_q)
print("seqused_q", seqused_q)
print("cu_seqlens_kv:", cu_seqlens_kv)
print("seqused_kv", seqused_kv)
print(" num_heads_q", num_heads_q)
print("num_heads_kv", num_heads_kv)
# Call the operator through the torch.ops registry.
# (The spelling npu_ops_transformer.ops.npu_flash_attn_metadata is an
# alternative entry point that is not used here.)
#
# Only the cumulative offsets (cu_seqlens_*) are supplied in this run.
# seqused_q, seqused_kv, max_seqlen_q and max_seqlen_kv are left out of the
# call; presumably the operator treats them as optional -- confirm against
# the operator schema before relying on that.
flash_attn_metadata = torch.ops.npu_ops_transformer.npu_flash_attn_metadata
result = flash_attn_metadata(
    # Sequence description.
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    # Attention geometry.
    num_heads_q=num_heads_q,
    num_heads_kv=num_heads_kv,
    head_dim=head_dim,
    batch_size=batch_size,
    # Masking configuration.
    mask_mode=mask_mode,
    win_left=win_left,
    win_right=win_right,
    # Tensor layouts.
    layout_q=layout_q,
    layout_kv=layout_kv,
    layout_out=layout_out,
)
# Verify the result: show the basic properties of the returned tensor.
print(f"Result shape: {result.shape}")
print(f"Result dtype: {result.dtype}")
print(f"Result device: {result.device}")
print(f"First 10 values: {result[:10].cpu().tolist()}")

# Layout constants of the metadata buffer, counted in elements of `result`.
# They are defined once and shared by the size check and the dump below so
# the two cannot drift apart.
# NOTE(review): aicNum / aivNum presumably are the AI Cube / AI Vector core
# counts of the target device -- confirm before running on other hardware.
aicNum = 36      # "fa" records per section
aivNum = 72      # "fd" records per section
faSize = 16      # stride of one "fa" record
fdSize = 16      # stride of one "fd" record
headerSize = 16  # leading header; element 0 holds the section count
alignSize = 4096  # the expected length is rounded up to a multiple of this

# Assertion checks.
# The expected length reserves room for batch_size * num_heads_kv sections
# (the upper bound implied by the original size formula), then rounds up.
maxSections = batch_size * num_heads_kv
shape_size = headerSize + (aicNum * faSize + aivNum * fdSize) * maxSections
shape_size = (shape_size + alignSize - 1) // alignSize * alignSize
assert result.shape == (shape_size,), f"Expected shape ({shape_size} ,), got {result.shape}"
assert result.dtype == torch.int32, f"Expected dtype int32, got {result.dtype}"
assert result.device.type == 'npu', f"Expected device npu, got {result.device.type}"
print("✅ Test passed!")

# Dump the metadata on the host, section by section.
result = result.cpu()
print("sectionNum:", result[0])
# A plain int keeps the loop bound and all index arithmetic out of tensor land.
sectionNum = int(result[0])

# Labels of the leading fields in each record, in storage order.
faFields = (
    "bn2 start ",
    "m start ",
    "s2 start ",
    "bn2 end ",
    "m end ",
    "s2 end ",
    "first fd data workspace idx ",
)
fdFields = (
    "fd bn2 idx ",
    "fd m idx ",
    "fd workspace idx ",
    "fd workspace num ",
    "m start ",
    "m num ",
)

# All "fa" records of every section come first; the "fd" records follow.
faBase = headerSize
fdBase = headerSize + aicNum * faSize * sectionNum

for sectionId in range(sectionNum):
    print("sectionId:", sectionId)
    for i in range(aicNum):
        faOffset = faBase + aicNum * faSize * sectionId + faSize * i
        for fieldIdx, label in enumerate(faFields):
            print(label, result[faOffset + fieldIdx])
    for i in range(aivNum):
        fdOffset = fdBase + aivNum * fdSize * sectionId + fdSize * i
        for fieldIdx, label in enumerate(fdFields):
            print(label, result[fdOffset + fieldIdx])