import pytest
import asyncio
import logging
from akg_agents.server.job_manager import get_job_manager
from akg_agents.core.worker.manager import get_worker_manager
from akg_agents.core.worker.local_worker import LocalWorker
from akg_agents.core.async_pool.device_pool import DevicePool
from ..utils import get_kernelbench_task_desc, add_op_prefix, get_kernelbench_op_name
# Module-level logger named after this test module; the test below uses it to
# report job submission, polling progress and failure details.
logger = logging.getLogger(__name__)
# Job states after which polling stops; any other state is treated as "still running".
_TERMINAL_STATES = ("completed", "failed", "error")

# Minimal ReLU task submitted when no KernelBench task can be loaded.
# The text is Python source, so class/function bodies must stay indented.
_FALLBACK_OP_NAME = "relu_service_test"
_FALLBACK_TASK_DESC = """
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, x):
        return torch.relu(x)
def get_inputs():
    return [torch.randn(16, 1024).cuda()]
def get_init_inputs():
    return []
"""


def _load_task(framework):
    """Choose the operator name and task description to submit.

    Asks ``get_kernelbench_op_name`` for ``[0]`` and uses the first name it
    returns; when nothing is returned or any step raises, the built-in ReLU
    fallback task is used instead.

    Args:
        framework: Framework name forwarded to the KernelBench helpers.

    Returns:
        A ``(op_name, task_desc)`` tuple.
    """
    try:
        benchmark_names = get_kernelbench_op_name([0], framework=framework)
        if not benchmark_names:
            raise ValueError("No benchmark found")
        op_base_name = benchmark_names[0]
        task_desc = get_kernelbench_task_desc(op_base_name, framework=framework)
        op_name = add_op_prefix(op_base_name, benchmark="KernelBench")
    except Exception as e:
        # Best-effort lookup: any failure here degrades to the fallback task
        # instead of failing the test.
        logger.warning("Failed to load KernelBench task: %s. Using fallback task.", e)
        return _FALLBACK_OP_NAME, _FALLBACK_TASK_DESC
    return op_name, task_desc


async def _wait_for_job(job_manager, job_id, timeout=600, interval=5):
    """Poll a job until it reports a terminal state or the poll budget runs out.

    Performs ``timeout // interval`` status checks, sleeping ``interval``
    seconds after each non-terminal one. Returns silently in both cases; the
    caller re-reads the status to decide the outcome.
    """
    for _ in range(timeout // interval):
        state = job_manager.get_job_status(job_id).get("status")
        if state in _TERMINAL_STATES:
            logger.info("Job finished with state: %s", state)
            return
        logger.info("Job running... status: %s", state)
        await asyncio.sleep(interval)


@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
async def test_e2e_service_flow_cuda_real():
    """
    End-to-end service test against a real environment:
    Client (simulated) -> ServerJobManager -> WorkerManager -> LocalWorker -> real execution

    Prerequisites:
    1. An NVIDIA GPU environment is required.
    2. AKG_AGENTS_API_KEY must be configured (used for LLM code generation).
    """
    framework = "torch"
    timeout = 600
    interval = 5

    # Register a local CUDA worker backed by a device pool built from [0]
    # (presumably GPU 0 -- confirm against DevicePool).
    worker_manager = get_worker_manager()
    cuda_pool = DevicePool([0])
    cuda_worker = LocalWorker(cuda_pool, backend="cuda")
    await worker_manager.register(cuda_worker, backend="cuda", arch="a100", capacity=1)
    # NOTE(review): the worker is never unregistered here; if the worker
    # manager is shared across tests, confirm whether cleanup is needed.

    job_manager = get_job_manager()
    op_name, task_desc = _load_task(framework)

    logger.info("Submitting CUDA Job: %s", op_name)
    job_id = await job_manager.submit_job({
        "op_name": op_name,
        "task_desc": task_desc,
        "job_type": "single",
        "backend": "cuda",
        "arch": "a100",
        "dsl": "triton_cuda",
        "framework": framework,
        "workflow": "coder_only_workflow",
    })
    logger.info("Job submitted: %s", job_id)

    await _wait_for_job(job_manager, job_id, timeout=timeout, interval=interval)

    # Re-read the status after polling: the job may have finished during the
    # last sleep, so this final check decides the outcome.
    final_status = job_manager.get_job_status(job_id)
    final_state = final_status["status"]
    if final_state != "completed":
        logger.error("Job failed details: %s", final_status)

    # Report a timeout as such instead of as "Job failed: None".
    assert final_state in _TERMINAL_STATES, (
        f"Job {job_id} did not reach a terminal state within {timeout}s "
        f"(last status: {final_state})"
    )
    assert final_state == "completed", f"Job failed: {final_status.get('error')}"
    assert final_status["result"]["success"] is True