#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::vector;
/// Inserts `from` into `into` at leading-dimension position `offset` and
/// returns the updated vector.
///
/// A destination of rank > 1 gets a vector.insert with a static position; a
/// destination of rank <= 1 gets a vector.insertelement whose position is an
/// arith.constant index materialized just before it.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto intoType = cast<VectorType>(into.getType());
  if (intoType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::InsertElementOp>(loc, intoType, from, into,
                                                    position);
  }
  return rewriter.create<InsertOp>(loc, from, into, offset);
}
/// Extracts the slice (or element) of `vector` at leading-dimension position
/// `offset`.
///
/// A source of rank > 1 gets a vector.extract with a static position, which
/// yields a lower-rank vector. A source of rank <= 1 gets a
/// vector.extractelement of the element type, whose position is an
/// arith.constant index materialized just before it.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto sourceType = cast<VectorType>(vector.getType());
  if (sourceType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::ExtractElementOp>(
        loc, sourceType.getElementType(), vector, position);
  }
  return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// RewritePattern for InsertStridedSliceOp whose source has a lower rank than
/// its destination.
///
/// Let k be the source rank and n the destination rank (k < n). The first
/// n - k offsets name a k-D subvector of the destination. The pattern emits:
///   1. an ExtractOp that pulls that k-D subvector out of the destination,
///   2. a same-rank (k-D into k-D) InsertStridedSliceOp of the source into
///      that subvector, left for other patterns to lower, and
///   3. an InsertOp that writes the updated subvector back at the same
///      position. This InsertOp replaces the original op.
class DecomposeDifferentRankInsertStridedSlice
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getOffsets().getValue().empty())
      return failure();

    int64_t sourceRank = op.getSourceVectorType().getRank();
    int64_t numLeadingDims = op.getDestVectorType().getRank() - sourceRank;
    assert(numLeadingDims >= 0);
    // Equal ranks are the job of ConvertSameRankInsertStridedSliceIntoShuffle.
    if (numLeadingDims == 0)
      return failure();

    Location loc = op.getLoc();
    // The leading offsets (all but the trailing `sourceRank` ones) locate the
    // k-D destination subvector that receives the source. Both the extract
    // and the final insert use this position.
    auto subvectorPos = getI64SubArray(op.getOffsets(), /*dropFront=*/0,
                                       /*dropBack=*/sourceRank);

    Value subvector =
        rewriter.create<ExtractOp>(loc, op.getDest(), subvectorPos);
    auto sameRankInsert = rewriter.create<InsertStridedSliceOp>(
        loc, op.getSource(), subvector,
        getI64SubArray(op.getOffsets(), /*dropFront=*/numLeadingDims),
        getI64SubArray(op.getStrides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(op, sameRankInsert.getResult(),
                                          op.getDest(), subvectorPos);
    return success();
  }
};
/// RewritePattern for InsertStridedSliceOp whose source and destination have
/// the same rank.
///
/// 1-D vectors are handled with two vector.shuffle ops:
///   1. The first widens the source to the destination length.
///   2. The second selects, for each destination lane, either the widened
///      source element or the original destination element.
///
/// Rank n > 1 is unrolled along the leading dimension. For each source slice:
///   - the slice is extracted,
///   - it is inserted into the matching destination slice through a rank n-1
///     InsertStridedSliceOp (lowered by a later application of this pattern),
///   - the updated slice is inserted into the accumulated result.
class ConvertSameRankInsertStridedSliceIntoShuffle
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // The rewrite creates new InsertStridedSliceOps, but each has one rank
    // less than the op being rewritten, so the recursion is bounded.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.getOffsets().getValue().empty())
      return failure();

    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();
    assert(dstRank >= srcRank);
    // Rank-changing inserts are the job of
    // DecomposeDifferentRankInsertStridedSlice.
    if (dstRank != srcRank)
      return failure();

    // Identical types: the result is simply the source.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    int64_t offset =
        cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
    // Exclusive upper bound of the destination positions written along the
    // leading dimension.
    int64_t end = offset + size * stride;
    Location loc = op.getLoc();

    if (srcRank == 1) {
      // Keep shape dims in int64_t; narrowing them to `int` is lossy.
      int64_t nDest = dstType.getShape().front();

      // 1. Widen the source to `nDest` lanes so both shuffle operands agree
      //    in length. Lanes at or beyond `size` repeat source lane 0. They
      //    are never selected by the mask built in step 2.
      SmallVector<int64_t> mask(nDest, 0);
      for (int64_t i = 0; i < size; ++i)
        mask[i] = i;
      Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
                                                      op.getSource(), mask);

      // 2. For each destination lane, take the widened source element when
      //    the lane is one of the strided insert positions. Otherwise keep the
      //    destination element. Shuffle indices >= nDest address the second
      //    operand.
      mask.clear();
      for (int64_t i = 0; i < nDest; ++i) {
        bool fromSource =
            i >= offset && i < end && (i - offset) % stride == 0;
        mask.push_back(fromSource ? (i - offset) / stride : nDest + i);
      }

      // 3. The selecting shuffle replaces the op.
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
                                             mask);
      return success();
    }

    // n-D case: unroll along the leading dimension.
    Value res = op.getDest();
    for (int64_t off = offset, idx = 0; off < end; off += stride, ++idx) {
      // Take the idx-th slice (or element) of the source.
      Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
      if (isa<VectorType>(extractedSource.getType())) {
        // Still a vector: merge it into the matching destination slice with a
        // lower-rank InsertStridedSliceOp. An element needs no merging.
        Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.getOffsets(), /*dropFront=*/1),
            getI64SubArray(op.getStrides(), /*dropFront=*/1));
      }
      // Write the (updated) slice into the accumulated result.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};
/// RewritePattern for ExtractStridedSliceOp that carries a single offset, so
/// only the leading dimension is sliced.
///
/// The slice becomes a vector.shuffle of the source with itself. Its mask
/// lists the selected leading-dimension positions offset, offset + stride,
/// ..., stopping below offset + size * stride.
class Convert1DExtractStridedSliceIntoShuffle
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    ArrayRef<Attribute> offsetAttrs = op.getOffsets().getValue();
    assert(!offsetAttrs.empty() && "Unexpected empty offsets");

    // Multi-offset slices are handled by DecomposeNDExtractStridedSlice.
    // Reject them before reading attributes or asserting on their element
    // type, since this pattern never rewrites them.
    if (offsetAttrs.size() != 1)
      return failure();

    auto dstType = op.getType();
    assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());

    int64_t offset = cast<IntegerAttr>(offsetAttrs.front()).getInt();
    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
    int64_t stride =
        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();

    SmallVector<int64_t, 4> mask;
    mask.reserve(size);
    for (int64_t pos = offset, end = offset + size * stride; pos < end;
         pos += stride)
      mask.push_back(pos);
    rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
                                           op.getVector(),
                                           rewriter.getI64ArrayAttr(mask));
    return success();
  }
};
/// Lowers an ExtractStridedSliceOp with a single offset into a chain of
/// static-position ops:
///   - one vector.extract per selected leading-dimension position, then
///   - vector.insert ops that place each extracted value, in order, into a
///     zero constant of the result type.
///
/// When `controlFn` is set, the pattern applies only if it returns true.
class Convert1DExtractStridedSliceIntoExtractInsertChain final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  Convert1DExtractStridedSliceIntoExtractInsertChain(
      MLIRContext *context,
      std::function<bool(ExtractStridedSliceOp)> controlFn,
      PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // The user-provided filter gets the first say.
    if (controlFn && !controlFn(op))
      return failure();
    // Only a single (leading-dimension) offset is supported.
    if (op.getOffsets().getValue().size() != 1)
      return failure();

    auto leadingInt = [](ArrayAttr attr) -> int64_t {
      return cast<IntegerAttr>(attr.getValue().front()).getInt();
    };
    const int64_t offset = leadingInt(op.getOffsets());
    const int64_t size = leadingInt(op.getSizes());
    const int64_t stride = leadingInt(op.getStrides());

    Location loc = op.getLoc();
    Value source = op.getVector();

    // First extract every selected element...
    SmallVector<Value> picked;
    picked.reserve(size);
    const int64_t end = offset + size * stride;
    for (int64_t pos = offset; pos < end; pos += stride)
      picked.push_back(rewriter.create<ExtractOp>(loc, source, pos));

    // ...then insert them, in order, into an all-zero result vector.
    Value acc = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(op.getType()));
    for (int64_t dst = 0; dst < size; ++dst)
      acc = rewriter.create<InsertOp>(loc, picked[dst], acc, dst);

    rewriter.replaceOp(op, acc);
    return success();
  }

private:
  /// Optional filter: the pattern applies only when this is empty or returns
  /// true for the op.
  std::function<bool(ExtractStridedSliceOp)> controlFn;
};
/// RewritePattern for ExtractStridedSliceOp carrying more than one offset.
///
/// The op is unrolled along the leading dimension. For each selected source
/// position:
///   - the slice at that position is extracted,
///   - its remaining dimensions are sliced by an ExtractStridedSliceOp with
///     one fewer offset,
///   - the sliced result is inserted into a zero-splat vector of the result
///     type.
class DecomposeNDExtractStridedSlice
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // The rewrite emits new ExtractStridedSliceOps with one fewer offset than
    // the op being rewritten, so the recursion is bounded.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    ArrayRef<Attribute> offsetAttrs = op.getOffsets().getValue();
    assert(!offsetAttrs.empty() && "Unexpected empty offsets");

    // Single-offset slices are handled by
    // Convert1DExtractStridedSliceIntoShuffle. Reject them before reading
    // attributes or asserting on their element type, since this pattern never
    // rewrites them.
    if (offsetAttrs.size() == 1)
      return failure();

    auto dstType = op.getType();
    auto elemType = dstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    int64_t offset = cast<IntegerAttr>(offsetAttrs.front()).getInt();
    int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
    int64_t stride =
        cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
    auto loc = op.getLoc();

    // Start from an all-zero result and fill it slice by slice.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // Take the `off`-th leading slice of the source, then slice its
      // remaining dimensions with a lower-rank ExtractStridedSliceOp.
      Value one = extractOne(rewriter, loc, op.getVector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, one, getI64SubArray(op.getOffsets(), /*dropFront=*/1),
          getI64SubArray(op.getSizes(), /*dropFront=*/1),
          getI64SubArray(op.getStrides(), /*dropFront=*/1));
      // Place the result at position `idx` of the output.
      res = insertOne(rewriter, loc, extracted, res, idx);
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};
/// Adds only the rank-decomposition patterns:
///   - DecomposeDifferentRankInsertStridedSlice
///   - DecomposeNDExtractStridedSlice
/// Both are registered with `benefit`.
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<DecomposeDifferentRankInsertStridedSlice,
               DecomposeNDExtractStridedSlice>(context, benefit);
}
/// Adds the pattern that lowers single-offset ExtractStridedSliceOps into
/// vector.extract / vector.insert chains. `controlFn` (possibly empty) is
/// moved into the pattern and filters which ops it rewrites.
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
    RewritePatternSet &patterns,
    std::function<bool(ExtractStridedSliceOp)> controlFn,
    PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
      context, std::move(controlFn), benefit);
}
/// Adds the decomposition patterns plus two shuffle-based lowerings:
///   - ConvertSameRankInsertStridedSliceIntoShuffle, for same-rank
///     InsertStridedSliceOps,
///   - Convert1DExtractStridedSliceIntoShuffle, for single-offset
///     ExtractStridedSliceOps.
/// All patterns share `benefit`.
void vector::populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
                                                               benefit);
  MLIRContext *context = patterns.getContext();
  patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
               Convert1DExtractStridedSliceIntoShuffle>(context, benefit);
}