"""Unit tests for Framework Adapters."""
import pytest
from akg_agents.op.verifier.adapters.factory import get_framework_adapter
class TestFrameworkAdapterTorch:
"""Test PyTorch Framework Adapter."""
def test_get_import_statements(self):
"""Test import statements generation."""
adapter = get_framework_adapter("torch")
imports = adapter.get_import_statements()
assert "import torch" in imports
assert imports.endswith("\n")
def test_get_framework_import_static(self):
"""Test framework model import for static shape."""
adapter = get_framework_adapter("torch")
imports = adapter.get_framework_import("test_op", False)
assert "from test_op_torch import Model as FrameworkModel" in imports
assert "get_inputs" in imports
assert "get_inputs_dyn_list" not in imports
def test_get_framework_import_dynamic(self):
"""Test framework model import for dynamic shape."""
adapter = get_framework_adapter("torch")
imports = adapter.get_framework_import("test_op", True)
assert "from test_op_torch import Model as FrameworkModel" in imports
assert "get_inputs_dyn_list" in imports
def test_get_device_setup_cuda(self):
"""Test device setup for CUDA."""
adapter = get_framework_adapter("torch")
code = adapter.get_device_setup_code("cuda", "a100", 0)
assert "CUDA_VISIBLE_DEVICES" in code
assert "torch.device(\"cuda\")" in code
def test_get_device_setup_ascend(self):
"""Test device setup for Ascend."""
adapter = get_framework_adapter("torch")
code = adapter.get_device_setup_code("ascend", "ascend910b4", 0)
assert "DEVICE_ID" in code
assert "torch.device(\"npu\")" in code
def test_get_process_input_code(self):
"""Test process_input code generation."""
adapter = get_framework_adapter("torch")
code = adapter.get_process_input_code("cuda", "triton_cuda")
assert "def process_input" in code
assert "x.to(device)" in code
def test_get_process_input_code_ascendc(self):
"""Test process_input code for AscendC."""
adapter = get_framework_adapter("torch")
code = adapter.get_process_input_code("ascend", "ascendc")
assert "def process_input" in code
assert "x.npu()" in code
def test_get_set_seed_code(self):
"""Test set seed code generation."""
adapter = get_framework_adapter("torch")
code = adapter.get_set_seed_code("cuda")
assert "torch.manual_seed(0)" in code
assert "torch.npu.manual_seed" not in code
code_ascend = adapter.get_set_seed_code("ascend")
assert "torch.manual_seed(0)" in code_ascend
assert "torch.npu.manual_seed(0)" in code_ascend
def test_get_limit(self):
"""Test precision limit calculation."""
adapter = get_framework_adapter("torch")
import torch
assert adapter.get_limit(torch.float16) == 0.004
assert adapter.get_limit(torch.bfloat16) == 0.03
assert adapter.get_limit(torch.int8) == 0.01
assert adapter.get_limit(torch.float32) == 0.02
def test_get_tensor_type(self):
"""Test tensor type."""
adapter = get_framework_adapter("torch")
import torch
assert adapter.get_tensor_type() == torch.Tensor
def test_get_binary_io_functions(self):
"""Test binary I/O functions generation."""
adapter = get_framework_adapter("torch")
code = adapter.get_binary_io_functions("test_op")
assert "def save_tensor" in code
assert "def load_tensor" in code
assert "def gen_binary_data" in code
assert "def load_binary_data" in code
assert "test_op" in code
class TestFrameworkAdapterFactory:
"""Test Framework Adapter Factory."""
def test_get_framework_adapter_torch(self):
"""Test getting PyTorch adapter."""
adapter = get_framework_adapter("torch")
assert adapter is not None
assert adapter.__class__.__name__ == "FrameworkAdapterTorch"
def test_get_framework_adapter_invalid(self):
"""Test getting invalid framework adapter."""
with pytest.raises(ValueError, match="Unsupported framework"):
get_framework_adapter("invalid")