import os
import pytest
from akg_agents.op.verifier.kernel_verifier import KernelVerifier
from akg_agents.utils.common_utils import create_log_dir
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_remote_worker, get_worker_manager
# Remote worker endpoints; each one can be overridden through its environment
# variable and otherwise falls back to the same local default.
# NOTE(review): both backends default to the same port (9001) — confirm that is
# intended; only the CUDA URL is used in the visible part of this file.
_DEFAULT_WORKER_URL = "http://localhost:9001"
cuda_worker_url = os.getenv("CUDA_WORKER_URL", _DEFAULT_WORKER_URL)
ascend_worker_url = os.getenv("ASCEND_WORKER_URL", _DEFAULT_WORKER_URL)
def get_device_id():
    """Return the accelerator device index used for every verification run.

    The index is fixed at 3; it is not read from the environment.
    """
    fixed_device_index = 3
    return fixed_device_index


# Resolved once at import time and reused by every parametrized test case.
device_id = get_device_id()
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
@pytest.mark.parametrize("op_name", [
    "l2norm_fwd",
    "merge_attn_states",
    "reshape_and_cache_flash",
    "fwd_diag_kernel",
    "fwd_kv_reduce",
    "fwd_none_diag_kernel",
    "linear_attn_decode_kernel",
    "log_softmax_kernel",
    "mean_kernel",
    "rms_norm_kernel",
    "gumbel_sample_kernel",
    "topk_log_softmax_kernel",
    "ranks_kernel",
    "append_block_ids_kernel",
    "gather_block_tables_kernel",
    "prepare_prefill_inputs_kernel",
    "prepare_pos_seq_lens_kernel",
    "combine_sampled_and_draft_tokens_kernel",
    "post_update_kernel",
    "layernorm_fn",
    "rms_norm_gated_triton",
    "kda_gate_fwd_kernel",
    "count_expert_num_tokens",
    "pack_bitmatrix",
    "compute_identity_kernel",
    "pack_seq_triton",
    "unpack_seq_triton",
    "triton_mrope",
])
async def test_vllm_triton_verifier_a100(op_name):
    """Check the accuracy of a vLLM Triton operator against its native version.

    Verifies that ``Model`` (native implementation) and ``ModelVLLM`` (vLLM
    optimized implementation) from the task file ``{op_name}.py`` produce
    consistent outputs. The task file is passed verbatim as the framework
    code; a copy with ``ModelVLLM`` renamed to ``ModelNew`` is submitted as
    the implementation under test.

    The test fails explicitly when the remote CUDA worker cannot be
    registered or selected, and when ``KernelVerifier.run`` reports a
    mismatch.
    """
    framework = "torch"
    dsl = "triton_cuda"
    backend = "cuda"
    arch = "a100"

    # A pytest test must never signal failure by returning a value:
    # `return False` is treated as a PASS (pytest only warns about a non-None
    # return), so an unreachable worker used to make every case "pass".
    try:
        await register_remote_worker(
            backend=backend,
            arch=arch,
            worker_url=cuda_worker_url,
        )
    except Exception as e:
        print(f" ✗ CUDA Worker 注册失败: {e}")
        pytest.fail(f"CUDA worker registration failed ({cuda_worker_url}): {e}")
    else:
        print(" ✓ CUDA Worker 注册成功")

    worker_manager = get_worker_manager()
    cuda_config = load_config(dsl, backend=backend)
    cuda_worker = await worker_manager.select(backend=backend, arch=arch)
    if not cuda_worker:
        print(" ✗ 无法获取 CUDA Worker")
        pytest.fail(f"No CUDA worker available for backend={backend!r}, arch={arch!r}")

    # NOTE(review): this path is relative to the current working directory;
    # the test presumably runs from the repository root — confirm.
    op_task_file = f"./benchmark/akg_kernels_bench/thirdparty/vllm/triton_ops/{op_name}.py"
    with open(op_task_file, "r", encoding="utf-8") as f:
        framework_code = f.read()

    # The verifier is told (via impl_func_name) that the implementation class
    # is named "ModelNew", so expose the vLLM implementation under that name.
    impl_func_name = "ModelNew"
    kernel_code = framework_code.replace("ModelVLLM", impl_func_name)

    # Called for its side effect; the returned value was never used here.
    # NOTE(review): confirm whether KernelVerifier should receive this log dir.
    create_log_dir(f'{op_name}_vllm_{framework}_{backend}_{arch}_{dsl}_test')

    verifier = KernelVerifier(
        op_name=op_name,
        framework_code=framework_code,
        framework=framework,
        dsl=dsl,
        backend=backend,
        arch=arch,
        impl_func_name=impl_func_name,
        config=cuda_config,
        worker=cuda_worker,
    )
    task_info = {"coder_code": kernel_code}
    result, error_log = await verifier.run(task_info, device_id=device_id)
    assert result, f"验证失败 [{op_name}]: {error_log}"