# Include MLIR Python CMake utilities
include(AddMLIRPython)
find_package(pybind11 CONFIG REQUIRED)

# Define the MLIR_PYTHON_PACKAGE_PREFIX macro as `mfusion.` (the trailing dot
# is part of the value) for targets created in this directory scope.
# NOTE(review): presumably read by the MLIR Python binding sources as the
# Python package import prefix - confirm against the MLIR bindings docs.
add_compile_definitions(
  "MLIR_PYTHON_PACKAGE_PREFIX=mfusion."
)

# Root source group of the package; the groups declared in this file attach
# to it through ADD_TO_PARENT.
declare_mlir_python_sources(MFusionPythonSources)

# Top-level package files, listed relative to the mfusion/ source directory.
declare_mlir_python_sources(MFusionPythonSources.TopLevel
  ADD_TO_PARENT MFusionPythonSources
  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mfusion"
  SOURCES
    __init__.py
    # _mlir_libs/__init__.py is generated automatically by the MLIR Python
    # bindings, so it cannot be listed here.
    _mlir_libs/_site_initialize_0.py
)

# Python bindings for the `mfuse` dialect: the TableGen file
# dialects/MfuseBinding.td plus the dialects/mfuse/__init__.py source, both
# listed relative to the mfusion/ source directory.
declare_mlir_dialect_python_bindings(
  DIALECT_NAME mfuse
  ADD_TO_PARENT MFusionPythonSources
  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mfusion"
  TD_FILE dialects/MfuseBinding.td
  SOURCES
    dialects/mfuse/__init__.py
)

# Tool sources (the tools/opt/__main__.py entry point), attached to the
# MFusionPythonSources parent group.
declare_mlir_python_sources(MFusionPythonSources.Tools
  ADD_TO_PARENT MFusionPythonSources
  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mfusion"
  SOURCES tools/opt/__main__.py
)

# Libraries handed to the Python extension module as PRIVATE_LINK_LIBS.
set(_mfusion_private_libs
  # LLVM / MLIR core libraries
  LLVMSupport
  MLIRParser
  MLIRIR
  MLIRSupport

  # MLIR dialect libraries
  MLIRFuncDialect
  MLIRArithDialect
  MLIRTensorDialect
  MLIRShapeDialect
  MLIRQuantDialect
  MLIRSparseTensorDialect
)

# Upstream CAPI libraries passed to the extension as EMBED_CAPI_LINK_LIBS.
# They are listed explicitly here; originally RegisterEverything provided them.
set(_mfusion_capi_libs
  MLIRCAPIConversion MLIRCAPITransforms MLIRCAPIRegisterEverything
)


# Collect the extension's C++ sources: every *.cc file directly inside this
# directory, stored as paths relative to this directory.
# CONFIGURE_DEPENDS makes the build system re-check the glob on each build, so
# adding or removing a .cc file triggers a re-configure instead of being
# silently missed until someone re-runs CMake by hand.
file(GLOB PYTHON_EXT_SRC
  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
  CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_SOURCE_DIR}/*.cc"
)

# Fail early with a clear message rather than declaring an extension that has
# no sources at all.
if("${PYTHON_EXT_SRC}" STREQUAL "")
  message(FATAL_ERROR
    "No *.cc sources found in ${CMAKE_CURRENT_SOURCE_DIR} for the _mfusion Python extension")
endif()

# Declare the `_mfusion` Python extension module built from the globbed
# sources.
#   PRIVATE_LINK_LIBS    - _mfusion_private_libs plus MFusionCAPI
#   EMBED_CAPI_LINK_LIBS - MFusionCAPI plus _mfusion_capi_libs
# NOTE(review): MFusionCAPI appears in both lists - confirm that linking it
# privately in addition to embedding it as a CAPI library is intentional.
declare_mlir_python_extension(MFusionPythonExtension
  MODULE_NAME _mfusion
  SOURCES ${PYTHON_EXT_SRC}
  PRIVATE_LINK_LIBS
    ${_mfusion_private_libs}
    MFusionCAPI
  EMBED_CAPI_LINK_LIBS
    MFusionCAPI
    ${_mfusion_capi_libs}
)

# Declared source groups and extensions that make up the Python package; the
# same list is passed as DECLARED_SOURCES to the aggregate CAPI library and to
# the package modules.
set(_source_components
  # Upstream MLIR sources and extensions
  MLIRPythonSources
  MLIRPythonExtension.Core
  MLIRPythonExtension.RegisterEverything

  # MFusion sources and extension
  MFusionPythonSources
  MFusionPythonExtension

  # Torch-MLIR Python sources and extensions
  TorchMLIRPythonSources
  TorchMLIRPythonExtensions
)

# Aggregate CAPI library for the declared source components, with MFusionCAPI
# embedded. It is written to the package's _mlir_libs directory in the build
# tree and installed to the matching path under the MFusionPythonPackage
# component.
add_mlir_python_common_capi_library(MFusionAggregateCAPI
  DECLARED_SOURCES ${_source_components}
  EMBED_LIBS MFusionCAPI
  OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/python_packages/mfusion/_mlir_libs"
  INSTALL_COMPONENT MFusionPythonPackage
  INSTALL_DESTINATION python_packages/mfusion/_mlir_libs
  RELATIVE_INSTALL_ROOT ".."
)

# The Torch-MLIR Python bindings are generated in the torch-mlir build
# directory. Make MFusionAggregateCAPI depend on each binding/extension target
# that exists in this build so they are built first; names that are not
# targets are skipped.
foreach(_torch_mlir_dep IN ITEMS
    TorchMLIRPythonSources.Dialects.torch.ops_gen
    TorchMLIRPythonSources.Dialects.torch
    TorchMLIRPythonSources.Dialects
    TorchMLIRPythonSources
    TorchMLIRPythonExtensions
)
  if(NOT TARGET ${_torch_mlir_dep})
    continue()
  endif()
  add_dependencies(MFusionAggregateCAPI ${_torch_mlir_dep})
endforeach()

# On Linux only, pass --exclude-libs,ALL to the linker so that symbols coming
# from static libraries are hidden, avoiding conflicts with other MLIR Python
# bindings.
target_link_options(MFusionAggregateCAPI
  PRIVATE
    "$<$<PLATFORM_ID:Linux>:LINKER:--exclude-libs,ALL>"
)

# Assemble the Python package from the declared source components under
# python_packages/mfusion (build tree and install prefix), with
# MFusionAggregateCAPI as the common CAPI link library.
add_mlir_python_modules(MFusionPythonPackage
  DECLARED_SOURCES ${_source_components}
  COMMON_CAPI_LINK_LIBS MFusionAggregateCAPI
  ROOT_PREFIX "${CMAKE_BINARY_DIR}/python_packages/mfusion"
  INSTALL_PREFIX "python_packages/mfusion"
)