import os
import sys
import platform
import setuptools
import torch
try:
    from skbuild import setup as skbuild_setup
except ImportError:
    # Emit the hint on stderr so it is reported as a diagnostic rather than
    # mixed into regular stdout output, then re-raise the original
    # ImportError so the build still fails with the real traceback.
    print("scikit-build is required to build from source.", file=sys.stderr)
    raise
# Maps the lower-cased machine names handled by this script to the
# architecture tag used in the package version.
_ARCH_ALIASES = {
    "aarch64": "aarch64",
    "arm64": "aarch64",
    "x86_64": "x86_64",
    "amd64": "x86_64",
}

arch = platform.machine().lower()
# Machines missing from the table keep their own lower-cased name.
arch_suffix = _ARCH_ALIASES.get(arch, arch)
# The architecture tag is appended to the base version after a "+".
package_version = "1.0.0+" + arch_suffix
def _get_torch_prefix():
    """Return the directory that contains the imported ``torch`` package file."""
    torch_dir, _module_file = os.path.split(torch.__file__)
    return torch_dir
def _build_variants():
    """Return the AscendC operator build scope (limited by CANN platform info).

    The torch_plugin adapter-layer .so always builds every variant and is not
    affected by this setting.

    The value is read from the ``RECSDK_BUILD_VERS`` environment variable with
    surrounding whitespace removed. When the variable is unset, empty, or
    contains only whitespace, the default ``"A2,A3,A5"`` is returned.

    Returns:
        str: The variant list to forward to the build.
    """
    # Strip first so that a blank value falls back to the default instead of
    # being forwarded as an effectively empty variant list.
    value = os.environ.get("RECSDK_BUILD_VERS", "").strip()
    return value or "A2,A3,A5"
def _ascend_serial_build():
    """Normalise ``RECSDK_ASCEND_SERIAL_BUILD`` to the string "ON" or "OFF".

    The variable defaults to "ON" when unset. After trimming whitespace and
    upper-casing, only "0", "OFF", "FALSE" and "NO" yield "OFF"; every other
    value (including an empty string) yields "ON".
    """
    raw = os.environ.get("RECSDK_ASCEND_SERIAL_BUILD", "ON")
    if raw.strip().upper() in ("0", "OFF", "FALSE", "NO"):
        return "OFF"
    return "ON"
def _get_cxx11_abi():
    """Return torch's ``_GLIBCXX_USE_CXX11_ABI`` flag normalised to 0 or 1.

    The flag is looked up on ``torch._C``; if the attribute is missing it is
    treated as 0.
    """
    abi_flag = getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", 0)
    return int(bool(abi_flag))
def cmake_args():
    """Build the list of ``-D`` definitions passed to the CMake configure step.

    Side effect: if ``CMAKE_BUILD_PARALLEL_LEVEL`` is not already present in
    the environment, it is set to half of the reported CPU count (4 CPUs are
    assumed when the count is unavailable), but never less than 1. An
    existing value is left untouched.

    Returns:
        list: Strings of the form ``-DNAME=value`` covering the Python
        interpreter, the torch installation directory, the operator build
        variants, the serial-build switch and the C++11 ABI flag.
    """
    torch_root = _get_torch_prefix()

    # Integer division yields 0 on a single-CPU host, and 0 is not a usable
    # job count, so clamp the default to at least one job.
    default_jobs = max(1, (os.cpu_count() or 4) // 2)
    os.environ.setdefault("CMAKE_BUILD_PARALLEL_LEVEL", str(default_jobs))

    return [
        f"-DPython3_EXECUTABLE={sys.executable}",
        f"-DCMAKE_PREFIX_PATH={torch_root}",
        f"-DRECSDK_BUILD_VERS={_build_variants()}",
        f"-DRECSDK_ASCEND_SERIAL_BUILD={_ascend_serial_build()}",
        f"-D_GLIBCXX_USE_CXX11_ABI={_get_cxx11_abi()}",
    ]
# Discover only the rec_sdk_ops package and its sub-packages under the
# current directory.
_packages = setuptools.find_packages(
    where=".",
    include=["rec_sdk_ops", "rec_sdk_ops.*"],
)

# Keyword arguments for the scikit-build setup() entry point, gathered in one
# place; cmake_args() is evaluated here, after package discovery.
_setup_options = {
    "name": "rec_sdk_ops",
    "version": package_version,
    "description": "RecSDK Custom Operations for Multiple Chips",
    "packages": _packages,
    "cmake_args": cmake_args(),
    "cmake_install_dir": "rec_sdk_ops",
    "include_package_data": True,
    "zip_safe": False,
}

skbuild_setup(**_setup_options)