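"""Autoresearch end-to-end system test — ReLU.

Operator: torch.relu(x)
Shape: (11, 37, 8191)

Why this case was chosen:
- ReLU is the simplest elementwise operator; any model can get it right, which keeps CI stable.
- 11, 37 and 8191 are all prime (8191 = 2^13 - 1), so any BLOCK_SIZE needs mask handling.
- 8191 is large enough to easily overflow the UB (unified buffer) if BLOCK_SIZE is chosen too large.
- Total elements 11×37×8191 ≈ 3.3M: no OOM, but enough to expose performance problems.

Device auto-detection: triton_ascend when an NPU is available, triton_cuda when a GPU is
available; the test is skipped when neither is present.
Performance gate: speedup > 1.0x (the kernel must be strictly faster than the reference).
"""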
import torch
import pytest
from akg_agents.op.langgraph_op.task import LangGraphTask
from akg_agents.op.config.config_validator import load_config
from akg_agents.core.worker.manager import register_local_worker
from akg_agents.utils.task_label import resolve_task_label
from ..utils import get_device_id
def _detect_backend():
"""Detect available accelerator. Returns (dsl, backend, arch, device_str) or None."""
try:
import torch_npu
if torch.npu.is_available():
return "triton_ascend", "ascend", "ascend910b4", "npu"
except ImportError:
pass
if torch.cuda.is_available():
return "triton_cuda", "cuda", "a100", "cuda"
return None
# Probe the accelerator once at import time; the result drives the skip mark
# below and is unpacked inside the test.
_backend_info = _detect_backend()
# Module-level pytest mark: skip when no backend was detected (None).
pytestmark = pytest.mark.skipif(
_backend_info is None,
reason="No NPU or GPU available — skipping autoresearch ST",
)
def _make_task_desc(device_str: str) -> str:
return f'''\
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.relu(x)
def get_inputs():
return [torch.randn(11, 37, 8191, device='{device_str}')]
def get_init_inputs():
return []
'''
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.asyncio
async def test_autoresearch_relu():
    """Run the autoresearch workflow end to end on ReLU with the detected backend.

    Intended coverage (per the original author's note): preflight → seed
    three-layer verification → AgentLoop → eval → result.

    The test asserts that the workflow reports success, that verification
    passed, that generated code is present, that a ``gen_time`` was
    profiled, and — when both a baseline and a generated timing are
    reported as positive numbers — that the generated kernel is strictly
    faster than the reference.
    """
    # The module-level skip mark guarantees _backend_info is not None here.
    dsl, backend, arch, device_str = _backend_info
    op_name = "akg_agents_relu"

    device_id = get_device_id()
    await register_local_worker([device_id], backend=backend, arch=arch)

    config = load_config(dsl=dsl, backend=backend)
    config["task_label"] = resolve_task_label(op_name=op_name, parallel_index=1)
    config["max_step"] = 10
    config["gen_retries"] = 10

    task = LangGraphTask(
        op_name=op_name,
        task_desc=_make_task_desc(device_str),
        task_id="autoresearch_test_001",
        backend=backend,
        arch=arch,
        dsl=dsl,
        config=config,
        framework="torch",
        workflow="autoresearch",
    )
    result_op_name, success, final_state = await task.run()

    assert result_op_name == op_name, (
        f"unexpected op name returned: {result_op_name!r} (expected {op_name!r})"
    )
    assert isinstance(final_state, dict), (
        f"final_state is not a dict: {type(final_state).__name__}"
    )
    assert success, f"autoresearch workflow failed: {final_state.get('verifier_error', '')}"
    assert final_state.get("verifier_result") is True, (
        f"verifier_result is not True: {final_state.get('verifier_result')!r}"
    )

    # ``or`` fallbacks also cover keys that are present but hold None, so a
    # missing value fails with a readable assertion instead of a TypeError.
    coder_code = final_state.get("coder_code") or ""
    assert len(coder_code) > 0, "final_state contains no generated coder_code"

    profile_res = final_state.get("profile_res") or {}
    assert "gen_time" in profile_res, f"profile_res has no gen_time: {profile_res!r}"

    gen_time = profile_res["gen_time"]
    base_time = profile_res.get("base_time")
    # NOTE(review): the performance gate is only evaluated when both timings
    # are positive; with a missing or zero timing it is silently skipped.
    # Confirm this best-effort behaviour is intended.
    if base_time and base_time > 0 and gen_time and gen_time > 0:
        speedup = base_time / gen_time
        assert speedup > 1.0, (
            f"kernel slower than ref: {gen_time:.1f}us vs ref {base_time:.1f}us "
            f"(speedup={speedup:.2f}x, required >1.0x)"
        )