#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>
// Debug category name for this file; it is also embedded in the prefix
// printed by DBGS() below.
#define DEBUG_TYPE "linalg-transforms"
using namespace mlir;
using namespace mlir::linalg;
// Expands to the LLVM debug stream after writing the "[linalg-transforms]: "
// prefix, so callers can keep streaming with `<<`.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
// Writes a single newline to the LLVM debug stream.
#define DBGSNL() (llvm::dbgs() << "\n")
/// Tries to peel `op` when it is an `scf.for` loop.
///
/// Returns the results of the partial-iteration loop when peeling succeeded,
/// the results of the original loop when it did not, and simply the results of
/// `op` for any other kind of operation.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp loop) {
        scf::ForOp peeledTail;
        LogicalResult peeled =
            scf::peelForLoopAndSimplifyBounds(rewriter, loop, peeledTail);
        if (failed(peeled)) {
          // A failed peel must not have produced a partial-iteration loop.
          assert(!peeledTail && "expected that loop was not peeled");
          return loop->getResults();
        }
        return peeledTail->getResults();
      })
      .Default([&](Operation *other) { return other->getResults(); });
}
/// Applies `peelLoop` to every loop in `loops`, in order. The values returned
/// by `peelLoop` are discarded.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (scf::ForOp forOp : loops) {
    (void)peelLoop(rewriter, forOp);
  }
}
#ifndef NDEBUG
/// Returns true when zero or one result expression of `map` is a function of
/// dimension `dim`. Only compiled when assertions are enabled.
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  auto dependsOnDim = [dim](AffineExpr expr) {
    return expr.isFunctionOfDim(dim);
  };
  return llvm::count_if(map.getResults(), dependsOnDim) <= 1;
}
#endif
/// Returns the index of the first result expression of `map` that is a
/// function of dimension `dim`, or `std::nullopt` when no result depends on
/// `dim`.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (auto [resultIdx, resultExpr] : llvm::enumerate(map.getResults())) {
    if (resultExpr.isFunctionOfDim(dim))
      return static_cast<int64_t>(resultIdx);
  }
  return std::nullopt;
}
/// Performs one packing step of the linalg metadata along iterator `dim`.
///
/// On success:
///   - a new iterator, of the same type as iterator `dim`, is appended to
///     `iteratorTypes`;
///   - every map in `indexingMaps` gains one extra trailing dimension, and the
///     maps that have a result depending on `dim` also gain one extra trailing
///     result equal to the new dimension;
///   - the returned vector holds, per indexing map, the index of the result
///     that depends on `dim`, or `std::nullopt` when the map does not use
///     `dim`.
///
/// Fails when the result depending on `dim` is not a plain `AffineDimExpr`.
/// `indexingMaps` and `iteratorTypes` are only written once every map has been
/// processed, so a failed call leaves both untouched.
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  // The new iterator is appended last, so its position is the current count.
  int64_t newDim = iteratorTypes.size();
  assert(dim >= 0 && dim < newDim && "dim must name an existing iterator");

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(indexingMaps.size());
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Every map gets the extra dimension, whether or not it uses `dim`.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // At most one result may depend on `dim`; find it if it exists.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // Only a bare dimension expression can be packed.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Append the new dimension as a trailing result of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Remember which result of this operand's map is being packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }

  // Commit both outputs together now that no failure can occur anymore.
  utils::IteratorType packedIteratorType = iteratorTypes[dim];
  iteratorTypes.push_back(packedIteratorType);
  indexingMaps = newMaps;
  return packedDimPerIndexingMap;
}
namespace {

/// Records one packing step: the packed size and, for each operand, the
/// operand dimension that is packed (or `std::nullopt` when the operand is not
/// affected by this step).
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Ordered collection of `PackedOperandsDim` entries with per-operand
/// extraction helpers.
struct PackedOperandsDimList {
  /// Appends `packedOperandsDims` to the list. The named parameter is an
  /// lvalue here, so the entry is stored by copy and the caller's object is
  /// left intact.
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.push_back(packedOperandsDims);
  }

  /// Returns the packed dimensions recorded for operand `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);

  /// Returns the packed sizes recorded for operand `operandPos`.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace
/// Lowers `packOp` to a `tensor.pad`, followed by either a single
/// `tensor.insert_slice` (when the pack behaves like a pad and the packed type
/// rank-reduces to the padded type) or a `tensor.expand_shape` plus
/// `linalg.transpose`.
///
/// Reports a match failure when any static inner tile size is dynamic.
/// On success `packOp` is replaced and the created pad / reshape / transpose
/// ops are returned; the reshape and transpose entries are null on the
/// insert_slice path.
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
// 1. Reject the cases that are not implemented: dynamic inner tiles.
auto packedTensorType =
cast<RankedTensorType>(packOp->getResultTypes().front());
if (llvm::any_of(packOp.getStaticInnerTiles(),
[](int64_t size) { return ShapedType::isDynamic(size); })) {
return rewriter.notifyMatchFailure(
packOp,
"non-static shape NYI, needs a more powerful tensor.expand_shape op");
}
Location loc = packOp->getLoc();
// The guard restores the rewriter's insertion point when this function exits.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(packOp);
// 2. Compute the packing metadata and the permutation that maps the packed
// shape to the strip-mined shape.
PackingMetadata packingMetadata = computePackingMetadata(
packedTensorType.getRank(), packOp.getInnerDimsPos());
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getPackInverseDestPerm(packOp);
// 3. Strip-mined shape: the packed shape with that permutation applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 4. Pad the source. Low padding is always 0; the high padding of each tiled
// dimension is `outerSize * innerSize - origSize`.
SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
rewriter.getIndexAttr(0));
for (auto [pos, innerSize] :
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
int outerPos =
packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
OpFoldResult origSize =
tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
OpFoldResult outerSize =
tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
AffineExpr s0, d0, d1;
bindDims(rewriter.getContext(), d0, d1);
bindSymbols(rewriter.getContext(), s0);
// Operands below bind as d0 = outerSize, d1 = origSize, s0 = innerSize.
auto map = AffineMap::get(2, 1, d0 * s0 - d1);
highs[pos] = affine::makeComposedFoldedAffineApply(
rewriter, loc, map, {outerSize, origSize, innerSize});
}
// The padded type is the strip-mined type collapsed along the
// reassociations.
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
packingMetadata.reassociations);
// Use the pack's padding value when present, otherwise a zero constant of the
// element type.
Value paddingValue = packOp.getPaddingValue();
if (!paddingValue) {
paddingValue = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
}
auto padOp =
rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
highs, paddingValue, false);
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
DBGS() << "insertPositions: ");
DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
DBGS() << "outerPositions: ");
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
llvm::interleaveComma(ri, llvm::dbgs() << "|");
});
DBGSNL();
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
// Shortcut: a pack that behaves like a pad is lowered to pad + insert_slice,
// but only when the packed type is a valid rank-reduction target of the
// padded type; otherwise fall through to the general lowering.
if (packOp.isLikePad()) {
SliceVerificationResult rankReduces =
isRankReducedType(packedTensorType, padOp.getResultType());
if (rankReduces == SliceVerificationResult::Success) {
auto emptyOp =
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets are all 0 and strides are all 1; sizes are those of the dest.
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> ones(packOp.getDestRank(),
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, padOp, emptyOp,
zeros, sizes,
ones);
LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
rewriter.replaceOp(packOp, insertSliceOp->getResults());
// No reshape or transpose op is created on this path.
return LowerPackResult{padOp, nullptr,
nullptr};
}
}
// 5. Expand the padded tensor to the strip-mined shape.
auto expandShapeResultType =
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeResultType, padOp.getResult(),
packingMetadata.reassociations);
// 6. Transpose the strip-mined shape into the packed shape, writing into the
// pack's destination.
SmallVector<int64_t> transpPerm =
invertPermutationVector(packedToStripMinedShapePerm);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "reshape op: " << reshapeOp; DBGSNL();
llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
// 7. Replace the pack with the transpose results.
rewriter.replaceOp(packOp, transposeOp->getResults());
return LowerPackResult{padOp, reshapeOp, transposeOp};
}
/// Lowers `unPackOp` either to a single `tensor.extract_slice` (when the
/// unpack behaves like an unpad) or to the sequence `tensor.empty` +
/// `linalg.transpose` + `tensor.collapse_shape` + `tensor.extract_slice` +
/// `linalg.copy`.
///
/// `unPackOp` is replaced in both cases. The returned struct holds the created
/// empty / transpose / reshape / extract_slice ops; the first three are null
/// on the unpad path.
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
tensor::UnPackOp unPackOp) {
Location loc = unPackOp->getLoc();
// The guard restores the rewriter's insertion point when this function exits.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
RankedTensorType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
// Shortcut: an unpack that behaves like an unpad is a plain slice of the
// source.
if (unPackOp.isLikeUnPad()) {
ArrayRef<int64_t> destShape = destTensorType.getShape();
// Sizes are 1 for the leading (packedRank - destRank) dimensions, followed
// by the sizes of the destination.
SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, unPackOp.getSource(),
SmallVector<OpFoldResult>(packedRank, zero), sizes,
SmallVector<OpFoldResult>(packedRank, one));
rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
// No empty, transpose or reshape op is created on this path.
return LowerUnPackOpResult{nullptr, nullptr,
nullptr, extractSliceOp};
}
// 1. Compute the permutation that maps the packed shape to the strip-mined
// shape; `packingMetadata` is filled in by the same call.
PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
// 2. Strip-mined shape: the packed shape with that permutation applied.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose the packed source into the strip-mined layout.
RankedTensorType stripMinedTensorType =
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMinedTensorType, packingMetadata.reassociations);
// The sizes of the source are permuted the same way to size the transpose
// destination.
SmallVector<OpFoldResult, 4> dims =
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
loc, dims, stripMinedTensorType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
DBGS() << "insertPositions: ");
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
llvm::interleaveComma(ri, llvm::dbgs() << "|");
});
DBGSNL();
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
// 4. Collapse the strip-mined tensor along the reassociations.
auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, transposeOp->getResult(0),
packingMetadata.reassociations);
// 5. Slice the collapsed tensor down to the destination sizes (offsets 0,
// strides 1).
int64_t destRank = destTensorType.getRank();
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));
// 6. Copy the slice into the unpack's destination operand.
auto copyOp = rewriter.create<linalg::CopyOp>(
loc, extractSliceOp->getResult(0), unPackOp.getDest());
// 7. Replace the unpack with the copy results.
rewriter.replaceOp(unPackOp, copyOp->getResults());
return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
/// Collects, in insertion order, the packed dimension recorded for operand
/// `operandPos` by every entry that has one.
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> packedDims;
  for (const PackedOperandsDim &entry : spec) {
    const std::optional<int64_t> &maybeDim =
        entry.packedDimForEachOperand[operandPos];
    if (maybeDim.has_value())
      packedDims.push_back(*maybeDim);
  }
  return packedDims;
}
/// Collects, in insertion order, the packed size of every entry that records a
/// packed dimension for operand `operandPos`.
SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> packSizes;
  for (const PackedOperandsDim &entry : spec) {
    bool operandIsPacked =
        entry.packedDimForEachOperand[operandPos].has_value();
    if (operandIsPacked)
      packSizes.push_back(entry.packedSize);
  }
  return packSizes;
}
/// Packs `linalgOp` according to `packedSizes`, which must contain one entry
/// per loop of the op. Entries that are the constant 0 are skipped.
///
/// The operands affected by the packing are wrapped in `tensor.pack` ops, a
/// `linalg.generic` op operating on the packed operands takes over the body of
/// `linalgOp`, and each result whose init was packed is wrapped in a matching
/// `tensor.unpack`. `linalgOp` is then replaced by the final results.
///
/// Reports a match failure when `packedSizes` has the wrong length, and fails
/// when the metadata of a requested dimension cannot be packed.
///
/// \returns the created pack ops, the packed linalg op and the created unpack
///          ops.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;

  // Step 1. Pack the indexing maps and iterator types, one dimension at a
  // time, recording for each step which operand dimension is packed.
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // A constant size of 0 means "do not pack this dimension".
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand =
        std::move(*maybePackedDimForEachOperand);

    // Emit the debug dump before handing `packedOperandsDims` over to the
    // list, so that the object is never read after `std::move`.
    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););

    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
  }

  // Step 2. Create a pack op for every operand that has packed dimensions;
  // inputs are processed first, then inits.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(
          DBGS() << "operand: " << operand << "\n";
          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
          DBGSNL(););

      // Operands without any packed dimension are forwarded unchanged.
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }

      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      // The padding value is omitted only when tiles and operand shape are
      // static and `requirePaddingValue` reports that none is needed;
      // otherwise a zero constant is supplied as the padding value.
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed generic op; its result types are the types of
  // the (possibly packed) inits, and it takes over the body of `linalgOp`.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Unpack every result whose init was produced by a pack op, reusing
  // that pack's source, inner dims positions and tiles.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace the original op.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
/// Returns a copy of `tensorType` whose shape has been permuted by
/// `permutationVector`; everything else about the type is kept.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  ArrayRef<int64_t> originalShape = tensorType.getShape();
  SmallVector<int64_t> permutedShape(originalShape.begin(),
                                     originalShape.end());
  applyPermutationToVector(permutedShape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(permutedShape);
}
/// Replaces `linalgOp` with a `linalg.generic` in which `opOperand` is
/// substituted by `transposedValue` and the operand's indexing map is composed
/// with the permutation map built from `permutation`.
///
/// `opOperand` must belong to `linalgOp`, and the type of `transposedValue`
/// must equal the operand's tensor type with its shape permuted by
/// `permutation` (both checked by assertions). The new op takes over the body
/// of `linalgOp` and is returned as a `LinalgOp`.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity-check the type of the replacement value.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Build the permutation map; the AffineMap API wants unsigned positions.
  SmallVector<unsigned> unsignedPermutation(permutation.begin(),
                                            permutation.end());
  AffineMap permutationMap =
      AffineMap::getPermutationMap(unsignedPermutation, rewriter.getContext());
  AffineMap originalMap = linalgOp.getMatchingIndexingMap(&opOperand);
  AffineMap transposedMap = permutationMap.compose(originalMap);

  // Patch the indexing map slot of the operand.
  SmallVector<AffineMap> newIndexingMaps = linalgOp.getIndexingMapsArray();
  newIndexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;

  // Patch the operand slot itself.
  SmallVector<Value> newOperands = linalgOp->getOperands();
  newOperands[opOperand.getOperandNumber()] = transposedValue;

  // Inputs come first in the operand list; everything after them is an output.
  ValueRange allOperands(newOperands);
  int64_t numInputs = linalgOp.getNumDpsInputs();
  ValueRange newInputs = allOperands.take_front(numInputs);
  ValueRange newOutputs = allOperands.drop_front(numInputs);

  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      linalgOp->getLoc(), newOutputs.getTypes(), newInputs, newOutputs,
      newIndexingMaps, linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
/// Transposes `packOp` by `innerPerm` / `outerPerm` and propagates the
/// transposition to its single user `linalgOp` and, when given, to
/// `maybeUnPackOp`.
///
/// Preconditions, each reported as a match failure when violated:
///   - the result of `packOp` has exactly one use, and that use is by
///     `linalgOp`;
///   - when `maybeUnPackOp` is non-null, the use is an init of `linalgOp` and
///     the unpack's source is the result tied to that init;
///   - the permutation of the whole packed operand assembled from `outerPerm`
///     and `innerPerm` is a valid permutation vector. An empty `outerPerm` or
///     `innerPerm` stands for the identity on its part.
///
/// All preconditions are verified before any op is created or the insertion
/// point is moved, so a failed call leaves the IR untouched.
///
/// \returns the transposed pack op, the transposed linalg op and the
///          transposed unpack op (null when `maybeUnPackOp` is null).
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Validate the pack / linalg / unpack relationship.
  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Compute and validate the permutation of the whole packed operand.
  // The leading `numLeadingDims` entries come from `outerPerm`, the trailing
  // `numTrailingDims` entries from `innerPerm` shifted by `numLeadingDims`.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 3. Transpose the pack op. From here on the IR is being modified.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  // Step 4. Transpose the linalg op. The operand number is saved first
  // because `packUse` belongs to `linalgOp`, which is replaced by the call
  // below.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 5. Transpose the unpack op, if any, feeding it the result tied to
  // the transposed operand of the new linalg op.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 6. Finally replace the original pack op.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
/// Packs the matmul-like contraction found in `linalgOp`.
///
/// `mnkPackedSizes` holds the three packing sizes, `mnkOrder` is a permutation
/// of {0, 1, 2} that places them among the three trailing loops, and
/// `mnkPaddedSizesNextMultipleOf` is either empty or holds three values; a
/// non-zero value replaces the corresponding packing size by the loop range
/// size rounded up to a multiple of that value.
///
/// Reports a match failure when `linalgOp` has fewer than 3 loops or when the
/// contraction dimensions cannot be inferred. Otherwise the op is generalized
/// if needed, its loops are interchanged, and the result of `pack` on the
/// interchanged op is returned.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<OpFoldResult> mnkPackedSizes,
ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
ArrayRef<int64_t> mnkOrder) {
assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
assert((mnkPaddedSizesNextMultipleOf.empty() ||
mnkPaddedSizesNextMultipleOf.size() == 3) &&
"num of packing sizes next multiple should be empty or of size 3");
assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
assert(isPermutationVector(mnkOrder) && "expected a permutation");
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
<< numLoops << "\nin: " << linalgOp << "\n");
return rewriter.notifyMatchFailure(
linalgOp, "need 3+ loops to find a matmul to pack");
}
// Reorder the user-provided positions and sizes according to `mnkOrder`.
// Target positions live in the last `numPackedDims` loops.
int64_t numPackedDims = mnkPackedSizes.size();
SmallVector<int64_t> mmnnkkPos(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
SmallVector<OpFoldResult> packedSizes(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
// An empty `mnkPaddedSizesNextMultipleOf` is treated as all zeros, i.e. no
// rounding for any of the three sizes.
SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
paddedSizesNextMultipleOf[mnkOrder[i]] =
mnkPaddedSizesNextMultipleOf.empty() ? 0
: mnkPaddedSizesNextMultipleOf[i];
}
// 1. Infer the contraction dimensions of the op.
FailureOr<ContractionDimensions> maybeDimensions =
inferContractionDims(linalgOp);
if (failed(maybeDimensions)) {
LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
<< "\n");
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer matmul iterators");
}
// 2. Pick the last inferred m, n and k dimension when several exist.
int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
kPos = maybeDimensions->k.back();
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "Start packing generic op greedily with (m@" << mPos
<< ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
<< "\n";);
// 2.a. Work on a linalg.generic: generalize named ops first. Failure to
// generalize is only checked by an assertion.
auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
if (!genericOp) {
FailureOr<GenericOp> generalizeResult =
generalizeNamedOp(rewriter, linalgOp);
assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
genericOp = *generalizeResult;
}
// 2.b. Interchange the loops so that m, n, k land on `mmnnkkPos`. Failure to
// interchange is only checked by an assertion.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
FailureOr<GenericOp> interchangeResult =
interchangeGenericOp(rewriter, genericOp, unsignedPerm);
assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
genericOp = *interchangeResult;
LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
// 3. Materialize the loop ranges of the interchanged op; their sizes are
// needed when rounding up to a multiple.
SmallVector<Range, 4> loopRanges =
cast<LinalgOp>(genericOp.getOperation())
.createLoopRanges(rewriter, genericOp.getLoc());
LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
DBGS() << "paddedSizesNextMultipleOf: ");
DBGSNL(););
LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
[](Range r) { llvm::dbgs() << r.size; });
DBGSNL(););
// 4. Build one packing size per loop: zeros for the leading loops, then the
// three (possibly rounded) sizes for the trailing loops.
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
if (paddedSizesNextMultipleOf[i] == 0) {
adjustedPackedSizes.push_back(packedSizes[i]);
continue;
}
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
// ceilDiv(size, multiple) * multiple; `adjustedPackedSizes.size()` is the
// index of the loop currently being handled.
adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
{loopRanges[adjustedPackedSizes.size()].size,
rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
}
LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
DBGS() << "adjustedPackedSizes: ");
DBGSNL(););
// 5. Delegate the actual packing.
return pack(rewriter, genericOp, adjustedPackedSizes);
}
/// Installs a tile-size computation function that materializes the constant
/// tile sizes `ts` as `arith.constant` index ops at the start of the first
/// block of the `func.func` enclosing the tiled op.
///
/// Asserts that no tile-size computation function was set before. Returns
/// `*this` so that calls can be chained.
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // Own a copy of the sizes: the callback may outlive the caller's storage.
  SmallVector<int64_t, 4> ownedSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [ownedSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    Block &entryBlock = op->getParentOfType<func::FuncOp>().getBody().front();
    b.setInsertionPointToStart(&entryBlock);
    SmallVector<Value, 4> sizeValues;
    for (int64_t size : ownedSizes) {
      Value constant = b.create<arith::ConstantIndexOp>(op->getLoc(), size);
      sizeValues.push_back(constant);
    }
    return sizeValues;
  };
  return *this;
}
/// Rewrites `copyOp` by forwarding to `vectorizeCopy` and returns its result
/// unchanged.
LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
memref::CopyOp copyOp, PatternRewriter &rewriter) const {
return vectorizeCopy(rewriter, copyOp);
}
/// Produces the padding-filled tensor for `padOp`.
///
/// When the pad has a constant padding value, `dest` is filled with it through
/// a `linalg.fill` and the fill result is returned. Otherwise a
/// `tensor.generate` of the pad's result type is created with `dynSizes`, the
/// pad's region is cloned into it, and the generate op is returned.
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  Location loc = padOp.getLoc();

  // Constant padding: a plain fill is enough.
  if (Value constantPad = padOp.getConstantPaddingValue())
    return rewriter.create<FillOp>(loc, constantPad, dest).result();

  // Non-constant padding: reuse the pad's region inside a tensor.generate.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      loc, padOp.getResultType(), dynSizes);
  IRMapping mapping;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), mapping);
  return generateOp;
}
/// Rewrites `padOp` into a `tensor.empty` of the result shape, filled with the
/// padding (see `createFillOrGenerateOp`), into which the pad's source is then
/// copied.
///
/// When `optimizeCopyFn` is set and succeeds, it is responsible for the copy;
/// otherwise the source is inserted with a `tensor.insert_slice` at the low
/// padding offsets with unit strides. Always returns success.
LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  Location loc = padOp.getLoc();

  // Turns an OpFoldResult into an index-typed Value, materializing a constant
  // when it holds an attribute.
  auto materializeIndex = [&](OpFoldResult ofr) -> Value {
    if (auto asValue = llvm::dyn_cast_if_present<Value>(ofr))
      return asValue;
    int64_t constant = cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
    return rewriter.create<arith::ConstantIndexOp>(loc, constant).getResult();
  };

  // Sizes of the tensor.empty: every dimension contributes its static size,
  // and dynamic dimensions additionally contribute `src + low + high`.
  auto resultType = padOp.getResultType();
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (int64_t dim = 0, rank = resultType.getRank(); dim < rank; ++dim) {
    if (resultType.isDynamicDim(dim)) {
      Value srcSize = materializeIndex(
          tensor::getMixedSize(rewriter, loc, padOp.getSource(), dim));
      Value lowPad = materializeIndex(padOp.getMixedLowPad()[dim]);
      Value srcPlusLow =
          rewriter.createOrFold<arith::AddIOp>(loc, srcSize, lowPad);
      Value highPad = materializeIndex(padOp.getMixedHighPad()[dim]);
      Value paddedSize =
          rewriter.createOrFold<arith::AddIOp>(loc, srcPlusLow, highPad);
      dynSizes.push_back(paddedSize);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Create the destination and fill it with the padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      loc, staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Give the user-provided hook a chance to perform the copy.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // Fallback: insert the whole source at the low-padding offsets.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
  int64_t sourceRank = padOp.getSourceType().getRank();
  SmallVector<OpFoldResult> unitStrides(sourceRank, rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      unitStrides);
  return success();
}
/// Rewrites a unit-stride `tensor.extract_slice` whose source is produced by a
/// `tensor.pad` through `tensor::bubbleUpPadSlice`.
///
/// Fails when the slice has non-unit strides, when the source is not defined
/// by a pad op, when `controlFn` is set and returns no value, or when
/// `bubbleUpPadSlice` fails. When `controlFn` returns a value it is used as
/// the zero-slice-guard flag, which otherwise defaults to true.
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    std::optional<bool> decision = controlFn(sliceOp);
    // An empty decision vetoes the rewrite.
    if (!decision)
      return failure();
    zeroSliceGuard = *decision;
  }

  FailureOr<TilingResult> bubbled =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(bubbled))
    return failure();

  rewriter.replaceOp(sliceOp, bubbled->tiledValues);
  return success();
}
/// Returns the source of `packOp` when the pack has no padding value;
/// otherwise returns a high-padded version of the source in which every tiled
/// dimension has been extended to its (static) inner tile size.
///
/// Asserts that the leading `sourceRank` dimensions of the destination are all
/// 1 and that every inner tile size is a constant.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value source = packOp.getSource();
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue)
    return source;

  ShapedType sourceType = packOp.getSourceType();
  int64_t sourceRank = sourceType.getRank();
  assert(llvm::all_of(packOp.getDestType().getShape().take_front(sourceRank),
                      [](int64_t val) { return val == 1; }));

  // Tiled dimensions take their tile size, the others keep the source size.
  DenseMap<int64_t, OpFoldResult> tileForDim = packOp.getDimAndTileMapping();
  SmallVector<int64_t> paddedShape;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    auto tileIt = tileForDim.find(dim);
    if (tileIt == tileForDim.end()) {
      paddedShape.push_back(sourceType.getDimSize(dim));
      continue;
    }
    std::optional<int64_t> tileSize = getConstantIntValue(tileIt->second);
    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
    paddedShape.push_back(tileSize.value());
  }

  auto paddedType =
      RankedTensorType::get(paddedShape, sourceType.getElementType());
  return tensor::createPadHighOp(paddedType, source, paddingValue,
                                 /*nofold=*/false, packOp.getLoc(), builder);
}
/// Normalizes `perm`, whose entries are positions in `[0, rank)`, into a
/// permutation of `[0, perm.size())`.
///
/// Each entry of `perm` is scattered to its position in a `rank`-sized vector
/// (storing the entry's index), the untouched slots are dropped, and the
/// resulting vector is inverted before being returned.
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kUnsetSlot = -1;

  // Scatter: slot `perm[idx]` remembers `idx`.
  SmallVector<int64_t> indexAtPosition(rank, kUnsetSlot);
  for (int64_t idx = 0, e = perm.size(); idx < e; ++idx)
    indexAtPosition[perm[idx]] = idx;

  // Compact: keep only the slots that were written.
  SmallVector<int64_t> compacted;
  for (int64_t entry : indexAtPosition) {
    if (entry != kUnsetSlot)
      compacted.push_back(entry);
  }

  // The scatter above produced the inverse; invert back.
  return invertPermutationVector(compacted);
}
/// Computes the permutation used by the pack/unpack generalization on a
/// rank-reduced tensor.
///
/// Dimensions of `shape` listed in `innerDimsPos` are always kept; the other
/// dimensions are kept only when their size is not 1. Kept dimensions are
/// numbered consecutively. The result lists the kept non-tiled dimensions
/// (permuted according to the normalized `outerDimsPerm` entries of the kept
/// dimensions, when any) followed by the kept tiled dimensions (permuted
/// according to the normalized `innerDimsPos`).
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> keptOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t unpackedRank = shape.size();

  // Number the dimensions that survive rank reduction.
  int64_t nextDim = 0;
  for (auto srcDim : llvm::seq<unsigned>(0, unpackedRank)) {
    bool isTiled = llvm::is_contained(innerDimsPos, srcDim);
    if (isTiled) {
      innerDims.push_back(nextDim++);
      continue;
    }
    // Non-tiled unit dimensions are dropped.
    if (shape[srcDim] == 1)
      continue;
    outerDims.push_back(nextDim++);
    if (!outerDimsPerm.empty())
      keptOuterDimsPerm.push_back(outerDimsPerm[srcDim]);
  }

  // Reorder the tiled dimensions following the normalized inner positions.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Same for the non-tiled dimensions, when an outer permutation was given.
  SmallVector<int64_t> perm = outerDims;
  keptOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, keptOuterDimsPerm);
  if (!keptOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, keptOuterDimsPerm);

  // Tiled dimensions come last.
  perm.append(innerDims);
  return perm;
}
/// Rewrites a tensor.pack op into an extract_slice, a linalg.transpose and an
/// insert_slice into the pack destination.
///
/// The match fails when any inner tile size is a dynamic value, or when a
/// result dimension indexed by the inner dims positions is not statically 1.
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  for (OpFoldResult tileSize : packOp.getMixedTiles()) {
    if (tileSize.is<Value>())
      return rewriter.notifyMatchFailure(
          packOp, "require inner tile sizes being static");
  }

  ArrayRef<int64_t> tiledPositions = packOp.getInnerDimsPos();
  ArrayRef<int64_t> resultShape = packOp.getDestType().getShape();
  const bool tiledOuterDimsAreUnit = llvm::all_of(
      tiledPositions, [&](int64_t pos) { return resultShape[pos] == 1; });
  if (!tiledOuterDimsAreUnit)
    return rewriter.notifyMatchFailure(
        packOp, "require the tiled outer dimensions of the result are all 1s");

  // Step 1: extract the tile (plus the non-unit untiled dims) from the source
  // with a rank-reducing slice.
  Location loc = packOp.getLoc();
  Value paddedSource = getPackOpSourceOrPaddedSource(rewriter, packOp);
  ArrayRef<int64_t> sourceShape = packOp.getSourceType().getShape();
  DenseMap<int64_t, OpFoldResult> tileForDim = packOp.getDimAndTileMapping();
  const int64_t sourceRank = packOp.getSourceRank();
  Attribute zeroAttr = rewriter.getIndexAttr(0);
  Attribute oneAttr = rewriter.getIndexAttr(1);

  SmallVector<OpFoldResult> sliceSizes;
  SmallVector<int64_t> sliceShape;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    auto tileIt = tileForDim.find(dim);
    if (tileIt != tileForDim.end()) {
      // Tiled dimension: read exactly one tile.
      OpFoldResult tileSize = tileIt->second;
      sliceShape.push_back(
          getConstantIntValue(tileSize).value_or(ShapedType::kDynamic));
      sliceSizes.push_back(tileSize);
      continue;
    }
    // Untiled dimension: read it whole, but drop it from the slice type when
    // its static extent is 1.
    const int64_t extent = sourceShape[dim];
    if (ShapedType::isDynamic(extent))
      sliceSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, paddedSource, dim).getResult());
    else
      sliceSizes.push_back(rewriter.getIndexAttr(extent));
    if (extent != 1)
      sliceShape.push_back(extent);
  }

  Type elemType = packOp.getSourceType().getElementType();
  auto sliceType = RankedTensorType::get(sliceShape, elemType);
  SmallVector<OpFoldResult> sliceOffsets(sourceRank, zeroAttr);
  SmallVector<OpFoldResult> sliceStrides(sourceRank, oneAttr);
  Value tile = rewriter.create<tensor::ExtractSliceOp>(
      loc, sliceType, paddedSource, sliceOffsets, sliceSizes, sliceStrides);

  // Step 2: transpose the slice into the packed dimension order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      sourceShape, tiledPositions, packOp.getOuterDimsPerm());
  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

  SmallVector<int64_t> transposedShape(sliceShape);
  applyPermutationToVector<int64_t>(transposedShape, perm);
  Value transposeInit =
      rewriter.create<tensor::EmptyOp>(loc, transposedShape, elemType);
  auto transposeOp =
      rewriter.create<linalg::TransposeOp>(loc, tile, transposeInit, perm);

  // Step 3: write the transposed tile over the whole destination.
  Value dest = packOp.getDest();
  const int64_t destRank = packOp.getDestRank();
  SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
  SmallVector<OpFoldResult> destOffsets(destRank, zeroAttr);
  SmallVector<OpFoldResult> destSizes =
      tensor::getMixedSizes(rewriter, loc, dest);
  auto insertOp = rewriter.create<tensor::InsertSliceOp>(
      loc, transposeOp.getResult()[0], dest, destOffsets, destSizes,
      destStrides);
  rewriter.replaceOp(packOp, insertOp.getResult());
  return success();
}
/// Rewrites a tensor.unpack op into an extract_slice of the source, a
/// linalg.transpose, a truncating extract_slice and an insert_slice into the
/// unpack destination.
///
/// The match fails when a source dimension indexed by the inner dims positions
/// is not statically 1.
LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  const int64_t srcRank = unpackOp.getSourceRank();
  const int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  for (int64_t pos : innerDimsPos) {
    if (srcShape[pos] != 1)
      return rewriter.notifyMatchFailure(
          unpackOp,
          "require the tiled outer dimensions of the result are all 1s");
  }

  // Step 1: slice the tile (plus the non-unit untiled outer dims) out of the
  // source with a rank-reducing extract_slice.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  SmallVector<OpFoldResult> sliceSizes;
  SmallVector<int64_t> sliceShape;
  SmallVector<Value> dynamicOuterSizes;
  for (int64_t dim = 0; dim < destRank; ++dim) {
    // Tiled outer dims are read with size 1 and dropped from the slice type.
    if (dimAndTileMapping.count(dim)) {
      sliceSizes.push_back(oneIdxAttr);
      continue;
    }
    const int64_t extent = srcShape[dim];
    if (!ShapedType::isDynamic(extent)) {
      sliceSizes.push_back(rewriter.getIndexAttr(extent));
    } else {
      Value dimSize =
          rewriter.create<tensor::DimOp>(loc, source, dim).getResult();
      sliceSizes.push_back(dimSize);
      dynamicOuterSizes.push_back(dimSize);
    }
    if (extent != 1)
      sliceShape.push_back(extent);
  }
  // The inner tile dims follow the outer dims, in both sizes and type. The
  // slice type is spelled out explicitly so inner tile dims are always kept.
  llvm::append_range(sliceSizes, unpackOp.getMixedTiles());
  llvm::append_range(sliceShape, srcShape.drop_front(destRank));

  Type elemType = unpackOp.getSourceType().getElementType();
  auto sliceType = RankedTensorType::get(sliceShape, elemType);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, sliceType, source, readOffsets, sliceSizes, readStrides);

  // Step 2: transpose with the inverse of the pack permutation.
  SmallVector<int64_t> perm =
      invertPermutationVector(getPackUnpackRankReducedPerm(
          srcShape.take_front(destRank), innerDimsPos,
          unpackOp.getOuterDimsPerm()));
  SmallVector<int64_t> transposedShape(sliceShape);
  applyPermutationToVector<int64_t>(transposedShape, perm);
  Value transposeInit = rewriter.create<tensor::EmptyOp>(
      loc, transposedShape, elemType, dynamicOuterSizes);
  auto transposeOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, transposeInit, perm);

  // Steps 3 and 4 share their sizes: every destination dim that is tiled or
  // not statically 1 contributes its size both to the truncating slice and to
  // the final insert; the remaining destination dims are written with size 1.
  Value dest = unpackOp.getDest();
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  SmallVector<OpFoldResult> tileSizes;
  SmallVector<OpFoldResult> writeSizes;
  for (int64_t dim = 0; dim < destRank; ++dim) {
    if (!dimAndTileMapping.count(dim) && destShape[dim] == 1) {
      writeSizes.push_back(oneIdxAttr);
      continue;
    }
    OpFoldResult size = tensor::getMixedSize(rewriter, loc, dest, dim);
    tileSizes.push_back(size);
    writeSizes.push_back(size);
  }

  // Step 3: truncate the transposed tile to the destination sizes.
  const int64_t transposedRank = transposedShape.size();
  SmallVector<OpFoldResult> tileStrides(transposedRank, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(transposedRank, zeroIdxAttr);
  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposeOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // Step 4: insert the truncated tile into the destination.
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  auto insertOp = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, dest, writeOffsets, writeSizes, writeStrides);
  rewriter.replaceOp(unpackOp, insertOp.getResult());
  return success();
}
/// Rewrites a 2-D convolution/pooling op into its 1-D counterpart when both
/// the kernel and the output have static extent 1 along H, or along W.
///
/// The operands are rank-reduced with extract_slice ops, a `Conv1DOp` is
/// created with the matching stride/dilation entry removed, and its result is
/// inserted back into the original output with a rank-expanding insert_slice.
/// When both H and W qualify, H is the dimension that gets removed.
///
/// Returns the new 1-D op on success. Fails for ops with pure buffer
/// semantics, for operands that are not ranked tensors, and when neither H nor
/// W can be removed.
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure();

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());
  // Not having pure buffer semantics does not by itself prove that every
  // operand is a ranked tensor; bail out instead of dereferencing a null type.
  if (!inputType || !kernelType || !outputType)
    return rewriter.notifyMatchFailure(
        convOp, "expected ranked tensor types for input, kernel and output");

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Positions of the H and W dimensions in the kernel shape (kh, kw) and in
  // the output shape (oh, ow), chosen by the layout of the matched op. The
  // (oh, ow) positions are also used to index the input type below.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // A dimension is removable only when both its window extent and its output
  // extent are statically 1.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // New operand types with the removed dimension dropped.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce the operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Drop the stride and dilation entry of the removed dimension
  // (entry 0 when removing H, entry 1 when removing W).
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the 1-D result back into the original output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
// Explicit instantiations of DownscaleSizeOneWindowed2DConvolution, one per
// (2-D op, 1-D op) pair: two convolution layouts followed by the sum/max/min
// pooling variants.
// NOTE(review): presumably required because the member template is defined in
// this translation unit rather than in the header - confirm.
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
Conv1DNcwFcwOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
PoolingNwcSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
PoolingNcwSumOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
PoolingNwcMaxOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
PoolingNwcMinOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
PoolingNcwMaxOp>;
/// Rewrites a DepthwiseConv2DNhwcHwcOp into a DepthwiseConv1DNwcWcOp when both
/// the kernel and the output have static extent 1 along H, or along W.
///
/// Kernel dims 0/1 and output dims 1/2 are read as the H/W extents. The
/// operands are rank-reduced with extract_slice ops, the matching
/// stride/dilation entry is removed, and the 1-D result is inserted back into
/// the original output. When both H and W qualify, H is removed.
///
/// Returns the new 1-D op on success. Fails for ops with pure buffer
/// semantics, for operands that are not ranked tensors, and when neither H nor
/// W can be removed.
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure();

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());
  // Not having pure buffer semantics does not by itself prove that every
  // operand is a ranked tensor; bail out instead of dereferencing a null type.
  if (!inputType || !kernelType || !outputType)
    return rewriter.notifyMatchFailure(
        convOp, "expected ranked tensor types for input, kernel and output");

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // A dimension is removable only when both its window extent and its output
  // extent are statically 1.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // New operand types with the removed dimension dropped.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce the operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Drop the stride and dilation entry of the removed dimension
  // (entry 0 when removing H, entry 1 when removing W).
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the 1-D result back into the original output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
/// Rewrites a Conv2DOp into a Conv1DOp when both the kernel and the output
/// have static extent 1 along dimension 0 (H), or along dimension 1 (W).
///
/// The operands are rank-reduced with extract_slice ops, a Conv1DOp is created
/// on them, and its result is inserted back into the original output. When
/// both dimensions qualify, dimension 0 is removed.
///
/// Returns the new 1-D op on success. Fails for ops with pure buffer
/// semantics, for operands that are not ranked tensors, and when neither
/// dimension can be removed.
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure();

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());
  // Not having pure buffer semantics does not by itself prove that every
  // operand is a ranked tensor; bail out instead of dereferencing a null type.
  if (!inputType || !kernelType || !outputType)
    return rewriter.notifyMatchFailure(
        convOp, "expected ranked tensor types for input, kernel and output");

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // A dimension is removable only when both its window extent and its output
  // extent are statically 1.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // New operand types with the removed dimension dropped.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce the operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert the 1-D result back into the original output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
/// Adds the 2-D to 1-D downscaling patterns for convolution and pooling ops
/// to `patterns`, each registered with the given `benefit`.
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();

  // Convolution patterns.
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, Conv1DNwcWcfOp>,
      DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, Conv1DNcwFcwOp>,
      DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(ctx, benefit);

  // Pooling patterns.
  patterns.add<
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
      ctx, benefit);
}