"""Adaptive Search 端到端系统测试 — ReLU (Triton Ascend)"""
import pytest
from akg_agents.op.adaptive_search import adaptive_search
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_local_worker
from akg_agents.utils.task_label import resolve_task_label
from ..utils import get_device_id
# Reference PyTorch program passed to ``adaptive_search`` as ``task_desc``.
# It defines ``Model`` (forward is ``torch.relu``), the input sizes, and the
# ``get_inputs`` / ``get_init_inputs`` factories; inputs are created with
# ``device='npu'``. The text is Python source, so the indentation inside the
# literal is significant: it must stay valid, properly indented Python.
RELU_TASK_DESC = """\
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


batch_size = 16
dim = 16384


def get_inputs():
    x = torch.rand(batch_size, dim, device='npu')
    return [x]


def get_init_inputs():
    return []
"""
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.asyncio
async def test_adaptive_search_relu_ascend910b4():
    """End-to-end test: adaptive_search generates ReLU (triton_ascend, ascend910b4).

    Registers a local worker on the device returned by ``get_device_id()``,
    runs ``adaptive_search`` on ``RELU_TASK_DESC`` with a deliberately small
    budget (``max_concurrent=1``, ``initial_task_count=1``,
    ``max_total_tasks=2``, ``config["max_step"] = 5``) and checks the summary
    dict it returns. The implementation-level assertions run only when at
    least one task succeeded (``total_success > 0``); presumably this is so a
    small budget that yields no successful kernel does not by itself fail the
    test — confirm with the test owners.
    """
    op_name = "akg_agents_relu"
    dsl = "triton_ascend"
    framework = "torch"
    backend = "ascend"
    arch = "ascend910b4"

    device_id = get_device_id()
    await register_local_worker([device_id], backend=backend, arch=arch)

    config = load_config(dsl=dsl, backend=backend)
    config["task_label"] = resolve_task_label(op_name=op_name, parallel_index=1)
    # Small value keeps the E2E run short; exact meaning of "max_step" is
    # defined by adaptive_search's config handling (not visible here).
    config["max_step"] = 5

    result = await adaptive_search(
        op_name=op_name,
        task_desc=RELU_TASK_DESC,
        dsl=dsl,
        framework=framework,
        backend=backend,
        arch=arch,
        config=config,
        max_concurrent=1,
        initial_task_count=1,
        max_total_tasks=2,
    )

    assert isinstance(result, dict), f"expected dict, got {type(result).__name__}"

    # Check the whole summary schema in one assertion so a single failing run
    # reports every missing field instead of stopping at the first one.
    required_keys = (
        "total_submitted",
        "total_completed",
        "total_success",
        "success_rate",
        "elapsed_time",
        "stop_reason",
    )
    missing = [key for key in required_keys if key not in result]
    assert not missing, f"result is missing keys: {missing}"

    assert result["total_submitted"] >= 1, (
        f"no tasks submitted (stop_reason={result['stop_reason']!r})"
    )

    if result["total_success"] > 0:
        best = result.get("best_implementations", [])
        assert best, "total_success > 0 but best_implementations is empty"
        assert "impl_code" in best[0], "best implementation has no 'impl_code'"
        assert result["success_rate"] > 0, (
            f"total_success={result['total_success']} "
            f"but success_rate={result['success_rate']!r}"
        )