#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::linalg;
/// Returns a copy of `type` whose layout is replaced by a strided layout in
/// which the offset and every stride are dynamic (`ShapedType::kDynamic`).
/// Only the layout is changed; all other properties are carried over by
/// `MemRefType::Builder`.
static MemRefType makeStridedLayoutDynamic(MemRefType type) {
  MLIRContext *ctx = type.getContext();
  // One dynamic stride per dimension.
  SmallVector<int64_t> dynamicStrides(type.getRank(), ShapedType::kDynamic);
  auto fullyDynamicLayout =
      StridedLayoutAttr::get(ctx, ShapedType::kDynamic, dynamicStrides);
  return MemRefType::Builder(type).setLayout(fullyDynamicLayout);
}
/// Collects the operand types of `op`, in operand order, for use as the
/// signature of the external library function. Memref operand types are
/// rewritten to the fully dynamic strided layout (see
/// `makeStridedLayoutDynamic`). All other types are passed through unchanged.
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
  SmallVector<Type, 4> signatureTypes;
  signatureTypes.reserve(op->getNumOperands());
  for (Type operandType : op->getOperandTypes()) {
    auto asMemRef = dyn_cast<MemRefType>(operandType);
    signatureTypes.push_back(asMemRef ? Type(makeStridedLayoutDynamic(asMemRef))
                                      : operandType);
  }
  return signatureTypes;
}
/// Returns a symbol reference to the external library function that
/// implements the Linalg op `op`.
///
/// If the enclosing module has no symbol with that name yet, a private
/// `func.func` declaration is created with the `llvm.emit_c_interface`
/// attribute. Its inputs are the op's operand types, with memrefs given a
/// fully dynamic strided layout, and it has no results.
///
/// Returns a match failure (no IR is modified) when:
///   - the op does not define a library call name,
///   - the op produces results (only void library calls can be emitted), or
///   - the op is not nested inside a `ModuleOp`.
static FailureOr<FlatSymbolRefAttr>
getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) {
  auto linalgOp = cast<LinalgOp>(op);
  auto fnName = linalgOp.getLibraryCallName();
  if (fnName.empty())
    return rewriter.notifyMatchFailure(
        op, "No library call defined for: " + op->getName().getStringRef());

  // The caller replaces `op` with a result-less call. This check must run
  // before the "symbol already declared" early-return below. Otherwise an op
  // with results whose callee already exists would be reported as a match,
  // and the replacement would mismatch the result count.
  if (op->getNumResults() != 0) {
    return rewriter.notifyMatchFailure(
        op,
        "Library call for linalg operation can be generated only for ops that "
        "have void return types");
  }

  auto module = op->getParentOfType<ModuleOp>();
  if (!module)
    return rewriter.notifyMatchFailure(
        op, "expected op to be nested inside a builtin.module");

  FlatSymbolRefAttr fnNameAttr =
      SymbolRefAttr::get(rewriter.getContext(), fnName);
  // Reuse an existing declaration or definition of the same name.
  if (module.lookupSymbol(fnNameAttr.getAttr()))
    return fnNameAttr;

  SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
  auto libFnType = rewriter.getFunctionType(inputTypes, {});

  // Restore the rewriter's insertion point (at `op`) once the declaration
  // has been created.
  OpBuilder::InsertionGuard guard(rewriter);
  // Insert just before the last operation in the module body. The body is
  // non-empty here because it (transitively) contains `op`.
  rewriter.setInsertionPoint(module.getBody(),
                             std::prev(module.getBody()->end()));
  func::FuncOp funcOp = rewriter.create<func::FuncOp>(
      op->getLoc(), fnNameAttr.getValue(), libFnType);
  // Request a C-compatible wrapper so the library can be implemented in C.
  funcOp->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                  UnitAttr::get(op->getContext()));
  funcOp.setPrivate();
  return fnNameAttr;
}
/// Returns `operands` in order, with each memref-typed value replaced by a
/// `memref.cast` to the fully dynamic strided layout (see
/// `makeStridedLayoutDynamic`). The casts are created with builder `b` at
/// `loc`. Values of non-memref type are returned as-is.
static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
                                      ValueRange operands) {
  SmallVector<Value, 4> canonicalized;
  canonicalized.reserve(operands.size());
  for (Value operand : operands) {
    Value replacement = operand;
    if (auto memrefType = dyn_cast<MemRefType>(operand.getType()))
      replacement = b.create<memref::CastOp>(
          loc, makeStridedLayoutDynamic(memrefType), operand);
    canonicalized.push_back(replacement);
  }
  return canonicalized;
}
/// Replaces `op` with a result-less `func.call` to its library function. The
/// callee is declared on demand by `getLibraryCallSymbolRef`. Memref
/// operands are first cast to the fully dynamic strided layout so they
/// match the declared signature.
LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
    LinalgOp op, PatternRewriter &rewriter) const {
  FailureOr<FlatSymbolRefAttr> callee = getLibraryCallSymbolRef(op, rewriter);
  if (failed(callee))
    return failure();

  SmallVector<Value, 4> callOperands = createTypeCanonicalizedMemRefOperands(
      rewriter, op->getLoc(), op->getOperands());
  rewriter.replaceOpWithNewOp<func::CallOp>(op, callee->getValue(),
                                            TypeRange(), callOperands);
  return success();
}
/// Adds the pattern that lowers Linalg ops to library calls
/// (`LinalgOpToLibraryCallRewrite`) to `patterns`.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LinalgOpToLibraryCallRewrite>(context);
}
namespace {
/// Pass that lowers Linalg operations to calls to external library functions.
/// It uses the patterns from `populateLinalgToStandardConversionPatterns`.
struct ConvertLinalgToStandardPass final
    : impl::ConvertLinalgToStandardBase<ConvertLinalgToStandardPass> {
  void runOnOperation() override;
};
} // namespace
void ConvertLinalgToStandardPass::runOnOperation() {
auto module = getOperation();
ConversionTarget target(getContext());
target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
func::FuncDialect, memref::MemRefDialect,
scf::SCFDialect>();
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
RewritePatternSet patterns(&getContext());
populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
/// Creates an instance of the Linalg-to-library-call conversion pass,
/// anchored on `ModuleOp`.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertLinalgToStandardPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertLinalgToStandardPass>();
  return pass;
}