"""Unit tests for Backend Adapters."""
import os
import pytest
from akg_agents.op.verifier.adapters.factory import get_backend_adapter
class TestBackendAdapterCuda:
"""Test CUDA Backend Adapter."""
def test_setup_environment(self):
"""Test environment setup."""
adapter = get_backend_adapter("cuda")
original_value = os.environ.get('CUDA_VISIBLE_DEVICES')
try:
adapter.setup_environment(0, "a100")
assert os.environ.get('CUDA_VISIBLE_DEVICES') == "0"
finally:
if original_value is None:
os.environ.pop('CUDA_VISIBLE_DEVICES', None)
else:
os.environ['CUDA_VISIBLE_DEVICES'] = original_value
def test_get_device_string(self):
"""Test device string generation."""
adapter = get_backend_adapter("cuda")
device_str = adapter.get_device_string(0)
assert device_str == "cuda:0"
def test_validate_arch(self):
"""Test architecture validation."""
adapter = get_backend_adapter("cuda")
assert adapter.validate_arch("a100") is True
assert adapter.validate_arch("v100") is True
assert adapter.validate_arch("rtx3090") is True
assert adapter.validate_arch("invalid") is False
class TestBackendAdapterAscend:
"""Test Ascend Backend Adapter."""
def test_setup_environment(self):
"""Test environment setup."""
adapter = get_backend_adapter("ascend")
original_value = os.environ.get('DEVICE_ID')
try:
adapter.setup_environment(0, "ascend910b4")
assert os.environ.get('DEVICE_ID') == "0"
finally:
if original_value is None:
os.environ.pop('DEVICE_ID', None)
else:
os.environ['DEVICE_ID'] = original_value
def test_get_device_string(self):
"""Test device string generation."""
adapter = get_backend_adapter("ascend")
device_str = adapter.get_device_string(0)
assert device_str == "npu:0"
def test_validate_arch(self):
"""Test architecture validation."""
adapter = get_backend_adapter("ascend")
assert adapter.validate_arch("ascend910b4") is True
assert adapter.validate_arch("ascend910b1") is True
assert adapter.validate_arch("ascend310p3") is True
assert adapter.validate_arch("invalid") is False
class TestBackendAdapterCpu:
"""Test CPU Backend Adapter."""
def test_setup_environment(self):
"""Test environment setup (should be no-op)."""
adapter = get_backend_adapter("cpu")
adapter.setup_environment(0, "x86_64")
def test_get_device_string(self):
"""Test device string generation."""
adapter = get_backend_adapter("cpu")
device_str = adapter.get_device_string(0)
assert device_str == "cpu"
def test_validate_arch(self):
"""Test architecture validation."""
adapter = get_backend_adapter("cpu")
assert adapter.validate_arch("x86_64") is True
assert adapter.validate_arch("aarch64") is True
assert adapter.validate_arch("invalid") is False
class TestBackendAdapterFactory:
"""Test Backend Adapter Factory."""
def test_get_backend_adapter_cuda(self):
"""Test getting CUDA adapter."""
adapter = get_backend_adapter("cuda")
assert adapter is not None
assert adapter.__class__.__name__ == "BackendAdapterCuda"
def test_get_backend_adapter_ascend(self):
"""Test getting Ascend adapter."""
adapter = get_backend_adapter("ascend")
assert adapter is not None
assert adapter.__class__.__name__ == "BackendAdapterAscend"
def test_get_backend_adapter_cpu(self):
"""Test getting CPU adapter."""
adapter = get_backend_adapter("cpu")
assert adapter is not None
assert adapter.__class__.__name__ == "BackendAdapterCpu"
def test_get_backend_adapter_invalid(self):
"""Test getting invalid backend adapter."""
with pytest.raises(ValueError, match="Unsupported backend"):
get_backend_adapter("invalid")