#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
// Expand the tablegen-generated definition selected by
// GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS inside namespace mlir, so the
// generated base is reachable as mlir::impl::LinalgSpecializeGenericOpsPassBase
// (the pass struct later in this file derives from it).
namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
// Debug-type tag used by LLVM's debug-output filtering (llvm/Support/Debug.h,
// `-debug-only=`). NOTE(review): no LLVM_DEBUG use is visible in this chunk.
#define DEBUG_TYPE "linalg-specialization"
// Replaces `genericOp` with a new NEWOP whose two inputs are the first two DPS
// inputs of `genericOp` (in reverse order when OPERANDS_SWAP is true) and whose
// single init is the first DPS init of `genericOp`. Evaluates to the new op
// returned by replaceOpWithNewOp.
// NOTE: the expansion refers to `rewriter` and `genericOp` by name, so both
// must be in scope wherever this macro is used.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
(rewriter.replaceOpWithNewOp<NEWOP>( \
genericOp, \
ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
ValueRange{genericOp.getDpsInits()[0]}))
// Unary counterpart: replaces `genericOp` with a new NEWOP taking the first DPS
// input and the first DPS init of `genericOp`; evaluates to the new op.
// Like REPLACE_BINARY_OP, it expects `rewriter` and `genericOp` in scope.
#define REPLACE_UNARY_OP(NEWOP) \
(rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
ValueRange{genericOp.getDpsInputs()[0]}, \
ValueRange{genericOp.getDpsInits()[0]}))
using namespace mlir;
using namespace mlir::linalg;
/// Reports whether the first (binary) op of `genericOp`'s body consumes the
/// block arguments in reverse order, i.e. its first operand is not block
/// argument 0. In that case debug builds additionally assert that the operands
/// are exactly (block arg 1, block arg 0).
///
/// Callers in this file only invoke it after isaElemwiseSingleBinaryOpInterface
/// has accepted `genericOp`, so the body's front op is the binary payload.
static bool areBinOpsSwapped(GenericOp genericOp) {
  Block *payload = genericOp.getBody();
  Operation *binOp = &payload->front();
  Value lhs = binOp->getOpOperand(0).get();

  // Natural order: nothing to swap.
  if (lhs == payload->getArgument(0))
    return false;

  // Otherwise the only accepted shape is the fully swapped one.
  assert(lhs == payload->getArgument(1) &&
         binOp->getOpOperand(1).get() == payload->getArgument(0) &&
         "binary op uses just one block arg");
  return true;
}
namespace {
/// Outcome of comparing the two matrix dimensions of an operand's indexing
/// map against the expected (row, column) dimension positions.
enum class IndexMatchResult {
  Match = 0,  // (row, col) appear in the expected order.
  Transposed, // (row, col) appear in swapped order.
  Mismatch    // Anything else.
};
} // namespace

/// Checks whether the two consecutive results of `map` starting at index
/// `rowDimIdx` are the bare dimensions `expectedPosOfRowDim` and
/// `expectedPosOfColDim`.
///
/// Returns `Match` when they appear in that order, `Transposed` when they
/// appear in the opposite order, and `Mismatch` otherwise. This includes the
/// cases where either result is not a plain dimension expression and where
/// `map` has too few results.
///
/// Example: for `A` in `C[M, N] = A[M, K] * B[K, N]` this distinguishes
/// `A[M, K]` (match) from `A[K, M]` (transposed).
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
                                        unsigned expectedPosOfRowDim,
                                        unsigned expectedPosOfColDim) {
  // The caller currently verifies the rank. Guard here as well so the
  // helper never reads past the end of the result list on its own.
  if (map.getNumResults() < rowDimIdx + 2)
    return IndexMatchResult::Mismatch;

  // The matrix indices follow the batch indices and must be pure dim ids.
  auto rowDimExpr = dyn_cast<AffineDimExpr>(map.getResult(rowDimIdx));
  auto colDimExpr = dyn_cast<AffineDimExpr>(map.getResult(rowDimIdx + 1));
  if (!rowDimExpr || !colDimExpr)
    return IndexMatchResult::Mismatch;

  unsigned posRowDim = rowDimExpr.getPosition();
  unsigned posColDim = colDimExpr.getPosition();
  if (posRowDim == expectedPosOfRowDim && posColDim == expectedPosOfColDim)
    return IndexMatchResult::Match;
  if (posRowDim == expectedPosOfColDim && posColDim == expectedPosOfRowDim)
    return IndexMatchResult::Transposed;
  return IndexMatchResult::Mismatch;
}

/// Replaces `op` with a new `NamedOpTy` whose inputs are the first two DPS
/// inputs of `op` and whose init is the first DPS init of `op`. Returns the
/// newly created op.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
  LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
      op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
      ValueRange{op.getDpsInits()[0]});
  return namedOp;
}

/// Attempts to rewrite a `linalg.generic` contraction into one of the named
/// matmul variants: matmul, matmul_transpose_a, matmul_transpose_b and their
/// batch_ counterparts.
///
/// Returns failure without touching `genericOp` when it does not exactly
/// match one of those forms.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                        GenericOp genericOp) {
  if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
    return failure();

  // Materialize the indexing maps once; all structural checks below use them.
  auto indexingMaps = genericOp.getIndexingMapsArray();

  // Early exit if not projected permutations.
  if (llvm::any_of(indexingMaps,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return failure();

  // A generic contraction may span several m/n/k dimensions. For example, it
  // may reduce over both k1 and k2 in
  //   C[m, n] += A[m, k1, k2] * B[k2, k1, n].
  // That is a valid contraction but not a named matmul, so require exactly
  // one dimension of each kind.
  auto res = inferContractionDims(genericOp);
  if (failed(res))
    return failure();
  const auto &dims = *res;
  if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
    return failure();

  // The payload must be a plain multiply-accumulate of a matching flavour.
  if (!mlir::linalg::detail::isContractionBody(
          *genericOp.getBlock(), [](Operation *first, Operation *second) {
            return (isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
                   (isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
                   (isa<complex::MulOp>(first) && isa<complex::AddOp>(second));
          }))
    return failure();

  // Each operand is indexed by all batch dims plus two of {m, n, k}.
  size_t numOfBatchDims = dims.batch.size();
  if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
        return m.getNumResults() != numOfBatchDims + 2;
      }))
    return failure();

  // The loop nest must consist of exactly the batch dims plus m, n and k.
  if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
    return failure();

  // A generic may permute the batch dims differently per operand. The named
  // ops take no maps, so every operand must access the batch dims as the
  // identity prefix (d0, d1, ...). With no batch dims this check is vacuous.
  if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
        for (unsigned i = 0; i < numOfBatchDims; ++i) {
          auto dimExpr = dyn_cast<AffineDimExpr>(m.getResult(i));
          if (!dimExpr || dimExpr.getPosition() != i)
            return true;
        }
        return false;
      }))
    return failure();

  IndexMatchResult a =
      matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
  IndexMatchResult b =
      matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
  IndexMatchResult c =
      matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

  if (llvm::is_contained({a, b, c}, IndexMatchResult::Mismatch))
    return failure();

  // No named variant transposes the output or transposes both inputs.
  if (c != IndexMatchResult::Match ||
      (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
    return failure();

  // Emit the matching named variant.
  if (numOfBatchDims) {
    if (a == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
                                                               genericOp);
    if (b == IndexMatchResult::Transposed)
      return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
                                                               genericOp);
    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
  }

  if (a == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
  if (b == IndexMatchResult::Transposed)
    return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // The recognizers are tried in a fixed order. The first one that matches
  // replaces `genericOp` via `rewriter` and returns the new named op. If
  // nothing matches, `genericOp` is left untouched and failure is returned.

  // Copy: input 0 -> init 0.
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp copyOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return copyOp;
  }

  // Fill: input 0 -> init 0.
  if (isaFillOpInterface(genericOp)) {
    LinalgOp fillOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return fillOp;
  }

  // Elementwise unary: only a math.exp payload is mapped here. Any other
  // unary payload falls through to the remaining recognizers.
  if (isaElemwiseSingleUnaryOpInterface(genericOp) &&
      isa<math::ExpOp>(&genericOp.getBody()->front())) {
    LinalgOp unaryOp = REPLACE_UNARY_OP(ExpOp);
    return unaryOp;
  }

  // Elementwise binary: map float add/sub/mul/div payloads to their named
  // ops. Inputs are swapped when the payload reads the block arguments in
  // reverse order. The swap is computed before dispatching on the op kind,
  // matching the original evaluation order.
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swapInputs = areBinOpsSwapped(genericOp);
    Operation *payloadOp = &genericOp.getBody()->front();
    if (isa<arith::AddFOp>(payloadOp)) {
      LinalgOp binaryOp = REPLACE_BINARY_OP(AddOp, swapInputs);
      return binaryOp;
    }
    if (isa<arith::SubFOp>(payloadOp)) {
      LinalgOp binaryOp = REPLACE_BINARY_OP(SubOp, swapInputs);
      return binaryOp;
    }
    if (isa<arith::MulFOp>(payloadOp)) {
      LinalgOp binaryOp = REPLACE_BINARY_OP(MulOp, swapInputs);
      return binaryOp;
    }
    if (isa<arith::DivFOp>(payloadOp)) {
      LinalgOp binaryOp = REPLACE_BINARY_OP(DivOp, swapInputs);
      return binaryOp;
    }
  }

  // Contractions are handed to the matmul-variant matcher; anything else is
  // left as a generic.
  if (!isaContractionOpInterface(genericOp))
    return failure();
  return specializeLinalgContractions(rewriter, genericOp);
}
namespace {
/// Pass that greedily applies the patterns collected by
/// populateLinalgGenericOpsSpecializationPatterns to the operation it runs
/// on. It signals pass failure if the greedy driver reports failure.
struct LinalgSpecializeGenericOpsPass
    : public impl::LinalgSpecializeGenericOpsPassBase<
          LinalgSpecializeGenericOpsPass> {
  using impl::LinalgSpecializeGenericOpsPassBase<
      LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;

  void runOnOperation() override {
    RewritePatternSet specializationPatterns(&getContext());
    populateLinalgGenericOpsSpecializationPatterns(specializationPatterns);

    LogicalResult rewriteStatus = applyPatternsAndFoldGreedily(
        getOperation(), std::move(specializationPatterns));
    if (failed(rewriteStatus))
      signalPassFailure();
  }
};
} // namespace
void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
    RewritePatternSet &patterns) {
  // Registers a single pattern, LinalgSpecializationPattern, bound to the
  // pattern set's context. NOTE(review): the pattern is declared outside
  // this chunk (presumably Linalg/Transforms/Transforms.h); confirm it
  // drives specializeGenericOp.
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LinalgSpecializationPattern>(ctx);
}