@@ -1,7 +1,11 @@
# BiShengIR Project.
-cmake_minimum_required(VERSION 3.28.0)
+cmake_minimum_required(VERSION 3.14.0)
project(bishengir)
+set(CMAKE_CXX_STANDARD 17)
+find_package(MLIR REQUIRED CONFIG)
+include_directories(${MLIR_INCLUDE_DIRS})
+
# bishengir-target-spec-tblgen related target has ninja version requirement
if(CMAKE_GENERATOR MATCHES "Ninja")
execute_process(COMMAND ${CMAKE_MAKE_PROGRAM} --version
@@ -87,12 +91,6 @@ find_program(
# Setup variables and includes
# -------------------------------------------------------------------------------
-set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir)
-set(MLIR_MAIN_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include)
-set(MLIR_CMAKE_DIR ${MLIR_MAIN_SRC_DIR}/cmake/modules)
-set(MLIR_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
-set(MLIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
-
set(BISHENGIR_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/bishengir)
set(BISHENGIR_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(BISHENGIR_MAIN_INCLUDE_DIR ${BISHENGIR_SRC_DIR}/include)
@@ -153,11 +151,6 @@ if(BISHENGIR_ENABLE_TORCH_CONVERSIONS)
include_directories(${TORCH_MLIR_INCLUDE_DIRS})
endif()
-if(MLIR_ENABLE_BINDINGS_PYTHON)
- include(MLIRDetectPythonEnv)
- mlir_configure_python_dev_packages()
-endif()
-
if(BISHENGIR_BUILD_EXAMPLES)
set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}
CACHE STRING "Ascend CANN install path")
@@ -180,13 +173,13 @@ add_llvm_install_targets(install-bishengir-publish-products
# Check format for BISHENGIR_VERSION
foreach(part MAJOR MINOR PATCHLEVEL)
if(NOT BISHENGIR_VERSION_${part} MATCHES "^[0-9]+$")
- message(FATAL_ERROR
+ message(FATAL_ERROR
"BISHENGIR_VERSION_${part} is invalid: '${BISHENGIR_VERSION_${part}}'. "
"This part must contain only digits.")
endif()
endforeach()
if(NOT "${BISHENGIR_VERSION_SUFFIX}" MATCHES "^[a-zA-Z0-9-]*$")
- message(FATAL_ERROR
+ message(FATAL_ERROR
"BISHENGIR_VERSION_SUFFIX is invalid: '${BISHENGIR_VERSION_SUFFIX}'. "
"Suffix must only contain alphanumeric characters or hyphens.")
endif()
@@ -220,11 +213,21 @@ if (BISHENGIR_BUILD_STANDALONE_IR_ONLY)
endforeach()
# TODO: Math and Utils should be added in STANDALONE_IR_BUILD_DIALECT
add_subdirectory(bishengir/include/bishengir/Dialect/MathExt)
+ add_subdirectory(bishengir/include/bishengir/Dialect/SCF)
+ add_subdirectory(bishengir/include/bishengir/Conversion)
add_subdirectory(bishengir/lib/Dialect/HIVM/Utils)
add_subdirectory(bishengir/lib/Dialect/HACC/Transforms)
add_subdirectory(bishengir/lib/Dialect/HACC/Utils)
add_subdirectory(bishengir/lib/Dialect/Math)
add_subdirectory(bishengir/lib/Dialect/Utils)
+ add_subdirectory(bishengir/lib/Dialect/SCF)
+ add_subdirectory(bishengir/lib/Dialect/HIVM/Transforms)
+ add_subdirectory(bishengir/lib/Dialect/HIVM/Analysis)
+ add_subdirectory(bishengir/lib/Dialect/Scope/Transforms)
+ add_subdirectory(bishengir/lib/Dialect/Annotation/Transforms)
+ add_subdirectory(bishengir/lib/Dialect/Analysis)
+ add_subdirectory(bishengir/lib/Conversion/HIVMToStandard)
+ add_subdirectory(bishengir/lib/Conversion/ArithToAffine)
else()
add_subdirectory(bishengir/include/bishengir)
add_subdirectory(bishengir/lib)
@@ -1898,7 +1898,7 @@ void ConvertHIVMToStandardPass::runOnOperation() {
canonicalPatterns, canonicalPatterns.getContext());
scf::ForOp::getCanonicalizationPatterns(canonicalPatterns,
canonicalPatterns.getContext());
- if (failed(applyPatternsGreedily(module, std::move(canonicalPatterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(module, std::move(canonicalPatterns)))) {
return signalPassFailure();
}
@@ -245,7 +245,8 @@ void DimensionAnalyzerBase::processSlicingOp(T slicingOp) {
std::swap(src, res);
SmallVector<OpFoldResult> srcShape;
if (auto expandOp = src.template getDefiningOp<tensor::ExpandShapeOp>()) {
- srcShape = expandOp.getMixedOutputShape();
+ Builder b(expandOp.getContext());
+ srcShape = getMixedValues(expandOp.getStaticOutputShape(), expandOp.getOutputShape(), b);
} else {
srcShape = llvm::map_to_vector(
utils::getShape(src.getType()),
@@ -94,7 +94,7 @@ void AddFFTSToSyncBlockSetOpPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.insert<AddFFTSPattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -379,7 +379,7 @@ LogicalResult applyPatterns(MLIRContext *context, GreedyRewriteConfig config,
Operation *op, bool *changed = nullptr) {
RewritePatternSet patterns(context);
addPattern<PATTERNS...>(patterns);
- return applyPatternsGreedily(op, std::move(patterns), config, changed);
+ return applyPatternsAndFoldGreedily(op, std::move(patterns), config, changed);
}
template <typename OpType>
@@ -472,7 +472,7 @@ LogicalResult propagteAlignInfoUpToAlloc(MLIRContext *context,
populatePropagateAlignUpToRootAllocationPattern(
patterns, hivm::StrideAlignDimsAttr::name.str(),
hivm::StrideAlignValueInByteAttr::name.str());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config))) {
LDBG("propagating up failed");
return failure();
}
@@ -500,7 +500,7 @@ LogicalResult populatePropagateAlignAmongOpOperands(MLIRContext *context,
populatePropagateAlignAmongOpOperandsPatterns(patterns);
GreedyRewriteConfig config = GreedyRewriteConfig();
config.maxIterations = 10000;
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config))) {
LDBG("propagating in operation failed");
return failure();
}
@@ -585,7 +585,7 @@ void EnableStrideAlignPass::runOnOperation() {
ConversionTarget target(getContext());
patterns.add<EnableAlignAllocation<memref::AllocOp>,
EnableAlignAllocation<memref::AllocaOp>>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
LDBG("enable align allocation failed");
return signalPassFailure();
}
@@ -215,7 +215,7 @@ void AlignAllocSizePass::runOnOperation() {
populatePropagateAlignUpToRootAllocationPattern(
patterns, hivm::AllocAlignDimsAttr::name.str(),
hivm::AllocAlignValueInByteAttr::name.str());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
@@ -225,7 +225,7 @@ void AlignAllocSizePass::runOnOperation() {
// step 3: modify the alloc and do size alignment
patterns.clear();
populateAlignAllocAlignPattern(patterns);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
@@ -69,7 +69,7 @@ void AllocToAllocaPass::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateAllocToAllocaPatterns(patterns);
- if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -6,12 +6,13 @@ file(GLOB TILING ${CMAKE_CURRENT_SOURCE_DIR}/Tiling/*.cpp)
file(GLOB TILE_AND_BIND_SUB_BLOCK ${CMAKE_CURRENT_SOURCE_DIR}/TileAndBindSubBlock/*.cpp)
add_bishengir_dialect_library(BiShengIRHIVMTransforms
+ PARTIAL_SOURCES_INTENDED
${ALIGN_BUFFER}
- ${BUBBLE_UP_EXTRACT_SLICE}
+ #${BUBBLE_UP_EXTRACT_SLICE}
${INJECT_SYNC}
${GRAPH_SYNC_CONFLICT_SOLVER}
- ${TILING}
- ${TILE_AND_BIND_SUB_BLOCK}
+ #${TILING}
+ #${TILE_AND_BIND_SUB_BLOCK}
AddFFTSToSyncBlockSetOp.cpp
AllocExtraBuffer.cpp
AllocToAlloca.cpp
@@ -42,7 +43,7 @@ add_bishengir_dialect_library(BiShengIRHIVMTransforms
EnableMultiBuffer.cpp
FlattenOps.cpp
HIVMAggregatedDecomposeOp.cpp
- HIVMBubbleUpExtractSlice.cpp
+ #HIVMBubbleUpExtractSlice.cpp
HIVMDecomposeOp.cpp
HIVMInlineOTFLoadStore.cpp
HIVMLowerToLoops.cpp
@@ -54,17 +55,17 @@ add_bishengir_dialect_library(BiShengIRHIVMTransforms
InferHIVMMemScope.cpp
InitEntryKernel.cpp
InjectBlockSync.cpp
- InlineFixpipe.cpp
+ #InlineFixpipe.cpp
InlineLoadCopy.cpp
InlineOTFBroadcast.cpp
InsertFreeLockVarBeforeReturn.cpp
InsertInferSyncBlockLockNumAndInitFunc.cpp
InsertInferTaskTypeFunc.cpp
InsertInferWorkSpaceSizeFunc.cpp
- InsertInitAndFinishForDebug.cpp
- InsertLoadStoreForMixCV.cpp
- InsertNZ2NDForDebug.cpp
- InsertWorkSpaceForMixCV.cpp
+ #InsertInitAndFinishForDebug.cpp
+ #InsertLoadStoreForMixCV.cpp
+ #InsertNZ2NDForDebug.cpp
+ #InsertWorkSpaceForMixCV.cpp
LiftLowestStride.cpp
LiftZeroRank.cpp
LowerCreateSyncBlockLock.cpp
@@ -85,9 +86,9 @@ add_bishengir_dialect_library(BiShengIRHIVMTransforms
SinkOpToConsumerInLoop.cpp
SplitMixKernel.cpp
SyncBlockHoisting.cpp
- TileAndBindSubBlock.cpp
+ #TileAndBindSubBlock.cpp
TileBatchMMIntoLoop.cpp
- TileCubeVectorLoop.cpp
+ #TileCubeVectorLoop.cpp
TritonGlobalKernelArgsToHIVMOp.cpp
UnitFlagInfoBase.cpp
@@ -125,5 +126,5 @@ add_bishengir_dialect_library(BiShengIRHIVMTransforms
MLIRMemRefDialect
BiShengIRMemRefExtDialect
BiShengIRDialectUtils
- BiShengIRTransform
+ #BiShengIRTransform
)
@@ -174,7 +174,7 @@ void CloneTensorEmptyPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateCloneTensorEmptyPattern(patterns);
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
} // namespace
@@ -289,7 +289,7 @@ void ComposeCollapseExpandPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<ComposeExpandOfCollapseOpPattern>(patterns.getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createComposeCollapseExpandPass() {
@@ -147,7 +147,7 @@ void ConstantizeBufferPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<ConstantizeAllocLikeOp<memref::AllocOp>,
ConstantizeAllocLikeOp<memref::AllocaOp>>(patterns.getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createConstantizeBufferSizePass() {
@@ -55,7 +55,7 @@ struct PropagateConvertLayoutPass
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
- if (failed(applyPatternsGreedily(module, std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
signalPassFailure();
}
};
@@ -99,7 +99,7 @@ void ConvertNonContiguousReshapeToCopyPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<ConvertMaybeNonContiguousReassociativeReshapeOpToCopy>(
patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
@@ -301,7 +301,7 @@ void ConvertToHIVMOpPass::runOnOperation() {
// rewrite op within cur funcOp
RewritePatternSet patterns(ctx);
populateHIVMOpRewritingRule(patterns);
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
});
}
@@ -376,7 +376,7 @@ void EnableMultiBufferPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.insert<MultiBufferPattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -93,7 +93,7 @@ void FlattenOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<FlattenOpsRewritePattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -102,7 +102,7 @@ void HIVMAggregatedDecomposeOpPass::runOnOperation() {
return;
RewritePatternSet patterns(&getContext());
patterns.add<HIVMDecomposePattern>(&getContext(), decomposePhase);
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createHIVMAggregatedDecomposeOpPass(
@@ -107,7 +107,7 @@ public:
populateBubbleUpExtractSliceOpPatterns(patterns);
populateCSEPattern(patterns);
tensor::populateFoldTensorEmptyPatterns(patterns, true);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config))) {
return signalPassFailure();
}
PassManager pm(funcOp->getContext());
@@ -127,7 +127,7 @@ public:
populateBubbleUpExtractSliceOpPatterns(patterns2);
populateCSEPattern(patterns2);
tensor::populateFoldTensorEmptyPatterns(patterns2, true);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns2), config))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns2), config))) {
return signalPassFailure();
}
if (failed(verifyMarkedExtractSlicesAreBubbledUp(funcOp))) {
@@ -1775,7 +1775,7 @@ void HIVMDecomposeOpPass::runOnOperation() {
auto funcCoreType = funcCoreTypeOpt.has_value() ? funcCoreTypeOpt.value()
: TFuncCoreType::AIC_OR_AIV;
patterns.add<SyncBlockOpLowering>(&getContext(), isMixModule, funcCoreType);
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createHIVMDecomposeOpPass() {
@@ -138,7 +138,7 @@ void HIVMInlineOTFLoadStore::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<UnalignedLastDimConcatStorePattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
@@ -85,7 +85,7 @@ void HIVMLowerToLoopsPass::runOnOperation() {
return;
RewritePatternSet patterns(&getContext());
patterns.add<HIVMLowerToLoopsPattern>(&getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createHIVMLowerToLoopsPass() {
@@ -76,7 +76,7 @@ void HIVMMapForallToBlocksPass::runOnOperation() {
populateForallToBlocksPatterns(patterns);
// To expand delinearizeIndexOps
affine::populateAffineExpandIndexOpsPatterns(patterns);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -269,7 +269,7 @@ void HIVMOptSinglePointOpPass::runOnOperation() {
SinglePointEltVecOp<hivm::VMinOp, arith::MinNumFOp, arith::MinSIOp,
None>>(&getContext());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
signalPassFailure();
}
@@ -592,13 +592,13 @@ void InlineFixpipe::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateInlineFixpipePatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
RewritePatternSet insertFixpipeForDevicePrintPattern(&getContext());
MLIRContext *ctx = insertFixpipeForDevicePrintPattern.getContext();
insertFixpipeForDevicePrintPattern.add<InsertFixpipeForDevicePrint>(ctx);
- if (failed(applyPatternsGreedily(getOperation(), std::move(insertFixpipeForDevicePrintPattern)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(insertFixpipeForDevicePrintPattern)))) {
signalPassFailure();
}
}
@@ -321,7 +321,7 @@ public:
void InlineLoadCopyPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<LoadCopyInlinePattern>(patterns.getContext());
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createInlineLoadCopyPass() {
@@ -155,7 +155,7 @@ public:
void InlineOTFBroadcastPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<VBrcInlinePattern>(patterns.getContext());
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createInlineOTFBroadcastPass() {
@@ -117,7 +117,7 @@ void InsertInitAndFinishForDebug::runOnOperation() {
// insert finish for every debug
RewritePatternSet patterns(context);
patterns.add<InsertFinish>(patterns.getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
} // namespace
@@ -869,7 +869,7 @@ LogicalResult applyInsertLoadBeforeSCFInitArgs(MLIRContext *context,
patterns.insert<InsertLoadBeforeSCFInitArgs<scf::ForOp>,
InsertLoadBeforeSCFInitArgs<scf::WhileOp>>(
patterns.getContext());
- return applyPatternsGreedily(funcOp, std::move(patterns));
+ return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
LogicalResult preProcessComplexControlFlow(MLIRContext *context, Operation *funcOp) {
@@ -878,7 +878,7 @@ LogicalResult preProcessComplexControlFlow(MLIRContext *context, Operation *func
patterns.insert<InsertLoadOpBetweenStoreLikeAndVectorOrCube<scf::YieldOp>>(
patterns.getContext()
);
- return applyPatternsGreedily(funcOp, std::move(patterns));
+ return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
void InsertLoadStoreForMixCVPass::runOnOperation() {
@@ -907,7 +907,7 @@ void InsertLoadStoreForMixCVPass::runOnOperation() {
if (hasCube) {
patterns.insert<DuplicateTensorExtractForCube>(patterns.getContext());
}
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -104,7 +104,7 @@ public:
void InsertNZ2NDForDebug::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<InsertNZ2NDForDebugPattern>(patterns.getContext());
- (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
} // namespace
@@ -163,7 +163,7 @@ void InsertWorkSpaceForMixCVPass::runOnOperation() {
auto funcOp = getOperation();
RewritePatternSet patterns(context);
InsertWorkSpaceForMixCVPattern(patterns);
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createInsertWorkSpaceForMixCVPass() {
@@ -463,7 +463,7 @@ void LiftLowestStridePass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLiftLowestStridePatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -89,7 +89,7 @@ void LiftZeroRankPass::runOnOperation() {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<HIVMLiftZeroRankPattern>(ctx);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -96,7 +96,7 @@ void LowerCreateSyncBlockLockPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<LowerCreateSyncBlockLock>(&getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
std::unique_ptr<Pass> mlir::hivm::createSyncBlockLockLoweringPass() {
@@ -122,7 +122,7 @@ void MarkDisableLoad::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<MarkDCacheInvalidatePattern>(patterns.getContext());
- (void)applyPatternsGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
} // namespace
@@ -246,7 +246,7 @@ void MarkMultiBufferPass::runOnOperation() {
MarkWorkspaceMultiBuffer<hivm::FixpipeOp>>(
patterns.getContext(), workspaceMultiBufferNum);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
signalPassFailure();
}
} // end anonymous namespace
@@ -19,8 +19,10 @@
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "bishengir/Dialect/HIVM/IR/HIVMImpl.h"
#include "bishengir/Dialect/HIVM/IR/HIVMInterfaces.h"
-#include "bishengir/Dialect/HIVM/Pipelines/Passes.h"
#include "bishengir/Dialect/HIVM/Transforms/Passes.h"
+#include "bishengir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -98,7 +100,16 @@ void MarkRealCoreTypePass::runOnOperation() {
// run split mix kernel pass to annotate core type attribute
PassManager pm(moduleClone.getContext());
pm.addPass(createSplitMixKernelPass());
- canonicalizationHIVMPipeline(pm);
+
+ // pm.addPass(createArithToAffineConversionPass());
+ pm.nest<func::FuncOp>().addPass(scf::createCanonicalizeIterArgPass());
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createSCFForLoopCanonicalizationPass());
+ pm.addPass(createCSEPass());
+ pm.nest<func::FuncOp>().addPass(createCanonicalizerPass());
+ pm.nest<func::FuncOp>().addPass(createHIVMOptSinglePointPass());
+ pm.nest<func::FuncOp>().addPass(createCanonicalizerPass());
+ // pm.nest<func::FuncOp>().addPass(memref::createDeadStoreEliminationPass());
if (failed(pm.run(moduleClone))) {
return signalPassFailure();
}
@@ -118,7 +118,7 @@ public:
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<NormalizeBitwiseSelectPattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
@@ -292,13 +292,13 @@ void NormalizeConvOpsPass::runOnOperation() {
RewritePatternSet patterns1(context);
populateNormalizeConvOpsPattern1(patterns1);
GreedyRewriteConfig config1 = GreedyRewriteConfig();
- (void)applyPatternsGreedily(funcOp, std::move(patterns1), config1);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns1), config1);
// Second Round
RewritePatternSet patterns2(context);
populateNormalizeConvOpsPattern2(patterns2);
GreedyRewriteConfig config2 = GreedyRewriteConfig();
- (void)applyPatternsGreedily(funcOp, std::move(patterns2), config2);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns2), config2);
}
} // namespace
@@ -275,7 +275,7 @@ void NormalizeLoopIteratorPass::runOnOperation() {
func::FuncOp funcOp = getOperation();
RewritePatternSet rewritePatterns(&getContext());
populateNormalizeLoopIneratorPattern(rewritePatterns);
- if (failed(applyPatternsGreedily(funcOp, std::move(rewritePatterns))))
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(rewritePatterns))))
signalPassFailure();
}
@@ -509,7 +509,7 @@ void NormalizeMatmulPass::runOnOperation() {
// But if it is BottomUpTraversal, the second mad will be decompose to
// 'mad + add' and lose 'mad + mad' optimization.
config.useTopDownTraversal = true;
- (void)applyPatternsGreedily(funcOp, std::move(patterns), config);
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns), config);
}
} // namespace
@@ -2444,7 +2444,7 @@ void PlanMemoryPass::runOnOperation() {
if (this->memMode == MemPlanMode::LOCAL_MEM_PLAN) {
RewritePatternSet normalizeLoopIterPatterns(&getContext());
populateNormalizeLoopIneratorPattern(normalizeLoopIterPatterns);
- if (failed(applyPatternsGreedily(funcOp,
+ if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(normalizeLoopIterPatterns)))) {
return signalPassFailure();
}
@@ -2517,7 +2517,7 @@ void PlanMemoryPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateBufferAddressToAllocOp(patterns, plannedBuffer2Offsets);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
@@ -350,7 +350,7 @@ void RecognizeDeinterleaveOpPass::runOnOperation() {
patterns.add<RecognizeDeinterleaveOpForLoad>(ctx);
patterns.add<RecognizeDeinterleaveOpForCopy>(ctx);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -387,7 +387,7 @@ void ReduceRankSubviewPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateReduceRankSubviewPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -67,7 +67,7 @@ void SinkOpToConsumerInLoopPass::runOnOperation() {
patterns.add<SinkOpToConsumerInLoop<hivm::VBrcOp>,
SinkOpToConsumerInLoop<linalg::BroadcastOp>,
SinkOpToConsumerInLoop<linalg::FillOp>>(ctx);
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -210,7 +210,7 @@ void postProcessCubeFunc(func::FuncOp func) {
patterns.insert<PostCubeReplacement>(patterns.getContext());
patterns.insert<FoldEmptyInsertSlice>(patterns.getContext());
tensor::populateFoldTensorEmptyPatterns(patterns);
- if (failed(applyPatternsGreedily(func.getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(func.getOperation(), std::move(patterns)))) {
llvm::report_fatal_error("postProcessCubeFunc failed");
}
removeOpWithAttrFromFunc<bufferization::ToTensorOp>(
@@ -148,7 +148,7 @@ void SyncBlockHoistingPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<HoistCreateSyncBlockLockInIfPattern>(patterns.getContext());
patterns.add<HoistingSyncBlockPattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
@@ -605,7 +605,7 @@ static LogicalResult limitUniqueSubBlockToStore(func::FuncOp funcOp) {
patterns.add<LimitUniqueSubBlockIdToStore>(funcOp.getContext());
GreedyRewriteConfig config;
config.maxIterations = kMaxIterations;
- return applyPatternsGreedily(funcOp, std::move(patterns));
+ return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
static scf::ForOp createSubBlockLoop(Location loc, OpBuilder &builder,
@@ -722,7 +722,7 @@ static LogicalResult tileAndSliceStore(func::FuncOp func,
func->getContext(), analyzer);
GreedyRewriteConfig config;
config.maxIterations = kMaxIterations;
- if (failed(applyPatternsGreedily(func, std::move(patterns), config))) {
+ if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
return failure();
}
@@ -844,7 +844,7 @@ TileAndBindSubBlockPass::attemptBindSubBlock(func::FuncOp func) {
RewritePatternSet patternsPost(&getContext());
patternsPost.add<mlir::hivm::detail::BubbleUpSubviewFromTiling>(
&getContext());
- if (failed(applyPatternsGreedily(newFunc, std::move(patternsPost)))) {
+ if (failed(applyPatternsAndFoldGreedily(newFunc, std::move(patternsPost)))) {
failAndRevert(newFunc);
return failure();
}
@@ -885,7 +885,7 @@ void TileAndBindSubBlockPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<CanonicalizeAllocToTensor>(&getContext());
- (void)applyPatternsGreedily(moduleOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
// Collect functions to process (can't modify while iterating)
SmallVector<func::FuncOp> functionsToProcess;
@@ -388,7 +388,7 @@ void TileBatchMMIntoLoopPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<TileBatchMM>(&getContext());
- if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
@@ -296,7 +296,7 @@ LogicalResult liftMemRefLoadsInLoop(ModuleOp module) {
MLIRContext *ctx = module.getContext();
RewritePatternSet patterns(ctx);
patterns.add<LiftToTensor>(ctx);
- return applyPatternsGreedily(module, std::move(patterns));
+ return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
/// Pattern to shrink memref alloc's size.
@@ -343,7 +343,7 @@ LogicalResult shrinkAlloc(ModuleOp module) {
MLIRContext *ctx = module.getContext();
RewritePatternSet patterns(ctx);
patterns.add<ShrinkAlloc>(ctx);
- return applyPatternsGreedily(module, std::move(patterns));
+ return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
/// Pattern to remove the dummy store.
@@ -709,7 +709,7 @@ public:
MLIRContext *ctx = module.getContext();
RewritePatternSet patterns(ctx);
patterns.add<RemoveDummyStore>(ctx);
- return applyPatternsGreedily(module, std::move(patterns));
+ return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
};
@@ -402,7 +402,7 @@ struct CanonicalizeIterArgPass
CanonicalizeIterArgPattern<scf::WhileOp>,
RemoveDeadIterArgPattern<scf::ForOp>>(
patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
@@ -78,7 +78,7 @@ void ForToForallPass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<ForToForallRewritePattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
@@ -276,7 +276,7 @@ struct RemoveRedundantLoopInitPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<RemoveRedundantLoopInitPattern>(patterns.getContext());
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
@@ -86,7 +86,7 @@ void ExtractScopeBodyPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<ExtractOpsFromBodyPattern<ScopeOp>>(&getContext());
- if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
signalPassFailure();
}
}
@@ -166,7 +166,7 @@ void OutlineScopePass::runOnOperation() {
patterns.add<OutlineScopeOp>(&getContext());
- if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
+ if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
signalPassFailure();
}
}