import functools
import hashlib
import os
import re
import shutil
import subprocess
import sysconfig
from pathlib import Path
import logging
from platform import python_version
from triton.backends.ascend.backend_register import backend_strategy_registry
import pybind11
# Active backend strategy name ("torch_npu" or "mindspore"). It is None until
# get_backend_func() resolves it on first use.
backend_policy = None


def get_backend_func(name, *args, **kwargs):
    """Dispatch ``name`` to the active backend strategy.

    On the first call the backend is chosen: "torch_npu" when both ``torch``
    and ``torch_npu`` can be imported, otherwise "mindspore". The choice is
    printed once and stored in the module-level ``backend_policy``.

    :param name: strategy function name understood by the registry.
    :return: whatever ``backend_strategy_registry.execute_func`` returns for
        ``(backend_policy, name, *args, **kwargs)``.
    """
    global backend_policy
    if backend_policy is None:
        try:
            import torch  # noqa: F401  (import is only a capability probe)
            import torch_npu  # noqa: F401
        except ImportError:
            backend_policy = "mindspore"
        else:
            backend_policy = "torch_npu"
        print(f"the backend policy is {backend_policy}")
    return backend_strategy_registry.execute_func(backend_policy, name, *args, **kwargs)
def get_logger(logger_name, logger_level_str):
    '''
    Return a non-propagating logger that writes formatted records to stderr.

    :param logger_name: name passed to ``logging.getLogger``.
    :param logger_level_str: case-insensitive level name ("DEBUG", "INFO",
        "WARNING", "ERROR", "CRITICAL"); any other value falls back to INFO.
    :return: the configured ``logging.Logger``.

    Calling this repeatedly with the same name only updates the level. The
    console handler is attached once, so messages are not printed twice.
    '''
    logging_level_mapping = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    logger.setLevel(logging_level_mapping.get(logger_level_str.upper(), logging.INFO))
    # Loggers are process-wide singletons per name. Without this guard every
    # call would stack another handler and duplicate each message.
    if not any(getattr(h, "_triton_ascend_console", False) for h in logger.handlers):
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        console_handler._triton_ascend_console = True
        logger.addHandler(console_handler)
    return logger
def downgrade_llir(llir):
    """Rewrite LLVM IR text so older LLVM parsers accept it.

    Applies, in order, the ``memory(...)`` attribute downgrade and the
    stacksave/stackrestore intrinsic-name downgrade, then returns the text.
    """
    for rewrite in (_downgrade_mem_attrs, _downgrade_stacksaverestore_intrinsics):
        llir = rewrite(llir)
    return llir
def _downgrade_mem_attrs(llir: str):
    """Rewrite LLVM ``memory(...)`` attributes into the legacy attribute spelling.

    Each ``memory(<loc>: <access>, ...)`` group is folded into at most two
    legacy attributes, considering only entries whose access is not ``none``:

    * access (OR-ed): nothing -> ``readnone``, read -> ``readonly``,
      write -> ``writeonly``; read+write emits nothing.
    * location (OR-ed): only ``argmem`` -> ``argmemonly``, only
      ``inaccessiblemem`` -> ``inaccessiblememonly``, both ->
      ``inaccessiblemem_or_argmemonly``. Any other location, or an entry
      without a location, makes the group emit no location attribute.

    Example: ``memory(argmem: read)`` -> ``readonly argmemonly``.

    Text that looks like ``memory(`` but does not parse as that attribute
    (e.g. a call to ``@copy_memory(ptr %p)``) is left unchanged.
    """
    # The look-behind stops matches inside LLVM identifiers such as
    # ``@copy_memory(`` or ``@memory(``. A real attribute starts a new token.
    memory_pattern = r"(?<![\w$.@%-])memory\(([^()]*)\)"
    loc_map = {"argmem": 1, "inaccessiblemem": 2, "other": 4}
    rw_map = {"readwrite": 3, "write": 2, "read": 1, "none": 0}
    rev_rw_map = {0: "readnone", 1: "readonly", 2: "writeonly"}
    rev_loc_map = {
        1: "argmemonly",
        2: "inaccessiblememonly",
        3: "inaccessiblemem_or_argmemonly",
    }

    def replace_mem_attr(m):
        body = m.group(1).strip()
        # str.split never returns an empty list, so test the text itself.
        # An empty ``memory()`` is treated as readnone.
        if not body:
            return "readnone"
        loc_attr = 0
        rw_attr = 0
        for attr_pair in body.split(","):
            pair = [part.strip() for part in attr_pair.split(":")]
            if len(pair) == 1:
                loc_str, rw_str = "other", pair[0]
            elif len(pair) == 2:
                loc_str, rw_str = pair
            else:
                return m.group(0)
            if rw_str not in rw_map:
                # Not a memory-effects attribute; keep the original text.
                return m.group(0)
            rw = rw_map[rw_str]
            # Locations other than argmem/inaccessiblemem collapse to "other".
            if loc_str in ("argmem", "inaccessiblemem"):
                loc = loc_map[loc_str]
            else:
                loc = loc_map["other"]
            # Entries with no access do not widen the effect set.
            if rw > 0:
                loc_attr |= loc
                rw_attr |= rw
        parts = (rev_rw_map.get(rw_attr, ""), rev_loc_map.get(loc_attr, ""))
        return " ".join(part for part in parts if part)

    return re.sub(memory_pattern, replace_mem_attr, llir)
def _downgrade_stacksaverestore_intrinsics(llir: str):
    """Drop the type suffix from ``llvm.stacksave.*`` / ``llvm.stackrestore.*``.

    For example ``@llvm.stacksave.p0`` becomes ``@llvm.stacksave``.
    """
    for intrinsic in ("llvm.stacksave", "llvm.stackrestore"):
        llir = re.sub(re.escape(intrinsic) + r"\.\w+", intrinsic, llir)
    return llir
def _get_triton_adapter_opt_path() -> str:
    """Return the path of ``triton-adapter-opt`` in this module's directory."""
    return os.path.join(os.path.dirname(__file__), "triton-adapter-opt")
def _get_mlir_path(path: str, *paths) -> str:
    """Join ``path`` and ``paths`` onto ``$MLIR_ROOT``.

    :raises EnvironmentError: if MLIR_ROOT is unset or empty.
    """
    mlir_root = os.getenv("MLIR_ROOT", "")
    if not mlir_root:
        raise EnvironmentError("MLIR_ROOT is not set.")
    return os.path.join(mlir_root, path, *paths)
def _get_llvm_path(path: str, *paths) -> str:
    """Join ``path`` and ``paths`` onto ``$LLVM_ROOT``.

    :raises EnvironmentError: if LLVM_ROOT is unset or empty.
    """
    llvm_root = os.getenv("LLVM_ROOT", "")
    if not llvm_root:
        raise EnvironmentError("LLVM_ROOT is not set.")
    return os.path.join(llvm_root, path, *paths)
def _get_npucompiler_path() -> str:
    """Locate the NPU compiler executable.

    Resolution order:
      1. ``bishengir-compile`` found on PATH;
      2. ``$TRITON_NPU_COMPILER_PATH/npuc``.

    :raises EnvironmentError: if ``bishengir-compile`` is not on PATH and
        TRITON_NPU_COMPILER_PATH is unset or empty.
    """
    npu_compiler_path = shutil.which("bishengir-compile")
    if npu_compiler_path is not None:
        return npu_compiler_path
    # os.getenv with a "" default never returns None, so check for emptiness.
    # Otherwise a missing setup would silently yield the relative path "npuc".
    npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
    if not npu_compiler_root:
        raise EnvironmentError(
            "Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH."
        )
    return os.path.join(npu_compiler_root, "npuc")
def _get_bisheng_path() -> str:
    """Locate the ``bisheng`` compiler executable.

    Resolution order:
      1. ``bisheng`` found on PATH;
      2. ``$TRITON_NPU_COMPILER_PATH/ccec``.

    :raises EnvironmentError: if ``bisheng`` is not on PATH and
        TRITON_NPU_COMPILER_PATH is unset or empty.
    """
    bisheng_path = shutil.which("bisheng")
    if bisheng_path is not None:
        return bisheng_path
    # os.getenv with a "" default never returns None, so check for emptiness.
    # Otherwise a missing setup would silently yield the relative path "ccec".
    npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "")
    if not npu_compiler_root:
        raise EnvironmentError(
            "Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH"
        )
    return os.path.join(npu_compiler_root, "ccec")
def _is_valid_bishengir_path(path: str) -> bool:
    """Return True iff ``path`` is a non-empty str that names an existing,
    executable regular file whose basename is exactly ``bishengir-compile``.
    """
    if not path or not isinstance(path, str):
        return False
    return (
        os.path.basename(path) == "bishengir-compile"
        and os.path.isfile(path)
        and os.access(path, os.X_OK)
    )
def _check_bishengir_api_change() -> bool:
    """Probe whether ``bishengir-compile --help`` mentions
    ``limit-auto-multi-buffer-buffer``.

    :return: True only if the compiler path is valid, ``--help`` exits with
        status 0, and its stdout contains that option name. Otherwise False,
        after printing an ``ERROR:`` line for an invalid path or a failure to
        run the probe.
    """
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        probe = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        print(f"ERROR: {e}")
        return False
    return probe.returncode == 0 and 'limit-auto-multi-buffer-buffer' in probe.stdout
def _check_bishengir_is_regbased() -> bool:
    """Probe whether ``bishengir-compile --help`` mentions ``reg-based``.

    :return: True only if the compiler path is valid, ``--help`` exits with
        status 0, and its stdout contains ``reg-based``. Otherwise False,
        after printing an ``ERROR:`` line for an invalid path or a failure to
        run the probe.
    """
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        probe = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        print(f"ERROR: {e}")
        return False
    return probe.returncode == 0 and 'reg-based' in probe.stdout
@functools.lru_cache(maxsize=None)
def _get_ascend_path() -> Path:
    """Return ``$ASCEND_HOME_PATH`` as a Path (cached after the first success).

    :raises EnvironmentError: if ASCEND_HOME_PATH is unset or empty. Failures
        are not cached, so a later call re-reads the environment.
    """
    ascend_home = os.getenv("ASCEND_HOME_PATH", "")
    if not ascend_home:
        raise EnvironmentError(
            "ASCEND_HOME_PATH is not set, source <ascend-toolkit>/set_env.sh first"
        )
    return Path(ascend_home)
def _is_ascend_sanitizer_enabled() -> bool:
    """True when TRITON_ENABLE_SANITIZER is "true"/"1" (case-insensitive); off by default."""
    value = os.getenv("TRITON_ENABLE_SANITIZER", "false")
    return value.lower() in {"true", "1"}
def _is_debug_line_info_disabled() -> bool:
    """True when TRITON_DISABLE_LINE_INFO is "true"/"1" (case-insensitive).

    The variable defaults to "true", so line info is disabled unless the user
    sets it to something else.
    """
    value = os.getenv("TRITON_DISABLE_LINE_INFO", "true")
    return value.lower() in {"true", "1"}
def _is_auto_map_parallel_blocks_enabled() -> bool:
    """True when TRITON_ALL_BLOCKS_PARALLEL is "true"/"1" (case-insensitive); off by default."""
    value = os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "false")
    return value.lower() in {"true", "1"}
def _enable_unpublished_feature() -> bool:
    """True when ENABLE_UNPUBLISHED_FEATURE is "true"/"1" (case-insensitive); off by default."""
    value = os.getenv("ENABLE_UNPUBLISHED_FEATURE", "false")
    return value.lower() in {"true", "1"}
def _enable_print_ub_bits() -> bool:
    """True when ENABLE_PRINT_UB_BITS is "true"/"1" (case-insensitive); off by default."""
    value = os.getenv("ENABLE_PRINT_UB_BITS", "false")
    return value.lower() in {"true", "1"}
def _get_cxx():
    """Return the C++ compiler to invoke.

    Preference: the ``CC`` environment variable (used verbatim, even if
    empty), then ``clang++`` on PATH, then ``g++`` on PATH.

    :raises RuntimeError: if none of these is available.
    """
    # NOTE(review): reads CC rather than the conventional CXX; kept as-is.
    from_env = os.environ.get("CC")
    if from_env is not None:
        return from_env
    for candidate in ("clang++", "g++"):
        located = shutil.which(candidate)
        if located is not None:
            return located
    raise RuntimeError("Failed to find C++ compiler")
def _get_cxx_precompiled(header_path):
    """Return the compiler command prefix for a precompiled-header build.

    * ``CC`` set -> ``[CC]``
    * else ``clang++`` on PATH -> ``[clang++, "-include", header_path]``
    * else ``g++`` on PATH -> ``[g++]``

    :raises RuntimeError: if no compiler is found.
    """
    from_env = os.environ.get("CC")
    if from_env is not None:
        return [from_env]
    clangxx = shutil.which("clang++")
    gxx = shutil.which("g++")
    if clangxx is not None:
        return [clangxx, "-include", header_path]
    if gxx is not None:
        return [gxx]
    raise RuntimeError("Failed to find C++ compiler")
def _precompile_npu_hash(header_src):
    """Return a SHA-256 hex digest for a precompiled-header build configuration.

    The digest covers the header source text, the C++ compiler from
    ``_get_cxx()``, ``sys.version``, the resolved Ascend toolkit path, and the
    entries returned by ``get_backend_func("version_hash")`` (assumed to be a
    sequence of strings, since they are joined with "_").
    Presumably used as a cache key for the precompiled NPU header (per the
    name); confirm with the caller.
    """
    import sys
    version_txt = [
        header_src,
        _get_cxx(),
        sys.version,
        # Use the full resolved toolkit path. The leaf name alone can be shared
        # by unrelated installs (e.g. two different ".../latest" directories),
        # which would give them identical keys.
        str(_get_ascend_path().resolve()),
    ]
    version_txt += get_backend_func("version_hash")
    return hashlib.sha256("_".join(version_txt).encode("utf-8")).hexdigest()
def _precompile_npu_ext(header_path):
    """Precompile ``header_path`` into ``precompiled.h.gch`` in the same directory.

    The command line uses the compiler from ``_get_cxx()`` with include paths
    for Python, this module's directory, the Ascend toolkit headers and
    pybind11, followed by ``get_backend_func("get_cc_cmd", build_pch=True)``.

    :param header_path: path of the header to precompile.
    :return: ``header_path`` unchanged.
    :raises subprocess.CalledProcessError: if the compiler exits non-zero.
    """
    src_dir = os.path.dirname(header_path)
    gch_path = os.path.join(src_dir, "precompiled.h.gch")
    cxx = _get_cxx()
    cc_cmd = [cxx, "-x", "c++-header", header_path]
    cc_cmd += ["-w"]
    # sysconfig.get_default_scheme() is public only from Python 3.10.
    if hasattr(sysconfig, "get_default_scheme"):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # "posix_local" is a Debian-specific scheme; use posix_prefix to locate
    # the interpreter's include directory.
    if scheme == "posix_local":
        scheme = "posix_prefix"
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    cc_cmd += [f"-I{py_include_dir}"]
    cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"]
    asc_path = _get_ascend_path()
    # When the toolkit has no include/experiment/runtime/runtime/rt.h, also
    # search its pkg_inc directories.
    rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h")
    if not os.path.exists(rt_path):
        cc_cmd += [
            f"-I{os.path.join(asc_path, 'pkg_inc')}",
            f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}",
        ]
    cc_cmd += [
        f"-I{os.path.join(asc_path, 'include')}",
        f"-I{os.path.join(asc_path, 'include/experiment')}",
        f"-I{os.path.join(asc_path, 'include/experiment/msprof')}",
        f"-I{pybind11.get_include()}",
    ]
    cc_cmd += get_backend_func("get_cc_cmd", build_pch=True)
    cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", gch_path]
    # check_call raises CalledProcessError on a non-zero exit, so a failed
    # build propagates to the caller; there is no return code to inspect.
    subprocess.check_call(cc_cmd)
    return header_path
def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="torch", precompile=False) -> str:
    """Compile ``src_path`` into a Python extension module and return its path.

    The shared object is written next to ``src_path`` as
    ``<obj_name><EXT_SUFFIX>``. It links against the toolkit's ``runtime``
    and ``ascendcl`` libraries and appends
    ``get_backend_func("get_cc_cmd", build_pch=False)``.

    :param obj_name: base name of the produced module.
    :param header_path: header whose directory is added to the include path;
        may be None. With ``precompile=True`` it is also passed to
        ``_get_cxx_precompiled``.
    :param src_path: C++ source file to compile.
    :param kernel_launcher: not read by this body; forwarded on the retry so
        the recursive call sees the caller's value.
    :param precompile: build with the precompiled-header command prefix.
    :return: path of the built shared object.
    :raises RuntimeError: if compilation fails. A precompiled build that fails
        with an error mentioning ``precompiled.h.gch`` is retried once without
        the precompiled header first.
    """
    suffix = sysconfig.get_config_var("EXT_SUFFIX")
    src_dir = os.path.dirname(src_path)
    so_path = os.path.join(src_dir, f"{obj_name}{suffix}")
    if precompile:
        cc_cmd = _get_cxx_precompiled(header_path)
        cc_cmd += [src_path]
    else:
        cxx = _get_cxx()
        cc_cmd = [cxx, src_path]
    cc_cmd += ["-w"]
    # sysconfig.get_default_scheme() is public only from Python 3.10.
    if hasattr(sysconfig, "get_default_scheme"):
        scheme = sysconfig.get_default_scheme()
    else:
        scheme = sysconfig._get_default_scheme()
    # "posix_local" is a Debian-specific scheme; use posix_prefix to locate
    # the interpreter's include directory.
    if scheme == "posix_local":
        scheme = "posix_prefix"
    py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
    cc_cmd += [f"-I{py_include_dir}"]
    cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"]
    asc_path = _get_ascend_path()
    if header_path is not None:
        cc_cmd += [f"-I{os.path.dirname(header_path)}"]
    # When the toolkit has no include/experiment/runtime/runtime/rt.h, also
    # search its pkg_inc directories.
    rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h")
    if not os.path.exists(rt_path):
        cc_cmd += [
            f"-I{os.path.join(asc_path, 'pkg_inc')}",
            f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}",
        ]
    cc_cmd += [
        f"-I{os.path.join(asc_path, 'include')}",
        f"-I{os.path.join(asc_path, 'include/experiment')}",
        f"-I{os.path.join(asc_path, 'include/experiment/msprof')}",
        f"-I{pybind11.get_include()}",
        f"-L{os.path.join(asc_path, 'lib64')}",
        "-lruntime",
        "-lascendcl",
    ]
    cc_cmd += get_backend_func("get_cc_cmd", build_pch=False)
    cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-Winvalid-pch", "-o", so_path]
    result = subprocess.run(cc_cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return so_path
    # Retry without the precompiled header only if it was actually used.
    # Otherwise a non-PCH build whose stderr happens to mention the .gch
    # would recurse with identical arguments until RecursionError.
    if precompile and "precompiled.h.gch" in result.stderr:
        return _build_npu_ext(
            obj_name,
            header_path,
            src_path,
            kernel_launcher=kernel_launcher,
            precompile=False,
        )
    raise RuntimeError(f"Failed to compile {src_path}, error: {result.stderr}")
def _get_kernel_target(metadata: dict):
    """Map the compile target in ``metadata`` to a (target, core, mode) triple.

    ``metadata["target"].arch`` must be a str:

    * ``Ascend910B*``: the lowered, "_"-stripped ``metadata["mix_mode"]`` picks
      the result. A prefix of "aiv" gives vec, "aic" gives cube, anything
      else gives mix.
    * other ``Ascend910*``: ``("ascend_910", "c100", "mix")``.

    :raises Exception: if ``metadata`` has no "target" key.
    :raises NotImplementedError: for any other arch.
    """
    if "target" not in metadata:
        raise Exception("No target provided!")
    arch = metadata["target"].arch
    assert isinstance(arch, str)
    if arch.startswith("Ascend910B"):
        mode = metadata["mix_mode"].lower().strip("_")
        if mode.startswith("aiv"):
            return "ascend_910b_vec", "c220-vec", "aiv"
        if mode.startswith("aic"):
            return "ascend_910b_cube", "c220-cube", "aic"
        return "ascend_910b", "c220", "mix"
    # NOTE(review): every non-910B "Ascend910..." arch lands here, including
    # Ascend910_93xx/95xx names; confirm c100 is intended for them.
    if arch.startswith("Ascend910"):
        return "ascend_910", "c100", "mix"
    raise NotImplementedError(f"NPU subtarget {arch} not supported yet")
def _check_cxx11_abi():
    """Return the active backend strategy's ``cxx_abi`` result.

    Presumably reports whether the C++11 ABI is in use (per the name);
    confirm against the backend strategy implementation.
    """
    abi_setting = get_backend_func("cxx_abi")
    return abi_setting
def convert_sigtype_to_int(sigty: str):
    """Map a Triton signature type string (e.g. "fp32", "i64") to its integer code.

    :raises ValueError: for a type not in the table.
    """
    codes = {
        "i1": 12,
        "i8": 2,
        "i16": 6,
        "i32": 3,
        "i64": 9,
        "u8": 4,
        "u16": 7,
        "u32": 8,
        "u64": 10,
        "fp16": 1,
        "bf16": 27,
        "fp32": 0,
        "fp64": 11,
        "fp8e5": 35,
        "fp8e4nv": 36,
    }
    code = codes.get(sigty)
    if code is None:
        raise ValueError(f"Unsupported data type: {sigty}")
    return code
def convert_dtype_to_numpy(dtype):
    """Look up ``dtype`` in the backend's "type_convert" table.

    Raises whatever the table's ``__getitem__`` raises for unknown keys.
    """
    table = get_backend_func("type_convert")
    return table[dtype]
def _check_bishengir_able_save_ir() -> bool:
    """Probe whether ``bishengir-compile --help`` mentions ``save-linked-ir``.

    :return: True only if the compiler path is valid, ``--help`` exits with
        status 0, and its stdout contains ``save-linked-ir``. Otherwise False,
        after printing an ``ERROR:`` line for an invalid path or a failure to
        run the probe.
    """
    bishengir_path = _get_npucompiler_path()
    if not _is_valid_bishengir_path(bishengir_path):
        print(f"ERROR: Invalid bishengir path format: {bishengir_path}")
        return False
    try:
        probe = subprocess.run(
            [bishengir_path, "--help"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        print(f"ERROR: {e}")
        return False
    return probe.returncode == 0 and 'save-linked-ir' in probe.stdout
def get_ascend_arch_from_env():
    """Return the Ascend arch named by ``$TRITON_ASCEND_ARCH``, or "" if unset/empty.

    :raises ValueError: if the variable names an arch outside the supported list.
    """
    arch = os.getenv("TRITON_ASCEND_ARCH", "")
    if arch == "":
        return arch
    valid_arch_list = [
        "Ascend910B1",
        "Ascend910B2",
        "Ascend910B3",
        "Ascend910B4",
        "Ascend910_9362",
        "Ascend910_9372",
        "Ascend910_9381",
        "Ascend910_9382",
        "Ascend910_9391",
        "Ascend910_9392",
        "Ascend310B1",
        "Ascend310B2",
        "Ascend310B3",
        "Ascend310B4",
        "Ascend910_9579",
        "Ascend910_9581",
        "Ascend910_9589",
        "Ascend910_9599",
    ]
    if arch not in valid_arch_list:
        valid_arch_str = ", ".join(valid_arch_list)
        # The two literals are concatenated, so the separating space must be explicit.
        raise ValueError(f"TRITON_ASCEND_ARCH = {arch} is invalid! "
                         f"Candidates are [{valid_arch_str}]")
    return arch
def is_ffts_supported(arch: str):
    '''
    Return whether FFTS is supported for ``arch``.

    Cases:
    - empty str: User does not specify arch, thus it runs on 910B/910D both of which support ffts. Return True.
    - Ascend910A: Return False.
    - Ascend310B4: 310B4 does not support ffts. Return False.
    - Other arch: 910B/910D supports ffts. Return True.
    '''
    # NOTE(review): only Ascend310B4 is excluded among the 310B variants; other
    # Ascend310B* names return True. Confirm that is intended.
    return arch not in ("Ascend910A", "Ascend310B4")
def force_disable_ffts():
    '''
    Return True when TRITON_DISABLE_FFTS is "true" or "1" (case-insensitive).

    The variable defaults to "false", so FFTS is not force-disabled unless set.
    '''
    raw = os.getenv("TRITON_DISABLE_FFTS", "false")
    return raw.lower() in {"true", "1"}