import pytest
import asyncio
from typing import Tuple, Dict, Any
from akg_agents.core.worker.interface import WorkerInterface
from akg_agents.core.worker.manager import WorkerManager
class MockWorker(WorkerInterface):
    """In-memory WorkerInterface stub for the WorkerManager tests.

    Every coroutine returns a fixed, canned value and performs no real work.
    ``name`` only serves to tell instances apart; ``get_doc`` echoes it back.
    """

    def __init__(self, name):
        # Identifier for this stub; prefixed onto get_doc() results.
        self.name = name

    async def verify(
        self,
        package_data: bytes,
        task_id: str,
        op_name: str,
        timeout: int = 300,
    ) -> Tuple[bool, str, Dict[str, Any]]:
        """Report a successful verification with an empty details dict."""
        return True, "Success", {}

    async def profile(
        self,
        package_data: bytes,
        task_id: str,
        op_name: str,
        profile_settings: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return an empty profiling result."""
        return {}

    async def generate_reference(
        self,
        package_data: bytes,
        task_id: str,
        op_name: str,
        timeout: int = 120,
    ) -> Tuple[bool, str, bytes]:
        """Report a successful reference generation with an empty payload."""
        return True, "Reference generated", b''

    async def profile_single_task(
        self,
        package_data: bytes,
        task_id: str,
        op_name: str,
        profile_settings: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return a fixed, unsuccessful single-task profile (no timing)."""
        return {'time_us': None, 'success': False, 'log': 'mock'}

    async def get_doc(self, doc_name: str) -> str:
        """Return ``"<worker name>:<doc_name>"``."""
        return f"{self.name}:{doc_name}"
@pytest.mark.asyncio
async def test_worker_manager_basic_flow():
    """Exercise the register -> select -> release lifecycle of one worker.

    A single worker with capacity 2 is registered. Two selections must both
    return it, with its reported load rising to 1 and then 2. Two releases
    must bring the load back to 1 and then 0.
    """
    async def current_load():
        # Only one worker is registered, so its entry is status[0].
        snapshot = await manager.get_status()
        return snapshot[0]["load"]

    manager = WorkerManager()
    lone_worker = MockWorker("worker1")
    await manager.register(lone_worker, backend="cuda", arch="a100", capacity=2)

    # First pick filters on both backend and arch.
    first_pick = await manager.select(backend="cuda", arch="a100")
    assert first_pick is lone_worker
    assert await current_load() == 1

    # Second pick filters on backend only and must land on the same worker.
    second_pick = await manager.select(backend="cuda")
    assert second_pick is lone_worker
    assert await current_load() == 2

    # Release in acquisition order; the load should step down each time.
    for handle, remaining in ((first_pick, 1), (second_pick, 0)):
        await manager.release(handle)
        assert await current_load() == remaining
@pytest.mark.asyncio
async def test_worker_manager_load_balancing():
    """Check how selections are spread across two equivalent workers.

    Two workers with identical backend/arch and capacity 2 are registered.
    Four selections must alternate w1, w2, w1, w2. After one slot taken from
    w1 is released, the next selection must return w1 again.
    """
    manager = WorkerManager()
    worker_a = MockWorker("w1")
    worker_b = MockWorker("w2")
    for candidate in (worker_a, worker_b):
        await manager.register(candidate, "cuda", "a100", capacity=2)

    acquired = []
    for expected in (worker_a, worker_b, worker_a, worker_b):
        chosen = await manager.select("cuda")
        assert chosen is expected
        acquired.append(chosen)

    # Free the first slot (taken from w1). The next pick is expected to go
    # back to w1 (presumably because it is now the less-loaded worker).
    await manager.release(acquired[0])
    follow_up = await manager.select("cuda")
    assert follow_up is worker_a
@pytest.mark.asyncio
async def test_worker_manager_filtering():
    """Check that select() honours the backend and arch filters."""
    manager = WorkerManager()
    cuda_worker = MockWorker("cuda")
    ascend_worker = MockWorker("ascend")
    await manager.register(cuda_worker, "cuda", "a100")
    await manager.register(ascend_worker, "ascend", "910b")

    # (positional args, keyword args, expected selection result)
    select_cases = [
        (("cuda",), {}, cuda_worker),
        (("ascend",), {}, ascend_worker),
        (("cuda",), {"arch": "a100"}, cuda_worker),
        # No worker is registered with this arch, so nothing matches.
        (("cuda",), {"arch": "v100"}, None),
    ]
    for args, kwargs, expected in select_cases:
        assert (await manager.select(*args, **kwargs)) is expected

    # Return every slot acquired above: cuda_worker was selected twice and
    # ascend_worker once.
    for held in (cuda_worker, cuda_worker, ascend_worker):
        await manager.release(held)
@pytest.mark.asyncio
async def test_worker_manager_tags():
    """Check that select() returns only workers carrying every requested tag.

    Afterwards, release every slot acquired by the successful selections
    (as the filtering test does) and confirm all workers report zero load.
    """
    manager = WorkerManager()
    w_remote = MockWorker("remote")
    w_local = MockWorker("local")
    await manager.register(w_remote, "cuda", "a100", tags={"remote", "fast"})
    await manager.register(w_local, "cuda", "a100", tags={"local"})

    assert (await manager.select("cuda", tags={"remote"})) is w_remote
    assert (await manager.select("cuda", tags={"local"})) is w_local
    # A multi-tag request matches only a worker holding all of the tags:
    # w_remote has both "fast" and "remote"; no worker has "fast" and "local".
    assert (await manager.select("cuda", tags={"fast", "remote"})) is w_remote
    assert (await manager.select("cuda", tags={"fast", "local"})) is None

    # Three selections succeeded (w_remote twice, w_local once). Give them
    # back so the test does not leave the manager holding slots.
    for held in (w_remote, w_local, w_remote):
        await manager.release(held)
    status = await manager.get_status()
    assert all(entry["load"] == 0 for entry in status)