#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::linalg;
/// Returns the constant extent of `range`, computed as `size - offset`.
///
/// Yields std::nullopt unless the stride is the constant 1 and both the
/// offset and the size fold to constant integers.
static std::optional<int64_t> getConstantRange(const Range &range) {
  // Only unit-stride ranges are handled.
  const std::optional<int64_t> constStride = getConstantIntValue(range.stride);
  if (constStride.value_or(0) != 1)
    return std::nullopt;

  const std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
  if (!constOffset.has_value())
    return std::nullopt;

  const std::optional<int64_t> constSize = getConstantIntValue(range.size);
  if (!constSize.has_value())
    return std::nullopt;

  return constSize.value() - constOffset.value();
}
/// Checks that each iteration dimension listed in `dims` is evenly divisible
/// by the matching entry of `tiles`.
///
/// `dims` are indices counted after the batch dimensions: every entry is
/// shifted by the number of batch dimensions inferred for `linalgOp` before it
/// is used to index the op's iteration domain.
///
/// Returns false when:
///   - `dims` and `tiles` differ in length, or are empty;
///   - `linalgOp` is not recognized as a contraction;
///   - a dimension is negative or falls outside the iteration domain;
///   - a tile size or a dimension range is not a compile-time constant;
///   - a tile size is not strictly positive;
///   - a dimension range is not a multiple of its tile size.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
                                    ArrayRef<OpFoldResult> tiles,
                                    ArrayRef<int64_t> dims) {
  if (dims.size() != tiles.size() || tiles.empty())
    return false;

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(linalgOp);
  if (failed(contractDims))
    return false;
  const int64_t batchDimsOffset =
      static_cast<int64_t>(contractDims->batch.size());

  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
  OpBuilder builder(tileOp);
  OpBuilder::InsertionGuard guard(builder);
  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
  const int64_t numLoops = static_cast<int64_t>(iterationDomain.size());

  for (auto dim : llvm::enumerate(dims)) {
    // A negative dimension would index the iteration domain out of bounds (or
    // silently alias a batch dimension once shifted).
    if (dim.value() < 0)
      return false;

    // Skip over the batch dimensions.
    const int64_t loopDim = dim.value() + batchDimsOffset;
    if (loopDim >= numLoops)
      return false;

    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
    std::optional<int64_t> rangeOnDim =
        getConstantRange(iterationDomain[loopDim]);

    // Non-constant tile sizes or ranges cannot be validated statically.
    if (!tileSize || !rangeOnDim)
      return false;

    // A zero tile would make the modulo below undefined behaviour, and a
    // negative tile is never a valid block factor.
    if (*tileSize <= 0)
      return false;

    // The dimension must be covered by whole tiles only.
    if (*rangeOnDim % *tileSize != 0)
      return false;
  }

  return true;
}
/// Transposes the outer and/or inner 2D blocks of one packed matmul operand.
///
/// The outer block is taken to start at result `numResults - 4` of
/// `operandMap` and the inner block at result `numResults - 2`. Each block
/// counts as already transposed when the dimension found at that result
/// differs from the corresponding entry of `blocksStartDimPos` (second-to-last
/// entry for the outer block, last entry for the inner block). A block is
/// swapped only when its current state differs from the requested
/// `transposeOuterBlocks` / `transposeInnerBlocks` flag.
///
/// Returns the result of `packTranspose` applied to `packOp` and `linalgOp`.
static FailureOr<PackTransposeResult>
transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                      tensor::PackOp packOp, AffineMap operandMap,
                      ArrayRef<unsigned> blocksStartDimPos,
                      bool transposeOuterBlocks, bool transposeInnerBlocks) {
  assert(operandMap.getNumDims() >= 4 &&
         "expected at least 4D prepacked matmul");
  assert(blocksStartDimPos.size() >= 2 &&
         "expected starting outer and inner block positions");

  // The blocks occupy the four innermost results of the operand map.
  const unsigned numResults = operandMap.getNumResults();
  const unsigned outerBlockPos = numResults - 4;
  const unsigned innerBlockPos = numResults - 2;

  // Current layout of the blocks relative to the expected start dimensions.
  const unsigned expectedOuterDim =
      blocksStartDimPos[blocksStartDimPos.size() - 2];
  const unsigned expectedInnerDim = blocksStartDimPos.back();
  const bool outerAlreadyTransposed =
      operandMap.getDimPosition(outerBlockPos) != expectedOuterDim;
  const bool innerAlreadyTransposed =
      operandMap.getDimPosition(innerBlockPos) != expectedInnerDim;

  // Swap a block only when its current state differs from the requested one.
  const bool swapInner = innerAlreadyTransposed != transposeInnerBlocks;
  const bool swapOuter = outerAlreadyTransposed != transposeOuterBlocks;

  SmallVector<int64_t> innerPerm;
  innerPerm.push_back(swapInner ? 1 : 0);
  innerPerm.push_back(swapInner ? 0 : 1);

  // Dimensions in front of the outer block keep their position; only the two
  // outer block dimensions may be exchanged.
  SmallVector<int64_t> outerPerm;
  for (unsigned pos = 0; pos < outerBlockPos; ++pos)
    outerPerm.push_back(pos);
  const int64_t blockBase = outerBlockPos;
  outerPerm.push_back(blockBase + (swapOuter ? 1 : 0));
  outerPerm.push_back(blockBase + (swapOuter ? 0 : 1));

  return packTranspose(rewriter, packOp, linalgOp,
                       /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
}
/// Packs a matmul-like `linalgOp` into a blocked layout and transposes the
/// packed LHS and RHS operands as requested by the packing options.
///
/// `controlPackMatmul` is queried for the options of `linalgOp`; returning
/// std::nullopt rejects the op.
///
/// Fails (with a match-failure note where possible) when:
///   - the op has pure buffer semantics;
///   - no options are provided, or there are not exactly 3 block factors;
///   - padding is disallowed and the dimensions cannot be split into full
///     tiles;
///   - packing, contraction-dimension inference on the packed op, or either
///     operand transposition fails.
///
/// On success returns the pack result, with its LHS/RHS pack ops and packed
/// linalg op updated to the transposed versions.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                        const ControlBlockPackMatmulFn &controlPackMatmul) {
  if (linalgOp.hasPureBufferSemantics())
    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");

  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
  if (!options)
    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");

  if (options->blockFactors.size() != 3)
    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");

  SmallVector<OpFoldResult> mnkTiles =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));

  // Without padding, every packed dimension must split into whole tiles.
  if (!options->allowPadding &&
      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "expect packing full tiles only");
  }

  OpBuilder::InsertionGuard guard(rewriter);
  // New ops are created right after the op being packed.
  rewriter.setInsertionPointAfter(linalgOp);

  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
      options->mnkOrder);
  if (failed(packedMatmul))
    return failure();

  assert(packedMatmul->packOps.size() == 3 &&
         "invalid number of pack ops after matmul packing");
  assert(packedMatmul->unPackOps.size() == 1 &&
         "invalid number of unpack ops after matmul packing");

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(packedMatmul->packedLinalgOp);
  if (failed(contractDims))
    return failure();

  // Read the indexing maps through the LinalgOp interface. This avoids
  // dereferencing the result of a dyn_cast to a concrete op type, which would
  // be null if the packed op were of a different kind.
  SmallVector<AffineMap> maps =
      packedMatmul->packedLinalgOp.getIndexingMapsArray();

  // Transpose the packed LHS according to the options.
  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
      contractDims->m, options->lhsTransposeOuterBlocks,
      options->lhsTransposeInnerBlocks);
  if (failed(packedLhs))
    return failure();

  // Keep the result in sync with the rewritten IR.
  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;

  // Transpose the packed RHS according to the options.
  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
      contractDims->k, options->rhsTransposeOuterBlocks,
      options->rhsTransposeInnerBlocks);
  if (failed(packedRhs))
    return failure();

  // Keep the result in sync with the rewritten IR.
  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;

  return packedMatmul;
}
namespace {
/// Rewrite pattern that applies `blockPackMatmul` to ops of type `OpTy`,
/// using the stored control function to obtain the packing options.
template <typename OpTy>
struct BlockPackMatmul : public OpRewritePattern<OpTy> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}

  /// Succeeds exactly when `blockPackMatmul` succeeds on `linalgOp`.
  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<PackResult> packResult =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    return failed(packResult) ? failure() : success();
  }

private:
  /// Callback providing the packing options for each candidate op.
  ControlBlockPackMatmulFn controlFn;
};
/// Specialization for `linalg.generic`: only generics that implement the
/// contraction interface and use one of three 2D matmul indexing-map layouts
/// are forwarded to `blockPackMatmul`.
template <>
struct BlockPackMatmul<linalg::GenericOp>
    : public OpRewritePattern<linalg::GenericOp> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalg::isaContractionOpInterface(linalgOp))
      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;

    AffineExpr i, j, k;
    bindDims(linalgOp->getContext(), i, j, k);

    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
    // True when the op's indexing maps equal the maps inferred from `exprs`.
    auto hasLayout = [&](MapList exprs) {
      return maps == AffineMap::inferFromExprList(exprs, linalgOp.getContext());
    };

    // Accepted layouts, checked in order: plain matmul, transposed LHS,
    // transposed RHS.
    const bool isSupportedMatmul = hasLayout({{i, k}, {k, j}, {i, j}}) ||
                                   hasLayout({{k, i}, {k, j}, {i, j}}) ||
                                   hasLayout({{i, k}, {j, k}, {i, j}});
    if (!isSupportedMatmul)
      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");

    FailureOr<PackResult> packResult =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    return failed(packResult) ? failure() : success();
  }

private:
  /// Callback providing the packing options for each candidate op.
  ControlBlockPackMatmulFn controlFn;
};
/// Pass that greedily applies the block-pack-matmul patterns to the current
/// operation, configured from the pass options.
struct LinalgBlockPackMatmul
    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;

  void runOnOperation() override {
    // Every candidate op gets the same options, built from the pass options;
    // the op handed to the callback is not inspected.
    ControlBlockPackMatmulFn controlFn =
        [&](linalg::LinalgOp) -> BlockPackMatmulOptions {
      BlockPackMatmulOptions packOptions;
      packOptions.blockFactors = SmallVector<int64_t>{*blockFactors};
      packOptions.allowPadding = allowPadding;
      packOptions.mnkPaddedSizesNextMultipleOf =
          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
      // An empty order leaves the default of the options struct in place.
      if (!mnkOrder.empty())
        packOptions.mnkOrder = SmallVector<int64_t>{*mnkOrder};
      packOptions.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
      packOptions.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
      packOptions.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
      packOptions.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
      return packOptions;
    };

    RewritePatternSet patterns(&getContext());
    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);

    Operation *root = getOperation();
    if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
      signalPassFailure();
  }
};
}
/// Adds the block-pack-matmul rewrite patterns to `patterns`: one for
/// `linalg.generic` and one for each named matmul / batch-matmul variant
/// listed below. All of them share `controlFn`.
void linalg::populateBlockPackMatmulPatterns(
    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BlockPackMatmul<linalg::GenericOp>,
               BlockPackMatmul<linalg::MatmulOp>,
               BlockPackMatmul<linalg::BatchMatmulOp>,
               BlockPackMatmul<linalg::MatmulTransposeAOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
               BlockPackMatmul<linalg::MatmulTransposeBOp>,
               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(context,
                                                                 controlFn);
}