import os
import platform
import re
import contextlib
import shlex
import shutil
import subprocess
import sys
import sysconfig
import tarfile
import zipfile
import urllib.request
import json
import glob
from io import BytesIO
from distutils.command.clean import clean
from pathlib import Path
from typing import Optional
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from setuptools.command.egg_info import egg_info
from setuptools.command.install import install
from setuptools.command.sdist import sdist
from dataclasses import dataclass
import pybind11
try:
from setuptools.command.bdist_wheel import bdist_wheel
except ImportError:
from wheel.bdist_wheel import bdist_wheel
try:
from setuptools.command.editable_wheel import editable_wheel
except ImportError:
class editable_wheel:
pass
# Make the in-tree ``python`` directory importable so the shared build helpers
# can be loaded before setuptools runs.
sys.path.insert(0, os.path.dirname(__file__))
from python.build_helpers import get_base_dir, get_cmake_dir

# Absolute directory containing this setup script.
triton_dir = os.path.dirname(os.path.abspath(__file__))

# triton-ascend build defaults; values already present in the environment take precedence.
for _env_name, _env_default in (
    ("TRITON_BUILD_WITH_CCACHE", "true"),
    ("TRITON_BUILD_WITH_CLANG_LLD", "true"),
    ("TRITON_BUILD_PROTON", "OFF"),
    ("TRITON_WHEEL_NAME", "triton-ascend"),
    ("TRITON_APPEND_CMAKE_ARGS", "-DTRITON_BUILD_UT=OFF"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
def is_git_repo():
    """Tell whether the directory containing this script is a git checkout.

    Only a ``.git`` *directory* counts; a ``.git`` file (as used by worktrees
    and submodules) is not treated as a repository.
    """
    git_dir = Path(__file__).parent.joinpath(".git")
    return git_dir.is_dir()
@dataclass
class Backend:
    """Description of one compiler backend, as produced by ``BackendInstaller.prepare``."""
    name: str  # backend name: the directory under third_party/, or the contents of backend/name.conf for plugins
    src_dir: str  # root directory of the backend sources
    backend_dir: str  # <src_dir>/backend; must contain compiler.py and driver.py
    language_dir: Optional[str]  # <src_dir>/language when it exists, otherwise None
    tools_dir: Optional[str]  # <src_dir>/tools when it exists, otherwise None
    install_dir: str  # python/triton/backends/<name> next to this setup script
    is_external: bool  # True for plugin backends listed in TRITON_PLUGIN_DIRS
class BackendInstaller:
    """Discover compiler backends and describe them as ``Backend`` records."""

    @staticmethod
    def prepare(backend_name: str, backend_src_dir: Optional[str] = None, is_external: bool = False):
        """Validate a backend source tree and describe it as a ``Backend``.

        Args:
            backend_name: Backend name. For in-tree backends this is also the
                directory name under ``third_party``.
            backend_src_dir: Root of the backend sources. Required for external
                backends; ignored (derived from ``backend_name``) for in-tree ones.
            is_external: True for plugin backends supplied via ``TRITON_PLUGIN_DIRS``.

        Returns:
            A ``Backend`` record for the validated tree.

        Raises:
            ValueError: if an external backend is requested without ``backend_src_dir``.
            AssertionError: if the backend directory or one of its mandatory files
                (``compiler.py``, ``driver.py``) is missing.
        """
        if not is_external:
            root_dir = "third_party"
            assert backend_name in os.listdir(
                root_dir), f"{backend_name} is requested for install but not present in {root_dir}"
            if is_git_repo():
                # Best effort: a failing submodule update or a missing git binary must not
                # abort the build; the existence checks below catch a truly absent backend.
                try:
                    subprocess.run(["git", "submodule", "update", "--init", backend_name], check=True,
                                   stdout=subprocess.DEVNULL, cwd=root_dir)
                except (subprocess.CalledProcessError, FileNotFoundError):
                    pass
            backend_src_dir = os.path.join(root_dir, backend_name)
        elif backend_src_dir is None:
            raise ValueError(f"backend_src_dir is required for external backend {backend_name}")
        backend_path = os.path.join(backend_src_dir, "backend")
        assert os.path.exists(backend_path), f"{backend_path} does not exist!"
        language_dir = os.path.join(backend_src_dir, "language")
        if not os.path.exists(language_dir):
            language_dir = None
        tools_dir = os.path.join(backend_src_dir, "tools")
        if not os.path.exists(tools_dir):
            tools_dir = None
        for file in ["compiler.py", "driver.py"]:
            assert os.path.exists(os.path.join(backend_path, file)), f"{file} does not exist in {backend_path}"
        install_dir = os.path.join(os.path.dirname(__file__), "python", "triton", "backends", backend_name)
        return Backend(name=backend_name, src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
                       tools_dir=tools_dir, install_dir=install_dir, is_external=is_external)

    @staticmethod
    def copy(active):
        """Prepare every in-tree backend named in *active*, preserving its order."""
        return [BackendInstaller.prepare(backend) for backend in active]

    @staticmethod
    def copy_externals():
        """Prepare the plugin backends listed in ``TRITON_PLUGIN_DIRS``.

        The variable holds ``;``-separated backend source directories. Each directory
        must contain ``backend/name.conf`` holding the backend name. Empty entries
        (e.g. from a blank value or a trailing ``;``) are ignored.

        Returns:
            A list of ``Backend`` records; empty when the variable is unset or blank.
        """
        raw_dirs = os.getenv("TRITON_PLUGIN_DIRS")
        if raw_dirs is None:
            return []
        backend_dirs = [entry.strip() for entry in raw_dirs.split(";") if entry.strip()]
        backend_names = [Path(os.path.join(d, "backend", "name.conf")).read_text().strip() for d in backend_dirs]
        return [
            BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True)
            for backend_name, backend_src_dir in zip(backend_names, backend_dirs)
        ]
def check_env_flag(name: str, default: str = "") -> bool:
    """Interpret environment variable *name* as a boolean switch.

    ``ON``, ``1``, ``YES``, ``TRUE`` and ``Y`` (case-insensitive) mean enabled;
    any other value means disabled. When the variable is unset, *default* is
    interpreted the same way.
    """
    truthy = ("ON", "1", "YES", "TRUE", "Y")
    raw_value = os.environ.get(name, default)
    return raw_value.upper() in truthy
def get_build_type():
    """Pick the CMake build type from environment switches.

    The first enabled switch wins, checked in this order: ``DEBUG``,
    ``REL_WITH_DEB_INFO``, ``TRITON_REL_BUILD_WITH_ASSERTS``,
    ``TRITON_BUILD_WITH_O1``. With none enabled, ``TritonRelBuildWithAsserts``
    is used.
    """
    precedence = (
        ("DEBUG", "Debug"),
        ("REL_WITH_DEB_INFO", "RelWithDebInfo"),
        ("TRITON_REL_BUILD_WITH_ASSERTS", "TritonRelBuildWithAsserts"),
        ("TRITON_BUILD_WITH_O1", "TritonBuildWithO1"),
    )
    for flag, build_type in precedence:
        if check_env_flag(flag):
            return build_type
    return "TritonRelBuildWithAsserts"
def get_env_with_keys(key: list):
    """Return the value of the first variable in *key* that is set, or ``""`` if none is.

    A variable that is set to an empty string counts as set.
    """
    return next((os.environ[k] for k in key if k in os.environ), "")
def is_offline_build() -> bool:
    """Report whether the build must avoid network downloads and in-tree vendored dependencies.

    Downstream projects and distributions that bootstrap their own dependencies
    from scratch and build inside offline sandboxes may set ``TRITON_OFFLINE_BUILD``
    to prevent any attempt to download pinned dependencies from the internet or to
    use dependencies vendored in-tree. Every dependency must then be supplied
    through its search-path variable (see ``Package.syspath_var_name``).
    A missing dependency aborts the build early, and dependency compatibility is
    not verified.

    Note that this flag is not exercised by CI and comes with no guarantees.
    """
    return check_env_flag("TRITON_OFFLINE_BUILD", default="")
@dataclass
class Package:
    """A third-party build dependency resolved by ``get_thirdparty_packages``."""
    package: str  # sub-directory of the triton cache that the archive is extracted into
    name: str  # directory (relative to that sub-directory) holding the extracted package; may be empty
    url: str  # download URL; also stored in version.txt to detect a stale cached copy
    include_flag: str  # CMake variable set to <package_dir>/include (skipped when empty)
    lib_flag: str  # CMake variable set to <package_dir>/lib (skipped when empty)
    syspath_var_name: str  # env var overriding the package location; also passed to CMake as -D<var>=<package_dir>
    sym_name: Optional[str] = None  # when set, a symlink with this name in the cache sub-directory points at the package
def get_json_package_info():
    """Describe the nlohmann/json header-only dependency (v3.11.3 ``include.zip``)."""
    return Package(
        package="json",
        name="",
        url="https://github.com/nlohmann/json/releases/download/v3.11.3/include.zip",
        include_flag="JSON_INCLUDE_DIR",
        lib_flag="",
        syspath_var_name="JSON_SYSPATH",
    )
def is_linux_os(os_id, os_release_path="/etc/os-release"):
    """Return True if the os-release file declares ``ID=<os_id>``.

    The file is parsed line by line as ``KEY=value``. Quoting of the value is
    optional, so both ``ID=ubuntu`` and ``ID="almalinux"`` are recognised. Other
    keys containing ``ID`` (``VERSION_ID``, ``PLATFORM_ID``, ``ID_LIKE``) are
    ignored instead of being matched by substring.

    Args:
        os_id: Distribution identifier to look for, e.g. ``"almalinux"``.
        os_release_path: os-release file to inspect (defaults to ``/etc/os-release``).

    Returns:
        False when the file does not exist or contains no matching ``ID`` entry.
    """
    if not os.path.exists(os_release_path):
        return False
    with open(os_release_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            key, sep, value = line.strip().partition("=")
            if sep and key == "ID":
                return value.strip().strip("\"'") == os_id
    return False
def get_llvm_package_info():
    """Describe the pre-built LLVM archive to fetch for this host.

    The platform suffix comes from ``TRITON_LLVM_SYSTEM_SUFFIX`` when set. Otherwise
    it is derived from the OS, the CPU architecture and, on x64 Linux, the glibc
    version. The archive revision is the first 8 characters of
    ``cmake/llvm-hash.txt``.

    Returns:
        A ``Package`` pointing at the pre-built archive. On hosts without a
        pre-built image, a message is printed and a ``Package`` with an empty URL
        is returned instead. This includes x64 Linux hosts whose glibc version
        cannot be determined.
    """
    system = platform.system()
    machine = platform.machine()
    arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}.get(machine, machine)

    def _no_prebuilt_package():
        # Shared fallback for every host without a pre-built LLVM image.
        print(f"LLVM pre-compiled image is not available for {system}-{arch}. "
              "Proceeding with user-configured LLVM from source build.")
        return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

    env_system_suffix = os.environ.get("TRITON_LLVM_SYSTEM_SUFFIX", None)
    if env_system_suffix:
        system_suffix = env_system_suffix
    elif system == "Darwin":
        system_suffix = f"macos-{arch}"
    elif system == "Linux" and arch == "arm64":
        system_suffix = "almalinux-arm64" if is_linux_os("almalinux") else "ubuntu-arm64"
    elif system == "Linux" and arch == "x64":
        # libc_ver() reports e.g. ("glibc", "2.35"); the version can be empty on
        # non-glibc systems, so parse defensively instead of crashing in int().
        glibc_parts = platform.libc_ver()[1].split(".")
        if len(glibc_parts) < 2 or not (glibc_parts[0].isdigit() and glibc_parts[1].isdigit()):
            return _no_prebuilt_package()
        vglibc = int(glibc_parts[0]) * 100 + int(glibc_parts[1])
        if vglibc > 228:
            system_suffix = "ubuntu-x64"
        elif vglibc > 217:
            system_suffix = "almalinux-x64"
        else:
            system_suffix = "centos-x64"
    else:
        return _no_prebuilt_package()
    llvm_hash_path = os.path.join(get_base_dir(), "cmake", "llvm-hash.txt")
    with open(llvm_hash_path, "r") as llvm_hash_file:
        rev = llvm_hash_file.read(8)
    name = f"llvm-{rev}-{system_suffix}"
    sym_name = f"llvm-{system_suffix}"
    url = f"https://triton-ascend-artifacts.obs.myhuaweicloud.com/llvm-builds/{name}.tar.gz"
    return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH", sym_name=sym_name)
def open_url(url):
    """Open *url* with a browser-like User-Agent header and a 300 second timeout.

    Returns the response object from ``urllib.request.urlopen``; the caller is
    responsible for closing it.
    """
    browser_headers = {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'}
    request = urllib.request.Request(url, data=None, headers=browser_headers)
    return urllib.request.urlopen(request, timeout=300)
def get_triton_cache_path():
    """Return the ``.triton`` directory used to cache downloaded build dependencies.

    The base directory is the first non-empty value among ``TRITON_HOME``,
    ``HOME``, ``USERPROFILE`` and ``HOMEPATH``.

    Raises:
        RuntimeError: if none of those variables holds a non-empty value.
    """
    candidates = ("TRITON_HOME", "HOME", "USERPROFILE", "HOMEPATH")
    base_dir = next((os.getenv(var) for var in candidates if os.getenv(var)), None)
    if not base_dir:
        raise RuntimeError("Could not find user home directory")
    return os.path.join(base_dir, ".triton")
def update_symlink(link_path, source_path):
    """Make *link_path* a symlink to the absolute form of *source_path*, replacing what is there.

    An existing symlink or regular file at *link_path* is unlinked, and an existing
    directory is removed recursively. Missing parent directories are created.
    """
    source_path = Path(source_path)
    link_path = Path(link_path)
    if link_path.is_symlink():
        link_path.unlink()
    elif link_path.is_dir():
        shutil.rmtree(link_path)
    elif link_path.exists():
        # A plain file (e.g. a copied compile_commands.json): shutil.rmtree would
        # raise NotADirectoryError on it.
        link_path.unlink()
    print(f"creating symlink: {link_path} -> {source_path}", file=sys.stderr)
    link_path.absolute().parent.mkdir(parents=True, exist_ok=True)
    link_path.symlink_to(source_path.absolute(), target_is_directory=True)
def get_thirdparty_packages(packages: list):
    """Ensure each third-party package is available locally and return its CMake flags.

    A package lives at ``<triton cache>/<package>/<name>`` unless the environment
    variable named by ``syspath_var_name`` points elsewhere. If that variable is
    unset and the cached ``version.txt`` does not hold the package URL, the cache
    sub-directory is wiped and the archive is downloaded and extracted again.

    Args:
        packages: ``Package`` descriptions to resolve.

    Returns:
        ``-D<flag>=<path>`` arguments for the include/lib/syspath flags of every package.

    Raises:
        RuntimeError: for an offline build when a package's search-path variable is unset.
    """

    def _extract_tar(archive, dest):
        # Downloaded archives are not trusted blindly: use the "data" extraction filter
        # (rejects absolute paths, ".." members and links escaping *dest*) when the
        # interpreter provides it (Python 3.12+ and security backports).
        if hasattr(tarfile, "data_filter"):
            archive.extractall(path=dest, filter="data")
        else:
            archive.extractall(path=dest)

    triton_cache_path = get_triton_cache_path()
    offline = is_offline_build()
    thirdparty_cmake_args = []
    for p in packages:
        package_root_dir = os.path.join(triton_cache_path, p.package)
        package_dir = os.path.join(package_root_dir, p.name)
        if os.environ.get(p.syspath_var_name):
            package_dir = os.environ[p.syspath_var_name]
        version_file_path = os.path.join(package_dir, "version.txt")
        input_defined = p.syspath_var_name in os.environ
        input_exists = os.path.exists(version_file_path)
        input_compatible = input_exists and Path(version_file_path).read_text() == p.url
        if offline and not input_defined:
            raise RuntimeError(f"Requested an offline build but {p.syspath_var_name} is not set")
        if not offline and not input_defined and not input_compatible:
            # Best effort: the cache sub-directory may not exist yet.
            with contextlib.suppress(Exception):
                shutil.rmtree(package_root_dir)
            os.makedirs(package_root_dir, exist_ok=True)
            print(f'downloading and extracting {p.url} ...')
            with open_url(p.url) as response:
                if p.url.endswith(".zip"):
                    file_bytes = BytesIO(response.read())
                    with zipfile.ZipFile(file_bytes, "r") as file:
                        file.extractall(path=package_root_dir)
                else:
                    with tarfile.open(fileobj=response, mode="r|*") as file:
                        _extract_tar(file, package_root_dir)
            # Record the source URL so later builds can detect a stale cache.
            with open(os.path.join(package_dir, "version.txt"), "w") as f:
                f.write(p.url)
        if p.include_flag:
            thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
        if p.lib_flag:
            thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
        if p.syspath_var_name:
            thirdparty_cmake_args.append(f"-D{p.syspath_var_name}={package_dir}")
        if p.sym_name is not None:
            sym_link_path = os.path.join(package_root_dir, p.sym_name)
            update_symlink(sym_link_path, package_dir)
    return thirdparty_cmake_args
def download_and_copy(name, src_func, dst_path, variable, version, url_func):
    """Fetch an NVIDIA redistributable archive and copy one entry of it into the nvidia backend.

    Args:
        name: Sub-directory of ``<triton cache>/nvidia`` the archive is extracted into.
        src_func: ``(system, arch, version) -> str``; path of the wanted entry inside the archive.
        dst_path: Destination relative to ``third_party/nvidia/backend``.
        variable: Environment variable that, when set, skips this step entirely.
        version: Expected tool version.
        url_func: ``(system, arch, version) -> str``; archive download URL.

    Nothing happens for offline builds or when *variable* is set. The archive is
    (re-)downloaded when the extracted entry is missing or, on Linux, when the
    executable already at the destination reports a different version.
    """
    if is_offline_build():
        return
    if variable in os.environ:
        return
    # Resolved only when needed, so an explicit override works without a home directory.
    triton_cache_path = get_triton_cache_path()
    base_dir = os.path.dirname(__file__)
    system = platform.system()
    arch = platform.machine()
    arch = {"arm64": "sbsa", "aarch64": "sbsa"}.get(arch, arch)
    supported = {"Linux": "linux", "Darwin": "linux"}
    url = url_func(supported[system], arch, version)
    src_path = src_func(supported[system], arch, version)
    tmp_path = os.path.join(triton_cache_path, "nvidia", name)
    dst_path = os.path.join(base_dir, "third_party", "nvidia", "backend", dst_path)
    src_path = os.path.join(tmp_path, src_path)
    download = not os.path.exists(src_path)
    if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None:
        curr_version = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
        curr_version = re.search(r"V([.|\d]+)", curr_version)
        assert curr_version is not None, f"No version information for {dst_path}"
        download = download or curr_version.group(1) != version
    if download:
        print(f'downloading and extracting {url} ...')
        # Close both the HTTP response and the tar stream once extraction is done.
        with open_url(url) as response, tarfile.open(fileobj=response, mode="r|*") as archive:
            if hasattr(tarfile, "data_filter"):
                # Reject path traversal / escaping links in the downloaded archive.
                archive.extractall(path=tmp_path, filter="data")
            else:
                archive.extractall(path=tmp_path)
    os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
    print(f'copy {src_path} to {dst_path} ...')
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(src_path, dst_path)
class CMakeClean(clean):
    """``clean`` command whose temporary build directory is the CMake build tree."""

    def initialize_options(self):
        super().initialize_options()
        # Presumably so that `clean` removes the CMake tree returned by get_cmake_dir().
        self.build_temp = get_cmake_dir()
class CMakeBuildPy(build_py):
    """``build_py`` that runs ``build_ext`` (the CMake build) before copying Python sources."""

    def run(self) -> None:
        self.run_command("build_ext")
        return super(CMakeBuildPy, self).run()
class CMakeExtension(Extension):
    """``Extension`` with no sources of its own; ``CMakeBuild`` builds it via CMake.

    Args:
        name: Extension name given to setuptools.
        path: Module path ``CMakeBuild`` uses to locate the output directory.
        sourcedir: Source directory, stored as an absolute path.
    """

    def __init__(self, name, path, sourcedir=""):
        super().__init__(name, sources=[])
        self.path = path
        self.sourcedir = os.path.abspath(sourcedir)
class CMakeBuild(build_ext):
    """``build_ext`` command that configures and builds Triton's native code with CMake and Ninja."""

    user_options = build_ext.user_options + \
        [('base-dir=', None, 'base directory of Triton')]

    def initialize_options(self):
        build_ext.initialize_options(self)
        self.base_dir = get_base_dir()

    def finalize_options(self):
        build_ext.finalize_options(self)

    def setup_coverage_env(self):
        """Export the environment variables required by the hitest coverage tool.

        ``HITEST_HOME``, ``HITEST_USER_ACCOUNT`` and ``LLTCOV_ROOTPATH`` override the
        built-in defaults. The hitest home directory is prepended exactly once to
        ``PATH`` and ``LD_LIBRARY_PATH``.
        """
        hitest_home = os.getenv('HITEST_HOME', '/opt/hitest/linux_avatar_x86_64')
        hitest_user_account = os.getenv('HITEST_USER_ACCOUNT', 'a00000000')
        lltcov_rootpath = os.getenv('LLTCOV_ROOTPATH', '/opt/covdata')
        coverage_env_vars = {
            'HitestHome': hitest_home,
            'isOverlappedCompile': '0',
            'PlatformToken': 'BOARD',
            'gcovmode': '0',
            'TimerPolicy': '1',
            'TimeInterval': '60',
            'SignalPolicy': '1',
            'SignalNUM': '34',
            'lltwrapper_cfg': '0',
            'HITEST_AGENT_INSIDE': '1',
            'USE_HLLT_COVERAGE': '1',
            'USE_HLLT_TESTCASE': '0',
            'simplemode': '0',
            'ncs_coverage_stub_mold': '1',
            # NOTE(review): "SOKCET" looks like a typo of "SOCKET"; confirm the key the tool reads.
            'HITEST_ENABLE_SOKCET': '0',
            'hitest_disable_cfg': '0',
            'hitest_disable_dfg': '1',
            'hitest_disable_ir': '1',
            'HITEST_DISABLE_MACRO': '0',
            'HITEST_REMOVE_INCLUDE_DIR': '0',
            'HITEST_AGENT_SET_THREADNAME_PRCTL': '1',
            'HITEST_INST_HEADER_FILE': '0',
            'HITEST_USER_ACCOUNT': hitest_user_account,
            'lltcovRootpath': lltcov_rootpath,
            'HITEST_COVSTUB_ROOT_DIR': f'{hitest_home}/apache-tomcat-8.0.39/webapps/datasource/Container_Default/base',
            'HITEST_EXEC_CMD_WITH_FILE': '1',
            'HITEST_PRINT_LOG_ENABLE': '1',
        }
        for key, value in coverage_env_vars.items():
            os.environ[key] = value
        for path_var in ('PATH', 'LD_LIBRARY_PATH'):
            current = os.environ.get(path_var, '')
            # Avoid a trailing ':' — an empty entry means "current directory" to the
            # shell and the dynamic loader.
            os.environ[path_var] = f'{hitest_home}:{current}' if current else hitest_home
        print("The currently set environment variables for the hitest coverage tool are read.")
        print(f" HitestHome: {hitest_home} (environment variables HITEST_HOME)")
        print(f" HITEST_USER_ACCOUNT: {hitest_user_account} (environment variables HITEST_USER_ACCOUNT)")
        print(f" lltcovRootpath: {lltcov_rootpath} (environment variables LLTCOV_ROOTPATH)")

    def run(self):
        """Fetch toolchain dependencies, check CMake >= 3.20, configure coverage and build every extension."""
        download_and_copy_dependencies()
        try:
            out = subprocess.check_output(["cmake", "--version"])
        except OSError as err:
            raise RuntimeError("CMake must be installed to build the following extensions: " +
                               ", ".join(e.name for e in self.extensions)) from err
        version_text = out.decode()
        match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", version_text)
        if match is None:
            raise RuntimeError(f"Could not determine the CMake version from: {version_text.strip()}")
        cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
        if (cmake_major, cmake_minor) < (3, 20):
            raise RuntimeError("CMake >= 3.20 is required")
        enable_hitest = os.getenv('TRITON_ENABLE_COVERAGE_HITEST', '0').lower() in ('1', 'on', 'true')
        if enable_hitest:
            self.setup_coverage_env()
            coverage_flag = "-DTRITON_ENABLE_COVERAGE_HITEST=ON"
            current_append = os.environ.get('TRITON_APPEND_CMAKE_ARGS', '')
            os.environ['TRITON_APPEND_CMAKE_ARGS'] = (f"{current_append} {coverage_flag}"
                                                      if current_append else coverage_flag)
        else:
            # Coverage disabled: drop HITEST_* / HitestHome / lltcovRootpath from the environment.
            for key in list(os.environ.keys()):
                if key.startswith('HITEST_') or key in ['HitestHome', 'lltcovRootpath']:
                    del os.environ[key]
        for ext in self.extensions:
            self.build_extension(ext)

    def get_pybind11_cmake_args(self):
        """CMake flags locating pybind11 (``PYBIND11_SYSPATH`` overrides the include dir)."""
        pybind11_sys_path = get_env_with_keys(["PYBIND11_SYSPATH"])
        if pybind11_sys_path:
            pybind11_include_dir = os.path.join(pybind11_sys_path, "include")
        else:
            pybind11_include_dir = pybind11.get_include()
        # NOTE(review): the values are wrapped in single quotes although no shell is
        # involved; confirm CMake tolerates the literal quote characters.
        return [f"-Dpybind11_INCLUDE_DIR='{pybind11_include_dir}'", f"-Dpybind11_DIR='{pybind11.get_cmake_dir()}'"]

    def get_proton_cmake_args(self):
        """CMake flags for the proton profiler: json, pybind11, CUPTI and roctracer include dirs."""
        cmake_args = get_thirdparty_packages([get_json_package_info()])
        cmake_args += self.get_pybind11_cmake_args()
        cupti_include_dir = get_env_with_keys(["TRITON_CUPTI_INCLUDE_PATH"])
        if cupti_include_dir == "":
            cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include")
        cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir]
        roctracer_include_dir = get_env_with_keys(["TRITON_ROCTRACER_INCLUDE_PATH"])
        if roctracer_include_dir == "":
            roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include")
        cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir]
        return cmake_args

    def build_extension(self, ext):
        """Configure and build *ext* with CMake/Ninja, then install the optimizer tools next to it.

        Raises:
            RuntimeError: if ``ninja`` is not on ``PATH``.
            subprocess.CalledProcessError: if a CMake step fails.
        """
        lit_dir = shutil.which('lit')
        ninja_dir = shutil.which('ninja')
        if ninja_dir is None:
            # Fail before any (potentially large) dependency download.
            raise RuntimeError("Ninja must be installed to build Triton (the CMake generator is 'Ninja')")
        thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
        thirdparty_cmake_args += self.get_pybind11_cmake_args()
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        wheeldir = os.path.dirname(extdir)
        os.makedirs(self.build_temp, exist_ok=True)
        python_include_dir = sysconfig.get_path("platinclude")
        cmake_args = [
            "-G", "Ninja",
            "-DCMAKE_MAKE_PROGRAM=" + ninja_dir,
            "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_PYTHON_MODULE=ON",
            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, "-DPython3_INCLUDE_DIR=" + python_include_dir,
            "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
            "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]),
            "-DTRITON_WHEEL_DIR=" + wheeldir,
            "-DLLVM_MAJOR_VERSION_22_COMPATIBLE=ON",
        ]
        if lit_dir is not None:
            cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
        cmake_args.extend(thirdparty_cmake_args)
        cfg = get_build_type()
        build_args = ["--config", cfg]
        cmake_args += [f"-DCMAKE_BUILD_TYPE={cfg}"]
        if platform.system() == "Windows":
            cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
        else:
            # os.cpu_count() returns None when the CPU count cannot be determined.
            max_jobs = os.getenv("MAX_JOBS", str(2 * (os.cpu_count() or 1)))
            build_args += ['-j' + max_jobs]
        if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
            cmake_args += [
                "-DCMAKE_C_COMPILER=clang",
                "-DCMAKE_CXX_COMPILER=clang++",
                "-DCMAKE_LINKER=lld",
                "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
            ]
        if check_env_flag("TRITON_BUILD_WITH_ASAN"):
            cmake_args += [
                "-DCMAKE_C_FLAGS=-fsanitize=address",
                "-DCMAKE_CXX_FLAGS=-fsanitize=address",
            ]
        passthrough_args = [
            "TRITON_BUILD_PROTON",
            "TRITON_BUILD_WITH_CCACHE",
            "TRITON_PARALLEL_LINK_JOBS",
        ]
        cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ]
        if check_env_flag("TRITON_BUILD_PROTON", "ON"):
            cmake_args += self.get_proton_cmake_args()
        if is_offline_build():
            cmake_args += ["-DTRITON_BUILD_UT=OFF"]
        ascendnpu_ir_tag = os.getenv("ASCENDNPU_IR_TAG")
        if ascendnpu_ir_tag is not None:
            cmake_args += [f"-DASCENDNPU_IR_TAG={ascendnpu_ir_tag}"]
        cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS")
        if cmake_args_append is not None:
            cmake_args += shlex.split(cmake_args_append)
        env = os.environ.copy()
        cmake_dir = Path(get_cmake_dir())
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
        update_symlink(Path(self.base_dir) / "compile_commands.json", cmake_dir / "compile_commands.json")
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
        subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
        for tool_name in ("triton-mlir-opt", "triton-opt"):
            self._install_tool(cmake_dir, extdir, tool_name)

    @staticmethod
    def _install_tool(cmake_dir, extdir, tool_name):
        """Copy ``<cmake_dir>/bin/<tool_name>`` into *extdir* if it was built.

        On non-Windows hosts the copy is made executable and, best effort, stripped.
        """
        tool_src = os.path.join(cmake_dir, "bin", tool_name)
        if not os.path.exists(tool_src):
            return
        tool_dst = os.path.join(extdir, tool_name)
        shutil.copy2(tool_src, tool_dst)
        if platform.system() != "Windows":
            os.chmod(tool_dst, 0o755)
            try:
                subprocess.check_call(["strip", "--strip-all", tool_dst])
                print(f"Stripped {tool_name} to reduce size")
            except (subprocess.CalledProcessError, FileNotFoundError):
                # strip missing or failed: ship the unstripped binary.
                pass
        print(f"Copied {tool_name} to {tool_dst}")
def download_and_copy_dependencies():
    """Fetch the NVIDIA toolchain pieces pinned in ``cmake/nvidia-toolchain-version.json``.

    Each component is handed to ``download_and_copy``. Setting the component's
    environment variable (e.g. ``TRITON_PTXAS_PATH``) skips it.
    """
    nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
    with open(nvidia_version_path, "r") as nvidia_version_file:
        toolchain_versions = json.load(nvidia_version_file)
    exe_extension = sysconfig.get_config_var("EXE")
    redist_root = "https://developer.download.nvidia.com/compute/cuda/redist"
    # (cache name, redist component, path inside the archive, destination, env override, version key)
    components = (
        ("nvcc", "cuda_nvcc", f"bin/ptxas{exe_extension}", "bin/ptxas", "TRITON_PTXAS_PATH", "ptxas"),
        ("cuobjdump", "cuda_cuobjdump", f"bin/cuobjdump{exe_extension}", "bin/cuobjdump",
         "TRITON_CUOBJDUMP_PATH", "cuobjdump"),
        ("nvdisasm", "cuda_nvdisasm", f"bin/nvdisasm{exe_extension}", "bin/nvdisasm",
         "TRITON_NVDISASM_PATH", "nvdisasm"),
        ("nvcc", "cuda_nvcc", "include", "include", "TRITON_CUDACRT_PATH", "cudacrt"),
        ("cudart", "cuda_cudart", "include", "include", "TRITON_CUDART_PATH", "cudart"),
        ("cupti", "cuda_cupti", "include", "include", "TRITON_CUPTI_INCLUDE_PATH", "cupti"),
        ("cupti", "cuda_cupti", "lib", "lib/cupti", "TRITON_CUPTI_LIB_PATH", "cupti"),
    )
    for cache_name, component, inner_path, dst, env_var, version_key in components:
        # Default arguments bind the current loop values into each lambda.
        download_and_copy(
            name=cache_name,
            src_func=lambda system, arch, version, component=component, inner_path=inner_path:
                f"{component}-{system}-{arch}-{version}-archive/{inner_path}",
            dst_path=dst,
            variable=env_var,
            version=toolchain_versions[version_key],
            url_func=lambda system, arch, version, component=component:
                f"{redist_root}/{component}/{system}-{arch}/{component}-{system}-{arch}-{version}-archive.tar.xz",
        )
# All backends to build and package: the in-tree ones first, then plugins from TRITON_PLUGIN_DIRS.
backends = BackendInstaller.copy(["ascend", "nvidia", "amd"]) + BackendInstaller.copy_externals()
def get_package_dirs():
    """Yield ``(package, directory)`` pairs for setuptools' ``package_dir``.

    External (plugin) backends are skipped. Every entry of a backend's
    ``language``/``tools`` directory becomes a ``triton.language.extra.*`` /
    ``triton.tools.extra.*`` package.
    """
    yield ("", "python")
    for backend in (b for b in backends if not b.is_external):
        yield (f"triton.backends.{backend.name}", backend.backend_dir)
        for prefix, extra_root in (("triton.language.extra", backend.language_dir),
                                   ("triton.tools.extra", backend.tools_dir)):
            if not extra_root:
                continue
            for entry in os.listdir(extra_root):
                yield (f"{prefix}.{entry}", os.path.join(extra_root, entry))
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        yield ("triton.profiler", "third_party/proton/proton")
        yield ("triton.profiler.hooks", "third_party/proton/proton/hooks")
def get_packages():
    """Yield every package name to include in the distribution.

    Covers:
    - the packages found under ``python/``;
    - each backend (external plugins included) together with its
      ``language``/``tools`` extras;
    - when ``TRITON_BUILD_PROTON`` is enabled, the proton profiler packages.

    The proton entries match the ones ``get_package_dirs`` maps, including
    ``triton.profiler.hooks``.
    """
    yield from find_packages(where="python")
    for backend in backends:
        yield f"triton.backends.{backend.name}"
        if backend.language_dir:
            for x in os.listdir(backend.language_dir):
                yield f"triton.language.extra.{x}"
        if backend.tools_dir:
            for x in os.listdir(backend.tools_dir):
                yield f"triton.tools.extra.{x}"
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        yield "triton.profiler"
        # setuptools only ships sub-packages that are listed explicitly; without this
        # entry the hooks package mapped in get_package_dirs() is left out of the wheel.
        yield "triton.profiler.hooks"
def add_link_to_backends(external_only):
    """Symlink backend sources into ``python/triton``.

    Each backend directory is linked at the backend's ``install_dir``. Every entry
    of its ``language`` and ``tools`` directories is linked under
    ``triton/language/extra`` and ``triton/tools/extra`` respectively. With
    *external_only*, in-tree backends are skipped.
    """
    triton_pkg_dir = os.path.join(os.path.dirname(__file__), "python", "triton")
    language_extra = os.path.abspath(os.path.join(triton_pkg_dir, "language", "extra"))
    tools_extra = os.path.abspath(os.path.join(triton_pkg_dir, "tools", "extra"))
    for backend in backends:
        if external_only and not backend.is_external:
            continue
        update_symlink(backend.install_dir, backend.backend_dir)
        for extra_root, extra_dir in ((backend.language_dir, language_extra),
                                      (backend.tools_dir, tools_extra)):
            if not extra_root:
                continue
            for entry in os.listdir(extra_root):
                update_symlink(os.path.join(extra_dir, entry), os.path.join(extra_root, entry))
def add_link_to_proton():
    """Expose ``third_party/proton/proton`` as ``python/triton/profiler`` via a symlink."""
    here = os.path.dirname(__file__)
    update_symlink(
        os.path.join(here, "python", "triton", "profiler"),
        os.path.abspath(os.path.join(here, "third_party", "proton", "proton")),
    )
def add_links(external_only):
    """Create the backend symlinks and, unless *external_only*, the proton link when proton is enabled."""
    add_link_to_backends(external_only=external_only)
    if external_only:
        return
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        add_link_to_proton()
class plugin_bdist_wheel(bdist_wheel):
    """``bdist_wheel`` that links external backends into the tree first.

    Not registered in the ``setup()`` cmdclass below, which maps ``bdist_wheel``
    to ``BuildWheel``.
    """

    def run(self):
        add_links(external_only=True)
        super(plugin_bdist_wheel, self).run()
class plugin_develop(develop):
    """``develop`` that links all backends (and proton, when enabled) into the tree first."""

    def run(self):
        add_links(external_only=False)
        super(plugin_develop, self).run()
class plugin_editable_wheel(editable_wheel):
    """``editable_wheel`` that links all backends (and proton, when enabled) into the tree first."""

    def run(self):
        add_links(external_only=False)
        super(plugin_editable_wheel, self).run()
class plugin_egg_info(egg_info):
    """``egg_info`` that links external backends into the tree first."""

    def run(self):
        add_links(external_only=True)
        super(plugin_egg_info, self).run()
class BuildWheel(bdist_wheel):
    """``bdist_wheel`` that links external backends and, for manylinux builds, repairs the wheel.

    When ``is_manylinux`` is true, the most recently modified ``*-linux_*.whl`` in
    ``dist_dir`` is passed to ``auditwheel repair`` for the ``manylinux_2_27`` and
    ``manylinux_2_28`` platform tags. The input wheel is deleted afterwards, even
    when the repair fails.

    Raises:
        RuntimeError: if no ``*-linux_*.whl`` exists in ``dist_dir`` for a manylinux build.
    """

    def run(self):
        add_links(external_only=True)
        bdist_wheel.run(self)
        if not is_manylinux:
            return
        candidates = glob.glob(os.path.join(self.dist_dir, "*-linux_*.whl"))
        if not candidates:
            raise RuntimeError(f"No '*-linux_*.whl' wheel found in {self.dist_dir} for auditwheel repair")
        # dist_dir may still hold wheels from earlier builds; the newest one is
        # presumably the wheel bdist_wheel.run() just produced.
        file = max(candidates, key=os.path.getmtime)
        auditwheel_cmd = [
            "auditwheel",
            "-v",
            "repair",
            "--plat",
            f"manylinux_2_27_{platform.machine()}",
            "--plat",
            f"manylinux_2_28_{platform.machine()}",
            "-w",
            self.dist_dir,
            file,
        ]
        try:
            subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE)
        finally:
            os.remove(file)
class plugin_install(install):
    """``install`` that links external backends into the tree first."""

    def run(self):
        add_links(external_only=True)
        super(plugin_install, self).run()
class plugin_sdist(sdist):
    """``sdist`` that refuses to run while plugin backends (``TRITON_PLUGIN_DIRS``) are configured."""

    def run(self):
        if any(backend.is_external for backend in backends):
            raise RuntimeError("sdist cannot be used with TRITON_PLUGIN_DIRS")
        super().run()
def get_entry_points():
    """Build the ``entry_points`` mapping for ``setup()``.

    Every backend is registered under ``triton.backends``. The proton console
    scripts are added only when ``TRITON_BUILD_PROTON`` is enabled.
    """
    entry_points = {}
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        entry_points["console_scripts"] = [
            "proton-viewer = triton.profiler.viewer:main",
            "proton = triton.profiler.proton:main",
        ]
    entry_points["triton.backends"] = [
        "{0} = triton.backends.{0}".format(backend.name) for backend in backends
    ]
    return entry_points
def get_git_commit_hash(length=8):
    """Return ``"+git<short sha>"`` for ``HEAD``, or ``""`` when git cannot provide it.

    Args:
        length: Minimum abbreviation length passed to ``git rev-parse --short``.
    """
    command = ["git", "rev-parse", f"--short={length}", "HEAD"]
    try:
        short_sha = subprocess.check_output(command).strip().decode("utf-8")
    except Exception:
        return ""
    return f"+git{short_sha}"
def get_git_branch():
    """Return the output of ``git rev-parse --abbrev-ref HEAD``, or ``""`` if git fails."""
    command = ["git", "rev-parse", "--abbrev-ref", "HEAD"]
    try:
        raw = subprocess.check_output(command)
    except Exception:
        return ""
    try:
        return raw.strip().decode("utf-8")
    except Exception:
        return ""
def get_git_version_suffix():
    """Return the ``+git<sha>`` suffix, or ``""`` outside a git checkout or on ``release*`` branches."""
    if is_git_repo() and not get_git_branch().startswith("release"):
        return get_git_commit_hash()
    return ""
# "3.5.0" + "+git<sha>" (outside release branches) + TRITON_WHEEL_VERSION_SUFFIX.
# NOTE(review): setup() below uses get_version() rather than this constant; confirm it is still needed.
TRITON_VERSION = "3.5.0" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
# Supported CPython versions, inclusive at both ends.
MIN_PYTHON = (3, 9)
MAX_PYTHON = (3, 14)
PYTHON_REQUIRES = ">={}.{},<{}.{}".format(MIN_PYTHON[0], MIN_PYTHON[1], MAX_PYTHON[0], MAX_PYTHON[1] + 1)
BASE_CLASSIFIERS = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "Topic :: Software Development :: Build Tools",
    "License :: OSI Approved :: MIT License",
]
# One "Programming Language :: Python :: 3.x" classifier per supported minor version.
PYTHON_CLASSIFIERS = [
    "Programming Language :: Python :: {}.{}".format(MIN_PYTHON[0], minor)
    for minor in range(MIN_PYTHON[1], MAX_PYTHON[1] + 1)
]
CLASSIFIERS = [*BASE_CLASSIFIERS, *PYTHON_CLASSIFIERS]
def get_default_version():
    """Return the stripped contents of ``version.txt`` beside this script, else ``"3.5.0"``."""
    version_file = Path(__file__).parent.joinpath("version.txt")
    if not version_file.exists():
        return "3.5.0"
    return version_file.read_text().strip()
def get_version():
    """Compose the wheel version string.

    The base is the ``TRITON_VERSION`` environment variable, or
    ``get_default_version()`` when unset. ``TRITON_WHEEL_VERSION_SUFFIX`` is
    appended, and non-manylinux builds additionally get the git hash suffix.
    """
    base = os.environ.get("TRITON_VERSION", get_default_version())
    suffix = os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
    git_suffix = "" if is_manylinux else get_git_commit_hash()
    return base + suffix + git_suffix
def get_package_name():
    """Distribution name taken from ``TRITON_WHEEL_NAME``, defaulting to ``"triton_ascend"``."""
    name = os.getenv("TRITON_WHEEL_NAME")
    return "triton_ascend" if name is None else name
# Lower-cased platform.machine() values mapped to the architecture family used to pick dependencies.
ARCHITECTURE_ALIASES = {
    **dict.fromkeys(("x86_64", "amd64", "i386", "i686"), "x86_64"),
    **dict.fromkeys(("arm64", "aarch64", "armv7l", "armv8l", "arm"), "arm"),
}
# Extra runtime requirements per architecture family.
ARCHITECTURE_DEPENDENCIES = {
    "x86_64": ["triton==3.5.0"],
    "arm": ["triton==3.5.0"],
}
def get_architecture():
    """Normalise ``platform.machine()`` to an architecture family (a key of ``ARCHITECTURE_DEPENDENCIES``).

    Raises:
        RuntimeError: for machines missing from ``ARCHITECTURE_ALIASES``.
    """
    machine = platform.machine().lower()
    try:
        family = ARCHITECTURE_ALIASES[machine]
    except KeyError as missing:
        raise RuntimeError(f"Unsupported CPU architecture: {machine}") from missing
    return family
def get_install_requirements():
    """Return the pinned runtime requirements plus those of the host's architecture family."""
    base_requirements = [
        "attrs==24.2.0",
        "numpy==1.26.4",
        "scipy==1.13.1;python_version<'3.13'",
        "scipy==1.15.1;python_version>='3.13'",
        "decorator==5.1.1",
        "psutil==6.0.0",
        "pytest==8.3.2",
        "pytest-xdist==3.6.1",
        "pyyaml",
        "pybind11",
        "pandas",
    ]
    return base_requirements + ARCHITECTURE_DEPENDENCIES[get_architecture()]
# IS_MANYLINUX enables the auditwheel repair in BuildWheel and drops the git hash from get_version().
is_manylinux = check_env_flag("IS_MANYLINUX", "FALSE")
readme = os.path.join(triton_dir, "README.md")
if not os.path.exists(readme):
    raise FileNotFoundError("Unable to find 'README.md'")
long_description = Path(readme).read_text(encoding="utf-8")
# Register the distribution. Names, version, packages, entry points and install
# requirements come from the helper functions defined above.
setup(
    name=get_package_name(),
    version=get_version(),
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description=long_description,
    packages=list(get_packages()),
    package_dir=dict(get_package_dirs()),
    entry_points=get_entry_points(),
    include_package_data=True,
    # Placeholder extension; CMakeBuild (registered as build_ext below) builds it with CMake.
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    # Custom commands; plugin_* variants create backend/proton symlinks before running.
    cmdclass={
        "bdist_wheel": BuildWheel,
        "build_ext": CMakeBuild,
        "build_py": CMakeBuildPy,
        "clean": CMakeClean,
        "develop": plugin_develop,
        "editable_wheel": plugin_editable_wheel,
        "egg_info": plugin_egg_info,
        "install": plugin_install,
        "sdist": plugin_sdist,
    },
    zip_safe=False,
    keywords=["Compiler", "Deep Learning"],
    url="https://gitcode.com/Ascend/triton-ascend/",
    python_requires=PYTHON_REQUIRES,
    classifiers=CLASSIFIERS,
    test_suite="tests",
    install_requires=get_install_requirements(),
    extras_require={
        "build": [
            "cmake>=3.20,<4.0",
            "lit",
        ],
        "tests": [
            "autopep8",
            "isort",
            "numpy",
            "pytest",
            "pytest-forked",
            "pytest-xdist",
            "scipy>=1.7.1",
            "llnl-hatchet",
        ],
        "tutorials": [
            "matplotlib",
            "pandas",
            "tabulate",
        ],
    },
)