"""ST test for SOL-ExecBench + KernelGen-Only workflow on Ascend NPU."""
import json
import os
import tempfile
import pytest
from akg_agents.core.worker.manager import register_local_worker, register_remote_worker
from akg_agents.op.config.config_validator import load_config
from akg_agents.op.langgraph_op.task import LangGraphTask as AIKGTask
from akg_agents.utils.environment_check import check_env_for_task
from ..utils import get_device_id
DEFINITION_JSON = json.dumps({
"name": "000_mock_relu",
"description": "A simple ReLU operation for testing SOL-ExecBench integration.",
"axes": {
"batch_size": {"type": "var", "description": "Batch size"},
"seq_len": {"type": "var", "description": "Sequence length"},
"hidden_size": {"type": "const", "value": 128, "description": "Hidden dimension size"},
},
"custom_inputs_entrypoint": None,
"inputs": {
"x": {"shape": ["batch_size", "seq_len", "hidden_size"], "dtype": "float32", "description": "Input tensor"}
},
"outputs": {
"out": {"shape": ["batch_size", "seq_len", "hidden_size"], "dtype": "float32", "description": "Output tensor"}
},
"reference": (
"import torch\n\n@torch.no_grad()\ndef run(x: torch.Tensor)"
" -> torch.Tensor:\n return torch.nn.functional.relu(x)\n"
),
}, indent=4, ensure_ascii=False)
REFERENCE_PY = """\
import torch
@torch.no_grad()
def run(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.relu(x)
"""
WORKLOAD_JSONL = (
'{"uuid": "1", "axes": {"batch_size": 2, "seq_len": 64}, '
'"inputs": {"x": {"type": "random"}}, "tolerance": {"max_atol": 1e-05, "max_rtol": 1e-05}}\n'
'{"uuid": "2", "axes": {"batch_size": 4, "seq_len": 128}, '
'"inputs": {"x": {"type": "random"}}, "tolerance": {"max_atol": 1e-05, "max_rtol": 1e-05}}\n'
)
TASK_DESC = (
"请实现一个 Triton Ascend 算子。\n\n"
f"## definition.json\n```json\n{DEFINITION_JSON}\n```\n\n"
f"## reference.py\n```python\n{REFERENCE_PY}```\n\n"
"## workload 示例(共 2 组,以下为第 1 组)\n"
'```json\n{"uuid": "1", "axes": {"batch_size": 2, "seq_len": 64}, '
'"inputs": {"x": {"type": "random"}}, "tolerance": {"max_atol": 1e-05, "max_rtol": 1e-05}}\n```\n\n'
"注意:请使用 Triton 编写 kernel,并将其封装在 ModelNew 类的 forward 方法中。"
)
@pytest.mark.level1
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_sol_kernelgen_triton_ascend():
"""SOL-ExecBench 端到端:KernelGen-Only workflow + Triton Ascend NPU"""
framework = "torch"
dsl = "triton_ascend"
backend = "ascend"
arch = "ascend910b4"
worker_mode = os.getenv("AKG_ST_WORKER_MODE", "local").strip().lower()
remote_url = os.getenv("AKG_ST_WORKER_URL", "http://127.0.0.1:19001").strip()
device_id = get_device_id()
config = load_config(
config_path="./python/akg_agents/op/config/triton_ascend_kernelgen_config.yaml"
)
with tempfile.TemporaryDirectory(prefix="sol_st_") as sol_dir:
for name, content in [
("definition.json", DEFINITION_JSON),
("reference.py", REFERENCE_PY),
("workload.jsonl", WORKLOAD_JSONL),
]:
with open(os.path.join(sol_dir, name), "w", encoding="utf-8") as f:
f.write(content)
config["sol_problem_dir"] = sol_dir
config["verify_timeout"] = 300
check_env_for_task(framework, backend, dsl, config, is_remote=(worker_mode == "remote"))
if worker_mode == "remote":
await register_remote_worker(backend=backend, arch=arch, worker_url=remote_url)
else:
await register_local_worker([device_id], backend=backend, arch=arch)
task = AIKGTask(
op_name="_relu",
task_desc=TASK_DESC,
task_id="st_sol_kernelgen_ascend_001",
backend=backend,
arch=arch,
dsl=dsl,
config=config,
framework=framework,
workflow="kernelgen_only_workflow",
bench_type="sol",
)
_, success, final_state = await task.run()
assert success, (
f"SOL KernelGen Ascend 端到端执行失败: "
f"{final_state.get('verifier_error') or final_state.get('error_message')}"
)