import os
import glob
import subprocess
import re
import sysconfig
from distutils.errors import CompileError
from distutils.spawn import find_executable
import torch
import torch_npu
from torch.utils.cpp_extension import BuildExtension
import torch.utils.cpp_extension as cpp_extension
from setuptools import setup, Extension, find_packages
from setuptools.command.build_ext import build_ext
# Directory that contains this setup script (symlinks resolved). Source
# discovery is anchored here so it does not depend on the working directory.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))

# Opt-in switch read from the environment; only the exact value '1' enables it.
USE_NINJA = os.getenv('USE_NINJA') == '1'

# Every ``*.asc`` file under ``csrc``.
#   * ``recursive=True`` only takes effect when the pattern contains ``**``,
#     so the pattern includes it; ``**`` also matches zero directories, which
#     keeps files located directly in ``csrc`` in the result.
#   * ``glob.escape`` stops glob metacharacters in the checkout path
#     (``[``, ``*``, ``?``) from being interpreted as part of the pattern.
#   * ``sorted`` makes the order independent of filesystem enumeration order.
source_files = sorted(
    glob.glob(
        os.path.join(glob.escape(BASE_DIR), "csrc", "**", "*.asc"),
        recursive=True,
    )
)
def _parse_chip_name(output):
    """Return the chip name from the first device row of ``npu-smi info`` output, or None."""
    for line in output.split('\n'):
        if not line.strip():
            continue
        # Skip table borders and header/banner rows.
        if re.search(r'Health|NPU\s+ID|Name|Version|\+|---|=', line):
            continue
        # A device row starts with a numeric id followed by an alphanumeric
        # token; that token is taken as the chip name.
        match = re.search(r'^\s*\|?\s*\d+\s*\|?\s*([A-Za-z0-9]+)', line)
        if match:
            return match.group(1)
    return None


def get_npu_arch():
    """Get NPU architecture version from npu-smi info.

    Runs ``npu-smi info``, extracts the chip name from the first device row
    and maps it to the architecture string handed to the compiler.

    Returns:
        ``'dav-3510'`` when the chip name contains ``950``, ``'dav-2201'``
        when it contains ``910``.

    Raises:
        RuntimeError: if ``npu-smi`` cannot be found or started, exits with a
            non-zero status, prints no recognisable device row, or reports a
            chip name with no known architecture mapping. When the failure
            came from launching the command, the underlying exception is
            attached as ``__cause__``.
    """
    # Only the subprocess call is guarded: the RuntimeErrors raised further
    # down must reach the caller as-is instead of being caught and re-wrapped.
    try:
        result = subprocess.run(
            ["npu-smi", "info"], capture_output=True, text=True, check=True
        )
    except FileNotFoundError as e:
        raise RuntimeError("npu-smi info is not found, please ensure CANN is installed") from e
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to execute npu-smi info: {e}") from e
    except (OSError, ValueError) as e:
        # Remaining launch/decoding failures (e.g. permission denied,
        # undecodable output).
        raise RuntimeError(f"Failed get NPU architecture: {e}") from e

    chip_name = _parse_chip_name(result.stdout)
    if not chip_name:
        raise RuntimeError("Failed to parse chip name from npu-smi info.")
    if '950' in chip_name:
        return 'dav-3510'
    if '910' in chip_name:
        return 'dav-2201'
    raise RuntimeError(f"New chip model: {chip_name}, please check the corresponding architecture: dav-xxx")
def get_dependency_paths():
    """Collect the header and library directories used to build the extension.

    Gathers the Python include/library directories reported by ``sysconfig``,
    the include paths reported by ``torch.utils.cpp_extension``, and the
    ``include``/``lib`` directories next to the installed ``torch`` and
    ``torch_npu`` packages.

    Returns:
        dict with two keys:
            ``"all_includes"``: list of include directories.
            ``"all_libs"``: list of library directories.
    """
    python_include = sysconfig.get_config_var("INCLUDEPY")
    python_lib = sysconfig.get_config_var("LIBDIR")
    torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib")
    torch_npu_path = os.path.dirname(torch_npu.__file__)

    include_candidates = [
        *cpp_extension.include_paths(),
        python_include,
        os.path.join(torch_npu_path, "include"),
    ]
    lib_candidates = [
        python_lib,
        torch_lib,
        os.path.join(torch_npu_path, "lib"),
    ]
    # ``sysconfig.get_config_var`` returns None for variables the running
    # interpreter does not define; drop those so callers never build flags
    # such as ``-INone`` or ``-LNone`` from them.
    return {
        "all_includes": [path for path in include_candidates if path],
        "all_libs": [path for path in lib_candidates if path],
    }
class AscendBuildExtension(BuildExtension):
    """``build_ext`` command that builds each extension with a single ``bisheng`` call.

    ``build_extension`` is overridden so that all sources of an extension are
    compiled and linked into the final shared object by one ``bisheng``
    invocation.
    """

    def _check_bisheng_compiler(self):
        """Raise RuntimeError when no ``bisheng`` executable can be located."""
        bisheng_compiler = find_executable('bisheng')
        if not bisheng_compiler:
            raise RuntimeError("bisheng command not found!")

    def build_extension(self, ext):
        """Compile and link ``ext`` into its shared object with ``bisheng``.

        Args:
            ext: extension object providing ``name`` and ``sources``.

        Raises:
            CompileError: if ``ext`` has no sources or the compiler
                invocation fails.
            RuntimeError: if ``bisheng`` is not available, or the NPU
                architecture cannot be determined.
        """
        # Fail early with an actionable message instead of letting the
        # compiler be invoked without any input file.
        if not ext.sources:
            raise CompileError(
                f"No source files found for extension '{ext.name}'."
            )

        self._check_bisheng_compiler()
        dep_paths = get_dependency_paths()

        ext_fullpath = self.get_ext_fullpath(ext.name)
        os.makedirs(os.path.dirname(ext_fullpath), exist_ok=True)

        # The extension is built with the same libstdc++ ABI setting that
        # PyTorch reports for itself.
        abi_value = "1" if torch.compiled_with_cxx11_abi() else "0"
        npu_arch = get_npu_arch()

        compile_cmd = [
            "bisheng",
            "-x", "asc",
            f"--npu-arch={npu_arch}",
            "-shared",
            "-fPIC",
            "-std=c++17",
            f"-D_GLIBCXX_USE_CXX11_ABI={abi_value}",
            "-ltorch_npu", "-ltorch", "-lc10",
            *ext.sources,
            "-o", ext_fullpath,
        ]
        compile_cmd.extend(f"-I{include_dir}" for include_dir in dep_paths["all_includes"])
        compile_cmd.extend(f"-L{lib_dir}" for lib_dir in dep_paths["all_libs"])

        try:
            self.spawn(compile_cmd)
        except Exception as e:
            # Report every spawn failure as a compile error, keeping the
            # original exception as the cause.
            raise CompileError(str(e)) from e
# The single extension module of this package, built from ``source_files``.
your_ext = Extension(
    "op_extension.custom_ops_lib",
    source_files,
    language="asc",
)

# Package metadata; ``build_ext`` is routed through AscendBuildExtension,
# configured with the USE_NINJA switch.
setup(
    name="op_extension",
    version="0.1",
    packages=find_packages(),
    ext_modules=[your_ext],
    cmdclass=dict(
        build_ext=AscendBuildExtension.with_options(use_ninja=USE_NINJA),
    ),
)