import pytest
import asyncio
import tarfile
import io
import os
import textwrap
import uuid
import httpx
from akg_agents.core.worker.local_worker import LocalWorker
from akg_agents.core.async_pool.device_pool import DevicePool
from akg_agents.op.verifier.kernel_verifier import KernelVerifier
from akg_agents.op.config.config_validator import load_config
from ..utils import get_device_id
# Resolved once at import time via the shared test utility; the hardware-backed
# tests in this module build their DevicePool from this value.
device_id = get_device_id()
@pytest.mark.asyncio
async def test_local_worker_verify_success():
    """A verification script that exits with status 0 is reported as success.

    The script is handed to ``LocalWorker.verify`` as an in-memory tar archive
    containing a single ``verify_<op_name>.py`` member, and the text it prints
    is expected to show up in the returned log.
    """
    op_name = "dummy_op"
    script_content = """
import os
import sys
print("Starting verification...")
print(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
print(f"ASCEND_VISIBLE_DEVICES={os.environ.get('ASCEND_VISIBLE_DEVICES')}")
print("Verification successful!")
sys.exit(0)
"""

    # Encode up front: the tar header must carry the exact payload size.
    payload = script_content.encode("utf-8")
    member = tarfile.TarInfo(name=f"verify_{op_name}.py")
    member.size = len(payload)

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w") as tar:
        tar.addfile(tarinfo=member, fileobj=io.BytesIO(payload))
    package_data = archive.getvalue()

    worker = LocalWorker(DevicePool([0, 1]), backend="cuda")
    task_id = "test_task_001"

    success, log, _artifacts = await worker.verify(
        package_data, task_id, op_name, timeout=10
    )

    assert success is True
    assert "Verification successful!" in log
    # NOTE(review): the script prints the variable *name* even when the variable
    # is unset, so this only proves the script's output reached the log --
    # confirm whether the assigned device id should be asserted instead.
    assert "CUDA_VISIBLE_DEVICES" in log
@pytest.mark.asyncio
async def test_local_worker_verify_failure():
    """A verification script that exits with status 1 is reported as failure.

    The message printed by the script before exiting must still be present in
    the log returned by ``LocalWorker.verify``.
    """
    op_name = "dummy_fail_op"
    script_content = """
import sys
print("Verification failed intentionally.")
sys.exit(1)
"""

    # Package the script as the only member of an uncompressed tar archive.
    raw_script = script_content.encode("utf-8")
    entry = tarfile.TarInfo(name=f"verify_{op_name}.py")
    entry.size = len(raw_script)

    stream = io.BytesIO()
    with tarfile.open(fileobj=stream, mode="w") as bundle:
        bundle.addfile(tarinfo=entry, fileobj=io.BytesIO(raw_script))
    package_data = stream.getvalue()

    worker = LocalWorker(DevicePool([0]), backend="cuda")

    success, log, _artifacts = await worker.verify(
        package_data, "test_task_002", op_name, timeout=10
    )

    assert success is False
    assert "Verification failed intentionally" in log
@pytest.mark.asyncio
async def test_local_worker_timeout():
    """A script that outlives the timeout is reported as a timed-out failure.

    The script sleeps for 5 seconds while ``verify`` is called with
    ``timeout=1``; the result must be unsuccessful and the log must mention
    "timed out".
    """
    op_name = "dummy_timeout_op"
    script_content = """
import time
import sys
print("Sleeping...")
time.sleep(5)
sys.exit(0)
"""

    # Build the single-file tar package that verify() expects.
    script_blob = script_content.encode("utf-8")
    header = tarfile.TarInfo(name=f"verify_{op_name}.py")
    header.size = len(script_blob)

    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w") as package:
        package.addfile(tarinfo=header, fileobj=io.BytesIO(script_blob))
    package_data = buffer.getvalue()

    worker = LocalWorker(DevicePool([0]), backend="cuda")

    success, log, _artifacts = await worker.verify(
        package_data, "test_task_003", op_name, timeout=1
    )

    assert success is False
    assert "timed out" in log
def test_local_worker_get_doc(monkeypatch):
    """``LocalWorker.get_doc("triton_ascend_api")`` returns the loader's text.

    The module-level ``load_triton_ascend_api_docs`` in ``local_worker`` is
    replaced with a stub so the test does not depend on the real docs.
    """
    worker = LocalWorker(DevicePool([0]), backend="ascend")

    def _stub_loader():
        return "doc"

    monkeypatch.setattr(
        "akg_agents.core.worker.local_worker.load_triton_ascend_api_docs",
        _stub_loader,
    )

    fetched = asyncio.run(worker.get_doc("triton_ascend_api"))
    assert fetched == "doc"
def test_worker_server_docs_endpoint(monkeypatch):
    """The docs endpoint of the worker server returns the worker's doc text.

    The server's module-level ``worker`` is swapped for a stub, and the ASGI
    app is called in-process through ``httpx.ASGITransport``.
    """
    from akg_agents.worker import server as worker_server

    class _FakeWorker:
        """Stand-in worker that only answers the expected doc request."""

        async def get_doc(self, doc_name: str) -> str:
            assert doc_name == "triton_ascend_api"
            return "# remote docs"

    monkeypatch.setattr(worker_server, "worker", _FakeWorker())

    async def _fetch_docs():
        asgi = httpx.ASGITransport(app=worker_server.app)
        client = httpx.AsyncClient(transport=asgi, base_url="http://testserver")
        async with client as session:
            reply = await session.get(
                "/api/v1/docs/triton_ascend_api",
            )
        return reply

    response = asyncio.run(_fetch_docs())

    assert response.status_code == 200
    body = response.json()
    assert body["content"] == "# remote docs"
def test_remote_worker_get_doc(monkeypatch):
    """``RemoteWorker.get_doc`` requests the docs URL and returns ``content``.

    ``httpx.AsyncClient`` as seen from the ``remote_worker`` module is replaced
    by a stub that checks the requested URL and returns a canned JSON payload,
    so no network traffic takes place.
    """
    from akg_agents.core.worker.remote_worker import RemoteWorker

    class _CannedResponse:
        """Response double: never raises, always yields the same JSON."""

        def raise_for_status(self):
            return None

        def json(self):
            return {"content": "# remote docs"}

    class _FakeAsyncClient:
        """Async-context-manager double for ``httpx.AsyncClient``."""

        def __init__(self, *args, **kwargs):
            self.timeout = kwargs.get("timeout")

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            # Returning False lets any exception from the body propagate.
            return False

        async def get(self, url, params=None):
            assert url == "http://worker/api/v1/docs/triton_ascend_api"
            assert params is None
            return _CannedResponse()

    monkeypatch.setattr(
        "akg_agents.core.worker.remote_worker.httpx.AsyncClient",
        _FakeAsyncClient,
    )

    worker = RemoteWorker("http://worker")
    fetched = asyncio.run(worker.get_doc("triton_ascend_api"))
    assert fetched == "# remote docs"
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
async def test_local_worker_real_triton_cuda_relu():
    """Verify the ReLU Triton-CUDA kernel through a LocalWorker-backed verifier.

    The framework task and the kernel source are read from
    ``./tests/op/resources`` (relative to the current working directory), and
    ``KernelVerifier.run`` is expected to report success.
    """
    op_name, framework = "relu", "torch"
    dsl, backend, arch = "triton_cuda", "cuda", "a100"

    config = load_config(dsl, backend=backend)

    resource_dir = f"./tests/op/resources/{op_name}_op"
    with open(f"{resource_dir}/{op_name}_{framework}.py", "r", encoding="utf-8") as task_src:
        op_task_str = textwrap.dedent(task_src.read())
    with open(f"{resource_dir}/{op_name}_{dsl}.py", "r", encoding="utf-8") as kernel_src:
        kernel_code = kernel_src.read()

    # Route verifier logs to a dedicated directory under the user's home.
    log_dir = os.path.expanduser("~/.akg_agents/test_logs/worker_test")
    os.makedirs(log_dir, exist_ok=True)
    config['log_dir'] = log_dir

    worker = LocalWorker(DevicePool([device_id]), backend=backend)

    verifier_kwargs = {
        "op_name": op_name,
        "framework_code": op_task_str,
        "task_id": str(uuid.uuid4())[:8],
        "framework": framework,
        "dsl": dsl,
        "backend": backend,
        "arch": arch,
        "impl_func_name": "ModelNew",
        "config": config,
        "worker": worker,
    }
    verifier = KernelVerifier(**verifier_kwargs)

    # NOTE(review): the pool is built from the module-level device_id, while
    # run() receives a literal 0 -- confirm this is intended.
    success, error_log = await verifier.run(
        {"coder_code": kernel_code}, current_step=0, device_id=0
    )

    assert success, f"验证失败: {error_log}"
    print(f"✓ LocalWorker 成功验证 {op_name} ({framework}/{dsl}/{backend})")
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
async def test_local_worker_real_triton_cuda_relu_profile():
    """Verify, then profile, the ReLU Triton-CUDA kernel via a LocalWorker.

    After ``KernelVerifier.run`` succeeds, ``run_profile`` is called with 10
    runs and 2 warm-ups; the result must contain ``gen_time``, ``base_time``
    and ``speedup``, with a finite ``gen_time`` and a positive ``base_time``.
    """
    op_name, framework = "relu", "torch"
    dsl, backend, arch = "triton_cuda", "cuda", "a100"

    config = load_config(dsl, backend=backend)

    resource_dir = f"./tests/op/resources/{op_name}_op"
    with open(f"{resource_dir}/{op_name}_{framework}.py", "r", encoding="utf-8") as task_src:
        op_task_str = textwrap.dedent(task_src.read())
    with open(f"{resource_dir}/{op_name}_{dsl}.py", "r", encoding="utf-8") as kernel_src:
        kernel_code = kernel_src.read()

    # Route verifier logs to a dedicated directory under the user's home.
    log_dir = os.path.expanduser("~/.akg_agents/test_logs/worker_profile_test")
    os.makedirs(log_dir, exist_ok=True)
    config['log_dir'] = log_dir

    worker = LocalWorker(DevicePool([device_id]), backend=backend)

    verifier_kwargs = {
        "op_name": op_name,
        "framework_code": op_task_str,
        "task_id": str(uuid.uuid4())[:8],
        "framework": framework,
        "dsl": dsl,
        "backend": backend,
        "arch": arch,
        "impl_func_name": "ModelNew",
        "config": config,
        "worker": worker,
    }
    verifier = KernelVerifier(**verifier_kwargs)

    task_info = {"coder_code": kernel_code}

    # NOTE(review): run() is given the int 0 and run_profile() the string "0",
    # while the pool uses the module-level device_id -- confirm the intended
    # type and value.
    success, error_log = await verifier.run(task_info, current_step=0, device_id=0)
    assert success, f"验证失败: {error_log}"

    profile_settings = {"run_times": 10, "warmup_times": 2}
    result = await verifier.run_profile(
        task_info,
        current_step=0,
        device_id="0",
        profile_settings=profile_settings,
    )

    assert result is not None
    for required_key in ('gen_time', 'base_time', 'speedup'):
        assert required_key in result
    assert result['gen_time'] < float('inf'), "生成代码时间应该是有效值"
    assert result['base_time'] > 0, "基准代码时间应该大于0"

    print(f"✓ LocalWorker 成功 profile {op_name}")
    print(f" Base time: {result['base_time']:.2f} us")
    print(f" Gen time: {result['gen_time']:.2f} us")
    print(f" Speedup: {result['speedup']:.2f}x")
@pytest.mark.level0
@pytest.mark.torch
@pytest.mark.triton
@pytest.mark.cuda
@pytest.mark.a100
@pytest.mark.asyncio
async def test_remote_worker_via_local_service():
    """Exercise the remote-worker scenario using a LocalWorker as a stand-in.

    The scenario being simulated:
      1. A Worker Service (FastAPI) wraps a LocalWorker.
      2. A RemoteWorker (HTTP client) calls that service.
      3. The service delegates to the LocalWorker.

    Note: no HTTP call is made here -- the verifier is wired directly to a
    LocalWorker to simulate what RemoteWorker would do. In production,
    RemoteWorker would make HTTP calls to the Worker Service.
    """
    op_name, framework = "relu", "torch"
    dsl, backend, arch = "triton_cuda", "cuda", "a100"

    config = load_config(dsl, backend=backend)

    resource_dir = f"./tests/op/resources/{op_name}_op"
    with open(f"{resource_dir}/{op_name}_{framework}.py", "r", encoding="utf-8") as task_src:
        op_task_str = textwrap.dedent(task_src.read())
    with open(f"{resource_dir}/{op_name}_{dsl}.py", "r", encoding="utf-8") as kernel_src:
        kernel_code = kernel_src.read()

    # Route verifier logs to a dedicated directory under the user's home.
    log_dir = os.path.expanduser("~/.akg_agents/test_logs/remote_worker_test")
    os.makedirs(log_dir, exist_ok=True)
    config['log_dir'] = log_dir

    worker = LocalWorker(DevicePool([device_id]), backend=backend)

    verifier_kwargs = {
        "op_name": op_name,
        "framework_code": op_task_str,
        "task_id": str(uuid.uuid4())[:8],
        "framework": framework,
        "dsl": dsl,
        "backend": backend,
        "arch": arch,
        "impl_func_name": "ModelNew",
        "config": config,
        "worker": worker,
    }
    verifier = KernelVerifier(**verifier_kwargs)

    # NOTE(review): the pool is built from the module-level device_id, while
    # run() receives a literal 0 -- confirm this is intended.
    success, error_log = await verifier.run(
        {"coder_code": kernel_code}, current_step=0, device_id=0
    )

    assert success, f"RemoteWorker 验证失败: {error_log}"
    print(f"✓ RemoteWorker (via LocalWorker) 成功验证 {op_name}")
    print(f" Note: This test uses LocalWorker to simulate RemoteWorker behavior")
    print(f" In production, RemoteWorker would call Worker Service via HTTP")