import glob
import os
import sysconfig
from distutils.errors import CompileError
from shutil import which
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
import torch
import torch_npu
import torch.utils.cpp_extension as cpp_extension
# Name of the Python package; also the prefix of the compiled extension module.
LIBRARY_NAME = "simt_mul"

# Directory containing this build script, with symlinks resolved.
BASE_DIR = os.path.split(os.path.realpath(__file__))[0]

# Kernel sources live in <BASE_DIR>/<LIBRARY_NAME>/csrc.
EXTENSIONS_DIR = os.path.join(BASE_DIR, LIBRARY_NAME, "csrc")

# Value handed to the compiler's --npu-arch flag; the NPU_ARCH environment
# variable overrides the default.
NPU_ARCH = os.environ.get("NPU_ARCH", "dav-3510")
def get_ascend_include_dirs():
    """Collect the Ascend toolkit header directories that exist on this host.

    Candidate installation roots come from the ``ASCEND_HOME_PATH``,
    ``ASCEND_TOOLKIT_HOME`` and ``ASCEND_OPP_PATH`` environment variables
    (unset or empty values are skipped), followed by two default install
    locations. Under every root the generic ``include`` directory and the
    architecture-specific ``<arch>-linux/include`` and
    ``<arch>-linux/asc/include`` directories are probed, together with an
    ``aclnn`` sub-directory of each.

    The architecture directory of the build host (``platform.machine()``) is
    probed first so that non-x86_64 hosts such as aarch64 are covered.
    ``x86_64-linux`` is still probed afterwards, which keeps the result on
    x86_64 hosts identical to the previous hard-coded behaviour.

    Returns:
        list[str]: Existing include directories, de-duplicated, in probe order.
    """
    # Imported locally so the function does not rely on the file's import block.
    import platform

    env_roots = (
        os.getenv("ASCEND_HOME_PATH"),
        os.getenv("ASCEND_TOOLKIT_HOME"),
        os.getenv("ASCEND_OPP_PATH"),
    )
    candidate_roots = [root for root in env_roots if root]
    candidate_roots += [
        "/usr/local/Ascend/cann",
        "/usr/local/Ascend/cann-9.0.0",
    ]

    # Host architecture first, then the historical x86_64 fallback.
    arch_names = []
    for arch in (platform.machine(), "x86_64"):
        if arch and arch not in arch_names:
            arch_names.append(arch)

    include_dirs = []
    seen = set()

    def _add_if_present(path):
        # Record each existing directory once, preserving first-seen order.
        if path not in seen and os.path.isdir(path):
            include_dirs.append(path)
            seen.add(path)

    for root in candidate_roots:
        probes = [os.path.join(root, "include")]
        for arch in arch_names:
            arch_root = os.path.join(root, f"{arch}-linux")
            probes.append(os.path.join(arch_root, "include"))
            probes.append(os.path.join(arch_root, "asc", "include"))
        for include_dir in probes:
            _add_if_present(include_dir)
            _add_if_present(os.path.join(include_dir, "aclnn"))

    return include_dirs
def get_dependency_paths():
    """Assemble the header and library search paths for the extension build.

    Include directories are listed in this order: the paths reported by
    ``torch.utils.cpp_extension.include_paths()``, the Python ``INCLUDEPY``
    directory, the ``torch_npu`` headers and its bundled
    ``third_party/acl/inc`` headers, the directories found by
    ``get_ascend_include_dirs()``, and finally the extension's own source
    directories. Library directories are the Python ``LIBDIR`` followed by the
    ``lib`` folders of the installed ``torch`` and ``torch_npu`` packages.

    Returns:
        dict: ``{"include_dirs": [...], "library_dirs": [...]}``.
    """
    torch_root = os.path.dirname(torch.__file__)
    npu_root = os.path.dirname(torch_npu.__file__)
    npu_include = os.path.join(npu_root, "include")

    include_dirs = list(cpp_extension.include_paths())
    include_dirs.append(sysconfig.get_config_var("INCLUDEPY"))
    include_dirs.append(npu_include)
    include_dirs.append(os.path.join(npu_include, "third_party", "acl", "inc"))
    include_dirs.extend(get_ascend_include_dirs())
    include_dirs.append(EXTENSIONS_DIR)
    include_dirs.append(os.path.join(EXTENSIONS_DIR, "simt"))

    library_dirs = [
        sysconfig.get_config_var("LIBDIR"),
        os.path.join(torch_root, "lib"),
        os.path.join(npu_root, "lib"),
    ]

    return dict(include_dirs=include_dirs, library_dirs=library_dirs)
class AscendBuildExtension(build_ext):
    """``build_ext`` command that builds extensions with the ``bisheng`` compiler.

    All sources of an extension are compiled and linked into the final shared
    object by a single ``bisheng`` invocation.

    Environment variables read by ``build_extension``:
        DEBUG: when set to ``"1"`` the build uses ``-O0`` and adds ``-g``.
        OPT_LEVEL: optimisation level used when ``DEBUG`` is not ``"1"``
            (default ``"0"``); it is appended verbatim to ``-O``.
    """

    def _check_bisheng_compiler(self):
        """Raise ``RuntimeError`` when ``bisheng`` cannot be found on ``PATH``."""
        if not which("bisheng"):
            raise RuntimeError("bisheng command not found")

    def build_extension(self, ext):
        """Compile and link ``ext`` into its final shared-object location.

        Args:
            ext: The extension to build; its ``name`` and ``sources`` are used.

        Raises:
            RuntimeError: If the ``bisheng`` compiler is not available.
            CompileError: If the extension has no sources or the compiler
                invocation fails.
        """
        self._check_bisheng_compiler()

        # Fail early with a clear message instead of invoking the compiler
        # without any input files.
        if not ext.sources:
            raise CompileError(
                f"no source files found for extension {ext.name!r}"
            )

        dep_paths = get_dependency_paths()
        ext_fullpath = self.get_ext_fullpath(ext.name)
        os.makedirs(os.path.dirname(ext_fullpath), exist_ok=True)

        # The extension must use the same libstdc++ ABI as the installed torch;
        # query it through the public API rather than the private torch._C flag.
        abi_value = "1" if torch.compiled_with_cxx11_abi() else "0"

        debug_mode = os.getenv("DEBUG", "0") == "1"
        # NOTE(review): non-debug builds also default to -O0 unless OPT_LEVEL
        # is set, and OPT_LEVEL is not validated - confirm this is intended.
        opt_level = os.getenv("OPT_LEVEL", "0")
        opt_flag = "-O0" if debug_mode else f"-O{opt_level}"

        compile_cmd = [
            "bisheng",
            "-x",
            "asc",
            "--enable-simt",
            f"--npu-arch={NPU_ARCH}",
            "-shared",
            "-fPIC",
            "-std=c++17",
            opt_flag,
            f"-D_GLIBCXX_USE_CXX11_ABI={abi_value}",
            *ext.sources,
        ]
        if debug_mode:
            compile_cmd.append("-g")
        compile_cmd.extend(
            f"-I{include_dir}" for include_dir in dep_paths["include_dirs"]
        )
        compile_cmd.extend(
            f"-L{library_dir}" for library_dir in dep_paths["library_dirs"]
        )
        compile_cmd.extend(
            [
                "-ltorch_npu",
                "-ltorch_python",
                "-ltorch_cpu",
                "-ltorch",
                "-lc10",
                "-o",
                ext_fullpath,
            ]
        )

        try:
            self.spawn(compile_cmd)
        except Exception as exc:
            # Surface any spawn failure as a build error, keeping the cause.
            raise CompileError(str(exc)) from exc
def get_extensions():
    """Describe the single extension module built from the ``.asc`` sources.

    Sources are gathered from ``EXTENSIONS_DIR`` first and from its ``simt``
    sub-directory second. ``glob`` returns matches in arbitrary order, so each
    group is sorted to keep the compiler command line reproducible.

    Returns:
        list: A one-element list holding the ``<LIBRARY_NAME>._C`` extension.
    """
    source_patterns = (
        os.path.join(EXTENSIONS_DIR, "*.asc"),
        os.path.join(EXTENSIONS_DIR, "simt", "*.asc"),
    )
    sources = []
    for pattern in source_patterns:
        sources.extend(sorted(glob.glob(pattern)))

    return [
        Extension(
            name=f"{LIBRARY_NAME}._C",
            sources=sources,
            language="asc",
        )
    ]
# Register the package; extension builds are routed through the custom
# bisheng-based build_ext command.
setup(
    # Package metadata.
    name=LIBRARY_NAME,
    version="0.0.1",
    description="Ascend SIMT override for aten::mul / aten::multiply",
    # Python packages and runtime requirements.
    packages=find_packages(),
    install_requires=["torch", "torch_npu"],
    # Native extension and the command that builds it.
    ext_modules=get_extensions(),
    cmdclass={"build_ext": AscendBuildExtension},
)