# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup script for MFUSE MLIR Python package."""

import os
import shutil
import subprocess
from pathlib import Path

from setuptools import Distribution, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py


class CMakeBuild(build_ext):
    """Custom build command to compile C++ extensions with CMake.

    Behaviour is controlled through environment variables:

    * ``BUILD_TYPE`` -- forwarded as ``-DCMAKE_BUILD_TYPE`` (default ``Release``).
    * ``CMAKE_PREFIX_PATH`` -- forwarded as ``-DCMAKE_PREFIX_PATH`` when set.
    * ``ENABLE_ASAN`` -- forwarded as ``-DENABLE_ASAN`` (default ``OFF``).
    * ``INC_BUILD`` -- when ``"1"``, skip the configure step if the build tree
      already holds a ``CMakeCache.txt``.
    * ``BUILD_JOBS`` -- value passed to ``cmake --build -j`` (default ``"8"``).
    * ``BUILD_TESTS`` -- when ``"ON"``, copy ``<build_temp>/tests`` into
      ``<repo>/build/tests``.
    """

    def run(self):
        """Configure and build with CMake, then assemble ``mfusion`` in ``build_lib``.

        Raises:
            subprocess.CalledProcessError: if a CMake invocation fails.
            FileNotFoundError: if the build did not produce the Python package.
            RuntimeError: if ``BUILD_TESTS=ON`` but ``<repo>/build`` is missing.
        """
        source_dir = Path(__file__).parent
        build_temp = Path(self.build_temp)
        build_temp.mkdir(parents=True, exist_ok=True)

        self._configure(source_dir, build_temp)
        self._build(build_temp)

        target_dir = Path(self.build_lib) / "mfusion"
        self._assemble_package(source_dir, build_temp, target_dir)

        if os.environ.get("BUILD_TESTS", "OFF") == "ON":
            self._copy_tests(source_dir, build_temp)

        self._copy_mfusion_opt(build_temp, target_dir)

    @staticmethod
    def _cmake_args():
        """Return the ``-D`` cache definitions derived from the environment."""
        cmake_args = []

        # Get BUILD_TYPE from environment variable, default to Release
        build_type = os.environ.get("BUILD_TYPE", "Release")
        cmake_args.append(f"-DCMAKE_BUILD_TYPE={build_type}")
        cmake_prefix_path = os.environ.get("CMAKE_PREFIX_PATH")
        if cmake_prefix_path:
            cmake_args.append(f"-DCMAKE_PREFIX_PATH={cmake_prefix_path}")

        enable_asan = os.environ.get("ENABLE_ASAN", "OFF")
        cmake_args.append(f"-DENABLE_ASAN={enable_asan}")
        return cmake_args

    def _configure(self, source_dir, build_temp):
        """Run the CMake configure step unless an incremental build can reuse the tree."""
        incremental = os.environ.get("INC_BUILD", "0") == "1"
        # An incremental build still needs a configured tree: without a
        # CMakeCache.txt, ``cmake --build`` aborts with "could not load cache".
        if incremental and (build_temp / "CMakeCache.txt").exists():
            return
        subprocess.check_call(
            ["cmake", "-S", str(source_dir), "-B", str(build_temp)] + self._cmake_args()
        )

    @staticmethod
    def _build(build_temp):
        """Compile the configured CMake tree in parallel."""
        build_jobs = os.environ.get("BUILD_JOBS", "8")
        subprocess.check_call(["cmake", "--build", str(build_temp), "-j", build_jobs])

    @staticmethod
    def _assemble_package(source_dir, build_temp, target_dir):
        """Replace ``target_dir`` with the generated package, then overlay repo sources."""
        python_package_dir = build_temp / "python_packages" / "mfusion"
        if not python_package_dir.is_dir():
            raise FileNotFoundError(
                f"CMake build did not produce the Python package at {python_package_dir}"
            )
        if target_dir.exists():
            shutil.rmtree(target_dir)
        shutil.copytree(python_package_dir, target_dir)

        # Merge Python sources from the repo on top of the generated files.
        source_python_dir = source_dir / "python" / "mfusion"
        if source_python_dir.exists():
            shutil.copytree(source_python_dir, target_dir, dirs_exist_ok=True)

    @staticmethod
    def _copy_tests(source_dir, build_temp):
        """Mirror ``<build_temp>/tests`` into ``<repo>/build/tests`` if it was built."""
        tests_dir = build_temp / "tests"
        if not tests_dir.exists():
            return
        build_dir = source_dir / "build"
        if not build_dir.exists():
            raise RuntimeError(f"Directory does not exist: {build_dir}. Please run build.sh to create it.")
        dst_tests = build_dir / "tests"
        if dst_tests.exists():
            shutil.rmtree(dst_tests)
        shutil.copytree(tests_dir, dst_tests)

    @staticmethod
    def _copy_mfusion_opt(build_temp, target_dir):
        """Copy the ``mfusion-opt`` executable into ``<target_dir>/_mlir_libs``."""
        mfusion_opt_src = build_temp / "bin" / "mfusion-opt"
        mfusion_opt_dst = target_dir / "_mlir_libs" / "mfusion-opt"
        if mfusion_opt_src.exists():
            # The destination directory is not guaranteed to be part of the
            # generated package; create it so copy2 cannot fail on a missing parent.
            mfusion_opt_dst.parent.mkdir(parents=True, exist_ok=True)
            # Preserve a symlinked executable as a symlink rather than its target.
            shutil.copy2(mfusion_opt_src, mfusion_opt_dst, follow_symlinks=False)
            print(f"Copied mfusion-opt to {mfusion_opt_dst}")
        else:
            print(f"Warning: mfusion-opt not found at {mfusion_opt_src}")


class BuildPyWithExt(build_py):
    """``build_py`` variant that chains into the ``build_ext`` command.

    After the regular pure-Python staging finishes, the registered
    ``build_ext`` command is run, so invoking ``build_py`` alone also
    triggers the native build.
    """

    def run(self):
        """Stage Python sources via the base command, then run ``build_ext``."""
        build_py.run(self)
        self.run_command("build_ext")


class BinaryDistribution(Distribution):
    """Distribution that always reports native extension content.

    The native artifacts (for example the ``mfusion-opt`` executable) come from
    a custom CMake-driven ``build_ext`` rather than declared ``ext_modules``.
    Without this override setuptools treats the project as pure Python:
    ``build`` skips its ``build_ext`` sub-command and ``bdist_wheel`` emits a
    platform-independent (``py3-none-any``) wheel that nonetheless contains
    platform-specific binaries.
    """

    def has_ext_modules(self):
        """Report extension content so wheels get a platform-specific tag."""
        return True


setup(
    distclass=BinaryDistribution,
    cmdclass={
        "build_ext": CMakeBuild,
        "build_py": BuildPyWithExt,
    },
)