#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
}
#define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Rewrites a fully-parallel `linalg.generic` with pure tensor semantics so
/// that every init (output) operand whose payload block argument is read in
/// the body is also passed as a new input operand. In the new op the init
/// operand itself is replaced by a `tensor.empty` with the same sizes and
/// element type, and uses of the old init block argument are redirected to
/// the new input block argument.
///
/// Fails to match when the op is not pure-tensor, has any non-parallel
/// iterator, or when no init block argument has uses.
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    // Collect the init operands whose payload block argument is used.
    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }
    if (candidates.empty())
      return failure();

    // New operand order is: original inputs, moved candidates, original
    // inits. The indexing maps follow the same order; each candidate reuses
    // the indexing map of its init operand.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    // Replace each candidate init with a `tensor.empty` of the same sizes.
    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    // Operand number of the first init; loop-invariant, so computed once.
    unsigned initStart = genericOp.getDpsInits().getBeginOperandIndex();
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      // Insert right after the init value so its (dynamic) sizes can be
      // queried at that point.
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType);
      newOutputOperands[op->getOperandNumber() - initStart] =
          empty.getResult();
    }

    auto newOp = rewriter.create<GenericOp>(
        loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
        newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    // Build the new body. Block arguments follow the new operand order.
    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    // Uses of a candidate's old init argument now read the new input argument.
    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
    // Candidate inits get a fresh, unmapped (hence unused) block argument;
    // all other inits keep a one-to-one mapping.
    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }
    for (auto &op : genericOp.getBody()->getOperations())
      rewriter.clone(op, mapper);
    rewriter.replaceOp(genericOp, newOp.getResults());
    return success();
  }
};
} // namespace
/// Updates the `linalg.index` ops in the body of `genericOp` after the loop
/// dimensions in `unitDims` were removed.
///  - An index op that queries a removed dimension is replaced by the index
///    constant 0.
///  - An index op that queries a surviving dimension is re-created with its
///    dimension lowered by the number of removed dimensions that precede it.
///    It is left untouched when nothing precedes it.
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  // Early-increment range: the visited op may be replaced during iteration.
  auto indexOps =
      llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>());
  for (IndexOp idxOp : indexOps) {
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    rewriter.setInsertionPoint(idxOp);
    auto queriedDim = idxOp.getDim();
    if (unitDims.count(queriedDim) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(idxOp, 0);
      continue;
    }
    unsigned shift = 0;
    for (unsigned removedDim : unitDims) {
      if (removedDim < queriedDim)
        ++shift;
    }
    if (shift == 0)
      continue;
    rewriter.replaceOpWithNewOp<IndexOp>(idxOp, queriedDim - shift);
  }
}
/// Converts `result`, produced by a rank-reduced op, back to the type of
/// `origDest` so that it can replace the original op's result.
///
/// `origDest` must be a ranked tensor (enforced by the `cast`).
///  - ExtractInsertSlice: inserts `result` into `origDest` with zero offsets,
///    unit strides and the full mixed sizes of `origDest`. The op is created
///    with `createOrFold`, so it may fold away.
///  - ReassociativeReshape: emits a `tensor.expand_shape` to the type of
///    `origDest` using `reassociation`.
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  using Strategy = ControlDropUnitDims::RankReductionStrategy;
  auto origResultType = cast<RankedTensorType>(origDest.getType());

  if (rankReductionStrategy != Strategy::ExtractInsertSlice) {
    assert(rankReductionStrategy == Strategy::ReassociativeReshape &&
           "unknown rank reduction strategy");
    auto expandOp = rewriter.create<tensor::ExpandShapeOp>(
        loc, origResultType, result, reassociation);
    return expandOp.getResult();
  }

  const unsigned rank = origResultType.getRank();
  SmallVector<OpFoldResult> zeroOffsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> fullSizes =
      tensor::getMixedSizes(rewriter, loc, origDest);
  SmallVector<OpFoldResult> unitStrides(rank, rewriter.getIndexAttr(1));
  return rewriter.createOrFold<tensor::InsertSliceOp>(
      loc, result, origDest, zeroOffsets, fullSizes, unitStrides);
}
/// Removes unit-extent dimensions from `operand` so that it has the shape
/// `targetShape`. Only `MemRefType` and `RankedTensorType` operands are
/// handled; any other type hits `llvm_unreachable`.
///  - ExtractInsertSlice: uses `rankReduceIfNeeded` of `memref.subview` or
///    `tensor.extract_slice`, which is asserted to succeed.
///  - ReassociativeReshape: emits `memref.collapse_shape` or
///    `tensor.collapse_shape` with `reassociation`. The memref result type
///    keeps the source element type and memory space and uses a
///    default-constructed (null) layout. The tensor result type keeps only
///    the element type.
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  using Strategy = ControlDropUnitDims::RankReductionStrategy;
  const bool viaSlices = rankReductionStrategy == Strategy::ExtractInsertSlice;
  Type operandType = operand.getType();

  if (auto memrefType = dyn_cast<MemRefType>(operandType)) {
    if (!viaSlices) {
      assert(rankReductionStrategy == Strategy::ReassociativeReshape &&
             "unknown rank reduction strategy");
      MemRefLayoutAttrInterface nullLayout;
      MemRefType collapsedType =
          MemRefType::get(targetShape, memrefType.getElementType(), nullLayout,
                          memrefType.getMemorySpace());
      return rewriter.create<memref::CollapseShapeOp>(loc, collapsedType,
                                                      operand, reassociation);
    }
    FailureOr<Value> reduced = memref::SubViewOp::rankReduceIfNeeded(
        rewriter, loc, operand, targetShape);
    assert(succeeded(reduced) && "not a unit-extent collapse");
    return *reduced;
  }

  if (auto tensorType = dyn_cast<RankedTensorType>(operandType)) {
    if (!viaSlices) {
      assert(rankReductionStrategy == Strategy::ReassociativeReshape &&
             "unknown rank reduction strategy");
      RankedTensorType collapsedType =
          RankedTensorType::get(targetShape, tensorType.getElementType());
      return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType,
                                                      operand, reassociation);
    }
    FailureOr<Value> reduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
        rewriter, loc, operand, targetShape);
    assert(succeeded(reduced) && "not a unit-extent collapse");
    return *reduced;
  }

  llvm_unreachable("unsupported operand type");
}
/// Describes how a single operand of a `linalg.generic` is rewritten once
/// unit-extent dimensions are dropped. It is filled in by
/// `dropUnitExtentFromOperandMetadata`.
struct UnitExtentReplacementInfo {
/// Indexing map of the operand in the rank-reduced op.
AffineMap indexMap;
/// Groups of original operand dimensions. Each group becomes one dimension
/// of the collapsed operand. The list is empty when every dimension is unit.
SmallVector<ReassociationIndices> reassociation;
/// Shape of the operand after the unit dimensions are removed (copied from
/// the original shape, so entries may be dynamic).
SmallVector<int64_t> targetShape;
};
/// Computes, for one operand of `genericOp`, the indexing map, reassociation
/// and collapsed shape to use once the loop dimensions absent from
/// `oldDimsToNewDimsMap` are dropped.
///
/// An operand dimension is unit-extent when its size is 1 and it is indexed
/// either by a dropped loop dimension or by the constant 0. Each remaining
/// dimension starts a reassociation group, and following unit dimensions
/// join that group. Leading unit dimensions join the first group. When all
/// dimensions are unit-extent, no group is emitted.
///
/// The new map has `oldDimsToNewDimsMap.size()` dimensions, the original
/// number of symbols, and one result per remaining operand dimension. Each
/// result is rewritten through `dimReplacements`.
static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  AffineMap operandMap = genericOp.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
  ArrayRef<AffineExpr> mapResults = operandMap.getResults();

  // True when operand dimension `pos` can be folded away.
  auto isDroppable = [&](unsigned pos) -> bool {
    if (shape[pos] != 1)
      return false;
    AffineExpr expr = mapResults[pos];
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      return !oldDimsToNewDimsMap.count(dimExpr.getPosition());
    auto cstExpr = dyn_cast<AffineConstantExpr>(expr);
    return cstExpr && cstExpr.getValue() == 0;
  };

  UnitExtentReplacementInfo info;
  SmallVector<AffineExpr> keptExprs;
  ReassociationIndices pending;
  bool pendingHasKeptDim = false;
  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) {
    if (isDroppable(pos)) {
      pending.push_back(pos);
      continue;
    }
    // A new kept dimension closes the previous group, if it had one.
    if (pendingHasKeptDim) {
      info.reassociation.push_back(pending);
      pending.clear();
    }
    pending.push_back(pos);
    pendingHasKeptDim = true;
    keptExprs.push_back(mapResults[pos].replaceDims(dimReplacements));
    info.targetShape.push_back(shape[pos]);
  }
  if (pendingHasKeptDim)
    info.reassociation.push_back(pending);

  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), operandMap.getNumSymbols(),
                     keptExprs, context);
  return info;
}
/// Rewrites `genericOp` into a new `linalg.generic` without the loop
/// dimensions that have a static trip count of 1 and that
/// `options.controlFn` allows to drop. Operands are collapsed and results
/// expanded using `options.rankReductionStrategy`. `linalg.index` ops that
/// refer to dropped or shifted loops are updated.
///
/// Returns failure (without modifying IR) if:
///  - the op has no indexing maps;
///  - the concatenated indexing maps are not invertible;
///  - the control function returns no allowed dimensions;
///  - the new indexing maps equal the old ones (nothing to do) or are not
///    invertible.
LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
const ControlDropUnitDims &options) {
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
if (indexingMaps.empty())
return failure();
// 1. Map every loop dimension back to a position in the concatenated operand
//    shapes. A loop has unit trip count when that position has static size 1.
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
if (!invertedMap) {
return rewriter.notifyMatchFailure(genericOp,
"invalid indexing maps for operation");
}
// Static sizes of all operands, concatenated in operand order.
SmallVector<int64_t> dims = genericOp.getStaticShape();
// 1a. Only loops listed by the control function may be dropped.
SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
if (allowedUnitDims.empty()) {
return rewriter.notifyMatchFailure(
genericOp, "control function returns no allowed unit dims to prune");
}
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
allowedUnitDims.end());
// Loop dimensions that will be removed.
llvm::SmallDenseSet<unsigned> unitDims;
for (const auto &expr : enumerate(invertedMap.getResults())) {
if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
if (dims[dimExpr.getPosition()] == 1 &&
unitDimsFilter.count(expr.index()))
unitDims.insert(expr.index());
}
}
// 2. Build the iterator types of the new op, and a replacement expression
//    per old loop. A dropped loop becomes constant 0 and a kept loop is
//    renumbered densely.
SmallVector<utils::IteratorType> newIteratorTypes;
llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
SmallVector<AffineExpr> dimReplacements;
unsigned newDims = 0;
for (auto [index, attr] :
llvm::enumerate(genericOp.getIteratorTypesArray())) {
if (unitDims.count(index)) {
dimReplacements.push_back(
getAffineConstantExpr(0, rewriter.getContext()));
} else {
newIteratorTypes.push_back(attr);
oldDimToNewDimMap[index] = newDims;
dimReplacements.push_back(
getAffineDimExpr(newDims, rewriter.getContext()));
newDims++;
}
}
// 3. Per operand (all vectors below are indexed by operand number), compute
//    the new indexing map, collapsed shape and reassociation, and whether
//    the operand actually needs collapsing.
SmallVector<AffineMap> newIndexingMaps;
SmallVector<SmallVector<ReassociationIndices>> reassociations;
SmallVector<SmallVector<int64_t>> targetShapes;
SmallVector<bool> collapsed;
// Only identity-layout memrefs and encoding-free ranked tensors are
// collapsed; other operands only get their indexing map rewritten.
auto hasCollapsibleType = [](OpOperand &operand) {
Type operandType = operand.get().getType();
if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
return memrefOperandType.getLayout().isIdentity();
}
if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
return tensorOperandType.getEncoding() == nullptr;
}
return false;
};
for (OpOperand &opOperand : genericOp->getOpOperands()) {
auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
if (!hasCollapsibleType(opOperand)) {
AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
newIndexingMaps.push_back(newIndexingMap);
targetShapes.push_back(llvm::to_vector(shape));
collapsed.push_back(false);
reassociations.push_back({});
continue;
}
auto replacementInfo = dropUnitExtentFromOperandMetadata(
rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
dimReplacements);
reassociations.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
targetShapes.push_back(replacementInfo.targetShape);
// The operand needs a collapse only if its rank actually shrinks.
collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
indexingMap.getNumResults()));
}
// Bail out if nothing changed or the new maps are not invertible. No IR has
// been created up to this point.
if (newIndexingMaps == indexingMaps ||
!inversePermutation(concatAffineMaps(newIndexingMaps)))
return failure();
Location loc = genericOp.getLoc();
// 4. Collapse the operands that need it (reshape or rank-reducing slice).
SmallVector<Value> newOperands;
for (OpOperand &opOperand : genericOp->getOpOperands()) {
int64_t idx = opOperand.getOperandNumber();
if (!collapsed[idx]) {
newOperands.push_back(opOperand.get());
continue;
}
newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
targetShapes[idx], reassociations[idx],
options.rankReductionStrategy));
}
// 5. Create the replacement op. Result types follow the (possibly
//    collapsed) init operands. The original body is moved over.
ArrayRef<Value> newInputs =
ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
ArrayRef<Value> newOutputs =
ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
SmallVector<Type> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
resultTypes.push_back(newOutputs[i].getType());
GenericOp replacementOp =
rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs,
newIndexingMaps, newIteratorTypes);
rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
replacementOp.getRegion().begin());
// 5a. Fix up `linalg.index` ops for dropped/renumbered loops.
replaceUnitDimIndexOps(replacementOp, unitDims, rewriter);
// 6. Expand collapsed results back to the original result types. Result i
//    corresponds to operand number (numInputs + i).
SmallVector<Value> resultReplacements;
for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) {
unsigned opOperandIndex = index + replacementOp.getNumDpsInputs();
Value origDest = genericOp.getDpsInitOperand(index)->get();
if (!collapsed[opOperandIndex]) {
resultReplacements.push_back(result);
continue;
}
Value expandedValue = expandValue(rewriter, loc, result, origDest,
reassociations[opOperandIndex],
options.rankReductionStrategy);
resultReplacements.push_back(expandedValue);
}
rewriter.replaceOp(genericOp, resultReplacements);
return success();
}
namespace {
/// Pattern adapter for `linalg::dropUnitDims`: each matched `linalg.generic`
/// is handed to that function together with the stored control options.
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    const ControlDropUnitDims &control = options;
    return linalg::dropUnitDims(rewriter, genericOp, control);
  }

private:
  /// Strategy and allowed-dimension control used for every match.
  ControlDropUnitDims options;
};
} // namespace
namespace {
/// Drops unit-extent dimensions from the source of a `tensor.pad`.
///
/// A source dimension is dropped when all of the following hold:
///  - the control function allows it;
///  - it has static size 1;
///  - both its low and high padding are the constant 0.
///
/// The source is collapsed according to `options.rankReductionStrategy`. A
/// lower-rank `tensor.pad` is then built with the remaining padding amounts,
/// and its result is expanded back to the original rank. With the slice
/// strategy, the destination of that expansion is a new `tensor.empty`.
///
/// Fails to match when:
///  - the control function allows nothing;
///  - the source has an encoding;
///  - the padding value is not constant;
///  - no dimension qualifies.
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }
    auto sourceType = padOp.getSourceType();
    if (sourceType.getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }
    // The new pad reuses this value directly, so it must be a constant.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    const int64_t padRank = sourceShape.size();
    llvm::SmallDenseSet<unsigned> allowed(allowedUnitDims.begin(),
                                          allowedUnitDims.end());
    auto isConstantZero = [](OpFoldResult ofr) {
      std::optional<int64_t> cst = getConstantIntValue(ofr);
      return cst && *cst == 0;
    };

    // Split source dimensions into dropped unit dims and kept dims.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (int64_t d = 0; d < padRank; ++d) {
      bool droppable = allowed.contains(d) && sourceShape[d] == 1 &&
                       isConstantZero(lowPad[d]) && isConstantZero(highPad[d]);
      if (droppable) {
        unitDims.insert(d);
        continue;
      }
      newShape.push_back(sourceShape[d]);
      newLowPad.push_back(lowPad[d]);
      newHighPad.push_back(highPad[d]);
    }
    if (unitDims.empty())
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");

    // Each kept dim opens a group and following unit dims join it. Leading
    // unit dims join the first group.
    SmallVector<ReassociationIndices> reassociationMap;
    ReassociationIndices group;
    bool groupHasKeptDim = false;
    for (int64_t d = 0; d < padRank; ++d) {
      if (!unitDims.contains(d)) {
        if (groupHasKeptDim) {
          reassociationMap.push_back(group);
          group.clear();
        }
        groupHasKeptDim = true;
      }
      group.push_back(d);
    }
    if (groupHasKeptDim)
      reassociationMap.push_back(group);

    Location loc = padOp.getLoc();
    Value collapsedSource =
        collapseValue(rewriter, loc, padOp.getSource(), newShape,
                      reassociationMap, options.rankReductionStrategy);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*resultType=*/Type(), collapsedSource, newLowPad, newHighPad,
        paddingVal, padOp.getNofold());

    // Slice strategy: insert into a fresh `tensor.empty` of the original
    // rank. Dropped dims have size 1 and kept dims take their size from the
    // new pad result.
    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t newPadDim = 0;
      for (int64_t d = 0; d < padRank; ++d) {
        if (unitDims.contains(d)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
        } else {
          expandedSizes.push_back(
              tensor::getMixedSize(rewriter, loc, newPadOp, newPadDim++));
        }
      }
      dest = rewriter.create<tensor::EmptyOp>(
          loc, expandedSizes, padOp.getResultType().getElementType());
    }

    Value expandedValue =
        expandValue(rewriter, loc, newPadOp.getResult(), dest,
                    reassociationMap, options.rankReductionStrategy);
    rewriter.replaceOp(padOp, expandedValue);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace
namespace {
/// Rewrites a `tensor.extract_slice` whose static result shape has unit
/// dimensions that `getReassociationMapForFoldingUnitDims` can fold. The
/// replacement is a rank-reduced `tensor.extract_slice` (same offsets,
/// sizes and strides) followed by a `tensor.expand_shape` back to the
/// original result type.
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> resultSizes = llvm::map_to_vector(
        resultType.getShape(), [&](int64_t dimSize) -> OpFoldResult {
          return rewriter.getIndexAttr(dimSize);
        });
    auto reassociation = getReassociationMapForFoldingUnitDims(resultSizes);
    if (!reassociation)
      return failure();
    // One group per dimension means there is nothing to fold.
    if (reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    Location loc = sliceOp.getLoc();
    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    auto reducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));
    Value reducedSlice = rewriter.create<tensor::ExtractSliceOp>(
        loc, reducedType, sliceOp.getSource(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, reducedSlice, *reassociation);
    return success();
  }
};

/// Rewrites `tensor.insert_slice` / `tensor.parallel_insert_slice` whose
/// static source shape has foldable unit dimensions. The source is first
/// collapsed with `tensor.collapse_shape`, then the op is re-created with
/// the collapsed source and the original dest, offsets, sizes and strides.
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> sourceSizes = llvm::map_to_vector(
        sourceType.getShape(), [&](int64_t dimSize) -> OpFoldResult {
          return rewriter.getIndexAttr(dimSize);
        });
    auto reassociation = getReassociationMapForFoldingUnitDims(sourceSizes);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp collapsedSource;
    {
      OpBuilder::InsertionGuard guard(rewriter);
      // For the parallel variant the collapse is created before the parent
      // op instead of at the insert op itself. Presumably the parent's
      // region may only hold the parallel-insert ops; confirm against the
      // op definition.
      if constexpr (std::is_same_v<InsertOpTy, tensor::ParallelInsertSliceOp>)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      collapsedSource = rewriter.create<tensor::CollapseShapeOp>(
          loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, collapsedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace
/// Registers the patterns that drop unit dims using reassociative reshapes.
/// It also registers the extract/insert-slice rank-reduction patterns and
/// fill/collapse/expand/empty canonicalizations, plus the tensor.empty
/// folding and shaped-type result-dim resolution patterns.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DropUnitDims, DropPadUnitDims>(ctx, options);
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(ctx);
  // Supporting cleanup patterns.
  linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
/// Registers the patterns that drop unit dims using rank-reducing slices,
/// plus fill/empty canonicalizations, tensor.empty folding and shaped-type
/// result-dim resolution.
/// Note: this overwrites `options.rankReductionStrategy` in the caller's
/// object with ExtractInsertSlice.
static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  options.rankReductionStrategy =
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DropUnitDims, DropPadUnitDims>(ctx, options);
  // Supporting cleanup patterns.
  linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
/// Selects the pattern set matching `options.rankReductionStrategy`:
/// rank-reducing slices or reassociative reshapes. Any other strategy value
/// registers nothing.
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  using Strategy = linalg::ControlDropUnitDims::RankReductionStrategy;
  switch (options.rankReductionStrategy) {
  case Strategy::ExtractInsertSlice:
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
    break;
  case Strategy::ReassociativeReshape:
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
    break;
  }
}
/// Registers `MoveInitOperandsToInput`. That pattern turns init operands that
/// are read inside a fully-parallel tensor `linalg.generic` into inputs.
void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MoveInitOperandsToInput>(ctx);
}
namespace {
/// Greedily applies the unit-extent folding patterns and the
/// move-init-to-input pattern on the anchor operation. When the
/// `useRankReducingSlices` option is set, the slice-based rank-reduction
/// strategy is chosen. Otherwise the strategy default-initialized by
/// `ControlDropUnitDims` is used.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;

  void runOnOperation() override {
    Operation *rootOp = getOperation();
    ControlDropUnitDims control;
    if (useRankReducingSlices) {
      control.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
    }
    RewritePatternSet patterns(rootOp->getContext());
    linalg::populateFoldUnitExtentDimsPatterns(patterns, control);
    populateMoveInitOperandsToInputPattern(patterns);
    // The greedy driver's result is intentionally discarded.
    (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
  }
};
} // namespace
namespace {

/// Returns the reassociation that removes dimension `pos` of a rank-`rank`
/// shape by merging it with a neighbour:
///  - `pos` is grouped with `pos + 1`, or with `pos - 1` when `pos` is the
///    last dimension;
///  - every other dimension forms its own group.
/// For rank 2 the result is the single group {0, 1}. For rank 1 it is empty
/// (collapse to rank 0). `rank` must be at least 1.
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Collapses dimension `pos` of `val` (expected to have size 1) with a
/// reassociative reshape. A negative `pos` means "no collapse", and `val` is
/// then returned unchanged.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base pattern that rewrites a named contraction `FromOpTy` (2 inputs,
/// 1 init) into a lower-rank named op `ToOpTy`. It collapses one unit
/// dimension per operand, as chosen by `getOperandUnitDims`. A tensor result
/// is expanded back to the original result type. Other attributes of the
/// original op are copied over, except the memoized indexing maps.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapses each of the three operands (lhs, rhs, init) at its entry in
  /// `operandCollapseDims`. A negative entry leaves that operand unchanged.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expands `result` back to `expandedType` by re-inserting dimension `dim`.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    // Only tensor inits produce a result; memref inits produce none.
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    // Copy the attributes, but not the memoized indexing maps of the old op.
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }
    return success();
  }

  /// Fills `operandUnitDims` with three entries (lhs, rhs, init). Each entry
  /// is the operand dimension to collapse, or a negative value to leave the
  /// operand unchanged. Returns failure if the op does not qualify.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Removes a unit batch dimension shared by all three operands.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims\n");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found\n");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Removes a unit non-batch dimension.
///  - When `reduceLeft` holds, the unit m dimension is collapsed; it is
///    shared by lhs and init.
///  - Otherwise the unit n dimension is collapsed; it is shared by rhs and
///    init.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// True for the (From, To) pairs that drop the m dimension.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims\n");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      // Guard before indexing: an op without an m dimension cannot match.
      if (contractionDims.m.empty()) {
        LLVM_DEBUG(llvm::dbgs() << "no m dimension to reduce\n");
        return failure();
      }
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      // Guard before indexing: an op without an n dimension cannot match.
      if (contractionDims.n.empty()) {
        LLVM_DEBUG(llvm::dbgs() << "no n dimension to reduce\n");
        return failure();
      }
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found\n");
    return failure();
  }
};

} // namespace
/// Registers the named-contraction rank-reduction patterns. Each pattern
/// replaces an op that has a unit dimension with the corresponding
/// lower-rank named op. Registration order is significant for equal-benefit
/// patterns and is kept stable here.
void mlir::linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // Unit batch dimension: batched op -> unbatched op.
  patterns.add<
      RankReduceToUnBatched<BatchMatmulOp, MatmulOp>,
      RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>,
      RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>,
      RankReduceToUnBatched<BatchMatvecOp, MatvecOp>,
      RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(ctx);
  // Unit m or n dimension of unbatched matmuls.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>,
               RankReduceMatmul<MatmulOp, MatvecOp>,
               RankReduceMatmul<MatmulTransposeAOp, VecmatOp>,
               RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(ctx);
  // Unit m or n dimension of batched matmuls.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>,
               RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>,
               RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>,
               RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(ctx);
  // Matrix-vector products with a unit dimension -> dot product.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>,
               RankReduceMatmul<VecmatOp, DotOp>>(ctx);
}