/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Dialect/Linalg/IR/LinalgExtOps.h"
#include "akg/Dialect/Linalg/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DECL_LINALGELEMENTWISEFUSIONEXT
#define GEN_PASS_DEF_LINALGELEMENTWISEFUSIONEXT
#include "akg/Dialect/Linalg/Passes.h.inc"
}
using namespace llvm;
using namespace mlir;
using namespace mlir::linalg;
/// Translates a reassociation group given in terms of result positions of
/// `indexingMap` into the iteration-space dimensions those results refer to.
///
/// `indexingMap` must be a projected permutation (asserted). Every entry of
/// `rangeReassociation` is used as an index into the results of the map, and
/// the position of the dimension expression found there is emitted, in the
/// same order as the input.
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainDims;
  domainDims.reserve(rangeReassociation.size());
  for (int64_t resultPos : rangeReassociation) {
    auto dimExpr = cast<AffineDimExpr>(indexingMap.getResults()[resultPos]);
    domainDims.push_back(dimExpr.getPosition());
  }
  return domainDims;
}
/// Returns true when `dimSequence` is either absent from the results of
/// `indexingMap` or appears in them as one contiguous run in the same order.
///
/// Both preconditions are asserted: `dimSequence` is non-empty and
/// `indexingMap` is a projected permutation.
///
/// The results are scanned left to right:
///  - hitting the first dimension of the sequence decides the answer on the
///    spot: the following results must spell out the whole sequence;
///  - hitting any other member of the sequence before its first dimension
///    means the sequence is broken;
///  - finishing the scan without meeting any member yields true.
static bool isDimSequencePreservedV2(AffineMap indexingMap,
                                     ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceDims;
  sequenceDims.insert(dimSequence.begin(), dimSequence.end());

  const unsigned firstSequenceDim = dimSequence.front();
  ArrayRef<AffineExpr> mapResults = indexingMap.getResults();
  const size_t numResults = mapResults.size();
  const size_t sequenceLength = dimSequence.size();

  for (size_t resultIdx = 0; resultIdx < numResults; ++resultIdx) {
    unsigned resultDim =
        cast<AffineDimExpr>(mapResults[resultIdx]).getPosition();

    if (resultDim != firstSequenceDim) {
      // A later member of the sequence showed up before its head.
      if (sequenceDims.count(resultDim))
        return false;
      continue;
    }

    // Head of the sequence found: the remainder must fit and match.
    if (resultIdx + sequenceLength > numResults)
      return false;
    for (size_t offset = 0; offset < sequenceLength; ++offset) {
      unsigned mappedDim =
          cast<AffineDimExpr>(mapResults[resultIdx + offset]).getPosition();
      if (mappedDim != dimSequence[offset])
        return false;
    }
    return true;
  }

  // None of the sequence dimensions is used by this map.
  return true;
}
// Computes which groups of iteration-space dimensions of `genericOp` can be
// collapsed, given a reassociation expressed over the dimensions of
// `fusableOperand`.
//
// Each group of `reassociation` with more than one entry is mapped through the
// indexing map of `fusableOperand` to iteration dimensions and is kept only if
// it passes every filter below. The returned list holds the accepted groups in
// the order they were encountered; it is empty when the op as a whole is not
// eligible or when no group qualifies.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
ArrayRef<ReassociationIndices> reassociation) {
// Op-level preconditions: pure tensor semantics and exactly one init operand.
if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
return {};
// Every indexing map must be a projected permutation; the helpers used below
// assert on this.
if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
return map.isProjectedPermutation();
})) {
return {};
}
// Reduction loops of the op, as reported by getReductionDims.
SmallVector<unsigned> reductionDims;
genericOp.getReductionDims(reductionDims);
// Iteration dimensions that already belong to an accepted group.
llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
auto iteratorTypes = genericOp.getIteratorTypesArray();
SmallVector<ReassociationIndices> iterationSpaceReassociation;
for (ReassociationIndicesRef foldedRangeDims : reassociation) {
assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
// A single-entry group folds nothing.
if (foldedRangeDims.size() == 1)
continue;
// Translate operand dimensions into iteration-space dimensions.
ReassociationIndices foldedIterationSpaceDims =
getDomainReassociation(indexingMap, foldedRangeDims);
// Reject groups that overlap an already accepted group.
if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
return processedIterationDims.count(dim);
}))
continue;
// The first dimension must be a parallel or a reduction iterator ...
utils::IteratorType startIteratorType =
iteratorTypes[foldedIterationSpaceDims[0]];
if (!isParallelIterator(startIteratorType) &&
!isReductionIterator(startIteratorType))
continue;
// ... and every other dimension of the group must share that iterator type.
if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
return iteratorTypes[dim] != startIteratorType;
}))
continue;
// For reduction groups, additionally require that the group occurs as one
// contiguous, identically ordered run inside `reductionDims`.
if (isReductionIterator(startIteratorType)) {
bool isContiguous = false;
for (const auto &startDim : llvm::enumerate(reductionDims)) {
// Advance to the entry equal to the first folded dimension.
if (startDim.value() != foldedIterationSpaceDims[0])
continue;
// Not enough reduction dims left to hold the whole group.
if (startDim.index() + foldedIterationSpaceDims.size() >
reductionDims.size())
break;
// Compare the window element by element.
isContiguous = true;
for (const auto &foldedDim :
llvm::enumerate(foldedIterationSpaceDims)) {
if (reductionDims[foldedDim.index() + startDim.index()] !=
foldedDim.value()) {
isContiguous = false;
break;
}
}
// Only the first matching start position is examined.
break;
}
if (!isContiguous)
continue;
}
// The group must be preserved (or entirely absent) in every indexing map.
if (llvm::any_of(genericOp.getIndexingMapsArray(),
[&](AffineMap indexingMap) {
return !isDimSequencePreservedV2(indexingMap,
foldedIterationSpaceDims);
}))
continue;
// Accept the group and mark its dimensions as consumed.
processedIterationDims.insert(foldedIterationSpaceDims.begin(),
foldedIterationSpaceDims.end());
iterationSpaceReassociation.emplace_back(
std::move(foldedIterationSpaceDims));
}
return iterationSpaceReassociation;
}
/// Rewrite pattern anchored on `tensor.collapse_shape` whose source is a
/// result of a `linalg.generic`. When the reassociation of the reshape maps to
/// collapsable iteration dimensions of the producer (see
/// `getCollapsableIterationSpaceDims`) and the control function accepts the
/// producer's matching init operand, the producer is rebuilt with those
/// iteration dimensions collapsed and replaced by the values returned from
/// `collapseOpIterationDims`.
class FoldReshapeWithGenericOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {
public:
  /// \param context       context the pattern is registered in.
  /// \param foldReshapes  callback deciding whether folding is allowed for a
  ///                      given operand.
  /// \param benefit       pattern benefit forwarded to the base class.
  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  /// Returns success only after the producer has actually been replaced;
  /// every other path reports a match failure and leaves the IR untouched.
  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
                                PatternRewriter &rewriter) const override {
    auto producerResult = dyn_cast<OpResult>(collapseOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(collapseOp,
                                         "source not produced by an operation");
    }

    auto genericOp = dyn_cast<GenericOp>(producerResult.getOwner());
    if (!genericOp) {
      return rewriter.notifyMatchFailure(collapseOp,
                                         "producer not a generic op");
    }

    auto fuseInitOperand =
        genericOp.getDpsInitOperand(producerResult.getResultNumber());
    SmallVector<ReassociationIndices> collapsableIterationDims =
        getCollapsableIterationSpaceDims(genericOp, fuseInitOperand,
                                         collapseOp.getReassociationIndices());
    if (collapsableIterationDims.empty()) {
      return rewriter.notifyMatchFailure(collapseOp,
                                         "index map cannot be collapsed");
    }

    if (!controlFoldingReshapes(fuseInitOperand)) {
      return rewriter.notifyMatchFailure(collapseOp, "control function failed");
    }

    // A pattern must not report success without changing the IR, so perform
    // the collapse here. The new ops are created right after the producer
    // (not at the reshape) because the producer may have other users located
    // between it and the reshape, and they must be dominated by the
    // replacement values.
    rewriter.setInsertionPointAfter(genericOp);
    // NOTE(review): relies on the `FailureOr<CollapseResult>` form of
    // `collapseOpIterationDims` declared in Linalg/Transforms/Transforms.h —
    // confirm against the LLVM revision this project builds with.
    auto collapseResult =
        collapseOpIterationDims(genericOp, collapsableIterationDims, rewriter);
    if (failed(collapseResult)) {
      return rewriter.notifyMatchFailure(
          collapseOp, "failed to collapse producer iteration dims");
    }

    rewriter.replaceOp(genericOp, collapseResult->results);
    return success();
  }

private:
  /// Callback consulted before folding; receives the producer's init operand.
  ControlFusionFn controlFoldingReshapes;
};
namespace {
/// Returns true iff `fusedOp` properly dominates every user of every result of
/// `producer`. A producer without any users trivially satisfies the check.
static bool checkFusedOpDominateAllProducerUsers(Operation *fusedOp, Operation *producer, DominanceInfo &domInfo) {
  return llvm::all_of(producer->getResults(), [&](OpResult producerResult) {
    return llvm::all_of(producerResult.getUsers(), [&](Operation *consumer) {
      return domInfo.properlyDominates(fusedOp, consumer);
    });
  });
}
/// Checks whether, starting at `op`, following the single use of each
/// single-result operation eventually reaches an operation that is either a
/// `func.return` or properly dominated by `fusedOp`.
///
/// The walk stops with `false` as soon as an operation on the chain
///  - has more than one result,
///  - has exactly one result that does not have exactly one use, or
///  - has no result at all (and is not itself an end point), since there is
///    then no user left to follow.
///
/// \returns true if an end point is reached, false otherwise.
static bool CheckIfMatchDominateInSimplePattern0(Operation *fusedOp, Operation *op, DominanceInfo &domInfo) {
  Operation *current = op;
  while (true) {
    const unsigned numResults = current->getNumResults();
    if (numResults > 1) {
      return false;
    }
    if (numResults == 1 && !current->getResult(0).hasOneUse()) {
      return false;
    }
    if (isa<func::ReturnOp>(current) || domInfo.properlyDominates(fusedOp, current)) {
      return true;
    }
    // A result-less operation that is not an end point terminates the chain;
    // indexing result 0 here would be out of bounds.
    if (numResults == 0) {
      return false;
    }
    // Exactly one result with exactly one use: step to that user.
    current = *current->getResult(0).getUsers().begin();
  }
}
/// Recursively relocates `op` and the chain formed by the first user of each
/// operation's first result so that every operation ends up immediately
/// before its user.
///
/// The recursion bottoms out when the user is a `func.return` or is properly
/// dominated by `fusedOp`; in that case `op` is moved before that user and the
/// moves propagate back up the chain. If the chain ends without reaching such
/// a user (an operation with no result or an unused result), nothing is moved
/// for that part of the chain.
///
/// \returns true if `op` was moved, false otherwise.
static bool TryingtToPreserveOrderInSimplePattern0(Operation *fusedOp, Operation *op, DominanceInfo &domInfo) {
  // No result or no user: there is nothing to follow, and dereferencing the
  // first user would be invalid.
  if (op->getNumResults() == 0 || op->getResult(0).use_empty()) {
    return false;
  }

  Operation *userOp = *op->getResult(0).getUsers().begin();
  const bool userIsEndPoint = isa<func::ReturnOp>(userOp) || domInfo.properlyDominates(fusedOp, userOp);
  if (!userIsEndPoint && !TryingtToPreserveOrderInSimplePattern0(fusedOp, userOp, domInfo)) {
    return false;
  }

  op->moveBefore(userOp);
  return true;
}
/// Rewrite pattern that fuses a `linalg.generic` with the producer of one of
/// its operands via `fuseElementwiseOps`, and afterwards also tries to replace
/// the producer itself by results of the fused operation.
class FuseElementwiseOpsExt : public OpRewritePattern<GenericOp> {
 public:
  /// \param context  context the pattern is registered in.
  /// \param fun      callback deciding whether fusion along an operand is
  ///                 allowed.
  /// \param domInfo  dominance information; only a reference is stored, so it
  ///                 must outlive the pattern.
  /// \param benefit  pattern benefit forwarded to the base class.
  FuseElementwiseOpsExt(MLIRContext *context, ControlFusionFn fun, DominanceInfo &domInfo, PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(std::move(fun)), domInfo(domInfo) {}

  /// Fuses along the first operand for which `areElementwiseOpsFusable` and
  /// the control function both agree and `fuseElementwiseOps` succeeds.
  /// Returns failure if no operand could be fused.
  LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand) || !controlFn(&opOperand)) {
        continue;
      }

      // `opOperand` is owned by `genericOp` and must not be touched once that
      // op has been replaced below, so everything needed from it (and from
      // the consumer) is read up front.
      Operation *producer = opOperand.get().getDefiningOp();
      const unsigned numConsumerResults = genericOp.getNumResults();

      FailureOr<linalg::ElementwiseOpFusionResult> fusionResult = fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult)) {
        continue;
      }

      // The trailing results of the fused op stand in for the consumer.
      Operation *fusedOp = fusionResult->fusedOp;
      rewriter.replaceOp(genericOp, fusedOp->getResults().take_back(numConsumerResults));

      // If the remaining users of the producer form a simple single-use
      // chain, move that chain so the fused op can dominate it.
      if (CheckIfMatchDominateInSimplePattern0(fusedOp, producer, domInfo)) {
        (void)TryingtToPreserveOrderInSimplePattern0(fusedOp, producer, domInfo);
      }

      // The producer can only be replaced when the fused op dominates all of
      // its users; otherwise keep it alongside the fused op.
      if (!checkFusedOpDominateAllProducerUsers(fusedOp, producer, domInfo)) {
        return success();
      }

      // NOTE(review): this assumes the leading results of the fused op line up
      // one-to-one with the producer's results — confirm against the result
      // layout produced by `fuseElementwiseOps`.
      rewriter.replaceOp(producer, fusedOp->getResults().take_front(producer->getNumResults()));
      return success();
    }
    return failure();
  }

 private:
  /// Callback consulted for every candidate operand.
  ControlFusionFn controlFn;
  /// Dominance information used for the user checks above.
  DominanceInfo &domInfo;
};
/// Pass driving the extended elementwise fusion. It runs two greedy rewrites
/// over the operation it is scheduled on:
///  1. `FuseElementwiseOpsExt`, the erase-unused-operands-and-results patterns
///     and the fold-reshape-by-expansion patterns;
///  2. `FoldReshapeWithGenericOpByCollapsing`.
/// The result of each greedy application is deliberately ignored.
struct LinalgElementwiseFusionExtPass : public impl::LinalgElementwiseFusionExtBase<LinalgElementwiseFusionExtPass> {
  /// Installs a control function that accepts every operand whose value is
  /// defined by an operation.
  LinalgElementwiseFusionExtPass() : LinalgElementwiseFusionExtBase() {
    controlFn = [](OpOperand *fusedOperand) { return fusedOperand->get().getDefiningOp() != nullptr; };
  }

  void runOnOperation() override {
    Operation *root = getOperation();
    MLIRContext *ctx = root->getContext();

    // Both rewrites share the same top-down traversal configuration.
    GreedyRewriteConfig greedyConfig;
    greedyConfig.useTopDownTraversal = true;

    // First rewrite: elementwise fusion plus clean-up and reshape expansion.
    RewritePatternSet fusionPatterns(ctx);
    DominanceInfo &dominance = getAnalysis<DominanceInfo>();
    (void)fusionPatterns.add<FuseElementwiseOpsExt>(ctx, controlFn, dominance);
    populateEraseUnusedOperandsAndResultsPatterns(fusionPatterns);
    populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, controlFn);
    (void)applyPatternsAndFoldGreedily(root, std::move(fusionPatterns), greedyConfig);

    // Second rewrite: fold collapse_shape ops into their generic producers.
    RewritePatternSet collapsePatterns(ctx);
    collapsePatterns.add<FoldReshapeWithGenericOpByCollapsing>(collapsePatterns.getContext(), controlFn);
    (void)applyPatternsAndFoldGreedily(root, std::move(collapsePatterns), greedyConfig);
  }

 private:
  /// Shared fusion control callback handed to every pattern above.
  ControlFusionFn controlFn;
};
}
/// Factory returning a freshly constructed `LinalgElementwiseFusionExtPass`.
std::unique_ptr<mlir::Pass> mlir::createLinalgElementwiseFusionExtPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<LinalgElementwiseFusionExtPass>();
  return pass;
}