""" Attention 相关用例 Golden 生成逻辑.
"""
import logging
from pathlib import Path
import numpy as np
from ml_dtypes import bfloat16
import math
import torch
from golden.net.deepseekv3.mla.mla_prolog_golden_v2 import gen_mla_prolog_data
from golden.op.attention_post import gen_pa_data, gen_post_data
from golden_register import GoldenRegister
fp32 = np.float32
def gen_attention_test_data(dtypes, bns2, epsilon, output_dir: Path, is_quant=False, is_nz=False, is_smooth=False, cache_mode="BNSD"):
b, n, s2, block_size, n_tile = bns2
n2, s = 1, 1
params = {
"b": b,
"s": s,
"s2": s2,
"h": 7168,
"num_heads": n,
"n2": n2,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"kv_lora_rank": 512,
"v_head_dim": 128,
}
quant_choice = (False, is_quant)
print("gen_attention_test_data cache_mode is ", cache_mode)
q_out, q_rope_out, kv_cache_out, kr_cache_out = \
gen_mla_prolog_data(params, dtypes, epsilon, output_dir, quant_choice, is_nz, is_smooth, block_size, cache_mode)
q_out = q_out.reshape(b, n, s, params["kv_lora_rank"])
q_rope_out = q_rope_out.reshape(b, n, s, params["qk_rope_head_dim"])
pa_out = gen_pa_data(output_dir, params, dtypes[0], q_out, q_rope_out, kv_cache_out, kr_cache_out,
block_size, n_tile, is_nz)
post_out = gen_post_data(output_dir, params, dtypes[0], pa_out, is_quant, is_nz)
return post_out
@GoldenRegister.reg_golden_func(
case_names=[
"DynamicAttention.dynamic_attention_low",
"DynamicAttention.dynamic_attention_high",
"DynamicAttention.low_latency_quant_smooth_nz",
"DynamicAttention.high_throughput_quant_smooth_nz",
]
)
def gen_attention_date(case_name: str, output: Path) -> bool:
print("gen_attention_date-------------------------- ")
if case_name == "DynamicAttention.dynamic_attention_low":
gen_attention_test_data((np.float16, np.float16), (4, 32, 256, 256, 32), 1e-5, output, False, False, False)
elif case_name == "DynamicAttention.dynamic_attention_high":
gen_attention_test_data((np.float16, np.float16), (32, 128, 4096, 128, 128), 1e-5, output, False, False, False)
elif case_name == "DynamicAttention.low_latency_quant_smooth_nz":
gen_attention_test_data((np.float16, np.float16), (4, 32, 256, 128, 32), 1e-5, output, True, True, True)
elif case_name == "DynamicAttention.high_throughput_quant_smooth_nz":
gen_attention_test_data((np.float16, np.float16), (32, 128, 4096, 4096, 128), 1e-5, output, True, True, True, "PA_NZ")
else:
logging.error("Can't get func to gen golden, Case(%s)", case_name)
return False
return True
if __name__ == "__main__":
logging.basicConfig(format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
level=logging.DEBUG)
gen_attention_date("DynamicAttention.dynamic_attention_low", "./")