/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Dialect/Linalg/IR/LinalgExtOps.h"
#include "akg/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
// Pull in the tablegen-generated declaration and definition of this pass; the
// definition provides impl::LinalgLowerTemplateOpBase, which the pass struct below derives from.
#define GEN_PASS_DECL_LINALGLOWERTEMPLATEOP
#define GEN_PASS_DEF_LINALGLOWERTEMPLATEOP
#include "akg/Dialect/Linalg/Passes.h.inc"
}  // namespace mlir
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalgExt;
namespace {
/// Rewrites a TemplateOp into a func.call of the function referenced by its
/// `TemplateFuncAttrName` symbol attribute.
///
/// Call operands are the template op's operands (shaped operands adapted to the
/// callee's parameter types by castOperands), followed by one index value per loop
/// holding that loop's size as produced by LinalgOp::createLoopRanges.
/// The template op's results are replaced by the call's results.
struct LowerTemplateOp : public OpRewritePattern<TemplateOp> {
 public:
  using OpRewritePattern<TemplateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TemplateOp templateOp, PatternRewriter &rewriter) const override {
    // Every precondition is checked before any IR is created, so a failed match
    // leaves the IR untouched (required by the greedy pattern driver).
    auto fnSym = dyn_cast_or_null<SymbolRefAttr>(templateOp->getAttr(TemplateFuncAttrName));
    if (!fnSym) {
      return rewriter.notifyMatchFailure(templateOp, "missing template function symbol attribute");
    }
    auto funcOp = dyn_cast_or_null<func::FuncOp>(SymbolTable::lookupNearestSymbolFrom(templateOp, fnSym));
    if (!funcOp) {
      return rewriter.notifyMatchFailure(templateOp, "template function symbol does not resolve to a func.func");
    }

    auto linalgOp = cast<LinalgOp>(templateOp.getOperation());
    auto fnType = funcOp.getFunctionType();
    // The callee must take every template operand plus one size per loop; otherwise
    // the generated func.call would fail verification.
    if (fnType.getNumInputs() != templateOp->getNumOperands() + linalgOp.getNumLoops()) {
      return rewriter.notifyMatchFailure(templateOp, "callee arity does not match operands plus loop sizes");
    }
    // replaceOp requires the replacement to provide exactly one value per result.
    if (fnType.getNumResults() != templateOp->getNumResults()) {
      return rewriter.notifyMatchFailure(templateOp, "callee result count does not match template op");
    }

    Location loc = templateOp->getLoc();
    SmallVector<Value> callOperands = castOperands(templateOp.getOperands(), funcOp.getArgumentTypes(), loc, rewriter);
    SmallVector<Range> loopRanges = linalgOp.createLoopRanges(rewriter, loc);
    // Append the size of each loop as an extra (index) argument.
    llvm::transform(loopRanges, std::back_inserter(callOperands), [&](const Range &range) {
      return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
    });
    auto callOp = rewriter.create<func::CallOp>(loc, funcOp, callOperands);
    rewriter.replaceOp(templateOp, callOp.getResults());
    return success();
  }

 private:
  /// Adapts `operands` positionally to `dstTypes`, inserting memref.cast ops at `loc` as needed.
  SmallVector<Value> castOperands(SmallVector<Value> operands, ArrayRef<Type> dstTypes, Location loc,
                                  PatternRewriter &rewriter) const;
};
/// Pass that greedily applies the template-op lowering patterns registered by
/// populateLinalgTemplateOpLowerPatterns to the operation it runs on, and
/// signals failure if the greedy driver reports failure.
struct LinalgLowerTemplateOpPass : public impl::LinalgLowerTemplateOpBase<LinalgLowerTemplateOpPass> {
  LinalgLowerTemplateOpPass() = default;
  LinalgLowerTemplateOpPass(const LinalgLowerTemplateOpPass &) = default;
  LinalgLowerTemplateOpPass &operator=(const LinalgLowerTemplateOpPass &) = delete;

  void runOnOperation() override {
    Operation *root = getOperation();
    RewritePatternSet loweringPatterns(root->getContext());
    populateLinalgTemplateOpLowerPatterns(loweringPatterns);
    if (succeeded(applyPatternsAndFoldGreedily(root, std::move(loweringPatterns)))) {
      return;
    }
    signalPassFailure();
  }
};
}
/// Adapts the template op's operands to the callee's parameter types.
///
/// Operands are paired positionally with `dstTypes`; pairing stops at the shorter of
/// the two (llvm::zip), so trailing callee parameters without a matching operand are
/// not consumed here.
///  - Non-shaped operands, and operands whose type already equals the parameter type,
///    are forwarded unchanged.
///  - A memref operand is cast with memref.cast directly to the parameter type when the
///    two types are cast-compatible, because func.call requires exact type equality.
///  - Otherwise the previous behaviour is kept: a memref with any static size is cast to
///    the same memref type with every size dynamic; an all-dynamic memref is forwarded.
/// Only memref shaped operands are supported (asserted).
SmallVector<Value> LowerTemplateOp::castOperands(SmallVector<Value> operands, ArrayRef<Type> dstTypes, Location loc,
                                                 PatternRewriter &rewriter) const {
  SmallVector<Value> newOperands;
  newOperands.reserve(operands.size());
  for (auto [oper, dstType] : llvm::zip(operands, dstTypes)) {
    Type srcType = oper.getType();
    if (!isa<ShapedType>(srcType) || srcType == dstType) {
      newOperands.push_back(oper);
      continue;
    }
    assert(isa<MemRefType>(srcType) && "currently only support memref");
    // Cast straight to the callee's parameter type whenever memref.cast permits it,
    // so the func.call operand types match the callee signature exactly.
    if (memref::CastOp::areCastCompatible(srcType, dstType)) {
      newOperands.push_back(rewriter.create<memref::CastOp>(loc, dstType, oper).getResult());
      continue;
    }
    // Fallback (original behaviour): erase all static sizes.
    auto memType = cast<MemRefType>(srcType);
    if (llvm::all_of(memType.getShape(), mlir::ShapedType::isDynamic)) {
      newOperands.push_back(oper);
      continue;
    }
    auto dynType = memType.clone(std::vector<int64_t>(memType.getShape().size(), mlir::ShapedType::kDynamic));
    newOperands.push_back(rewriter.create<memref::CastOp>(loc, dynType, oper).getResult());
  }
  return newOperands;
}
/// Registers the LowerTemplateOp rewrite (TemplateOp -> func.call) into `patterns`.
void mlir::populateLinalgTemplateOpLowerPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  RewritePatternSet &updated = patterns.add<LowerTemplateOp>(context);
  (void)updated;
}
/// Creates an instance of the template-op lowering pass, anchored on func.func.
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLinalgLowerTemplateOpPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass = std::make_unique<LinalgLowerTemplateOpPass>();
  return pass;
}