import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

from torch_npu.utils.cpp_extension import NpuExtension
from torch_npu.testing.common_utils import set_npu_device

# Select the NPU device up front, before any extension is built.
set_npu_device()

# Compiler flags shared by every extension below (-g keeps debug info).
CXX_FLAGS = ['-g']

# Ninja is strictly opt-in: only the exact value "1" in $USE_NINJA enables it.
USE_NINJA = os.environ.get('USE_NINJA') == '1'

# Strip three trailing components from this file's absolute path (the file
# name and its two enclosing directories) to obtain the repository root.
REPO_ROOT = os.path.abspath(__file__)
for _level in range(3):
    REPO_ROOT = os.path.dirname(REPO_ROOT)
del _level

# The AOTI shim implementation is compiled directly into one test extension.
SHIM_SOURCE = os.path.join(
    REPO_ROOT, 'torch_npu', 'csrc', 'inductor', 'aoti_torch', 'shim_npu.cpp'
)

# One row per test extension: (module name, source files, extra keyword
# arguments forwarded to NpuExtension). Every extension also receives
# extra_compile_args=CXX_FLAGS.
_EXTENSION_SPECS = [
    ('torch_test_cpp_extension.npu', ['extension.cpp'], {}),
    ('torch_test_cpp_extension.npu_from_blob', ['test_from_blob.cpp'], {}),
    ('torch_test_cpp_extension.npu_external_stream', ['external_stream_test.cpp'], {}),
    ('torch_test_cpp_extension.stable_libtorch', ['test_stable_libtorch.cpp'], {}),
    # The shim extension compiles a repository source file, so the repo root
    # must be on its include path.
    (
        'torch_test_cpp_extension.npu_aoti_shim',
        ['npu_aoti_shim_extension.cpp', SHIM_SOURCE],
        {'include_dirs': [REPO_ROOT]},
    ),
]

ext_modules = [
    NpuExtension(module_name, sources, extra_compile_args=CXX_FLAGS, **extra_kwargs)
    for module_name, sources, extra_kwargs in _EXTENSION_SPECS
]

# Register the test package and its compiled extensions. The build_ext step is
# handled by torch's BuildExtension, with ninja enabled only when USE_NINJA is set.
setup(
    name='torch_test_cpp_extension',
    packages=['torch_test_cpp_extension'],
    ext_modules=ext_modules,
    # NOTE(review): include_dirs is passed as a bare string, not a list. Given
    # the directory name ("self_compiler_include_dirs_test"), this looks like a
    # deliberate check of how a string include_dirs is handled -- confirm
    # before converting it to a list.
    include_dirs='self_compiler_include_dirs_test',
    cmdclass={'build_ext': BuildExtension.with_options(use_ninja=USE_NINJA)})