import os
import platform
import re
import contextlib
import shlex
import shutil
import subprocess
import sys
import sysconfig
import tarfile
import zipfile
import urllib.request
import json
import glob
from io import BytesIO
from distutils.command.clean import clean
from pathlib import Path
from typing import List, NamedTuple, Optional
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from dataclasses import dataclass
from distutils.command.install import install
from setuptools.command.develop import develop
from setuptools.command.egg_info import egg_info
from wheel.bdist_wheel import bdist_wheel
import pybind11
# Repository root: the parent of the directory that contains this setup script.
triton_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Build defaults for triton-ascend. setdefault() keeps any value the caller
# already exported, so each of these can still be overridden from the shell.
_BUILD_ENV_DEFAULTS = {
    "TRITON_BUILD_WITH_CCACHE": "true",
    "TRITON_BUILD_WITH_CLANG_LLD": "true",
    "TRITON_BUILD_PROTON": "OFF",
    "TRITON_WHEEL_NAME": "triton-ascend",
    "TRITON_APPEND_CMAKE_ARGS": "-DTRITON_BUILD_UT=OFF",
}
for _env_key, _env_value in _BUILD_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_key, _env_value)
@dataclass
class Backend:
    """Description of one compiler backend that is packaged together with Triton."""
    # Backend name: the directory name under third_party/ or, for external
    # plugins, the contents of <src_dir>/backend/name.conf.
    name: str
    # setuptools package_data globs ("<relative dir>/*") for every directory under backend_dir.
    package_data: List[str]
    # Same kind of globs for every directory under language_dir (empty list when there is none).
    language_package_data: List[str]
    # Backend source root (third_party/<name>, or one entry of TRITON_PLUGIN_DIRS).
    src_dir: str
    # Absolute path of <src_dir>/backend; it must contain compiler.py and driver.py.
    backend_dir: str
    # Absolute path of <src_dir>/language, or None when that directory does not exist.
    language_dir: Optional[str]
    # Path triton/backends/<name> next to this setup script; replaced by a symlink to backend_dir.
    install_dir: str
    # True for plugins supplied through TRITON_PLUGIN_DIRS.
    is_external: bool
class BackendInstaller:
    """Locate compiler backends and describe how each one is packaged."""

    @staticmethod
    def prepare(backend_name: str, backend_src_dir: Optional[str] = None, is_external: bool = False):
        """Validate a backend source tree and return its ``Backend`` description.

        Args:
            backend_name: name of the backend (directory under ``../third_party``
                for in-tree backends).
            backend_src_dir: source root of an external backend. It is ignored
                for in-tree backends, where it is derived from ``backend_name``.
            is_external: True for plugins supplied via ``TRITON_PLUGIN_DIRS``.

        Returns:
            Backend: paths and package_data globs for the backend.

        Raises:
            AssertionError: if the backend directory or its ``compiler.py`` /
                ``driver.py`` is missing.
        """
        if not is_external:
            root_dir = os.path.join(os.pardir, "third_party")
            assert backend_name in os.listdir(
                root_dir), f"{backend_name} is requested for install but not present in {root_dir}"
            # Best effort only. Failures are ignored: git may be missing
            # (FileNotFoundError) or the submodule update may fail
            # (CalledProcessError).
            try:
                subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True,
                               stdout=subprocess.DEVNULL, cwd=root_dir)
            except (subprocess.CalledProcessError, FileNotFoundError):
                pass
            backend_src_dir = os.path.join(root_dir, backend_name)
        backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend"))
        assert os.path.exists(backend_path), f"{backend_path} does not exist!"
        language_dir = os.path.abspath(os.path.join(backend_src_dir, "language"))
        if not os.path.exists(language_dir):
            language_dir = None
        for file in ["compiler.py", "driver.py"]:
            # Plain f-string placeholders; the previous "${...}" form leaked a literal "$".
            assert os.path.exists(os.path.join(backend_path, file)), f"{file} does not exist in {backend_path}"
        install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name)
        # One "<dir>/*" glob per directory so that setuptools ships the whole tree.
        package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _ in os.walk(backend_path)]
        language_package_data = []
        if language_dir is not None:
            language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _ in os.walk(language_dir)]
        return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data,
                       src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
                       install_dir=install_dir, is_external=is_external)

    @staticmethod
    def copy(active):
        """Prepare every in-tree backend named in *active*."""
        return [BackendInstaller.prepare(backend) for backend in active]

    @staticmethod
    def copy_externals():
        """Prepare the external backends listed in ``TRITON_PLUGIN_DIRS`` (';'-separated).

        Each listed directory must contain ``backend/name.conf`` holding the
        backend name. Empty entries (e.g. from a trailing ';') are skipped.
        """
        backend_dirs = os.getenv("TRITON_PLUGIN_DIRS")
        if backend_dirs is None:
            return []
        # Skip empty entries; otherwise "" would resolve to a relative
        # "backend/name.conf" and fail to be read.
        backend_dirs = [entry for entry in backend_dirs.strip().split(";") if entry]
        backend_names = [
            Path(os.path.join(plugin_dir, "backend", "name.conf")).read_text().strip() for plugin_dir in backend_dirs
        ]
        return [
            BackendInstaller.prepare(backend_name, backend_src_dir=backend_src_dir, is_external=True)
            for backend_name, backend_src_dir in zip(backend_names, backend_dirs)
        ]
def check_env_flag(name: str, default: str = "") -> bool:
    """Interpret environment variable *name* as a boolean switch.

    ON, 1, YES, TRUE and Y (case-insensitive) count as enabled. *default* is
    used when the variable is not set.
    """
    truthy = {"ON", "1", "YES", "TRUE", "Y"}
    value = os.getenv(name, default)
    return value.upper() in truthy
def get_build_type():
    """Return the CMake build type selected by environment flags.

    Flags are checked in priority order: DEBUG, REL_WITH_DEB_INFO,
    TRITON_REL_BUILD_WITH_ASSERTS, TRITON_BUILD_WITH_O1. The first enabled flag
    wins. With none set, "TritonRelBuildWithAsserts" is used.
    """
    flag_to_build_type = (
        ("DEBUG", "Debug"),
        ("REL_WITH_DEB_INFO", "RelWithDebInfo"),
        ("TRITON_REL_BUILD_WITH_ASSERTS", "TritonRelBuildWithAsserts"),
        ("TRITON_BUILD_WITH_O1", "TritonBuildWithO1"),
    )
    for flag, build_type in flag_to_build_type:
        if check_env_flag(flag):
            return build_type
    return "TritonRelBuildWithAsserts"
def get_env_with_keys(key: list):
    """Return the value of the first variable in *key* that is present in the environment.

    Returns "" when none of the names is set. A variable set to "" still counts
    as present and ends the search.
    """
    present = (os.environ[name] for name in key if name in os.environ)
    return next(present, "")
def is_offline_build() -> bool:
    """
    Return True when ``TRITON_OFFLINE_BUILD`` is enabled.

    Downstream projects and distributions that bootstrap their own dependencies
    from scratch and build in offline sandboxes may set ``TRITON_OFFLINE_BUILD``.
    It prevents any attempt to download pinned dependencies from the internet
    or to use dependencies vendored in-tree.

    Dependencies must then be provided through their search-path variables
    (see ``syspath_var_name`` in ``Package``). A missing dependency aborts the
    build early, and dependency compatibility is not verified.

    Note that this flag isn't tested by the CI and does not provide any guarantees.
    """
    offline = check_env_flag("TRITON_OFFLINE_BUILD", default="")
    return offline
class Package(NamedTuple):
    """A third-party dependency that the build may download and expose to CMake."""
    # Subdirectory of the Triton cache (<cache>/<package>) the archive is extracted into.
    package: str
    # Directory expected inside that subdirectory after extraction ("" means the subdirectory itself).
    name: str
    # Download URL; also written to version.txt to detect a stale cached copy.
    url: str
    # CMake variable set to <package dir>/include (skipped when empty).
    include_flag: str
    # CMake variable set to <package dir>/lib (skipped when empty).
    lib_flag: str
    # Environment variable that, when set, points at a user-provided package directory.
    syspath_var_name: str
def get_json_package_info():
    """Describe the nlohmann/json v3.11.3 header-only archive.

    The result exposes only an include directory (``JSON_INCLUDE_DIR``). It can
    be overridden through ``JSON_SYSPATH``.
    """
    release_url = "https://github.com/nlohmann/json/releases/download/v3.11.3/include.zip"
    return Package(
        package="json",
        name="",
        url=release_url,
        include_flag="JSON_INCLUDE_DIR",
        lib_flag="",
        syspath_var_name="JSON_SYSPATH",
    )
def is_linux_os(os_id, os_release_path="/etc/os-release"):
    """Return True when the os-release file declares ``ID`` equal to *os_id*.

    The ``ID`` line is parsed as ``KEY=VALUE``, so both quoted (``ID="almalinux"``)
    and unquoted (``ID=ubuntu``) forms match. Other keys that merely end in
    ``ID`` (e.g. ``VARIANT_ID``) are not confused with it.

    Args:
        os_id: distribution identifier to compare against.
        os_release_path: file to read; defaults to ``/etc/os-release``.

    Returns:
        bool: False when the file does not exist or has no matching ``ID``.
    """
    if not os.path.exists(os_release_path):
        return False
    with open(os_release_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            key, sep, value = line.strip().partition("=")
            if sep and key == "ID":
                return value.strip().strip("\"'") == os_id
    return False
def get_llvm_package_info():
    """Describe the pre-built LLVM archive that matches the host platform.

    Returns:
        Package: on supported hosts, a package named
        ``llvm-<first 8 chars of cmake/llvm-hash.txt>-<platform suffix>`` whose
        URL points at the triton-ascend artifact bucket. On unsupported hosts,
        or on Linux x64 when the glibc version cannot be determined, a
        message is printed and a package with an empty URL is returned. In that
        case no download source exists, so presumably ``LLVM_SYSPATH`` must
        point at a user-built LLVM (TODO confirm).
    """
    system = platform.system()
    machine = platform.machine()
    arch = {"x86_64": "x64", "arm64": "arm64", "aarch64": "arm64"}.get(machine, machine)

    def _user_llvm_package():
        # No pre-built archive is available for this host.
        print(
            f"LLVM pre-compiled image is not available for {system}-{arch}. Proceeding with user-configured LLVM from source build."
        )
        return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

    if system == "Darwin":
        system_suffix = f"macos-{arch}"
    elif system == "Linux" and arch == "arm64":
        system_suffix = "almalinux-arm64" if is_linux_os("almalinux") else "ubuntu-arm64"
    elif system == "Linux" and arch == "x64":
        # Encode glibc "M.m" as M * 100 + m (e.g. 2.28 -> 228) to choose between
        # the ubuntu / almalinux / centos builds.
        glibc_parts = platform.libc_ver()[1].split(".")
        try:
            vglibc = int(glibc_parts[0]) * 100 + int(glibc_parts[1])
        except (ValueError, IndexError):
            # platform.libc_ver() yields an empty version when glibc is not
            # detected (e.g. musl); fall back instead of crashing.
            return _user_llvm_package()
        if vglibc > 228:
            system_suffix = "ubuntu-x64"
        elif vglibc > 217:
            system_suffix = "almalinux-x64"
        else:
            system_suffix = "centos-x64"
    else:
        return _user_llvm_package()

    llvm_hash_path = os.path.join(get_base_dir(), "cmake", "llvm-hash.txt")
    with open(llvm_hash_path, "r") as llvm_hash_file:
        # Only the abbreviated (8-character) LLVM revision is part of the archive name.
        rev = llvm_hash_file.read(8)
    name = f"llvm-{rev}-{system_suffix}"
    url = f"https://triton-ascend-artifacts.obs.myhuaweicloud.com/llvm-builds/{name}.tar.gz"
    return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
def open_url(url):
    """Open *url* with a browser-like User-Agent and a 300-second timeout.

    Returns the response object from ``urllib.request.urlopen``.
    """
    browser_ua = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'
    req = urllib.request.Request(url, data=None, headers={'User-Agent': browser_ua})
    return urllib.request.urlopen(req, timeout=300)
def get_triton_cache_path():
    """Return the ``.triton`` cache directory under the Triton or user home.

    ``TRITON_HOME`` takes precedence. Otherwise the first non-empty value of
    ``HOME``, ``USERPROFILE`` or ``HOMEPATH`` is used.

    Raises:
        RuntimeError: if none of those variables holds a non-empty value.
    """
    home = os.getenv("TRITON_HOME")
    if not home:
        candidates = (os.getenv(var) for var in ("HOME", "USERPROFILE", "HOMEPATH"))
        home = next((value for value in candidates if value), None)
    if not home:
        raise RuntimeError("Could not find user home directory")
    return os.path.join(home, ".triton")
def get_thirdparty_packages(packages: list):
    """Make each third-party package available and return CMake ``-D`` flags for it.

    For every ``Package``, a non-empty ``<syspath_var_name>`` environment
    variable selects a user-provided directory. Otherwise the package lives in
    the Triton cache. It is (re-)downloaded and extracted whenever its
    ``version.txt`` does not record ``p.url``.

    Args:
        packages: iterable of ``Package`` tuples.

    Returns:
        list[str]: ``-D<include_flag>=<dir>/include`` and ``-D<lib_flag>=<dir>/lib``
        arguments for the packages that define those flags.

    Raises:
        RuntimeError: in an offline build (``TRITON_OFFLINE_BUILD``) when a
            package's syspath variable is unset or empty.
    """
    triton_cache_path = get_triton_cache_path()
    offline = is_offline_build()
    thirdparty_cmake_args = []
    for p in packages:
        package_root_dir = os.path.join(triton_cache_path, p.package)
        user_dir = os.environ.get(p.syspath_var_name)
        # Treat an empty variable exactly like an unset one. Previously an empty
        # value skipped the download but still pointed CMake at the (possibly
        # missing or stale) cache directory.
        input_defined = bool(user_dir)
        package_dir = user_dir if input_defined else os.path.join(package_root_dir, p.name)
        version_file_path = os.path.join(package_dir, "version.txt")
        input_compatible = os.path.exists(version_file_path) and Path(version_file_path).read_text() == p.url
        if offline and not input_defined:
            raise RuntimeError(f"Requested an offline build but {p.syspath_var_name} is not set")
        if not offline and not input_defined and not input_compatible:
            # Start from an empty cache directory; a missing one is fine.
            shutil.rmtree(package_root_dir, ignore_errors=True)
            os.makedirs(package_root_dir, exist_ok=True)
            print(f'downloading and extracting {p.url} ...')
            with open_url(p.url) as response:
                if p.url.endswith(".zip"):
                    # zipfile needs a seekable stream, so buffer the body in memory.
                    file_bytes = BytesIO(response.read())
                    with zipfile.ZipFile(file_bytes, "r") as archive:
                        archive.extractall(path=package_root_dir)
                else:
                    # "r|*" streams the (optionally compressed) tarball without seeking.
                    with tarfile.open(fileobj=response, mode="r|*") as archive:
                        archive.extractall(path=package_root_dir)
            # Record the source URL so later builds can reuse this extraction.
            Path(version_file_path).write_text(p.url)
        if p.include_flag:
            thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
        if p.lib_flag:
            thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib")
    return thirdparty_cmake_args
def download_and_copy(name, src_path, dst_path, variable, version, url_func):
    """Fetch an NVIDIA tool archive into the Triton cache and copy it into the backend tree.

    The call does nothing in an offline build or when *variable* is set in the
    environment. Otherwise the archive from ``url_func(os_name, arch, version)``
    is extracted into ``<cache>/nvidia/<name>`` when the cached source is
    missing, or when the installed binary's ``--version`` does not report
    *version*. The source path is then copied to
    ``third_party/nvidia/backend/<dst_path>``.

    Args:
        name: cache subdirectory name.
        src_path: path inside the extracted archive, or a callable
            ``(platform_name, version) -> path``.
        dst_path: destination relative to ``third_party/nvidia/backend``.
        variable: environment variable whose presence disables this step.
        version: expected tool version string.
        url_func: callable ``(os_name, arch, version) -> url``.
    """
    if is_offline_build():
        return
    if variable in os.environ:
        return
    # Resolve the cache only once we know it is needed (it raises without a home dir).
    triton_cache_path = get_triton_cache_path()
    base_dir = os.path.dirname(__file__)
    system = platform.system()
    machine = platform.machine()
    arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}.get(machine, machine)
    supported = {"Linux": "linux", "Darwin": "linux"}
    url = url_func(supported[system], arch, version)
    tmp_path = os.path.join(triton_cache_path, "nvidia", name)
    dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path)
    platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
    src_path = src_path(platform_name, version) if callable(src_path) else src_path
    src_path = os.path.join(tmp_path, src_path)
    download = not os.path.exists(src_path)
    if os.path.exists(dst_path) and system == "Linux" and shutil.which(dst_path) is not None:
        version_output = subprocess.check_output([dst_path, "--version"]).decode("utf-8").strip()
        found = re.search(r"V([.|\d]+)", version_output)
        # An unparseable version is treated as a mismatch rather than crashing.
        download = download or found is None or found.group(1) != version
    if download:
        print(f'downloading and extracting {url} ...')
        # Close both the HTTP response and the tar stream when done.
        with open_url(url) as response, tarfile.open(fileobj=response, mode="r|*") as archive:
            archive.extractall(path=tmp_path)
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    print(f'copy {src_path} to {dst_path} ...')
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(src_path, dst_path)
def get_base_dir():
    """Return the absolute path of the directory one level above this setup script."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.dirname(script_dir)
def get_cmake_dir():
    """Return the CMake build directory for this platform and interpreter, creating it if needed.

    The directory is ``<base>/python/build/cmake.<platform>-<implementation>-<py version>``.
    """
    dir_name = "cmake.{}-{}-{}".format(
        sysconfig.get_platform(),
        sys.implementation.name,
        sysconfig.get_python_version(),
    )
    build_dir = Path(get_base_dir()).joinpath("python", "build", dir_name)
    build_dir.mkdir(parents=True, exist_ok=True)
    return build_dir
class CMakeClean(clean):
    """``clean`` command whose ``build_temp`` is the CMake build directory."""

    def initialize_options(self):
        super().initialize_options()
        # Note: get_cmake_dir() also creates the directory if it is missing.
        self.build_temp = get_cmake_dir()
class CMakeBuildPy(build_py):
    """``build_py`` that runs ``build_ext`` (the CMake build) before copying Python sources."""

    def run(self) -> None:
        self.run_command('build_ext')
        result = build_py.run(self)
        return result
class CMakeExtension(Extension):
    """Placeholder extension (no sources) whose build is delegated to CMake.

    Args:
        name: extension name passed to setuptools.
        path: module path used to compute the output directory.
        sourcedir: CMake source directory, stored as an absolute path.
    """

    def __init__(self, name, path, sourcedir=""):
        super().__init__(name, sources=[])
        self.path = path
        self.sourcedir = os.path.abspath(sourcedir)
class CMakeBuild(build_ext):
    """``build_ext`` replacement that configures and builds Triton's native code with CMake + Ninja."""

    user_options = build_ext.user_options + \
        [('base-dir=', None, 'base directory of Triton')]

    def initialize_options(self):
        build_ext.initialize_options(self)
        # Passed to ``cmake`` as the source directory.
        self.base_dir = get_base_dir()

    def finalize_options(self):
        build_ext.finalize_options(self)

    @staticmethod
    def _prepend_to_env_path(var, entry):
        """Prepend *entry* to the ':'-separated environment variable *var*.

        No trailing ':' is added when *var* is unset or empty. An empty entry
        means "current directory" in PATH (POSIX), and glibc's loader treats
        LD_LIBRARY_PATH the same way.
        """
        current = os.environ.get(var, '')
        os.environ[var] = f'{entry}:{current}' if current else entry

    def setup_coverage_env(self):
        """Export the environment variables required by the hitest coverage tool.

        ``HITEST_HOME``, ``HITEST_USER_ACCOUNT`` and ``LLTCOV_ROOTPATH`` may be
        overridden from the environment. The hitest home is prepended (once)
        to PATH and LD_LIBRARY_PATH.
        """
        hitest_home = os.getenv('HITEST_HOME', '/opt/hitest/linux_avatar_x86_64')
        hitest_user_account = os.getenv('HITEST_USER_ACCOUNT', 'a00000000')
        lltcov_rootpath = os.getenv('LLTCOV_ROOTPATH', '/opt/covdata')
        coverage_env_vars = {
            'HitestHome': hitest_home,
            'isOverlappedCompile': '0',
            'PlatformToken': 'BOARD',
            'gcovmode': '0',
            'TimerPolicy': '1',
            'TimeInterval': '60',
            'SignalPolicy': '1',
            'SignalNUM': '34',
            'lltwrapper_cfg': '0',
            'HITEST_AGENT_INSIDE': '1',
            'USE_HLLT_COVERAGE': '1',
            'USE_HLLT_TESTCASE': '0',
            'simplemode': '0',
            'ncs_coverage_stub_mold': '1',
            # NOTE(review): "SOKCET" looks like a misspelling of SOCKET; kept verbatim
            # in case the hitest tool reads exactly this name — confirm.
            'HITEST_ENABLE_SOKCET': '0',
            'hitest_disable_cfg': '0',
            'hitest_disable_dfg': '1',
            'hitest_disable_ir': '1',
            'HITEST_DISABLE_MACRO': '0',
            'HITEST_REMOVE_INCLUDE_DIR': '0',
            'HITEST_AGENT_SET_THREADNAME_PRCTL': '1',
            'HITEST_INST_HEADER_FILE': '0',
            'HITEST_USER_ACCOUNT': hitest_user_account,
            'lltcovRootpath': lltcov_rootpath,
            'HITEST_COVSTUB_ROOT_DIR': f'{hitest_home}/apache-tomcat-8.0.39/webapps/datasource/Container_Default/base',
            'HITEST_EXEC_CMD_WITH_FILE': '1',
            'HITEST_PRINT_LOG_ENABLE': '1',
        }
        os.environ.update(coverage_env_vars)
        # Previously done twice, which duplicated the entries on both search paths.
        self._prepend_to_env_path('PATH', hitest_home)
        self._prepend_to_env_path('LD_LIBRARY_PATH', hitest_home)
        print("The currently set environment variables for the hitest coverage tool are read.")
        print(f" HitestHome: {hitest_home} (environment variables HITEST_HOME)")
        print(f" HITEST_USER_ACCOUNT: {hitest_user_account} (environment variables HITEST_USER_ACCOUNT)")
        print(f" lltcovRootpath: {lltcov_rootpath} (environment variables LLTCOV_ROOTPATH)")

    def run(self):
        """Check the CMake version, apply the hitest coverage settings, then build every extension.

        Raises:
            RuntimeError: if CMake is missing, its version cannot be parsed, or it is older than 3.18.
        """
        try:
            out = subprocess.check_output(["cmake", "--version"])
        except OSError as err:
            raise RuntimeError("CMake must be installed to build the following extensions: " +
                               ", ".join(e.name for e in self.extensions)) from err

        version_text = out.decode()
        match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", version_text)
        if match is None:
            raise RuntimeError(f"Could not determine the CMake version from: {version_text.strip()!r}")
        cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))
        if (cmake_major, cmake_minor) < (3, 18):
            raise RuntimeError("CMake >= 3.18.0 is required")

        enable_hitest = os.getenv('TRITON_ENABLE_COVERAGE_HITEST', '0').lower() in ('1', 'on', 'true')
        if enable_hitest:
            self.setup_coverage_env()
            current_append = os.environ.get('TRITON_APPEND_CMAKE_ARGS', '')
            if current_append:
                os.environ['TRITON_APPEND_CMAKE_ARGS'] = current_append + " -DTRITON_ENABLE_COVERAGE_HITEST=ON"
            else:
                os.environ['TRITON_APPEND_CMAKE_ARGS'] = "-DTRITON_ENABLE_COVERAGE_HITEST=ON"
        else:
            # Coverage disabled: strip HITEST_* / HitestHome / lltcovRootpath from the environment.
            for key in list(os.environ):
                if key.startswith('HITEST_') or key in ['HitestHome', 'lltcovRootpath']:
                    del os.environ[key]

        for ext in self.extensions:
            self.build_extension(ext)

    def get_pybind11_cmake_args(self):
        """Return ``-DPYBIND11_INCLUDE_DIR=...``, honouring ``PYBIND11_SYSPATH`` when set."""
        pybind11_sys_path = get_env_with_keys(["PYBIND11_SYSPATH"])
        if pybind11_sys_path:
            pybind11_include_dir = os.path.join(pybind11_sys_path, "include")
        else:
            pybind11_include_dir = pybind11.get_include()
        return [f"-DPYBIND11_INCLUDE_DIR={pybind11_include_dir}"]

    def get_proton_cmake_args(self):
        """Return the CMake arguments needed to build proton (json, pybind11, CUPTI, roctracer paths)."""
        cmake_args = get_thirdparty_packages([get_json_package_info()])
        cmake_args += self.get_pybind11_cmake_args()
        cupti_include_dir = get_env_with_keys(["TRITON_CUPTI_INCLUDE_PATH"])
        if cupti_include_dir == "":
            cupti_include_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "include")
        cmake_args += ["-DCUPTI_INCLUDE_DIR=" + cupti_include_dir]
        cupti_lib_dir = get_env_with_keys(["TRITON_CUPTI_LIB_PATH"])
        if cupti_lib_dir == "":
            cupti_lib_dir = os.path.join(get_base_dir(), "third_party", "nvidia", "backend", "lib", "cupti")
        cmake_args += ["-DCUPTI_LIB_DIR=" + cupti_lib_dir]
        roctracer_include_dir = get_env_with_keys(["ROCTRACER_INCLUDE_PATH"])
        if roctracer_include_dir == "":
            roctracer_include_dir = os.path.join(get_base_dir(), "third_party", "amd", "backend", "include")
        cmake_args += ["-DROCTRACER_INCLUDE_DIR=" + roctracer_include_dir]
        return cmake_args

    def build_extension(self, ext):
        """Configure *ext* with CMake (Ninja generator), build it, then build the ``mlir-doc`` target.

        Raises:
            RuntimeError: if ``ninja`` is not on PATH.
            subprocess.CalledProcessError: if a CMake step fails.
        """
        lit_dir = shutil.which('lit')
        ninja_dir = shutil.which('ninja')
        if ninja_dir is None:
            # Fail before downloading LLVM, rather than with a TypeError while
            # assembling the CMake arguments below.
            raise RuntimeError("Ninja must be installed and available on PATH to build Triton")
        thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
        thirdparty_cmake_args += self.get_pybind11_cmake_args()
        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
        os.makedirs(self.build_temp, exist_ok=True)
        python_include_dir = sysconfig.get_path("platinclude")
        cmake_args = [
            "-G", "Ninja",
            "-DCMAKE_MAKE_PROGRAM=" +
            ninja_dir,
            "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON", "-DLLVM_ENABLE_WERROR=ON",
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF",
            "-DTRITON_BUILD_PYTHON_MODULE=ON", "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
            "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]),
            "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]),
            "-DLLVM_MAJOR_VERSION_20_COMPATIBLE=ON"
        ]
        if lit_dir is not None:
            cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
        cmake_args.extend(thirdparty_cmake_args)

        cfg = get_build_type()
        build_args = ["--config", cfg]
        if platform.system() == "Windows":
            cmake_args += [f"-DCMAKE_RUNTIME_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"]
            # NOTE(review): CMake only accepts "-A" (platform specification) for
            # generators that support it (e.g. Visual Studio), not Ninja, which is
            # selected above. This branch likely needs revisiting.
            if sys.maxsize > 2**32:
                cmake_args += ["-A", "x64"]
        else:
            cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg]
            max_jobs = os.getenv("MAX_JOBS")
            if max_jobs is None:
                # os.cpu_count() returns None when the count cannot be determined.
                max_jobs = str(2 * (os.cpu_count() or 1))
            build_args += ['-j' + max_jobs]

        if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
            cmake_args += [
                "-DCMAKE_C_COMPILER=clang",
                "-DCMAKE_CXX_COMPILER=clang++",
                "-DCMAKE_LINKER=lld",
                "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
                "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
            ]
        if check_env_flag("TRITON_BUILD_WITH_ASAN"):
            cmake_args += [
                "-DCMAKE_C_FLAGS=-fsanitize=address",
                "-DCMAKE_CXX_FLAGS=-fsanitize=address",
            ]
        if check_env_flag("TRITON_BUILD_WITH_CCACHE"):
            cmake_args += [
                "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache",
            ]
        if check_env_flag("TRITON_BUILD_PROTON", "ON"):
            cmake_args += self.get_proton_cmake_args()
        else:
            cmake_args += ["-DTRITON_BUILD_PROTON=OFF"]
        if is_offline_build():
            cmake_args += ["-DTRITON_BUILD_UT=OFF"]
        ascendnpu_ir_tag = os.getenv("ASCENDNPU_IR_TAG")
        if ascendnpu_ir_tag is not None:
            cmake_args += [f"-DASCENDNPU_IR_TAG={ascendnpu_ir_tag}"]
        # User-supplied extra arguments come last so they can override the defaults above.
        cmake_args_append = os.getenv("TRITON_APPEND_CMAKE_ARGS")
        if cmake_args_append is not None:
            cmake_args += shlex.split(cmake_args_append)

        env = os.environ.copy()
        cmake_dir = get_cmake_dir()
        subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
        subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
        subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir)
# NVIDIA toolchain versions, loaded from cmake/nvidia-toolchain-version.json.
nvidia_version_path = os.path.join(get_base_dir(), "cmake", "nvidia-toolchain-version.json")
NVIDIA_TOOLCHAIN_VERSION = json.loads(Path(nvidia_version_path).read_text())
def get_platform_dependent_src_path(subdir):
    """Return a callable that maps ``(platform, version)`` to the archive path of *subdir*.

    From version 12.5 onwards the archives place *subdir* under
    ``targets/<platform>/``; older versions keep it at the top level.

    The previous check (``major >= 12 and minor >= 5``) compared the
    components independently, so e.g. 13.0 wrongly selected the old layout.
    It also required exactly three version components. The comparison is now
    a (major, minor) tuple comparison, and a missing minor component counts as 0.
    """

    def _src_path(platform, version):
        # ``platform`` keeps the original positional/keyword name; it shadows
        # the module only inside this function.
        parts = version.split(".")
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        if (major, minor) >= (12, 5):
            return f"targets/{platform}/{subdir}"
        return subdir

    return _src_path
# Backends to package: the in-tree "ascend" backend followed by any external
# plugins listed in TRITON_PLUGIN_DIRS.
backends = BackendInstaller.copy(["ascend"]) + BackendInstaller.copy_externals()
def add_link_to_backends():
    """Symlink each backend into ``triton/backends/<name>``.

    Each entry of a backend's ``language`` directory is also symlinked into
    ``triton/language/extra``. Existing links are unlinked and existing
    directories removed first.
    """

    def _relink(target, link_path):
        # Replace whatever currently sits at link_path with a symlink to target.
        if os.path.islink(link_path):
            os.unlink(link_path)
        elif os.path.exists(link_path):
            shutil.rmtree(link_path)
        os.symlink(target, link_path)

    extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "language", "extra"))
    for backend in backends:
        _relink(backend.backend_dir, backend.install_dir)
        if not backend.language_dir:
            continue
        for entry in os.listdir(backend.language_dir):
            _relink(os.path.join(backend.language_dir, entry), os.path.join(extra_dir, entry))
def add_link_to_proton():
    """Make ``triton/profiler`` a symlink to ``../third_party/proton/proton``, replacing what was there."""
    here = os.path.dirname(__file__)
    proton_dir = os.path.abspath(os.path.join(here, os.pardir, "third_party", "proton", "proton"))
    proton_install_dir = os.path.join(here, "triton", "profiler")
    if os.path.islink(proton_install_dir):
        os.unlink(proton_install_dir)
    elif os.path.exists(proton_install_dir):
        shutil.rmtree(proton_install_dir)
    os.symlink(proton_dir, proton_install_dir)
def add_links():
    """Create the backend symlinks, plus the proton link when ``TRITON_BUILD_PROTON`` is enabled."""
    add_link_to_backends()
    if not check_env_flag("TRITON_BUILD_PROTON", "ON"):
        return
    add_link_to_proton()
class plugin_install(install):
    """``install`` command that creates the backend/proton symlinks first."""

    def run(self):
        add_links()
        super().run()
class plugin_develop(develop):
    """``develop`` command that creates the backend/proton symlinks first."""

    def run(self):
        add_links()
        super().run()
class plugin_bdist_wheel(bdist_wheel):
    """``bdist_wheel`` command that creates the backend/proton symlinks first."""

    def run(self):
        add_links()
        super().run()
class plugin_egginfo(egg_info):
    """``egg_info`` command that creates the backend/proton symlinks first."""

    def run(self):
        add_links()
        super().run()
class BuildWheel(bdist_wheel):
    """``bdist_wheel`` that creates the symlinks and, for manylinux builds, repairs the wheel.

    When ``is_manylinux`` is true, the freshly built ``*-linux_*.whl`` is
    repaired with auditwheel into ``manylinux_2_27``/``manylinux_2_28`` wheels
    in ``dist_dir``, and the unrepaired wheel is deleted.
    """

    def run(self):
        add_links()
        bdist_wheel.run(self)
        if not is_manylinux:
            return
        candidates = glob.glob(os.path.join(self.dist_dir, "*-linux_*.whl"))
        if not candidates:
            # Previously an uninformative IndexError.
            raise RuntimeError(f"No '*-linux_*.whl' wheel found in {self.dist_dir} to repair with auditwheel")
        file = candidates[0]
        machine = platform.machine()
        auditwheel_cmd = [
            "auditwheel",
            "-v",
            "repair",
            "--plat",
            f"manylinux_2_27_{machine}",
            "--plat",
            f"manylinux_2_28_{machine}",
            "-w",
            self.dist_dir,
            file,
        ]
        try:
            subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE)
        finally:
            # The unrepaired wheel is removed even when auditwheel fails.
            os.remove(file)
# Non-Python files shipped in the wheel: the C compile helpers, every
# backend's data globs, and the language-extra globs of all backends combined.
package_data = {"triton/tools": ["compile.h", "compile.c"]}
package_data.update({f"triton/backends/{b.name}": b.package_data for b in backends})
package_data["triton/language/extra"] = [pattern for b in backends for pattern in b.language_package_data]
def get_language_extra_packages():
    """List the ``triton/language/extra/<subdir>`` packages contributed by backends.

    A backend's ``language`` directory is walked (following symlinks). Every
    subdirectory that directly contains at least one ``.py`` file becomes a
    package. The ``language`` root itself is skipped.
    """
    packages = []
    for backend in backends:
        root = backend.language_dir
        if root is None:
            continue
        for current, _subdirs, filenames in os.walk(root, followlinks=True):
            if current == root:
                continue
            if not any(fname.endswith(".py") for fname in filenames):
                continue
            packages.append(os.path.join("triton/language/extra", os.path.relpath(current, root)))
    return packages
def get_packages():
    """Return the Python packages included in the distribution.

    The list holds the core triton packages, one package per backend, the
    backend language extras, and ``triton/profiler`` when proton is enabled.
    """
    core_packages = [
        "triton",
        "triton/_C",
        "triton/compiler",
        "triton/language",
        "triton/language/extra",
        "triton/runtime",
        "triton/backends",
        "triton/tools",
        "triton/extension",
        "triton/extension/buffer",
        "triton/extension/buffer/language",
    ]
    result = [
        *core_packages,
        *(f'triton/backends/{backend.name}' for backend in backends),
        *get_language_extra_packages(),
    ]
    if check_env_flag("TRITON_BUILD_PROTON", "ON"):
        result.append("triton/profiler")
    return result
def get_entry_points():
    """Return setuptools entry points: the proton console scripts when proton is built, otherwise none."""
    if not check_env_flag("TRITON_BUILD_PROTON", "ON"):
        return {}
    return {
        "console_scripts": [
            "proton-viewer = triton.profiler.viewer:main",
            "proton = triton.profiler.proton:main",
        ]
    }
def get_git_commit_hash(length=8):
    """Return ``"+git<short HEAD hash>"``, or "" when git cannot report it.

    Args:
        length: number of hash characters requested via ``--short``.
    """
    cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD']
    try:
        short_hash = subprocess.check_output(cmd).strip().decode('utf-8')
    except Exception:
        return ""
    return f"+git{short_hash}"
def get_default_version():
    """Return the stripped contents of ``version.txt`` beside this script, or "3.2.0" if absent."""
    version_file = Path(__file__).parent / "version.txt"
    return version_file.read_text().strip() if version_file.exists() else "3.2.0"
def get_version():
    """Compose the package version string.

    The result is ``TRITON_VERSION`` (or the default version), followed by
    ``TRITON_WHEEL_VERSION_SUFFIX``. For non-manylinux builds a ``+git<hash>``
    local suffix is appended.
    """
    base_version = os.environ.get("TRITON_VERSION", get_default_version())
    wheel_suffix = os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
    git_suffix = "" if is_manylinux else get_git_commit_hash()
    return base_version + wheel_suffix + git_suffix
def get_package_name():
    """Return the distribution name from ``TRITON_WHEEL_NAME``, defaulting to "triton_ascend"."""
    configured = os.environ.get("TRITON_WHEEL_NAME")
    return "triton_ascend" if configured is None else configured
# IS_MANYLINUX switches the wheel to auditwheel repair and drops the git suffix from the version.
is_manylinux = check_env_flag("IS_MANYLINUX", "FALSE")

# The repository README is used as the long description on the package index.
readme = os.path.join(triton_dir, "README.md")
if not os.path.exists(readme):
    raise FileNotFoundError("Unable to find 'README.md'")
long_description = Path(readme).read_text(encoding="utf-8")
# Package metadata and build hooks. The cmdclass entries route the setuptools
# commands through the CMake-driven build and the symlink-creating subclasses.
setup(
    name=get_package_name(),
    version=get_version(),
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description=long_description,
    packages=get_packages(),
    entry_points=get_entry_points(),
    package_data=package_data,
    include_package_data=True,
    # A single source-less extension; its build is delegated to CMakeBuild.
    ext_modules=[CMakeExtension("triton", "triton/_C/")],
    cmdclass={
        "build_ext": CMakeBuild,
        "build_py": CMakeBuildPy,
        "clean": CMakeClean,
        "install": plugin_install,
        "develop": plugin_develop,
        # BuildWheel (not plugin_bdist_wheel) also performs the optional auditwheel repair.
        "bdist_wheel": BuildWheel,
        "egg_info": plugin_egginfo,
    },
    zip_safe=False,
    keywords=["Compiler", "Deep Learning"],
    url="https://gitcode.com/Ascend/triton-ascend/",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Software Development :: Build Tools",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
    ],
    test_suite="tests",
    # Runtime dependencies, mostly pinned to exact versions.
    install_requires=[
        "attrs==24.2.0",
        "numpy==1.26.4",
        "scipy==1.13.1",
        "decorator==5.1.1",
        "psutil==6.0.0",
        "pytest==8.3.2",
        "pytest-xdist==3.6.1",
        "pyyaml"
    ],
    extras_require={
        # NOTE(review): CMakeBuild.run only enforces CMake >= 3.18, while this extra asks for >= 3.20.
        "build": [
            "cmake>=3.20",
            "lit",
        ],
        "tests": [
            "autopep8",
            "flake8",
            "isort",
            "numpy",
            "pytest",
            "scipy>=1.7.1",
            "llnl-hatchet",
        ],
        "tutorials": [
            "matplotlib",
            "pandas",
            "tabulate",
        ],
    },
)