"""Unit tests for DSL Adapters."""
import pytest
from akg_agents.op.verifier.adapters.factory import get_dsl_adapter, get_framework_adapter
class TestDSLAdapterTritonCuda:
    """Code-generation checks for the ``triton_cuda`` DSL adapter."""

    DSL = "triton_cuda"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports mention both ``triton`` and ``triton.language``."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import triton", "import triton.language as tl"):
            assert expected in generated

    def test_get_impl_import(self):
        """The impl import loads ``ModelNew`` from the per-op triton_cuda module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert "from test_op_triton_cuda_impl import ModelNew" in generated

    def test_call_impl(self):
        """The call snippet invokes ``impl_model`` with unpacked inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_needs_binary_io(self):
        """The adapter reports that binary I/O is not required."""
        assert self._adapter().needs_binary_io() is False

    def test_needs_compilation(self):
        """The adapter reports that no compilation step is required."""
        assert self._adapter().needs_compilation() is False

    def test_benchmark_impl(self):
        """Benchmark code uses ``triton.testing.do_bench`` with the given counts."""
        snippet = self._adapter().benchmark_impl(
            "test_func", "inputs", 10, 100, "cuda", "test_op"
        )
        for fragment in ("triton.testing.do_bench", "warmup=10", "rep=100"):
            assert fragment in snippet
class TestDSLAdapterTritonAscend:
    """Code-generation checks for the ``triton_ascend`` DSL adapter."""

    DSL = "triton_ascend"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports include triton and the patch helper."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import triton", "apply_triton_patches"):
            assert expected in generated

    def test_get_impl_import(self):
        """The impl import loads ``ModelNew`` from the per-op triton_ascend module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert "from test_op_triton_ascend_impl import ModelNew" in generated

    def test_call_impl(self):
        """The call snippet invokes ``impl_model`` with unpacked inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl_ascend(self):
        """Ascend benchmark code references the NPU profiler and autotune timings."""
        snippet = self._adapter().benchmark_impl(
            "test_func", "inputs", 10, 100, "ascend", "test_op"
        )
        expected_fragments = (
            "profiler_npu",
            "get_collected_config_timings",
            "autotune_info_case_",
        )
        for fragment in expected_fragments:
            assert fragment in snippet

    def test_get_special_setup_code(self):
        """The special setup snippet applies the triton patches."""
        setup = self._adapter().get_special_setup_code()
        assert "apply_triton_patches" in setup
class TestDSLAdapterAscendC:
    """Code-generation checks for the ``ascendc`` DSL adapter."""

    DSL = "ascendc"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports include ``torch_npu`` and ``subprocess``."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import torch_npu", "import subprocess"):
            assert expected in generated

    def test_get_impl_import(self):
        """AscendC emits no impl import at all (an empty string)."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert generated == ""

    def test_call_impl(self):
        """The call snippet runs the build script, then calls the compiled module."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        expected_fragments = (
            "subprocess.run",
            "run.sh",
            "import test_func",
            "run_test_func",
            "test_func.run_test_func(*inputs)",
        )
        for fragment in expected_fragments:
            assert fragment in snippet

    def test_needs_compilation(self):
        """The adapter reports that a compilation step is required."""
        assert self._adapter().needs_compilation() is True
class TestDSLAdapterCpp:
    """Code-generation checks for the ``cpp`` DSL adapter."""

    DSL = "cpp"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports include ``torch``."""
        generated = self._adapter().get_import_statements("torch")
        assert "import torch" in generated

    def test_get_impl_import(self):
        """The impl import loads ``ModelNew`` from the per-op cpp module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert "from test_op_cpp_impl import ModelNew" in generated

    def test_call_impl(self):
        """The call snippet invokes ``impl_model`` with unpacked inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """CPU benchmark code times with ``perf_counter`` and mentions warmup."""
        snippet = self._adapter().benchmark_impl(
            "test_func", "inputs", 10, 100, "cpu", "test_op"
        )
        assert "time.perf_counter" in snippet
        # Case-insensitive: any spelling of "warmup" satisfies the check.
        assert "warmup" in snippet.lower()
class TestDSLAdapterCudaC:
    """Code-generation checks for the ``cuda_c`` DSL adapter."""

    DSL = "cuda_c"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports pull in ``load_inline`` from the torch cpp extension."""
        generated = self._adapter().get_import_statements("torch")
        assert "from torch.utils.cpp_extension import load_inline" in generated

    def test_get_impl_import(self):
        """The impl import loads ``ModelNew`` from the per-op cuda_c module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert "from test_op_cuda_c_impl import ModelNew" in generated

    def test_create_impl_module(self):
        """Module creation instantiates ``ModelNew`` and moves it to ``device``."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.create_impl_module("torch", torch_adapter)
        for fragment in (
            "impl_model = ModelNew(*init_params)",
            "impl_model = impl_model.to(device)",
        ):
            assert fragment in snippet

    def test_call_impl(self):
        """The call snippet reuses the pre-built ``impl_model``."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """The benchmark defines a closure that calls ``impl_model`` and synchronizes."""
        snippet = self._adapter().benchmark_impl(
            "test_func", "inputs", 5, 50, "cuda", "test_op"
        )
        for fragment in (
            "def cuda_c_benchmark_fn()",
            "impl_model(*inputs)",
            "torch.cuda.synchronize()",
        ):
            assert fragment in snippet
class TestDSLAdapterTilelangCuda:
    """Code-generation checks for the ``tilelang_cuda`` DSL adapter."""

    DSL = "tilelang_cuda"

    def _adapter(self):
        """Build a fresh adapter for the DSL under test."""
        return get_dsl_adapter(self.DSL)

    def test_get_import_statements(self):
        """Generated imports expose ``tilelang.language`` as ``T``."""
        generated = self._adapter().get_import_statements("torch")
        assert "import tilelang.language as T" in generated

    def test_get_impl_import(self):
        """The impl import loads ``ModelNew`` from the per-op tilelang_cuda module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert "from test_op_tilelang_cuda_impl import ModelNew" in generated

    def test_create_impl_module(self):
        """Module creation instantiates ``ModelNew`` and moves it to ``device``."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.create_impl_module("torch", torch_adapter)
        for fragment in (
            "impl_model = ModelNew(*init_params)",
            "impl_model = impl_model.to(device)",
        ):
            assert fragment in snippet

    def test_call_impl(self):
        """The call snippet invokes ``impl_model`` with unpacked inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl(
            "test_func", "inputs", 0, torch_adapter, "test_op"
        )
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """The benchmark names its closure, calls ``impl_model`` and synchronizes."""
        snippet = self._adapter().benchmark_impl(
            "test_func", "inputs", 5, 50, "cuda", "test_op"
        )
        for fragment in (
            "tilelang_cuda_benchmark_fn",
            "impl_model(*inputs)",
            "torch.cuda.synchronize()",
        ):
            assert fragment in snippet
class TestDSLAdapterFactory:
    """Checks that ``get_dsl_adapter`` dispatches each DSL name to its class."""

    @staticmethod
    def _check_dispatch(dsl_name, expected_class_name):
        """Fetch the adapter for *dsl_name* and verify its concrete class name."""
        resolved = get_dsl_adapter(dsl_name)
        assert resolved is not None
        assert resolved.__class__.__name__ == expected_class_name

    def test_get_dsl_adapter_triton_cuda(self):
        """``triton_cuda`` resolves to ``DSLAdapterTritonCuda``."""
        self._check_dispatch("triton_cuda", "DSLAdapterTritonCuda")

    def test_get_dsl_adapter_triton_ascend(self):
        """``triton_ascend`` resolves to ``DSLAdapterTritonAscend``."""
        self._check_dispatch("triton_ascend", "DSLAdapterTritonAscend")

    def test_get_dsl_adapter_ascendc(self):
        """``ascendc`` resolves to ``DSLAdapterAscendC``."""
        self._check_dispatch("ascendc", "DSLAdapterAscendC")

    def test_get_dsl_adapter_invalid(self):
        """An unknown DSL name raises ``ValueError`` mentioning "Unsupported DSL"."""
        with pytest.raises(ValueError, match="Unsupported DSL"):
            get_dsl_adapter("invalid")