#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

#include <algorithm>
#include <optional>
#define DEBUG_TYPE "concat-patterns"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_DECOMPOSETENSORCONCAT
#define GEN_PASS_DEF_CONCATREMOVAL
#define GEN_PASS_DEF_SIMPLIFYTENSORCONCAT
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
}
}
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Lowers `tensor.concat` into a chain of `tensor.insert_slice` ops writing
/// into the destination obtained from `tensor::getOrCreateDestination`,
/// followed by a `tensor.cast` back to the concat's declared result type.
///
/// The pattern only fires when that destination is produced by a
/// `tensor.empty`.
struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    Location loc = concatOp.getLoc();

    FailureOr<Value> maybeDest =
        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
    if (failed(maybeDest) || !maybeDest->getDefiningOp<tensor::EmptyOp>())
      return failure();
    Value dest = *maybeDest;

    ValueRange inputs = concatOp.getInputs();
    int64_t concatDim = concatOp.getDim();
    Value concatDimIdx = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(concatDim));

    int64_t rank = concatOp.getResultType().getRank();
    SmallVector<OpFoldResult> unitStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sliceOffsets(rank, rewriter.getIndexAttr(0));

    // Build (d0, ..., dN-1) -> (d0, d0 + d1, ..., d0 + ... + dN-1) and apply
    // it to (0, size(in0), ..., size(inN-2)): result i is the start offset of
    // input i along the concat dimension.
    SmallVector<AffineExpr> prefixSums;
    SmallVector<OpFoldResult> mapOperands;
    AffineExpr running = rewriter.getAffineDimExpr(0);
    prefixSums.push_back(running);
    mapOperands.push_back(rewriter.getIndexAttr(0));
    for (size_t i = 1, e = inputs.size(); i < e; ++i) {
      running = running + rewriter.getAffineDimExpr(i);
      prefixSums.push_back(running);
      mapOperands.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, inputs[i - 1], concatDimIdx));
    }
    AffineMap prefixSumMap = AffineMap::get(inputs.size(), 0, prefixSums,
                                            rewriter.getContext());
    SmallVector<OpFoldResult> startOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, prefixSumMap, mapOperands);

    // Thread the destination through one insert_slice per input.
    Value acc = dest;
    for (size_t i = 0, e = inputs.size(); i < e; ++i) {
      Value input = inputs[i];
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, input);
      sliceOffsets[concatDim] = startOffsets[i];
      acc = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, input, acc, sliceOffsets, sizes, unitStrides);
    }

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        concatOp, concatOp.getResultType(), acc);
    return success();
  }
};
/// Rewrites a statically-shaped, single-use `tensor.concat` into a chain of
/// `tensor.insert_slice` ops into one `tensor.empty` of the result shape.
///
/// An input produced by a destination-style op whose only init comes from a
/// `tensor.empty` is additionally re-targeted: its init is replaced by a
/// `tensor.extract_slice` of the concat destination, so the producer writes
/// into its final location and no longer uses its original empty init.
///
/// Preconditions, all checked before any IR is created:
///   * the concat result has exactly one use;
///   * the result and every input have static shapes;
///   * every input that is an op result is defined in the concat's block
///     (required for the block-local ordering and insertion points below).
struct ConcatRemoval : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  /// Returns the destination-style producer of `input` if it can be
  /// re-targeted (exactly one init, produced by a `tensor.empty`), or nullptr.
  static Operation *getRetargetableProducer(Value input) {
    Operation *op = input.getDefiningOp();
    if (!op) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat input is not coming from "
                           "an operation\n");
      return nullptr;
    }
    auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
    if (!dpsOp) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat input is not DPS\n");
      return nullptr;
    }
    if (dpsOp.getNumDpsInits() != 1) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat input DPS does not have 1 "
                           "init argument\n");
      return nullptr;
    }
    // The init may be a block argument; getDefiningOp<T>() tolerates that.
    if (!dpsOp.getDpsInits()[0].getDefiningOp<tensor::EmptyOp>()) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat input does not write into "
                           "an empty op\n");
      return nullptr;
    }
    return op;
  }

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    Location loc = concatOp.getLoc();
    if (!concatOp->getResult(0).hasOneUse()) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat result does not have "
                           "exactly one use\n");
      return failure();
    }
    auto resultTy = cast<RankedTensorType>(concatOp.getResultType());
    if (!resultTy.hasStaticShape()) {
      LLVM_DEBUG(DBGS() << "[concat-removal] concat result shape is dynamic\n");
      return failure();
    }

    Block *block = concatOp->getBlock();
    int64_t dim = concatOp.getDim();
    auto inputs = llvm::to_vector(concatOp.getInputs());

    // Static offset of each input along `dim`, indexed by operand position so
    // that a value appearing several times gets one slot per occurrence.
    SmallVector<int64_t> dimOffsets;
    dimOffsets.reserve(inputs.size());
    int64_t nextOffset = 0;
    for (Value input : inputs) {
      auto inputTy = cast<RankedTensorType>(input.getType());
      if (!inputTy.hasStaticShape()) {
        LLVM_DEBUG(DBGS() << "[concat-removal] concat input shape is dynamic\n");
        return failure();
      }
      Operation *def = input.getDefiningOp();
      if (def && def->getBlock() != block) {
        LLVM_DEBUG(DBGS() << "[concat-removal] concat input is defined in "
                             "another block\n");
        return failure();
      }
      dimOffsets.push_back(nextOffset);
      nextOffset += inputTy.getDimSize(dim);
    }

    // Visit inputs in program order of their producers, block arguments
    // first. All producers live in `block`, so isBeforeInBlock is valid and
    // the comparator is a strict weak ordering.
    SmallVector<unsigned> order;
    order.reserve(inputs.size());
    for (unsigned i = 0, e = inputs.size(); i < e; ++i)
      order.push_back(i);
    llvm::stable_sort(order, [&](unsigned lhs, unsigned rhs) {
      Operation *lhsOp = inputs[lhs].getDefiningOp();
      Operation *rhsOp = inputs[rhs].getDefiningOp();
      if (!lhsOp || !rhsOp)
        return !lhsOp && rhsOp;
      return lhsOp->isBeforeInBlock(rhsOp);
    });

    rewriter.setInsertionPointToStart(block);
    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    for (unsigned idx : order) {
      Value input = inputs[idx];
      ArrayRef<int64_t> shape =
          cast<RankedTensorType>(input.getType()).getShape();
      SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
      offsets[dim] = rewriter.getIndexAttr(dimOffsets[idx]);
      SmallVector<OpFoldResult> sizes;
      for (int64_t size : shape)
        sizes.push_back(rewriter.getIndexAttr(size));
      SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));

      // The new insert_slice must follow both the previous one (`result`) and
      // the input's producer so that all of its operands dominate it.
      Operation *anchor = result.getDefiningOp();
      // A value used more than once by the concat cannot be redirected into
      // two different slices, so it is only copied.
      Operation *producer = llvm::count(inputs, input) == 1
                                ? getRetargetableProducer(input)
                                : nullptr;
      if (producer) {
        // The insertion point is right after the previous insert_slice, which
        // precedes `producer` because inputs are visited in program order.
        auto slice = rewriter.create<tensor::ExtractSliceOp>(
            loc, result, offsets, sizes, strides);
        auto dpsOp = cast<DestinationStyleOpInterface>(producer);
        rewriter.modifyOpInPlace(producer, [&] {
          dpsOp.getDpsInitsMutable()[0].set(slice.getResult());
        });
        anchor = producer;
      } else if (Operation *def = input.getDefiningOp()) {
        if (anchor->isBeforeInBlock(def))
          anchor = def;
      }
      rewriter.setInsertionPointAfter(anchor);
      result = rewriter.create<tensor::InsertSliceOp>(loc, input, result,
                                                      offsets, sizes, strides);
    }

    rewriter.replaceOp(concatOp, result);
    return success();
  }
};
/// Merges runs of consecutive `tensor.concat` inputs that are
/// `tensor.extract_slice`s of the same source covering adjacent regions into
/// one larger `tensor.extract_slice` of that source.
///
/// Slice B is adjacent to slice A along dimension `d` when A ends where B
/// begins along `d` and both have identical offsets and sizes on every other
/// dimension. All pairs in a run must be adjacent along the same `d`
/// (`mergeDim`).
///
/// If `mergeDim` is the concat dimension, the merged slice replaces the run
/// directly. Otherwise the merged slice is reshaped (collapse + expand) into
/// the shape the run occupied in the concat. That is a pure reshape only when
/// every slice has size 1 on all dimensions in
/// [min(mergeDim, dim), max(mergeDim, dim)), so other runs are left untouched.
///
/// Only slices with static offsets/sizes/strides, unit strides and no rank
/// reduction are considered. Each successful rewrite strictly reduces the
/// number of concat operands, so repeated application terminates.
struct SimplifyTensorConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  /// A run of consecutive concat inputs that can be replaced by one value.
  struct SliceRun {
    /// Operand position of the first slice in the concat.
    unsigned start = 0;
    /// The slices, in concat order.
    SmallVector<tensor::ExtractSliceOp> slices;
    /// Dimension along which consecutive slices are adjacent (-1 while the run
    /// holds a single slice).
    int64_t mergeDim = -1;
  };

  /// Returns the extract_slice defining `val` if it is a merge candidate:
  /// static offsets/sizes/strides, unit strides, and not rank-reducing.
  static FailureOr<tensor::ExtractSliceOp> getMergeableSlice(Value val) {
    auto slice = val.getDefiningOp<tensor::ExtractSliceOp>();
    if (!slice)
      return failure();
    if (ShapedType::isDynamicShape(slice.getStaticOffsets()) ||
        ShapedType::isDynamicShape(slice.getStaticSizes()) ||
        ShapedType::isDynamicShape(slice.getStaticStrides())) {
      LLVM_DEBUG(DBGS() << "getMergeableSlice: Shape is dynamic\n");
      return failure();
    }
    if (llvm::any_of(slice.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      LLVM_DEBUG(DBGS() << "getMergeableSlice: Non unit strides\n");
      return failure();
    }
    auto srcTy = cast<RankedTensorType>(slice.getSource().getType());
    auto resTy = cast<RankedTensorType>(slice.getResult().getType());
    if (srcTy.getRank() != resTy.getRank()) {
      LLVM_DEBUG(DBGS() << "getMergeableSlice: rank-reducing slice\n");
      return failure();
    }
    return slice;
  }

  /// Returns the unique dimension along which `next` directly follows `prev`,
  /// or std::nullopt if the slices are not adjacent (different sources, gaps,
  /// or differing offsets/sizes on another dimension).
  static std::optional<int64_t> getAdjacentDim(tensor::ExtractSliceOp prev,
                                               tensor::ExtractSliceOp next) {
    if (prev.getSource() != next.getSource())
      return std::nullopt;
    ArrayRef<int64_t> prevOffsets = prev.getStaticOffsets();
    ArrayRef<int64_t> prevSizes = prev.getStaticSizes();
    ArrayRef<int64_t> nextOffsets = next.getStaticOffsets();
    ArrayRef<int64_t> nextSizes = next.getStaticSizes();
    std::optional<int64_t> adjacentDim;
    for (int64_t d = 0, rank = prevOffsets.size(); d < rank; ++d) {
      if (prevOffsets[d] == nextOffsets[d] && prevSizes[d] == nextSizes[d])
        continue;
      // At most one dimension may differ, and it must be a direct continuation.
      if (adjacentDim || prevOffsets[d] + prevSizes[d] != nextOffsets[d])
        return std::nullopt;
      adjacentDim = d;
    }
    return adjacentDim;
  }

  /// Whether `run` can be replaced without reordering elements (see the
  /// struct comment for the reshape condition).
  static bool canMergeRun(int64_t concatDim, const SliceRun &run) {
    if (run.mergeDim == concatDim)
      return true;
    int64_t lo = std::min(run.mergeDim, concatDim);
    int64_t hi = std::max(run.mergeDim, concatDim);
    for (tensor::ExtractSliceOp slice : run.slices) {
      ArrayRef<int64_t> sizes = slice.getStaticSizes();
      for (int64_t d = lo; d < hi; ++d)
        if (sizes[d] != 1)
          return false;
    }
    return true;
  }

  /// Collects all mergeable runs (length >= 2) in concat operand order.
  static SmallVector<SliceRun> findSliceRuns(ConcatOp concatOp) {
    int64_t concatDim = concatOp.getDim();
    SmallVector<SliceRun> runs;
    SliceRun current;
    // Records `current` if it is worth and legal to merge, then resets it.
    auto flush = [&]() {
      if (current.slices.size() > 1 && canMergeRun(concatDim, current))
        runs.push_back(current);
      current = SliceRun();
    };
    for (auto [pos, input] : llvm::enumerate(concatOp.getInputs())) {
      FailureOr<tensor::ExtractSliceOp> maybeSlice = getMergeableSlice(input);
      if (failed(maybeSlice)) {
        flush();
        continue;
      }
      tensor::ExtractSliceOp slice = *maybeSlice;
      if (!current.slices.empty()) {
        std::optional<int64_t> adjDim =
            getAdjacentDim(current.slices.back(), slice);
        if (adjDim && (current.mergeDim < 0 || current.mergeDim == *adjDim)) {
          current.slices.push_back(slice);
          current.mergeDim = *adjDim;
          continue;
        }
        flush();
      }
      current.start = pos;
      current.slices.push_back(slice);
    }
    flush();
    return runs;
  }

  /// Creates the value replacing `run`: one extract_slice covering the whole
  /// run, reshaped to the run's concatenated shape when needed.
  static Value buildMergedValue(PatternRewriter &rewriter, int64_t concatDim,
                                const SliceRun &run) {
    tensor::ExtractSliceOp first = run.slices.front();
    auto firstTy = cast<RankedTensorType>(first.getResult().getType());
    Type elemTy = firstTy.getElementType();
    Attribute encoding = firstTy.getEncoding();
    Location loc = first.getLoc();

    auto totalSizeAlong = [&](int64_t d) {
      int64_t total = 0;
      for (tensor::ExtractSliceOp slice : run.slices)
        total += slice.getStaticSizes()[d];
      return total;
    };

    SmallVector<int64_t> mergedShape(first.getStaticSizes());
    mergedShape[run.mergeDim] = totalSizeAlong(run.mergeDim);
    auto mergedTy = RankedTensorType::get(mergedShape, elemTy, encoding);
    Value merged = rewriter.create<tensor::ExtractSliceOp>(
        loc, mergedTy, first.getSource(), ValueRange({}), ValueRange({}),
        ValueRange({}), rewriter.getDenseI64ArrayAttr(first.getStaticOffsets()),
        rewriter.getDenseI64ArrayAttr(mergedShape),
        rewriter.getDenseI64ArrayAttr(first.getStaticStrides()));
    if (run.mergeDim == concatDim)
      return merged;

    // Group dims [lo, hi] into one; all but the innermost have size 1 in each
    // slice (canMergeRun), so both shapes linearize that group identically.
    SmallVector<int64_t> concatShape(first.getStaticSizes());
    concatShape[concatDim] = totalSizeAlong(concatDim);
    int64_t lo = std::min(run.mergeDim, concatDim);
    int64_t hi = std::max(run.mergeDim, concatDim);
    SmallVector<ReassociationIndices> reassoc;
    SmallVector<int64_t> collapsedShape;
    for (int64_t d = 0, rank = mergedShape.size(); d < rank; ++d) {
      if (d > lo && d <= hi) {
        reassoc.back().push_back(d);
        collapsedShape.back() *= mergedShape[d];
      } else {
        reassoc.push_back({d});
        collapsedShape.push_back(mergedShape[d]);
      }
    }
    Value collapsed = rewriter.create<tensor::CollapseShapeOp>(
        loc, RankedTensorType::get(collapsedShape, elemTy, encoding), merged,
        reassoc);
    return rewriter.create<tensor::ExpandShapeOp>(
        loc, RankedTensorType::get(concatShape, elemTy, encoding), collapsed,
        reassoc);
  }

  /// Prints the discovered runs under -debug-only=concat-patterns.
  static void debugSliceRuns(ArrayRef<SliceRun> runs) {
    LLVM_DEBUG({
      if (runs.empty())
        DBGS() << "Merge opportunities: (empty)\n\n";
      else
        DBGS() << "Merge opportunities:\n";
      for (const SliceRun &run : runs) {
        DBGS() << "[DIMENSION " << run.mergeDim << "] ";
        tensor::ExtractSliceOp first = run.slices.front();
        first.getSource().dump();
        for (tensor::ExtractSliceOp slice : run.slices) {
          DBGS() << "==> ";
          slice.dump();
        }
        llvm::dbgs() << "\n\n";
      }
    });
  }

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<SliceRun> runs = findSliceRuns(concatOp);
    debugSliceRuns(runs);
    if (runs.empty())
      return failure();

    int64_t concatDim = concatOp.getDim();
    auto oldInputs = llvm::to_vector(concatOp.getInputs());
    SmallVector<Value> newInputs;
    unsigned next = 0;
    rewriter.setInsertionPoint(concatOp);
    // Runs are disjoint and sorted by start; rebuild the operand list by
    // position so repeated values elsewhere in the concat are unaffected.
    for (const SliceRun &run : runs) {
      newInputs.append(oldInputs.begin() + next, oldInputs.begin() + run.start);
      newInputs.push_back(buildMergedValue(rewriter, concatDim, run));
      next = run.start + run.slices.size();
    }
    newInputs.append(oldInputs.begin() + next, oldInputs.end());
    rewriter.modifyOpInPlace(concatOp,
                             [&] { concatOp->setOperands(newInputs); });
    return success();
  }
};
}
/// Registers the pattern that decomposes tensor.concat into insert_slice ops.
void mlir::tensor::populateDecomposeTensorConcatPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DecomposeTensorConcatOp>(context);
}
/// Registers the ConcatRemoval pattern.
void mlir::tensor::populateConcatRemovalPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ConcatRemoval>(context);
}
/// Registers the pattern that merges adjacent extract_slice concat inputs.
void mlir::tensor::populateSimplifyTensorConcatPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SimplifyTensorConcatOp>(context);
}
namespace {
/// Pass applying the DecomposeTensorConcat patterns greedily.
struct DecomposeTensorConcatPass final
    : tensor::impl::DecomposeTensorConcatBase<DecomposeTensorConcatPass> {
  void runOnOperation() override;
};

/// Pass applying the ConcatRemoval patterns greedily.
struct ConcatRemovalPass final
    : tensor::impl::ConcatRemovalBase<ConcatRemovalPass> {
  void runOnOperation() override;
};

/// Pass applying the SimplifyTensorConcat patterns greedily.
struct SimplifyTensorConcatPass final
    : tensor::impl::SimplifyTensorConcatBase<SimplifyTensorConcatPass> {
  void runOnOperation() override;
};
} // namespace
void DecomposeTensorConcatPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
tensor::populateDecomposeTensorConcatPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
void ConcatRemovalPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
tensor::populateConcatRemovalPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
void SimplifyTensorConcatPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
tensor::populateSimplifyTensorConcatPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
/// Factory for the pass that decomposes tensor.concat ops.
std::unique_ptr<Pass> tensor::createDecomposeTensorConcatPass() {
  auto pass = std::make_unique<DecomposeTensorConcatPass>();
  return pass;
}
/// Factory for the ConcatRemoval pass.
std::unique_ptr<Pass> tensor::createConcatRemovalPass() {
  auto pass = std::make_unique<ConcatRemovalPass>();
  return pass;
}
/// Factory for the pass that merges adjacent extract_slice concat inputs.
std::unique_ptr<Pass> tensor::createSimplifyTensorConcatPass() {
  auto pass = std::make_unique<SimplifyTensorConcatPass>();
  return pass;
}