import os
import platform
import re
import contextlib
import shlex
import shutil
import subprocess
import sys
import sysconfig
import tarfile
import zipfile
import urllib.request
import json
import glob

from io import BytesIO
from distutils.command.clean import clean
from pathlib import Path
from typing import Optional

from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from setuptools.command.egg_info import egg_info
from setuptools.command.install import install
from setuptools.command.sdist import sdist

from dataclasses import dataclass

import pybind11

try:
    from setuptools.command.bdist_wheel import bdist_wheel
except ImportError:
    from wheel.bdist_wheel import bdist_wheel

try:
    from setuptools.command.editable_wheel import editable_wheel
except ImportError:
    # create a dummy class, since there is no command to override
    class editable_wheel:
        pass


sys.path.insert(0, os.path.dirname(__file__))

from python.build_helpers import get_base_dir, get_cmake_dir


# Absolute path of the directory that contains this file.
triton_dir = os.path.dirname(os.path.abspath(__file__))

# Build defaults. `setdefault` only assigns a value when the variable is not
# already present in the environment, so user-provided settings take precedence.
os.environ.setdefault("TRITON_BUILD_WITH_CCACHE", "true")
os.environ.setdefault("TRITON_BUILD_WITH_CLANG_LLD", "true")
os.environ.setdefault("TRITON_BUILD_PROTON", "OFF")
os.environ.setdefault("TRITON_WHEEL_NAME", "triton-ascend")
os.environ.setdefault("TRITON_APPEND_CMAKE_ARGS", "-DTRITON_BUILD_UT=OFF")


def is_git_repo():
    """Report whether a ``.git`` directory sits next to this file."""
    git_dir = Path(__file__).parent / ".git"
    return git_dir.is_dir()


@dataclass
class Backend:
    """Source and install locations of a single backend."""
    name: str  # backend name; also the last component of `install_dir`
    src_dir: str  # root of the backend's source tree
    backend_dir: str  # `<src_dir>/backend`
    language_dir: Optional[str]  # `<src_dir>/language`, or None when that directory does not exist
    tools_dir: Optional[str]  # `<src_dir>/tools`, or None when that directory does not exist
    install_dir: str  # `python/triton/backends/<name>` next to this file
    is_external: bool  # True for plugins supplied from outside the tree


class BackendInstaller:
    """Locates backend source trees and describes each one as a ``Backend`` record."""

    @staticmethod
    def prepare(backend_name: str, backend_src_dir: Optional[str] = None, is_external: bool = False):
        """Validate one backend's layout and return its ``Backend`` description.

        Args:
            backend_name: Name of the backend.
            backend_src_dir: Root of the backend sources. Only used for external
                backends; for in-tree backends it is ``third_party/<backend_name>``.
            is_external: True when the backend lives outside ``third_party``.

        Returns:
            A ``Backend`` whose ``language_dir`` / ``tools_dir`` are None when the
            corresponding directory does not exist.

        Raises:
            AssertionError: if the backend directory, its ``backend`` subdirectory,
                or ``compiler.py`` / ``driver.py`` inside it are missing.
        """
        # Initialize submodule if there is one for in-tree backends.
        if not is_external:
            root_dir = "third_party"
            assert backend_name in os.listdir(
                root_dir), f"{backend_name} is requested for install but not present in {root_dir}"

            if is_git_repo():
                # Best effort: a failing `git submodule update` or a missing git
                # executable must not abort the build.
                with contextlib.suppress(subprocess.CalledProcessError, FileNotFoundError):
                    subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
                                   stdout=subprocess.DEVNULL, cwd=root_dir)

            backend_src_dir = os.path.join(root_dir, backend_name)

        backend_path = os.path.join(backend_src_dir, "backend")
        assert os.path.exists(backend_path), f"{backend_path} does not exist!"

        language_dir = os.path.join(backend_src_dir, "language")
        if not os.path.exists(language_dir):
            language_dir = None

        tools_dir = os.path.join(backend_src_dir, "tools")
        if not os.path.exists(tools_dir):
            tools_dir = None

        for file in ["compiler.py", "driver.py"]:
            assert os.path.exists(os.path.join(backend_path, file)), f"{file} does not exist in {backend_path}"

        install_dir = os.path.join(os.path.dirname(__file__), "python", "triton", "backends", backend_name)

        return Backend(name=backend_name, src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
                       tools_dir=tools_dir, install_dir=install_dir, is_external=is_external)

    # Describe all in-tree backends under triton/third_party.
    @staticmethod
    def copy(active):
        """Return a ``Backend`` for every in-tree backend name in ``active``."""
        return [BackendInstaller.prepare(backend) for backend in active]

    # Describe all external plugins provided by the `TRITON_PLUGIN_DIRS` env var.
    # TRITON_PLUGIN_DIRS is a semicolon-separated list of paths to the plugins.
    # Expect to find the name of the backend under dir/backend/name.conf
    @staticmethod
    def copy_externals():
        """Return a ``Backend`` for every plugin directory listed in ``TRITON_PLUGIN_DIRS``.

        Blank entries (an empty variable, a trailing ``;``) are ignored instead of
        being treated as the current directory.
        """
        raw_dirs = os.getenv("TRITON_PLUGIN_DIRS")
        if raw_dirs is None:
            return []
        plugin_dirs = [entry.strip() for entry in raw_dirs.split(";")]
        plugin_dirs = [entry for entry in plugin_dirs if entry]
        backend_names = [
            Path(os.path.join(plugin_dir, "backend", "name.conf")).read_text().strip() for plugin_dir in plugin_dirs
        ]
        return [
            BackendInstaller.prepare(backend_name, backend_src_dir=plugin_dir, is_external=True)
            for backend_name, plugin_dir in zip(backend_names, plugin_dirs)
        ]


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
    """Return True when env var ``name`` (or ``default`` if unset) is ON/1/YES/TRUE/Y, case-insensitively."""
    truthy_values = ["ON", "1", "YES", "TRUE", "Y"]
    raw_value = os.getenv(name, default)
    return raw_value.upper() in truthy_values


def get_build_type():
    """Return the CMake build type selected by the first enabled environment flag."""
    # Checked in order; the first flag that is switched on wins.
    flag_to_build_type = (
        ("DEBUG", "Debug"),
        ("REL_WITH_DEB_INFO", "RelWithDebInfo"),
        ("TRITON_REL_BUILD_WITH_ASSERTS", "TritonRelBuildWithAsserts"),
        ("TRITON_BUILD_WITH_O1", "TritonBuildWithO1"),
    )
    for flag, build_type in flag_to_build_type:
        if check_env_flag(flag):
            return build_type
    # TODO: change to release when stable enough
    return "TritonRelBuildWithAsserts"


def get_env_with_keys(key: list):
    """Return the value of the first variable in ``key`` that is set in the environment, else ``""``."""
    present_values = (os.environ[candidate] for candidate in key if candidate in os.environ)
    return next(present_values, "")


def is_offline_build() -> bool:
    """
    Tell whether `TRITON_OFFLINE_BUILD` is switched on in the build environment.

    Downstream projects and distributions that bootstrap their own dependencies and build
    inside offline sandboxes may set this flag to prevent any attempt at downloading pinned
    dependencies from the internet or at using dependencies vendored in-tree.

    Dependencies must then be supplied through their search paths (cf. `syspath_var_name`
    in `Package`). A missing dependency aborts the build early, and the compatibility of
    the supplied dependencies is not verified.

    Note that this flag isn't tested by the CI and does not provide any guarantees.
    """
    offline_requested = check_env_flag("TRITON_OFFLINE_BUILD", "")
    return offline_requested


# --- third party packages -----


@dataclass
class Package:
    """Description of a third-party dependency consumed by `get_thirdparty_packages`."""
    package: str  # subdirectory of the Triton cache that holds this dependency
    name: str  # directory name of the unpacked dependency inside that subdirectory
    url: str  # download URL of the archive; also stored in version.txt to detect stale caches
    include_flag: str  # CMake variable set to `<package dir>/include`; empty to skip
    lib_flag: str  # CMake variable set to `<package dir>/lib`; empty to skip
    syspath_var_name: str  # env var overriding the package dir; also passed on as a CMake variable
    sym_name: Optional[str] = None  # optional symlink name created next to the package dir


# json
def get_json_package_info():
    """Return the ``Package`` describing the nlohmann/json header archive."""
    return Package(
        package="json",
        name="",
        url="https://github.com/nlohmann/json/releases/download/v3.11.3/include.zip",
        include_flag="JSON_INCLUDE_DIR",
        lib_flag="",
        syspath_var_name="JSON_SYSPATH",
    )


def is_linux_os(os_id, os_release_path="/etc/os-release"):
    """Return True when the ``ID`` field of the os-release file equals ``os_id``.

    Only the ``ID`` key is compared, so fields such as ``VERSION_ID`` or
    ``PLATFORM_ID`` can no longer produce a false match, and both quoted
    (``ID="almalinux"``) and unquoted (``ID=ubuntu``) values are recognised.

    Args:
        os_id: Distribution identifier to look for, e.g. ``"almalinux"``.
        os_release_path: File to inspect; defaults to ``/etc/os-release``.

    Returns:
        False when the file is missing or unreadable, or has no matching ``ID``.
    """
    try:
        with open(os_release_path, "r") as f:
            for raw_line in f:
                key, separator, value = raw_line.strip().partition("=")
                if separator and key.strip() == "ID":
                    # Values may be wrapped in single or double quotes.
                    return value.strip().strip("\"'") == os_id
    except OSError:
        return False
    return False


# llvm
def get_llvm_package_info():
    """Return the ``Package`` describing the LLVM build to use on this host.

    When a pre-built archive exists for the host (or `TRITON_LLVM_SYSTEM_SUFFIX`
    names one), the package points at its download URL. Otherwise a message is
    printed and a URL-less package is returned so that a user-configured LLVM is used.
    """
    system = platform.system()
    machine = platform.machine()
    # Unknown machine names are passed through unchanged.
    arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}.get(machine, machine)

    def user_configured_llvm():
        print(
            f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
        )
        return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

    suffix_override = os.environ.get("TRITON_LLVM_SYSTEM_SUFFIX", None)
    if suffix_override:
        system_suffix = suffix_override
    elif system == "Darwin":
        system_suffix = f"macos-{arch}"
    elif system != "Linux":
        return user_configured_llvm()
    elif arch == 'arm64':
        system_suffix = 'almalinux-arm64' if is_linux_os('almalinux') else 'ubuntu-arm64'
    elif arch == 'x64':
        # Encode glibc "major.minor" as major * 100 + minor, e.g. 2.28 -> 228.
        glibc_parts = [int(part) for part in platform.libc_ver()[1].split('.')]
        glibc_version = glibc_parts[0] * 100 + glibc_parts[1]
        if glibc_version > 228:
            # Ubuntu 24 LTS (v2.39)
            # Ubuntu 22 LTS (v2.35)
            # Ubuntu 20 LTS (v2.31)
            system_suffix = "ubuntu-x64"
        elif glibc_version > 217:
            # Manylinux_2.28 (v2.28)
            # AlmaLinux 8 (v2.28)
            system_suffix = "almalinux-x64"
        else:
            # Manylinux_2014 (v2.17)
            # CentOS 7 (v2.17)
            system_suffix = "centos-x64"
    else:
        return user_configured_llvm()

    # NOTE: selecting an assert-enabled LLVM via TRITON_USE_ASSERT_ENABLED_LLVM is
    # currently disabled; only one archive flavour per system suffix is used.
    hash_file_path = os.path.join(get_base_dir(), "cmake", "llvm-hash.txt")
    with open(hash_file_path, "r") as hash_file:
        # Only the first 8 characters of the pinned revision are part of the archive name.
        rev = hash_file.read(8)
    name = f"llvm-{rev}-{system_suffix}"
    # Create a stable symlink that doesn't include revision
    sym_name = f"llvm-{system_suffix}"
    url = f"https://triton-ascend-artifacts.obs.myhuaweicloud.com/llvm-builds/{name}.tar.gz"
    return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH", sym_name=sym_name)


def open_url(url):
    """Open ``url`` with a browser-like User-Agent header and return the response object."""
    request = urllib.request.Request(
        url,
        None,
        {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'},
    )
    # Set timeout to 300 seconds to prevent the request from hanging forever.
    return urllib.request.urlopen(request, timeout=300)


# ---- package data ---


def get_triton_cache_path():
    """Return ``<home>/.triton``, where home is the first non-empty value among
    TRITON_HOME, HOME, USERPROFILE and HOMEPATH.

    Raises:
        RuntimeError: if none of these variables holds a non-empty value.
    """
    home_candidates = (os.getenv(var) for var in ("TRITON_HOME", "HOME", "USERPROFILE", "HOMEPATH"))
    user_home = next((candidate for candidate in home_candidates if candidate), None)
    if user_home is None:
        raise RuntimeError("Could not find user home directory")
    return os.path.join(user_home, ".triton")


def update_symlink(link_path, source_path):
    """Make ``link_path`` a directory symlink to the absolute ``source_path``.

    Whatever currently occupies ``link_path`` is removed first: a symlink
    (including a dangling one) or a regular file is unlinked, a real directory
    is deleted recursively. Missing parent directories of the link are created.

    Args:
        link_path: Location of the symlink to create or replace.
        source_path: Target the symlink should point to.
    """
    source_path = Path(source_path)
    link_path = Path(link_path)

    # `is_symlink` is checked first so that a link to a directory is unlinked
    # rather than having its target tree deleted.
    if link_path.is_symlink() or link_path.is_file():
        link_path.unlink()
    elif link_path.exists():
        shutil.rmtree(link_path)

    print(f"creating symlink: {link_path} -> {source_path}", file=sys.stderr)
    link_path.absolute().parent.mkdir(parents=True, exist_ok=True)  # Ensure link's parent directory exists
    link_path.symlink_to(source_path.absolute(), target_is_directory=True)


def get_thirdparty_packages(packages: list):
    """Make every package available locally and return the matching CMake arguments.

    For each package the directory is taken from the environment variable named
    by ``syspath_var_name`` when that variable holds a non-empty value; otherwise
    it is ``<triton cache>/<package>/<name>``, which is (re)downloaded unless its
    ``version.txt`` already contains the package URL.

    An empty override variable is treated the same as an unset one, so the
    download decision and the chosen directory can no longer disagree.

    Args:
        packages: ``Package`` descriptions to resolve.

    Returns:
        List of ``-D<var>=<path>`` strings for CMake.

    Raises:
        RuntimeError: in an offline build when a package has no override directory.
    """
    triton_cache_path = get_triton_cache_path()
    offline = is_offline_build()
    thirdparty_cmake_args = []
    for p in packages:
        package_root_dir = os.path.join(triton_cache_path, p.package)
        override_dir = os.environ.get(p.syspath_var_name)
        input_defined = bool(override_dir)
        package_dir = override_dir if input_defined else os.path.join(package_root_dir, p.name)
        version_file_path = os.path.join(package_dir, "version.txt")

        if not input_defined:
            if offline:
                raise RuntimeError(f"Requested an offline build but {p.syspath_var_name} is not set")
            input_compatible = os.path.exists(version_file_path) and Path(version_file_path).read_text() == p.url
            if not input_compatible:
                # Start from a clean directory; a missing one is not an error.
                shutil.rmtree(package_root_dir, ignore_errors=True)
                os.makedirs(package_root_dir, exist_ok=True)
                print(f'downloading and extracting {p.url} ...')
                with open_url(p.url) as response:
                    if p.url.endswith(".zip"):
                        # ZipFile needs a seekable file, so buffer the download in memory.
                        file_bytes = BytesIO(response.read())
                        with zipfile.ZipFile(file_bytes, "r") as file:
                            file.extractall(path=package_root_dir)
                    else:
                        with tarfile.open(fileobj=response, mode="r|*") as file:
                            file.extractall(path=package_root_dir)
                # write version url to package_dir
                with open(version_file_path, "w") as f:
                    f.write(p.url)
        if p.include_flag:
            thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
        if p.lib_flag:
            thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
        if p.syspath_var_name:
            thirdparty_cmake_args.append(f"-D{p.syspath_var_name}={package_dir}")
        if p.sym_name is not None:
            sym_link_path = os.path.join(package_root_dir, p.sym_name)
            update_symlink(sym_link_path, package_dir)

    return thirdparty_cmake_args


def download_and_copy(name, src_func, dst_path, variable, version, url_func):
    """Download one NVIDIA toolchain archive if needed and copy a file or directory out of it.

    Nothing is done in an offline build or when the environment variable
    ``variable`` is set.

    Args:
        name: Subdirectory of ``<triton cache>/nvidia`` used to cache the extracted archive.
        src_func: ``(system, arch, version) -> str`` giving the path inside the cache to copy from.
        dst_path: Destination relative to ``third_party/nvidia/backend`` next to this file.
        variable: Name of the environment variable that disables this download when set.
        version: Pinned version of the component.
        url_func: ``(system, arch, version) -> str`` giving the archive URL.

    Raises:
        KeyError: if ``platform.system()`` is neither ``Linux`` nor ``Darwin``.
    """
    if is_offline_build():
        return
    triton_cache_path = get_triton_cache_path()
    if variable in os.environ:
        return
    base_dir = os.path.dirname(__file__)
    system = platform.system()
    arch = platform.machine()
    # NOTE: This might be wrong for jetson if both grace chips and jetson chips return aarch64
    arch = {"arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch)
    supported = {"Linux": "linux", "Darwin": "linux"}
    url = url_func(supported[system], arch, version)
    src_path = src_func(supported[system], arch, version)
    tmp_path = os.path.join(triton_cache_path, "nvidia", name)  # path to cache the download
    dst_path = os.path.join(base_dir, "third_party", "nvidia", "backend", dst_path)  # final binary path
    src_path = os.path.join(tmp_path, src_path)
    download = not os.path.exists(src_path)
    if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None:
        # An executable is already installed: download again if it reports a different version.
        curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
        curr_version = re.search(r"V([.|\d]+)", curr_version)
        assert curr_version is not None, f"No version information for {dst_path}"
        download = download or curr_version.group(1) != version
    if download:
        print(f'downloading and extracting {url} ...')
        # Close both the HTTP response and the archive once extraction is done.
        with open_url(url) as response, tarfile.open(fileobj=response, mode="r|*") as archive:
            archive.extractall(path=tmp_path)
    os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
    print(f'copy {src_path} to {dst_path} ...')
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(src_path, dst_path)


# ---- cmake extension ----


class CMakeClean(clean):
    """``clean`` command whose temporary build directory is the CMake build directory."""

    def initialize_options(self):
        """Run the stock initialisation, then point ``build_temp`` at ``get_cmake_dir()``."""
        super().initialize_options()
        self.build_temp = get_cmake_dir()


class CMakeBuildPy(build_py):
    """``build_py`` command that runs ``build_ext`` before the regular Python build."""

    def run(self) -> None:
        """Run ``build_ext`` first, then the standard ``build_py`` step."""
        self.run_command('build_ext')
        return build_py.run(self)


class CMakeExtension(Extension):
    """Extension without sources of its own; it only records where to build from."""

    def __init__(self, name, path, sourcedir=""):
        """Store ``path`` as given and ``sourcedir`` as an absolute path."""
        super().__init__(name, sources=[])
        self.path = path
        self.sourcedir = os.path.abspath(sourcedir)


class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds the extension with CMake and Ninja."""

    user_options = build_ext.user_options + \
        [('base-dir=', None, 'base directory of Triton')]

    def initialize_options(self):
        """Run the stock initialisation and default ``base_dir`` to ``get_base_dir()``."""
        build_ext.initialize_options(self)
        self.base_dir = get_base_dir()

    def finalize_options(self):
        build_ext.finalize_options(self)

    def setup_coverage_env(self):
        """Setting environment variables required for the hitest coverage tool"""
        # To enable the hitest cov tool, you need to set the following three environment variables.
        hitest_home = os.getenv('HITEST_HOME', '/opt/hitest/linux_avatar_x86_64')
        hitest_user_account = os.getenv('HITEST_USER_ACCOUNT', 'a00000000')
        lltcov_rootpath = os.getenv('LLTCOV_ROOTPATH', '/opt/covdata')  # Path to the output coverage binary file

        # hitest default environment variables
        coverage_env_vars = {
            'HitestHome': hitest_home,
            'isOverlappedCompile': '0',
            'PlatformToken': 'BOARD',
            'gcovmode': '0',
            'TimerPolicy': '1',
            'TimeInterval': '60',
            'SignalPolicy': '1',
            'SignalNUM': '34',
            'lltwrapper_cfg': '0',
            'HITEST_AGENT_INSIDE': '1',
            'USE_HLLT_COVERAGE': '1',
            'USE_HLLT_TESTCASE': '0',
            'simplemode': '0',
            'ncs_coverage_stub_mold': '1',
            'HITEST_ENABLE_SOKCET': '0',
            'hitest_disable_cfg': '0',
            'hitest_disable_dfg': '1',
            'hitest_disable_ir': '1',
            'HITEST_DISABLE_MACRO': '0',
            'HITEST_REMOVE_INCLUDE_DIR': '0',
            'HITEST_AGENT_SET_THREADNAME_PRCTL': '1',
            'HITEST_INST_HEADER_FILE': '0',
            'HITEST_USER_ACCOUNT': hitest_user_account,
            'lltcovRootpath': lltcov_rootpath,
            'HITEST_COVSTUB_ROOT_DIR': f'{hitest_home}/apache-tomcat-8.0.39/webapps/datasource/Container_Default/base',
            'HITEST_EXEC_CMD_WITH_FILE': '1',
            'HITEST_PRINT_LOG_ENABLE': '1',
        }
        os.environ.update(coverage_env_vars)

        # Put the hitest directory in front of both search paths, exactly once.
        current_path = os.environ.get('PATH', '')
        os.environ['PATH'] = f'{hitest_home}:{current_path}'
        current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
        os.environ['LD_LIBRARY_PATH'] = f'{hitest_home}:{current_ld_path}'

        print("The currently set environment variables for the hitest coverage tool are read.")
        print(f"  HitestHome: {hitest_home} (environment variables HITEST_HOME)")
        print(f"  HITEST_USER_ACCOUNT: {hitest_user_account} (environment variables HITEST_USER_ACCOUNT)")
        print(f"  lltcovRootpath: {lltcov_rootpath} (environment variables LLTCOV_ROOTPATH)")

    def run(self):
        """Fetch dependencies, check the CMake version and build every extension.

        Raises:
            RuntimeError: if CMake cannot be run, its version cannot be parsed,
                or it is older than 3.20.
        """
        download_and_copy_dependencies()

        try:
            out = subprocess.check_output(["cmake", "--version"])
        except OSError as err:
            raise RuntimeError("CMake must be installed to build the following extensions: " +
                               ", ".join(e.name for e in self.extensions)) from err

        version_text = out.decode()
        match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", version_text)
        if match is None:
            raise RuntimeError(f"Could not determine the CMake version from: {version_text!r}")
        cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
        if (cmake_major, cmake_minor) < (3, 20):
            raise RuntimeError("CMake >= 3.20 is required")

        # To enable the hitest coverage tool, you need to set the environment variable TRITON_ENABLE_COVERAGE_HITEST=1
        enable_hitest = os.getenv('TRITON_ENABLE_COVERAGE_HITEST', '0').lower() in ('1', 'on', 'true')
        if enable_hitest:
            self.setup_coverage_env()
            current_append = os.environ.get('TRITON_APPEND_CMAKE_ARGS', '')
            if current_append:
                os.environ['TRITON_APPEND_CMAKE_ARGS'] = current_append + " -DTRITON_ENABLE_COVERAGE_HITEST=ON"
            else:
                os.environ['TRITON_APPEND_CMAKE_ARGS'] = "-DTRITON_ENABLE_COVERAGE_HITEST=ON"
        else:
            # clean up existing HITEST_* environment variables to avoid pollution.
            for key in list(os.environ.keys()):
                if key.startswith('HITEST_') or key in ['HitestHome', 'lltcovRootpath']:
                    del os.environ[key]

        for ext in self.extensions:
            self.build_extension(ext)

    def get_pybind11_cmake_args(self):
        """Return the CMake arguments locating pybind11's headers and CMake package.

        The include directory is ``$PYBIND11_SYSPATH/include`` when that variable
        is set, otherwise the one reported by the imported ``pybind11`` module.
        """
        pybind11_sys_path = get_env_with_keys(["PYBIND11_SYSPATH"])
        if pybind11_sys_path:
            pybind11_include_dir = os.path.join(pybind11_sys_path, "include")
        else:
            pybind11_include_dir = pybind11.get_include()
        return [f"-Dpybind11_INCLUDE_DIR='{pybind11_include_dir}'", f"-Dpybind11_DIR='{pybind11.get_cmake_dir()}'"]

    def get_proton_cmake_args(self):
        """Return the extra CMake arguments needed when Proton is built."""
        cmake_args = get_thirdparty_packages([get_json_package_info()])
        cmake_args += self.get_pybind11_cmake_args()
        cupti_include_dir = get_env_with_keys(["TRITON_CUPTI_INCLUDE_PATH"])
        if cupti_include_dir == "":
            cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include")
        cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir]
        roctracer_include_dir = get_env_with_keys(["TRITON_ROCTRACER_INCLUDE_PATH"])
        if roctracer_include_dir == "":
            roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include")
        cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir]
        return cmake_args

    @staticmethod
    def _install_tool(cmake_dir, extdir, tool_name):
        """Copy ``bin/<tool_name>`` from the CMake build tree into ``extdir`` for runtime use.

        Does nothing when the tool was not built. On non-Windows systems the copy
        is made executable and stripped on a best-effort basis.
        """
        tool_src = os.path.join(cmake_dir, "bin", tool_name)
        if not os.path.exists(tool_src):
            return
        tool_dst = os.path.join(extdir, tool_name)
        shutil.copy2(tool_src, tool_dst)
        # Make it executable (Unix-like systems)
        if platform.system() != "Windows":
            os.chmod(tool_dst, 0o755)
            # Strip debug symbols to reduce binary size (can reduce size by ~80%)
            try:
                subprocess.check_call(["strip", "--strip-all", tool_dst])
                print(f"Stripped {tool_name} to reduce size")
            except (subprocess.CalledProcessError, FileNotFoundError):
                # strip command not available or failed, continue without stripping
                pass
        print(f"Copied {tool_name} to {tool_dst}")

    def build_extension(self, ext):
        """Configure, build and post-process one extension with CMake.

        Raises:
            RuntimeError: if the ``ninja`` executable cannot be found.
            subprocess.CalledProcessError: if a CMake invocation fails.
        """
        lit_dir = shutil.which('lit')
        ninja_dir = shutil.which('ninja')
        if ninja_dir is None:
            raise RuntimeError("Ninja must be installed to build the following extension: " + ext.name)
        # lit is used by the test suite
        thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
        thirdparty_cmake_args += self.get_pybind11_cmake_args()
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        wheeldir = os.path.dirname(extdir)

        # create build directories
        os.makedirs(self.build_temp, exist_ok=True)
        # python directories
        python_include_dir = sysconfig.get_path("platinclude")
        cmake_args = [
            "-G", "Ninja",  # Ninja is much faster than make
            "-DCMAKE_MAKE_PROGRAM=" +
            ninja_dir,  # Pass explicit path to ninja otherwise cmake may cache a temporary path
            "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_PYTHON_MODULE=ON",
            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, "-DPython3_INCLUDE_DIR=" + python_include_dir,
            "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
            "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]),
            "-DTRITON_WHEEL_DIR=" + wheeldir,
            "-DLLVM_MAJOR_VERSION_22_COMPATIBLE=ON"
        ]
        if lit_dir is not None:
            cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
        cmake_args.extend(thirdparty_cmake_args)

        # configuration
        cfg = get_build_type()
        build_args = ["--config", cfg]

        cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"]
        if platform.system() == "Windows":
            cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
        else:
            max_jobs = os.getenv("MAX_JOBS")
            if max_jobs is None:
                # os.cpu_count() may return None when the count cannot be determined.
                max_jobs = str(2 * (os.cpu_count() or 1))
            build_args += ['-j' + max_jobs]

        if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
            cmake_args += [
                "-DCMAKE_C_COMPILER=clang",
                "-DCMAKE_CXX_COMPILER=clang++",
                "-DCMAKE_LINKER=lld",
                "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
            ]

        # Note that asan doesn't work with binaries that use the GPU, so this is
        # only useful for tools like triton-opt that don't run code on the GPU.
        #
        # I tried and gave up getting msan to work.  It seems that libstdc++'s
        # std::string does not play nicely with clang's msan (I didn't try
        # gcc's).  I was unable to configure clang to ignore the error, and I
        # also wasn't able to get libc++ to work, but that doesn't mean it's
        # impossible. :)
        if check_env_flag("TRITON_BUILD_WITH_ASAN"):
            cmake_args += [
                "-DCMAKE_C_FLAGS=-fsanitize=address",
                "-DCMAKE_CXX_FLAGS=-fsanitize=address",
            ]

        # environment variables we will pass through to cmake
        passthrough_args = [
            "TRITON_BUILD_PROTON",
            "TRITON_BUILD_WITH_CCACHE",
            "TRITON_PARALLEL_LINK_JOBS",
        ]
        cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ]

        if check_env_flag("TRITON_BUILD_PROTON", "ON"):  # Default ON
            cmake_args += self.get_proton_cmake_args()

        if is_offline_build():
            # unit test builds fetch googletests from GitHub
            cmake_args += ["-DTRITON_BUILD_UT=OFF"]

        # Allow specifying AscendNPU-IR tag/commit via environment variable
        ascendnpu_ir_tag = os.getenv("ASCENDNPU_IR_TAG")
        if ascendnpu_ir_tag is not None:
            cmake_args += [f"-DASCENDNPU_IR_TAG={ascendnpu_ir_tag}"]

        # User-supplied arguments go last.
        cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS")
        if cmake_args_append is not None:
            cmake_args += shlex.split(cmake_args_append)

        env = os.environ.copy()
        cmake_dir = get_cmake_dir()
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
        update_symlink(Path(self.base_dir) / "compile_commands.json", Path(cmake_dir) / "compile_commands.json")
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
        subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)

        # triton-mlir-opt is needed for converting MLIR to Bytecode,
        # triton-opt for converting ttir to ttadapter.
        for tool_name in ("triton-mlir-opt", "triton-opt"):
            self._install_tool(cmake_dir, extdir, tool_name)


def download_and_copy_dependencies():
    """Fetch every pinned NVIDIA toolchain component via ``download_and_copy``.

    Versions are read from ``cmake/nvidia-toolchain-version.json`` under the base
    directory. Components are processed in a fixed order: ptxas, cuobjdump,
    nvdisasm, then the nvcc, cudart and cupti headers, and finally the cupti libraries.
    """
    versions_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
    with open(versions_path, "r") as versions_file:
        # parse this json file to get the version of the nvidia toolchain
        toolchain_versions = json.load(versions_file)

    exe_extension = sysconfig.get_config_var("EXE")

    def archive_stem(component, system, arch, version):
        return f"cuda_{component}-{system}-{arch}-{version}-archive"

    def make_src_func(component, inner_path):
        def src_func(system, arch, version):
            return f"{archive_stem(component, system, arch, version)}/{inner_path}"

        return src_func

    def make_url_func(component):
        def url_func(system, arch, version):
            redist_dir = f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_{component}/{system}-{arch}"
            return f"{redist_dir}/{archive_stem(component, system, arch, version)}.tar.xz"

        return url_func

    # (archive component, path inside the archive, destination, override env var, version key)
    artifacts = (
        ("nvcc", f"bin/ptxas{exe_extension}", "bin/ptxas", "TRITON_PTXAS_PATH", "ptxas"),
        ("cuobjdump", f"bin/cuobjdump{exe_extension}", "bin/cuobjdump", "TRITON_CUOBJDUMP_PATH", "cuobjdump"),
        ("nvdisasm", f"bin/nvdisasm{exe_extension}", "bin/nvdisasm", "TRITON_NVDISASM_PATH", "nvdisasm"),
        ("nvcc", "include", "include", "TRITON_CUDACRT_PATH", "cudacrt"),
        ("cudart", "include", "include", "TRITON_CUDART_PATH", "cudart"),
        ("cupti", "include", "include", "TRITON_CUPTI_INCLUDE_PATH", "cupti"),
        ("cupti", "lib", "lib/cupti", "TRITON_CUPTI_LIB_PATH", "cupti"),
    )
    for component, inner_path, destination, override_var, version_key in artifacts:
        download_and_copy(
            name=component,
            src_func=make_src_func(component, inner_path),
            dst_path=destination,
            variable=override_var,
            version=toolchain_versions[version_key],
            url_func=make_url_func(component),
        )


# All backends taking part in the build: the in-tree "ascend", "nvidia" and "amd"
# backends followed by any external plugins listed in TRITON_PLUGIN_DIRS.
# This runs at import time and asserts that each backend directory is laid out as expected.
backends = [*BackendInstaller.copy(["ascend", "nvidia", "amd"]), *BackendInstaller.copy_externals()]


def get_package_dirs():
    """Yield ``(package name, source directory)`` pairs for the root package, every
    in-tree backend with its language/tools extras, and Proton when enabled."""
    yield ("", "python")

    # we use symlinks for external plugins
    for backend in (b for b in backends if not b.is_external):
        yield (f"triton.backends.{backend.name}", backend.backend_dir)

        # Install the contents of each backend's `language` directory into
        # `triton.language.extra` and of its `tools` directory into `triton.tools.extra`.
        for subpackage, extras_dir in (("language", backend.language_dir), ("tools", backend.tools_dir)):
            if not extras_dir:
                continue
            for entry in os.listdir(extras_dir):
                yield (f"triton.{subpackage}.extra.{entry}", os.path.join(extras_dir, entry))

    if check_env_flag("TRITON_BUILD_PROTON", "ON"):  # Default ON
        yield ("triton.profiler", "third_party/proton/proton")
        yield ("triton.profiler.hooks", "third_party/proton/proton/hooks")


def get_packages():
    """Yield every package name to pass as the ``packages`` argument of ``setup()``.

    Starts with the packages discovered under ``python``, then adds one
    ``triton.backends.<name>`` package per backend, plus a package for each entry in the
    backend's optional ``language`` and ``tools`` directories. ``triton.profiler`` is
    added only when the ``TRITON_BUILD_PROTON`` flag check succeeds.
    """
    yield from find_packages(where="python")

    for backend in backends:
        yield f"triton.backends.{backend.name}"

        # Language entries are emitted before tools entries.
        extra_sources = (
            ("triton.language.extra", backend.language_dir),
            ("triton.tools.extra", backend.tools_dir),
        )
        for namespace, source_root in extra_sources:
            if source_root:
                yield from (f"{namespace}.{entry}" for entry in os.listdir(source_root))

    # NOTE(review): "ON" is the second argument to check_env_flag (presumably its fallback),
    # but the top of this file setdefaults TRITON_BUILD_PROTON to "OFF" -- confirm the
    # effective default.
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        yield "triton.profiler"


def add_link_to_backends(external_only):
    """Call ``update_symlink`` for each backend and for the contents of its extra directories.

    Args:
        external_only: when true, backends whose ``is_external`` is false are skipped.
    """
    triton_pkg_root = os.path.join(os.path.dirname(__file__), "python", "triton")

    for backend in backends:
        if external_only and not backend.is_external:
            continue

        update_symlink(backend.install_dir, backend.backend_dir)

        # Entries of the backend's `language` directory go under triton/language/extra,
        # entries of its `tools` directory under triton/tools/extra (in that order).
        for subpackage, source_root in (("language", backend.language_dir), ("tools", backend.tools_dir)):
            if not source_root:
                continue
            extra_dir = os.path.abspath(os.path.join(triton_pkg_root, subpackage, "extra"))
            for entry in os.listdir(source_root):
                link_source = os.path.join(source_root, entry)
                link_location = os.path.join(extra_dir, entry)
                update_symlink(link_location, link_source)


def add_link_to_proton():
    """Call ``update_symlink`` for ``python/triton/profiler`` and ``third_party/proton/proton``.

    Both paths are resolved relative to this file's directory; only the proton source
    path is made absolute.
    """
    here = os.path.dirname(__file__)
    proton_source = os.path.abspath(os.path.join(here, "third_party", "proton", "proton"))
    profiler_location = os.path.join(here, "python", "triton", "profiler")
    update_symlink(profiler_location, proton_source)


def add_links(external_only):
    """Create the backend links and, for a full (non external-only) run, the proton link.

    Args:
        external_only: forwarded to ``add_link_to_backends``; when true the proton link
            is never considered.
    """
    add_link_to_backends(external_only=external_only)

    if external_only:
        return

    # NOTE(review): "ON" is the second argument to check_env_flag (presumably its fallback),
    # but the top of this file setdefaults TRITON_BUILD_PROTON to "OFF" -- confirm the
    # effective default.
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        add_link_to_proton()


class plugin_bdist_wheel(bdist_wheel):
    """``bdist_wheel`` command that refreshes the external-backend links before building.

    NOTE(review): the ``setup()`` call in this file registers ``BuildWheel`` for the
    ``"bdist_wheel"`` command; this class does not appear in that ``cmdclass`` mapping.
    """

    def run(self):
        """Link external backends only, then run the stock ``bdist_wheel`` command."""
        add_links(external_only=True)
        super().run()


class plugin_develop(develop):
    """``develop`` command that refreshes all backend links before installing in place."""

    def run(self):
        """Link every backend (``external_only=False``), then run the stock ``develop`` command."""
        add_links(external_only=False)
        super().run()


class plugin_editable_wheel(editable_wheel):
    """``editable_wheel`` command that refreshes all backend links before building.

    When setuptools does not provide ``editable_wheel``, the base class is the empty
    placeholder defined next to the imports at the top of this file.
    """

    def run(self):
        """Link every backend (``external_only=False``), then run the base command."""
        add_links(external_only=False)
        super().run()


class plugin_egg_info(egg_info):
    """``egg_info`` command that refreshes the external-backend links before generating metadata."""

    def run(self):
        """Link external backends only, then run the stock ``egg_info`` command."""
        add_links(external_only=True)
        super().run()


class BuildWheel(bdist_wheel):
    """``bdist_wheel`` command that links external backends and repairs manylinux wheels.

    After the regular wheel build, when the module-level ``is_manylinux`` flag is set,
    the freshly built ``*-linux_*.whl`` is passed to ``auditwheel repair`` and the
    unrepaired wheel is deleted afterwards.
    """

    def run(self):
        """Build the wheel and, for manylinux builds, run ``auditwheel repair`` on it.

        Raises:
            RuntimeError: if a manylinux build left no ``*-linux_*.whl`` in ``self.dist_dir``.
            subprocess.CalledProcessError: if ``auditwheel`` exits with a non-zero status.
        """
        add_links(external_only=True)
        super().run()

        if not is_manylinux:
            return

        pattern = os.path.join(self.dist_dir, "*-linux_*.whl")
        matches = glob.glob(pattern)
        if not matches:
            # Fail with an actionable message instead of an IndexError on matches[0].
            raise RuntimeError(f"No wheel matching '{pattern}' was produced; nothing to repair with auditwheel")
        file = matches[0]

        # NOTE(review): `--plat` is passed twice. If auditwheel keeps only the last
        # occurrence, the manylinux_2_27 tag has no effect -- confirm which is intended.
        auditwheel_cmd = [
            "auditwheel",
            "-v",
            "repair",
            "--plat",
            f"manylinux_2_27_{platform.machine()}",
            "--plat",
            f"manylinux_2_28_{platform.machine()}",
            "-w",
            self.dist_dir,
            file,
        ]

        try:
            subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE)
        finally:
            # The unrepaired wheel is removed whether or not the repair succeeded.
            os.remove(file)


class plugin_install(install):
    """``install`` command that refreshes the external-backend links before installing."""

    def run(self):
        """Link external backends only, then run the stock ``install`` command."""
        add_links(external_only=True)
        super().run()


class plugin_sdist(sdist):
    """``sdist`` command that refuses to run when any external backend is configured."""

    def run(self):
        """Run the stock ``sdist`` command unless an external backend is present.

        Raises:
            RuntimeError: if any entry in ``backends`` has ``is_external`` set.
        """
        if any(candidate.is_external for candidate in backends):
            raise RuntimeError("sdist cannot be used with TRITON_PLUGIN_DIRS")
        super().run()


def get_entry_points():
    """Return the ``entry_points`` mapping for ``setup()``.

    The ``triton.backends`` group always lists one entry per backend. The proton console
    scripts are included, ahead of that group, only when the ``TRITON_BUILD_PROTON`` flag
    check succeeds.
    """
    # NOTE(review): "ON" is the second argument to check_env_flag (presumably its fallback),
    # but the top of this file setdefaults TRITON_BUILD_PROTON to "OFF" -- confirm the
    # effective default.
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        console_group = {
            "console_scripts": [
                "proton-viewer = triton.profiler.viewer:main",
                "proton = triton.profiler.proton:main",
            ]
        }
    else:
        console_group = {}

    backend_group = [f"{b.name} = triton.backends.{b.name}" for b in backends]
    return {**console_group, "triton.backends": backend_group}


def get_git_commit_hash(length=8):
    """Return ``"+git<short-sha>"`` for the current HEAD, or ``""`` if that cannot be obtained.

    Args:
        length: value passed to ``git rev-parse --short=<length>``.

    Any exception (git missing, not a repository, decoding failure, ...) is swallowed on
    purpose so that a build outside a git checkout still works.
    """
    try:
        rev_parse = ["git", "rev-parse", f"--short={length}", "HEAD"]
        short_sha = subprocess.check_output(rev_parse).strip().decode("utf-8")
        return f"+git{short_sha}"
    except Exception:
        return ""


def get_git_branch():
    """Return the name git reports for HEAD (``rev-parse --abbrev-ref``), or ``""`` on any failure."""
    try:
        raw_name = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
        return raw_name.strip().decode("utf-8")
    except Exception:
        # Best effort: a missing git binary or a non-repository simply yields no branch.
        return ""


def get_git_version_suffix():
    """Return the git-hash version suffix, or ``""`` when none should be appended.

    No suffix is produced outside a git checkout or on a branch whose name starts with
    ``"release"``; otherwise the result of ``get_git_commit_hash()`` is returned.
    """
    if is_git_repo() and not get_git_branch().startswith("release"):
        return get_git_commit_hash()
    return ""


# Version string assembled at import time: the literal base version, followed by
# get_git_version_suffix(), followed by the TRITON_WHEEL_VERSION_SUFFIX environment
# variable (empty when unset).
# keep it separate for easy substitution
TRITON_VERSION = "3.5.0" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")

# Supported Python range as (major, minor) pairs, both ends inclusive. The
# `python_requires` specifier and the trove classifiers below are derived from these two
# values so they cannot drift apart.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 14)

# The upper bound is exclusive in the specifier, hence the `+ 1` on the maximum minor version.
PYTHON_REQUIRES = ">={}.{},<{}.{}".format(MIN_PYTHON[0], MIN_PYTHON[1], MAX_PYTHON[0], MAX_PYTHON[1] + 1)
BASE_CLASSIFIERS = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Software Development :: Build Tools",
    "License :: OSI Approved :: MIT License",
]
# One classifier per supported minor version; the major version is taken from MIN_PYTHON.
PYTHON_CLASSIFIERS = [
    "Programming Language :: Python :: {}.{}".format(MIN_PYTHON[0], minor)
    for minor in range(MIN_PYTHON[1], MAX_PYTHON[1] + 1)
]
CLASSIFIERS = [*BASE_CLASSIFIERS, *PYTHON_CLASSIFIERS]

# Temporary design.
# A version.txt file holding both the version and the commit id would be better;
# version.txt would then be converted to version.py at build time.
def get_default_version():
    """Return the stripped contents of ``version.txt`` next to this file, or ``"3.5.0"`` if it is absent."""
    candidate = Path(__file__).parent / "version.txt"
    return candidate.read_text().strip() if candidate.exists() else "3.5.0"


def get_version():
    """Compose the distribution version string.

    The base comes from the ``TRITON_VERSION`` environment variable, falling back to
    ``get_default_version()``. ``TRITON_WHEEL_VERSION_SUFFIX`` is appended when set, and
    the git commit hash suffix is added for builds where ``is_manylinux`` is false.
    """
    base_version = os.environ.get("TRITON_VERSION", get_default_version())
    wheel_suffix = os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")

    pieces = [base_version, wheel_suffix]
    if not is_manylinux:
        pieces.append(get_git_commit_hash())

    return "".join(pieces)


def get_package_name():
    """Return the distribution name from ``TRITON_WHEEL_NAME``, or ``"triton_ascend"`` when unset.

    The hard-coded fallback is only reached when the variable is absent from the
    environment; an empty value is returned as-is.
    """
    configured_name = os.environ.get("TRITON_WHEEL_NAME")
    return "triton_ascend" if configured_name is None else configured_name


# Maps lower-cased `platform.machine()` values to the architecture keys used by
# ARCHITECTURE_DEPENDENCIES. Note that the 32-bit x86 names ("i386", "i686") are placed
# in the "x86_64" bucket, and every ARM variant listed here in the "arm" bucket.
ARCHITECTURE_ALIASES = {
    "x86_64": "x86_64",
    "amd64": "x86_64",
    "i386": "x86_64",
    "i686": "x86_64",
    "arm64": "arm",
    "aarch64": "arm",
    "armv7l": "arm",
    "armv8l": "arm",
    "arm": "arm",
}

# Extra install requirements per architecture key; get_install_requirements() appends
# the matching list. Both keys currently carry the same requirement.
ARCHITECTURE_DEPENDENCIES = {
    "x86_64": ["triton==3.5.0"],
    "arm": ["triton==3.5.0"],
}


def get_architecture():
    """Return the canonical architecture key for the machine running the build.

    Raises:
        RuntimeError: if the lower-cased ``platform.machine()`` value is not a key of
            ``ARCHITECTURE_ALIASES`` (chained from the underlying ``KeyError``).
    """
    machine = platform.machine().lower()
    try:
        canonical = ARCHITECTURE_ALIASES[machine]
    except KeyError as missing:
        message = f"Unsupported CPU architecture: {machine}"
        raise RuntimeError(message) from missing
    return canonical


def get_install_requirements():
    """Return the ``install_requires`` list: common requirements plus the architecture-specific ones.

    Raises:
        RuntimeError: propagated from ``get_architecture()`` on an unsupported machine.
    """
    common_requirements = [
        "attrs==24.2.0",
        "numpy==1.26.4",
        # scipy is pinned differently depending on the interpreter version.
        "scipy==1.13.1;python_version<'3.13'",
        "scipy==1.15.1;python_version>='3.13'",
        "decorator==5.1.1",
        "psutil==6.0.0",
        "pytest==8.3.2",
        "pytest-xdist==3.6.1",
        "pyyaml",
        "pybind11",
        "pandas",
    ]
    per_arch_requirements = ARCHITECTURE_DEPENDENCIES[get_architecture()]
    return common_requirements + list(per_arch_requirements)


# Build-mode flag read from the IS_MANYLINUX environment variable; "FALSE" is the second
# argument to check_env_flag (presumably the value used when the variable is unset --
# confirm against check_env_flag). Read by BuildWheel.run() and get_version().
is_manylinux = check_env_flag("IS_MANYLINUX", "FALSE")
# README.md next to this file supplies the long description; the build aborts without it.
readme = os.path.join(triton_dir, "README.md")
if not os.path.exists(readme):
    raise FileNotFoundError("Unable to find 'README.md'")
with open(readme, encoding="utf-8") as fdesc:
    long_description = fdesc.read()


# Package definition. Name, version, package layout, entry points and requirements are
# computed by the helper functions above; the custom command classes hook backend
# linking and the CMake build into the standard setuptools commands.
setup(
    name=get_package_name(),
    version=get_version(),
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description=long_description,
    # The long description is read from README.md, so declare it as Markdown; without
    # this the package index falls back to interpreting it as reStructuredText.
    long_description_content_type="text/markdown",
    packages=list(get_packages()),
    package_dir=dict(get_package_dirs()),
    entry_points=get_entry_points(),
    include_package_data=True,
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    cmdclass={
        "bdist_wheel": BuildWheel,
        "build_ext": CMakeBuild,
        "build_py": CMakeBuildPy,
        "clean": CMakeClean,
        "develop": plugin_develop,
        "editable_wheel": plugin_editable_wheel,
        "egg_info": plugin_egg_info,
        "install": plugin_install,
        "sdist": plugin_sdist,
    },
    zip_safe=False,
    # for PyPI
    keywords=["Compiler", "Deep Learning"],
    url="https://gitcode.com/Ascend/triton-ascend/",
    python_requires=PYTHON_REQUIRES,
    classifiers=CLASSIFIERS,
    test_suite="tests",
    install_requires=get_install_requirements(),
    extras_require={
        "build": [
            "cmake>=3.20,<4.0",
            "lit",
        ],
        "tests": [
            "autopep8",
            "isort",
            "numpy",
            "pytest",
            "pytest-forked",
            "pytest-xdist",
            "scipy>=1.7.1",
            "llnl-hatchet",
        ],
        "tutorials": [
            "matplotlib",
            "pandas",
            "tabulate",
        ],
    },
)