"""OpTaskBuilder ST 测试。
当前文件只保留一个最简单的 ReLU 场景用例。
"""
import os
import pytest
from akg_agents.core.worker.manager import register_local_worker
from akg_agents.op.agents.op_task_builder import OpTaskBuilder
from akg_agents.op.langgraph_op.op_task_builder_state import OpTaskBuilderStatus
from ..utils import get_device_id
# Set the AKG_AGENTS_STREAM_OUTPUT switch to "on" for the whole test process.
# NOTE(review): this assignment runs after the akg_agents imports above; if the
# package reads the variable at import time it would not see this value — confirm.
os.environ["AKG_AGENTS_STREAM_OUTPUT"] = "on"
# Target configuration shared by the worker registration helper and the
# builder state constructed in the test below.
FRAMEWORK = "torch"
BACKEND = "ascend"
ARCH = "ascend910b4"
DSL = "triton_ascend"
# Resolved once at import time through the test-utils helper; used as the single
# device handed to register_local_worker.
DEVICE_ID = get_device_id()
async def ensure_worker_registered():
    """Make sure a local worker exists for the configured BACKEND/ARCH.

    Asks the worker manager whether a worker for ``BACKEND``/``ARCH`` is
    already known; if not, registers a local worker on ``DEVICE_ID``.
    Returns ``None`` in both cases.
    """
    from akg_agents.core.worker.manager import get_worker_manager

    manager = get_worker_manager()
    already_registered = await manager.has_worker(backend=BACKEND, arch=ARCH)
    if already_registered:
        # Nothing to do: a matching worker is already available.
        return
    await register_local_worker([DEVICE_ID], backend=BACKEND, arch=ARCH)
@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.ascend
@pytest.mark.ascend910b4
@pytest.mark.asyncio
async def test_simple_relu_request():
    """A clearly specified ReLU request must yield a result that can be acted on.

    The builder is run once on a fresh state. The final status has to be
    either READY or NEED_CLARIFICATION. When it is READY, the result must
    also carry a non-empty ``op_name`` and a non-empty ``generated_task_desc``
    containing the text ``class Model``.
    """
    await ensure_worker_registered()

    # Fresh state: no prior iterations, retries or conversation turns.
    initial_state = {
        "user_input": "我需要一个ReLU激活函数的算子,输入是16x16384的张量",
        "framework": FRAMEWORK,
        "backend": BACKEND,
        "arch": ARCH,
        "dsl": DSL,
        "iteration": 0,
        "max_iterations": 5,
        "max_check_retries": 3,
        "check_retry_count": 0,
        "conversation_history": [],
    }
    outcome = await OpTaskBuilder().run(initial_state)

    status = outcome.get("status")
    accepted_statuses = (
        OpTaskBuilderStatus.READY,
        OpTaskBuilderStatus.NEED_CLARIFICATION,
    )
    assert status in accepted_statuses

    # Only a READY result is expected to carry generated artefacts.
    if status == OpTaskBuilderStatus.READY:
        task_desc = outcome.get("generated_task_desc", "")
        assert outcome.get("op_name")
        assert task_desc
        assert "class Model" in task_desc