@@ -54,6 +54,12 @@ endif()
# stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`).
option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF)
+option(TORCH_MLIR_ENABLE_LINALG "Add Linalg conversion support" ON)
+if(TORCH_MLIR_ENABLE_LINALG)
+ add_definitions(-DTORCH_MLIR_ENABLE_LINALG)
+ list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_LINALG")
+endif()
+
option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON)
if(TORCH_MLIR_ENABLE_TOSA)
add_definitions(-DTORCH_MLIR_ENABLE_TOSA)
@@ -26,6 +26,7 @@ def ConvertTorchToSCF: Pass<"convert-torch-to-scf", "func::FuncOp"> {
let constructor = "mlir::torch::createConvertTorchToSCFPass()";
}
+#ifdef TORCH_MLIR_ENABLE_LINALG
def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
let summary = "Convert recognized Torch ops to Linalg ops";
let description = [{
@@ -104,6 +105,7 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> {
}];
let constructor = "mlir::torch::createConvertTorchToLinalgPass()";
}
+#endif
def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> {
let summary = "Convert Torch ops to the Tensor dialect";
@@ -1,6 +1,8 @@
add_subdirectory(TorchOnnxToTorch)
add_subdirectory(TorchToArith)
-add_subdirectory(TorchToLinalg)
+if(TORCH_MLIR_ENABLE_LINALG)
+ add_subdirectory(TorchToLinalg)
+endif()
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToTensor)
if(TORCH_MLIR_ENABLE_TOSA)
@@ -15,12 +17,14 @@ add_subdirectory(Utils)
# TODO: Automate this with add_torch_mlir_conversion_library.
set(linked_libs TorchMLIRTorchToArith
- TorchMLIRTorchToLinalg
TorchMLIRTorchToSCF
TorchMLIRTorchToTensor
TorchMLIRTorchToTMTensor
TorchMLIRTorchConversionToMLProgram
TorchMLIRConversionUtils)
+if(TORCH_MLIR_ENABLE_LINALG)
+ list(APPEND linked_libs TorchMLIRTorchToLinalg)
+endif()
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
endif()
@@ -15,7 +15,9 @@
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
+#ifdef TORCH_MLIR_ENABLE_LINALG
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
+#endif // TORCH_MLIR_ENABLE_LINALG
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
@@ -14,7 +14,9 @@
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
+#ifdef TORCH_MLIR_ENABLE_LINALG
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
+#endif // TORCH_MLIR_ENABLE_LINALG
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h"
@@ -81,8 +83,10 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
// and those constants get somewhat obscured by TorchToArith.
pm.addNestedPass<func::FuncOp>(createConvertTorchToTMTensorPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+#ifdef TORCH_MLIR_ENABLE_LINALG
pm.addNestedPass<func::FuncOp>(createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+#endif // TORCH_MLIR_ENABLE_LINALG
pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
pm.addNestedPass<func::FuncOp>(createConvertTorchToTensorPass());