import os
import glob
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension
import torch_npu
from torch_npu.utils.cpp_extension import NpuExtension
# Directory of the installed torch_npu package; the header directories
# referenced below live under its "include/" sub-tree.
PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.abspath(torch_npu.__file__))

# Ninja-based builds are enabled only when the environment sets USE_NINJA=1;
# any other value (or leaving it unset) disables ninja.
USE_NINJA = os.getenv('USE_NINJA') == '1'

# Directory containing this setup script. Sources are resolved relative to it
# so the build does not depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))

# C++ sources located directly in csrc/ (sub-directories are NOT scanned: the
# pattern has no "**" component, so glob's recursive flag would have no effect
# and is omitted). glob() returns files in filesystem-dependent order; sort
# them so the compile/link order is reproducible across machines.
source_files = sorted(glob.glob(os.path.join(BASE_DIR, "csrc", "*.cpp")))

# Header directories, relative to PYTORCH_NPU_INSTALL_PATH, that the custom
# operator sources are compiled against.
_NPU_INCLUDE_SUBDIRS = (
    "include/third_party/acl/inc",
    "include/third_party/op-plugin",
    "include/third_party/op-plugin/op_plugin/include",
)

exts = []
ext = NpuExtension(
    name="cpp_extension_base.custom_ops_lib",
    sources=source_files,
    # The include directories are passed as raw "-I<dir>" compiler flags,
    # one per entry of _NPU_INCLUDE_SUBDIRS, in the order listed above.
    extra_compile_args=[
        '-I' + os.path.join(PYTORCH_NPU_INSTALL_PATH, subdir)
        for subdir in _NPU_INCLUDE_SUBDIRS
    ],
)
exts.append(ext)

setup(
    name="cpp_extension_base",
    version='1.0',
    keywords='cpp_extension_base',
    ext_modules=exts,
    packages=find_packages(),
    # torch's BuildExtension drives the extension build; use_ninja toggles
    # the ninja backend according to the USE_NINJA environment variable.
    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
)