cmake_minimum_required(VERSION 3.10)
project(PTAExtensionOPS)

# Locate the directory that contains the installed `torch` package, i.e. the
# parent of torch's package directory (dirname of dirname of torch/__init__.py).
# The torch and torch_npu include/lib directories are derived from it below.
# stderr is not captured, so Python's own error output stays visible on failure.
execute_process(
  COMMAND python3 -c "import os;import torch; print(os.path.dirname(os.path.dirname(torch.__file__)))"
  OUTPUT_VARIABLE python_site_packages_path
  RESULT_VARIABLE torch_query_result
)
string(STRIP "${python_site_packages_path}" python_site_packages_path)
# Fail fast: continuing with an empty path would silently turn every torch
# include/link directory into a root-relative path such as "/torch/include".
if(NOT "${torch_query_result}" EQUAL 0 OR "${python_site_packages_path}" STREQUAL "")
  message(FATAL_ERROR "Failed to locate the torch installation via python3 (result: '${torch_query_result}'); make sure python3 is on PATH and `import torch` works.")
endif()
# Do not embed run-time search paths (RPATH/RUNPATH) into the built library;
# presumably torch, torch_npu and the Ascend libraries are expected to be found
# by the dynamic loader at run time by other means (e.g. LD_LIBRARY_PATH or
# already-loaded libraries) — confirm for the deployment environment.
set(CMAKE_SKIP_RPATH TRUE)
# Directory-wide C++ flags. These apply to every C++ target in this directory.
# CMake's default link rules also pass CMAKE_CXX_FLAGS to the link step, which is
# why linker options (-Wl,..., -rdynamic, -pie) appear in a compile-flag variable.
# The first six calls APPEND to the existing flags.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-conversion-null")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-deprecated-declarations")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pipe -fstack-protector-strong")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fpic -fpie -Wl,--build-id=none")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-common")
# The remaining three calls PREPEND, so the final string starts with
# "-Wl,-Bsymbolic -rdynamic -Wl,--no-undefined -fabi-version=11 -fstack-protector-all ..."
# and ends with the appended flags above.
# NOTE(review): with GCC the last of mutually exclusive options normally wins, so
# the appended -fstack-protector-strong likely overrides the prepended
# -fstack-protector-all, and the appended -fpie overrides -fPIE — confirm the
# intended stack-protector level.
set(CMAKE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -fPIE -pie ${CMAKE_CXX_FLAGS}")
# Pin the GCC C++ ABI version; presumably chosen to match the toolchain torch was
# built with — TODO confirm.
set(CMAKE_CXX_FLAGS "-fabi-version=11 ${CMAKE_CXX_FLAGS}")
# -Wl,--no-undefined makes the link fail on unresolved symbols, which is why all
# dependencies (including libpython) are linked explicitly further down.
set(CMAKE_CXX_FLAGS "-Wl,-Bsymbolic -rdynamic -Wl,--no-undefined ${CMAKE_CXX_FLAGS}")
# Install roots of the torch and torch_npu Python packages, both located directly
# inside the site-packages directory found at the top of this file.
set(PYTORCH_INSTALL_PATH "${python_site_packages_path}/torch")
set(PYTORCH_NPU_INSTALL_PATH "${python_site_packages_path}/torch_npu")

# Hardened link options shared by the shared-library and executable link lines.
# NOTE(review): this set contains -shared, so any executable added in this
# directory would also be linked with -shared; only a shared library is built here.
set(LD_FLAGS_GLOBAL "-shared -rdynamic -ldl -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,--build-id=none -s")
string(APPEND CMAKE_SHARED_LINKER_FLAGS " ${LD_FLAGS_GLOBAL} -fexceptions")
string(APPEND CMAKE_EXE_LINKER_FLAGS " ${LD_FLAGS_GLOBAL} -pie -fPIE")

# Linker search paths for libraries linked by bare name (c10, torch, torch_cpu,
# torch_npu, ...). These must precede add_library(): link_directories() only
# affects targets created after it is called.
link_directories(
  "${PYTORCH_INSTALL_PATH}/lib"
  "${PYTORCH_NPU_INSTALL_PATH}/lib"
  "$ENV{ASCEND_HOME_PATH}/lib64"
)
# Root of the Ascend toolkit. It provides the ACL headers (here) and the
# libascendcl/libnnopbase/libopapi libraries linked at the end of this file.
# Precedence: ASCEND_CUSTOM_PATH, then ASCEND_TOOLKIT_HOME, then ASCEND_HOME_PATH
# (whose lib64 directory is already on the linker search path above).
if(NOT "$ENV{ASCEND_CUSTOM_PATH}" STREQUAL "")
  set(ASCEND_PATH "$ENV{ASCEND_CUSTOM_PATH}")
elseif(NOT "$ENV{ASCEND_TOOLKIT_HOME}" STREQUAL "")
  set(ASCEND_PATH "$ENV{ASCEND_TOOLKIT_HOME}")
else()
  set(ASCEND_PATH "$ENV{ASCEND_HOME_PATH}")
endif()
# An empty root would silently expand to "/include" and "/lib64/lib*.so";
# report the real problem instead of a later missing-header/linker error.
if("${ASCEND_PATH}" STREQUAL "")
  message(FATAL_ERROR "Ascend toolkit location unknown: set ASCEND_CUSTOM_PATH, ASCEND_TOOLKIT_HOME or ASCEND_HOME_PATH.")
endif()
set(ACL_INCLUDE_DIRS ${ASCEND_PATH}/include)
include_directories(
  ${ACL_INCLUDE_DIRS}
  ${ACL_INCLUDE_DIRS}/aclnn
)
# The plugin shared library. Source paths are relative to this CMakeLists.txt's
# directory; every translation unit lives under plugin/.
add_library(PTAExtensionOPS SHARED
  plugin/register_ops.cpp
  plugin/la.cpp
  plugin/adalayernorm.cpp
  plugin/find_op_path.cpp
  plugin/la_preprocess.cpp
  plugin/rainfusionattention.cpp
  plugin/ada_block_sparse_attention.cpp
  plugin/sparse_block_estimate.cpp
  plugin/layernorm.cpp
  plugin/block_sparse_attention.cpp
  plugin/quant_flash_attn.cpp
  plugin/quant_flash_attn_metadata.cpp
)
# Compile the plugin sources as (at least) C++17.
target_compile_features(PTAExtensionOPS PRIVATE cxx_std_17)
# Select the libstdc++ dual ABI (_GLIBCXX_USE_CXX11_ABI) for the plugin.
# USER_ABI_VERSION may be "0" or "1" (surrounding whitespace is ignored); when it
# is unset or empty the default is 0.
# NOTE(review): the value must match the ABI torch/torch_npu were compiled with;
# `torch.compiled_with_cxx11_abi()` reports it — confirm the default of 0.
string(STRIP "$ENV{USER_ABI_VERSION}" ABI)
if("${ABI}" STREQUAL "")
  set(ABI 0)
elseif(NOT ("${ABI}" STREQUAL "0" OR "${ABI}" STREQUAL "1"))
  message(FATAL_ERROR "The value of USER_ABI_VERSION must be 0 or 1, but current value is '${ABI}'")
endif()
# Pass the macro as a compile definition rather than a raw -D compile option.
target_compile_definitions(PTAExtensionOPS PRIVATE _GLIBCXX_USE_CXX11_ABI=${ABI})
# Header search paths, in this order: torch_npu's bundled ACL and HCCL headers,
# torch_npu's own headers, then torch's headers (top-level, the
# torch/csrc/distributed tree, and the torch/csrc/api/include tree).
include_directories(
  ${PYTORCH_NPU_INSTALL_PATH}/include/third_party/acl/inc
  ${PYTORCH_NPU_INSTALL_PATH}/include/third_party/hccl/inc
  ${PYTORCH_NPU_INSTALL_PATH}/include
  ${PYTORCH_INSTALL_PATH}/include
  ${PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed
  ${PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
)
# Locate the python3 interpreter. find_program() is used instead of running
# `which`, which is not guaranteed to be installed (a missing `which` used to
# leave Python3_EXECUTABLE empty and every query below silently empty). As a
# cache variable, the interpreter can also be chosen with -DPython3_EXECUTABLE=<path>.
find_program(Python3_EXECUTABLE NAMES python3)
if(NOT Python3_EXECUTABLE)
  message(FATAL_ERROR "python3 not found in PATH")
endif()

# pta_python_sysconfig(<out_var> <expr>)
# Evaluates the Python expression <expr> (with `sysconfig` imported) using
# Python3_EXECUTABLE and stores its printed value, trailing whitespace stripped,
# in <out_var> in the caller's scope. Aborts configuration if Python fails.
# Example: pta_python_sysconfig(inc "sysconfig.get_path('include')")
function(pta_python_sysconfig out_var expr)
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import sysconfig; print(${expr})"
    OUTPUT_VARIABLE query_value
    RESULT_VARIABLE query_result
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  if(NOT "${query_result}" EQUAL 0)
    message(FATAL_ERROR "Python sysconfig query '${expr}' failed with '${Python3_EXECUTABLE}' (result: '${query_result}')")
  endif()
  set(${out_var} "${query_value}" PARENT_SCOPE)
endfunction()

# Header directory (Python.h), library directory and libpython file name.
pta_python_sysconfig(Python3_INCLUDE_DIRS "sysconfig.get_path('include')")
pta_python_sysconfig(Python3_LIBRARY_DIRS "sysconfig.get_config_var('LIBDIR')")
pta_python_sysconfig(Python3_LIBRARY_NAME "sysconfig.get_config_var('LDLIBRARY')")
# Resolve the full path of the libpython to link against (required because the
# library is linked with -Wl,--no-undefined).
#
# Candidates are the sysconfig-reported file name (LDLIBRARY) and its .so/.a
# counterpart. Besides LIBDIR, these directories are also probed:
#   * LIBDIR/<MULTIARCH> - Debian/Ubuntu system Pythons typically report
#     LIBDIR=/usr/lib but install libpython under /usr/lib/<multiarch triplet>;
#   * LIBPL - the interpreter's config directory, where static archives
#     (libpythonX.Y.a) are commonly installed.
# Directories are tried in that order, each with the reported name first, so any
# setup resolved by the old LIBDIR-only lookup still resolves to the same file.
execute_process(
  COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_config_var('MULTIARCH') or '')"
  OUTPUT_VARIABLE python3_multiarch
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
execute_process(
  COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_config_var('LIBPL') or '')"
  OUTPUT_VARIABLE python3_libpl
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

set(python3_lib_search_dirs "${Python3_LIBRARY_DIRS}")
if(NOT "${python3_multiarch}" STREQUAL "")
  list(APPEND python3_lib_search_dirs "${Python3_LIBRARY_DIRS}/${python3_multiarch}")
endif()
if(NOT "${python3_libpl}" STREQUAL "")
  list(APPEND python3_lib_search_dirs "${python3_libpl}")
endif()

# Candidate file names: the reported name, then the other flavour (.so <-> .a).
# An empty LDLIBRARY yields no candidates (previously "LIBDIR/" itself passed
# the EXISTS check and a directory was handed to the linker).
set(python3_lib_names "${Python3_LIBRARY_NAME}")
string(REGEX MATCH "(\\.[^.]+)$" PYLIB_NAME_EXT "${Python3_LIBRARY_NAME}")
string(REGEX REPLACE "(\\.[^.]+)$" "" PYLIB_NAME_WE "${Python3_LIBRARY_NAME}")
if("${PYLIB_NAME_EXT}" STREQUAL ".so")
  list(APPEND python3_lib_names "${PYLIB_NAME_WE}.a")
elseif("${PYLIB_NAME_EXT}" STREQUAL ".a")
  list(APPEND python3_lib_names "${PYLIB_NAME_WE}.so")
endif()

# First existing <dir>/<name> wins.
set(Python3_LIBRARIES "")
foreach(python3_lib_dir IN LISTS python3_lib_search_dirs)
  foreach(python3_lib_name IN LISTS python3_lib_names)
    if("${Python3_LIBRARIES}" STREQUAL "" AND EXISTS "${python3_lib_dir}/${python3_lib_name}")
      set(Python3_LIBRARIES "${python3_lib_dir}/${python3_lib_name}")
    endif()
  endforeach()
endforeach()

if("${Python3_LIBRARIES}" STREQUAL "")
  message(FATAL_ERROR "Python library '${Python3_LIBRARY_NAME}' (or its .so/.a counterpart) not found in: ${python3_lib_search_dirs}")
endif()
message(STATUS "Found Python library: ${Python3_LIBRARIES}")
# Python headers (Python.h) for the plugin sources.
include_directories(${Python3_INCLUDE_DIRS})
# No link_directories() for the Python library directory: that command only
# affects targets created after it is called, so it could not reach
# PTAExtensionOPS (created earlier by add_library). libpython is linked by its
# full path (Python3_LIBRARIES) instead, which needs no search directory.
# Link dependencies of the plugin:
#   * c10 / torch / torch_cpu / torch_npu by name, resolved via the
#     link_directories() search paths set up before add_library();
#   * libpython by the full path resolved above;
#   * the Ascend libraries by full path from the resolved toolkit root.
# ascendcl is listed only once, by full path. A second bare `ascendcl` entry
# would be resolved through the linker search path (ASCEND_HOME_PATH/lib64),
# which can be a different toolkit than ASCEND_PATH when ASCEND_CUSTOM_PATH
# is set.
target_link_libraries(PTAExtensionOPS PUBLIC
  c10
  torch
  torch_cpu
  torch_npu
  "${Python3_LIBRARIES}"
  "${ASCEND_PATH}/lib64/libascendcl.so"
  "${ASCEND_PATH}/lib64/libnnopbase.so"
  "${ASCEND_PATH}/lib64/libopapi.so"
)