# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for DSL Adapters."""

import pytest
from akg_agents.op.verifier.adapters.factory import get_dsl_adapter, get_framework_adapter


class TestDSLAdapterTritonCuda:
    """Tests for the adapter returned by ``get_dsl_adapter("triton_cuda")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``triton_cuda`` DSL."""
        return get_dsl_adapter("triton_cuda")

    def test_get_import_statements(self):
        """Generated imports include both ``triton`` and ``triton.language``."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import triton", "import triton.language as tl"):
            assert expected in generated

    def test_get_impl_import(self):
        """The implementation is imported through its ``ModelNew`` class."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        # Adapters now uniformly import a ``ModelNew`` class, and the module
        # name carries an ``_impl`` suffix.
        assert "from test_op_triton_cuda_impl import ModelNew" in generated

    def test_call_impl(self):
        """The generated call site invokes ``impl_model`` on the inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        # The implementation is called through ``impl_model`` (a ModelNew instance).
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_needs_binary_io(self):
        """The adapter reports that binary I/O is not required."""
        assert self._adapter().needs_binary_io() is False

    def test_needs_compilation(self):
        """The adapter reports that no compilation step is required."""
        assert self._adapter().needs_compilation() is False

    def test_benchmark_impl(self):
        """Benchmark code uses ``triton.testing.do_bench`` with the given warmup/rep."""
        snippet = self._adapter().benchmark_impl("test_func", "inputs", 10, 100, "cuda", "test_op")
        for expected in ("triton.testing.do_bench", "warmup=10", "rep=100"):
            assert expected in snippet


class TestDSLAdapterTritonAscend:
    """Tests for the adapter returned by ``get_dsl_adapter("triton_ascend")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``triton_ascend`` DSL."""
        return get_dsl_adapter("triton_ascend")

    def test_get_import_statements(self):
        """Generated imports pull in triton and reference ``apply_triton_patches``."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import triton", "apply_triton_patches"):
            assert expected in generated

    def test_get_impl_import(self):
        """The implementation is imported through its ``ModelNew`` class."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        # Adapters now uniformly import a ``ModelNew`` class, and the module
        # name carries an ``_impl`` suffix.
        assert "from test_op_triton_ascend_impl import ModelNew" in generated

    def test_call_impl(self):
        """The generated call site invokes ``impl_model`` on the inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        # The implementation is called through ``impl_model`` (a ModelNew instance).
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl_ascend(self):
        """Ascend benchmark code uses the NPU profiler and collects autotune timings."""
        snippet = self._adapter().benchmark_impl("test_func", "inputs", 10, 100, "ascend", "test_op")
        expected_fragments = (
            "profiler_npu",
            "get_collected_config_timings",
            "autotune_info_case_",
        )
        for expected in expected_fragments:
            assert expected in snippet

    def test_get_special_setup_code(self):
        """The special setup code references ``apply_triton_patches``."""
        setup = self._adapter().get_special_setup_code()
        assert "apply_triton_patches" in setup


class TestDSLAdapterAscendC:
    """Tests for the adapter returned by ``get_dsl_adapter("ascendc")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``ascendc`` DSL."""
        return get_dsl_adapter("ascendc")

    def test_get_import_statements(self):
        """Generated imports include ``torch_npu`` and ``subprocess``."""
        generated = self._adapter().get_import_statements("torch")
        for expected in ("import torch_npu", "import subprocess"):
            assert expected in generated

    def test_get_impl_import(self):
        """AscendC emits no implementation import at all (empty string)."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        assert generated == ""

    def test_call_impl(self):
        """The call site runs ``run.sh`` via subprocess, then calls the built module."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        # impl_func_name and the inputs expression are now substituted
        # directly into the generated code via an f-string.
        expected_fragments = (
            "subprocess.run",
            "run.sh",
            "import test_func",
            "run_test_func",
            "test_func.run_test_func(*inputs)",
        )
        for expected in expected_fragments:
            assert expected in snippet

    def test_needs_compilation(self):
        """The adapter reports that a compilation step is required."""
        assert self._adapter().needs_compilation() is True


class TestDSLAdapterCpp:
    """Tests for the adapter returned by ``get_dsl_adapter("cpp")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``cpp`` DSL."""
        return get_dsl_adapter("cpp")

    def test_get_import_statements(self):
        """Generated imports include ``torch``."""
        generated = self._adapter().get_import_statements("torch")
        assert "import torch" in generated

    def test_get_impl_import(self):
        """The implementation is imported through its ``ModelNew`` class."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        # Adapters now uniformly import a ``ModelNew`` class, and the module
        # name carries an ``_impl`` suffix.
        assert "from test_op_cpp_impl import ModelNew" in generated

    def test_call_impl(self):
        """The generated call site invokes ``impl_model`` on the inputs."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        # The implementation is called through ``impl_model`` (a ModelNew instance).
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """CPU benchmark code times with ``time.perf_counter`` and mentions a warmup."""
        snippet = self._adapter().benchmark_impl("test_func", "inputs", 10, 100, "cpu", "test_op")
        assert "time.perf_counter" in snippet
        # Case-insensitive check: any spelling of "warmup" is accepted.
        assert "warmup" in snippet.lower()


class TestDSLAdapterCudaC:
    """Tests for the adapter returned by ``get_dsl_adapter("cuda_c")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``cuda_c`` DSL."""
        return get_dsl_adapter("cuda_c")

    def test_get_import_statements(self):
        """CUDA C adapter should emit torch cpp extension imports."""
        generated = self._adapter().get_import_statements("torch")
        assert "from torch.utils.cpp_extension import load_inline" in generated

    def test_get_impl_import(self):
        """CUDA C adapter imports the ``ModelNew`` class from its impl module."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        # The module name carries an ``_impl`` suffix.
        assert "from test_op_cuda_c_impl import ModelNew" in generated

    def test_create_impl_module(self):
        """The impl model is instantiated from ``init_params`` and moved to ``device``."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.create_impl_module("torch", torch_adapter)
        for expected in (
            "impl_model = ModelNew(*init_params)",
            "impl_model = impl_model.to(device)",
        ):
            assert expected in snippet

    def test_call_impl(self):
        """The call site reuses the existing ``impl_model`` instance."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """The benchmark section wraps ``impl_model`` and synchronizes CUDA."""
        snippet = self._adapter().benchmark_impl("test_func", "inputs", 5, 50, "cuda", "test_op")
        for expected in (
            "def cuda_c_benchmark_fn()",
            "impl_model(*inputs)",
            "torch.cuda.synchronize()",
        ):
            assert expected in snippet


class TestDSLAdapterTilelangCuda:
    """Tests for the adapter returned by ``get_dsl_adapter("tilelang_cuda")``."""

    def _adapter(self):
        """Return a fresh adapter for the ``tilelang_cuda`` DSL."""
        return get_dsl_adapter("tilelang_cuda")

    def test_get_import_statements(self):
        """Generated imports include ``tilelang.language`` aliased as ``T``."""
        generated = self._adapter().get_import_statements("torch")
        assert "import tilelang.language as T" in generated

    def test_get_impl_import(self):
        """The implementation is imported through its ``ModelNew`` class."""
        generated = self._adapter().get_impl_import("test_op", "test_func")
        # The module name carries an ``_impl`` suffix.
        assert "from test_op_tilelang_cuda_impl import ModelNew" in generated

    def test_create_impl_module(self):
        """The impl model is instantiated from ``init_params`` and moved to ``device``."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.create_impl_module("torch", torch_adapter)
        for expected in (
            "impl_model = ModelNew(*init_params)",
            "impl_model = impl_model.to(device)",
        ):
            assert expected in snippet

    def test_call_impl(self):
        """The call site reuses the existing ``impl_model`` instance."""
        dsl_adapter = self._adapter()
        torch_adapter = get_framework_adapter("torch")
        snippet = dsl_adapter.call_impl("test_func", "inputs", 0, torch_adapter, "test_op")
        assert "impl_output = impl_model(*inputs)" in snippet

    def test_benchmark_impl(self):
        """The benchmark section wraps ``impl_model`` and synchronizes CUDA."""
        snippet = self._adapter().benchmark_impl("test_func", "inputs", 5, 50, "cuda", "test_op")
        for expected in (
            "tilelang_cuda_benchmark_fn",
            "impl_model(*inputs)",
            "torch.cuda.synchronize()",
        ):
            assert expected in snippet


class TestDSLAdapterFactory:
    """Tests that ``get_dsl_adapter`` maps DSL names to the right adapter classes."""

    def _assert_adapter_class(self, dsl_name, expected_class_name):
        """Fetch the adapter for ``dsl_name`` and check its concrete class name."""
        adapter = get_dsl_adapter(dsl_name)
        assert adapter is not None
        assert adapter.__class__.__name__ == expected_class_name

    def test_get_dsl_adapter_triton_cuda(self):
        """``triton_cuda`` resolves to ``DSLAdapterTritonCuda``."""
        self._assert_adapter_class("triton_cuda", "DSLAdapterTritonCuda")

    def test_get_dsl_adapter_triton_ascend(self):
        """``triton_ascend`` resolves to ``DSLAdapterTritonAscend``."""
        self._assert_adapter_class("triton_ascend", "DSLAdapterTritonAscend")

    def test_get_dsl_adapter_ascendc(self):
        """``ascendc`` resolves to ``DSLAdapterAscendC``."""
        self._assert_adapter_class("ascendc", "DSLAdapterAscendC")

    def test_get_dsl_adapter_invalid(self):
        """An unknown DSL name raises ``ValueError`` mentioning "Unsupported DSL"."""
        # ``match`` is applied with re.search against the exception message.
        with pytest.raises(ValueError, match="Unsupported DSL"):
            get_dsl_adapter("invalid")