import itertools
import torch
import torch_npu
from testcases import ENABLED_PARAMS
import check_valid_param
import prologv2_no_quant_pa_bsnd
import pytest
DEBUG_ON = 0
for _, params in enumerate(ENABLED_PARAMS):
for key, value in params.items():
locals()[f"param_{key}"] = value
param_names = [
"batch_size", "He", "Hcq", "Hckv", "q_head_num",
"kv_head_num", "head_dim", "rope_head_dim", "q_seq", "kv_seq",
"block_size", "input_layout", "cache_mode", "cq_epsilon" , "ckv_epsilon" , "dtype"
]
param_values = [
locals()["param_batch_size"],
locals()["param_He"],
locals()["param_Hcq"],
locals()["param_Hckv"],
locals()["param_q_head_num"],
locals()["param_kv_head_num"],
locals()["param_head_dim"],
locals()["param_rope_head_dim"],
locals()["param_q_seq"],
locals()["param_kv_seq"],
locals()["param_block_size"],
locals()["param_input_layout"],
locals()["param_cache_mode"],
locals()["param_cq_epsilon"],
locals()["param_ckv_epsilon"],
locals()["param_dtype"]
]
locals()["param_combinations"] = []
for combo in itertools.product(*param_values):
param_dict = dict(zip(param_names, combo))
locals()["param_combinations"].append(param_dict)
@pytest.mark.ci
@pytest.mark.parametrize("param_combinations", locals()["param_combinations"])
def test_mla_prolog_v2(param_combinations):
batch_size = param_combinations['batch_size']
He = param_combinations['He']
Hcq = param_combinations['Hcq']
Hckv = param_combinations['Hckv']
q_head_num = param_combinations['q_head_num']
kv_head_num = param_combinations['kv_head_num']
head_dim = param_combinations['head_dim']
rope_head_dim = param_combinations['rope_head_dim']
q_seq = param_combinations['q_seq']
kv_seq = param_combinations['kv_seq']
block_size = param_combinations['block_size']
input_layout = param_combinations['input_layout']
cache_mode = param_combinations['cache_mode']
cq_epsilon = param_combinations['cq_epsilon']
ckv_epsilon = param_combinations['ckv_epsilon']
dtype = param_combinations['dtype']
torch_npu.npu.set_device(0)
try:
test_data = (
batch_size, He, Hcq, Hckv, q_head_num, kv_head_num, head_dim, rope_head_dim,
q_seq, kv_seq, block_size, input_layout, cache_mode, cq_epsilon, ckv_epsilon, dtype
)
check_valid_param.validate_config(test_data)
except ValueError as e:
pytest.skip(f"参数校验失败: {e}")
if input_layout == "BSH":
test_data = batch_size, He, Hcq, Hckv, q_head_num, kv_head_num, head_dim, rope_head_dim, \
q_seq, kv_seq, block_size, input_layout, cache_mode, cq_epsilon, ckv_epsilon, dtype
expect, result = prologv2_no_quant_pa_bsnd.test_prologv2_no_quant(test_data)
check_valid_param.check_result(expect, result)