import pytest
import torch
import importlib.util
import sys
from pathlib import Path
from akg_agents.op.verifier.kernel_verifier import KernelVerifier
from akg_agents.utils.common_utils import create_log_dir
from akg_agents.op.config.config_validator import load_config
from ..utils import get_device_id
# Device index shared by every test in this module. It is resolved once, at
# import time, by the project helper ``get_device_id`` (imported from ..utils).
# NOTE(review): presumably an integer device index; its source and the meaning
# of negative values are not visible here -- confirm against ..utils.get_device_id.
device_id = get_device_id()
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
@pytest.mark.parametrize("op_name", [
    "assign_extend_cache_locs",
    "assign_req_to_token_pool",
    "compute_position",
    "fused_qkvzba_split_reshape_cat",
    "get_mla_kv_buffer",
    "merge_state_triton",
    "moe_align_block_size_triton",
    "moe_sum_reduce_triton",
    "set_mla_kv_buffer",
    "set_mla_kv_scale_buffer",
    "write_req_to_token_pool",
    "triton_tanh",
    "get_last_loc",
    "merge_state_kernel",
    "prefill_attention_fwd_kernel",
    "extend_attention_fwd_kernel",
    "decode_attention_fwd_kernel_stage1",
    "decode_attention_fwd_kernel_stage2",
    "decode_grouped_attention_fwd_kernel_stage1",
    "_fwd_grouped_kernel_stage1_rope",
    "add_tree_reduce_u64",
    "chunked_sgmv_lora_expand",
    "chunked_sgmv_lora_shrink",
    "fmix32",
    "hash_tiles32_kernel_blocked",
    "rotl32",
    "sgemm_lora_a",
    "sgemm_lora_b",
    "qkv_lora_b",
    "gate_up_lora_b",
])
async def test_sglang_verifier_a100(op_name):
    """Run KernelVerifier on one SGLang task file (torch / triton_cuda / cuda / a100).

    The task file's source is used twice: verbatim as ``framework_code``, and
    with every ``ModelSGLang`` occurrence renamed to ``ModelNew`` as the
    implementation code (``task_info["coder_code"]``). The test asserts that the
    verifier's first return value is truthy and reports its error log otherwise.
    """
    framework, dsl, backend, arch = "torch", "triton_cuda", "cuda", "a100"
    config = load_config(dsl, backend=backend)

    task_path = f"./benchmark/akg_kernels_bench/thirdparty/sglang/{op_name}.py"
    with open(task_path, "r", encoding="utf-8") as src_file:
        reference_src = src_file.read()
    # Rename the model class to "ModelNew", which matches impl_func_name below.
    candidate_src = reference_src.replace("ModelSGLang", "ModelNew")

    # Called for its side effect of creating the per-test log directory; the
    # returned path is not passed on anywhere in this test.
    # NOTE(review): confirm whether KernelVerifier/config should receive it.
    create_log_dir(f'{op_name}_sglang_{framework}_{backend}_{arch}_{dsl}_test')

    checker = KernelVerifier(
        op_name=op_name,
        framework_code=reference_src,
        framework=framework,
        dsl=dsl,
        backend=backend,
        arch=arch,
        impl_func_name="ModelNew",
        config=config,
    )
    result, error_log = await checker.run(
        {"coder_code": candidate_src}, device_id=device_id
    )
    assert result, f"验证失败: {error_log}"
def _iter_output_tensors(obj, label):
    """Yield ``(label, tensor)`` for every tensor reachable inside ``obj``.

    ``obj`` may be a tensor or an arbitrarily nested list/tuple/dict holding
    tensors and other values. A nested element's label is its parent's label
    followed by ``.<index>`` (lists/tuples) or ``.<key>`` (dicts), e.g. ``"1.0"``.
    Non-tensor leaves are skipped.
    """
    if isinstance(obj, torch.Tensor):
        yield label, obj
    elif isinstance(obj, (list, tuple)):
        for idx, item in enumerate(obj):
            yield from _iter_output_tensors(item, f"{label}.{idx}")
    elif isinstance(obj, dict):
        for key, item in obj.items():
            yield from _iter_output_tensors(item, f"{label}.{key}")


def _assert_no_nan_inf(op_name, label, out):
    """Raise ``AssertionError`` if tensor ``out`` has a NaN (checked first) or an Inf element.

    Args:
        op_name: Operator name, used only in the error message.
        label: Position of ``out`` within the model output (e.g. ``"0"`` or ``"1.0"``).
        out: The tensor to check.
    """
    nan_count = torch.isnan(out).sum().item()
    if nan_count > 0:
        raise AssertionError(
            f"{op_name}: 输出 {label} 包含 {nan_count} 个 NaN 值 "
            f"(shape: {out.shape}, dtype: {out.dtype})"
        )
    inf_count = torch.isinf(out).sum().item()
    if inf_count > 0:
        raise AssertionError(
            f"{op_name}: 输出 {label} 包含 {inf_count} 个 Inf 值 "
            f"(shape: {out.shape}, dtype: {out.dtype})"
        )


@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.parametrize("op_name", [
    "align_evict_mask_to_page_size",
    "alloc_decode",
    "alloc_extend",
    "assign_draft_cache_locs_page_size_1",
    "assign_draft_cache_locs",
    "copy_all_layer_kv_cache_tiled",
    "create_chunked_prefix_cache_kv_indices",
    "create_extend_after_decode_spec_info",
    "fill_accepted_out_cache_loc",
    "fill_new_verified_id",
    "filter_finished_cache_loc_kernel",
    "generate_draft_decode_kv_indices",
    "get_target_cache_loc",
])
def test_sglang_class_method_no_reference_a100(op_name):
    """
    Reference-free verification for class_method kernels: only checks that the
    outputs contain no NaN or Inf values.

    The op's task module is loaded from its file path. ``Model`` is built from
    ``get_init_inputs()`` and run on ``get_inputs()``; top-level tensor inputs
    are moved to the test device first. Every tensor in the output is checked,
    including tensors nested inside lists/tuples/dicts.

    Raises:
        AssertionError: if any output tensor contains NaN or Inf.
    """
    op_task_file = f"./benchmark/akg_kernels_bench/thirdparty/sglang/class_method/{op_name}.py"
    module_name = f"sglang_class_method_{op_name}"
    spec = importlib.util.spec_from_file_location(module_name, op_task_file)
    module = importlib.util.module_from_spec(spec)
    # Registered before execution, mirroring the importlib
    # "import a source file directly" recipe.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    Model = module.Model
    get_inputs = module.get_inputs
    get_init_inputs = module.get_init_inputs

    # A negative device_id falls back to the first CUDA device.
    device = torch.device(f"cuda:{device_id}" if device_id >= 0 else "cuda:0")
    init_params = get_init_inputs()
    inputs = [
        inp.to(device) if isinstance(inp, torch.Tensor) else inp
        for inp in get_inputs()
    ]
    model = Model(*init_params)
    if hasattr(model, 'to'):
        model = model.to(device)
    # Only output values are inspected, so no autograd graph is needed.
    with torch.no_grad():
        output = model(*inputs)

    top_level = output if isinstance(output, (list, tuple)) else [output]
    for i, out in enumerate(top_level):
        # Walk nested containers too; top-level labels stay "0", "1", ...
        for label, tensor in _iter_output_tensors(out, str(i)):
            _assert_no_nan_inf(op_name, label, tensor)
    print(f"{op_name}: 验证通过,输出无 NaN/Inf")