#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Lowers a fixed-size 2-D -> 1-D `vector.shape_cast` by walking the leading
/// dimension of the source: every row is pulled out with `vector.extract` and
/// written into a zero-initialized 1-D accumulator with
/// `vector.insert_strided_slice` at offset `row * rowLength` (stride 1).
///
/// The pattern does not apply when either vector type is scalable, or when
/// the ranks are anything other than 2 (source) and 1 (result).
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();

    // Reject everything that is not a fixed-size 2-D -> 1-D cast.
    const bool anyScalable = srcType.isScalable() || dstType.isScalable();
    const bool is2DTo1D = srcType.getRank() == 2 && dstType.getRank() == 1;
    if (anyScalable || !is2DTo1D)
      return failure();

    Location loc = op.getLoc();
    Value source = op.getSource();
    llvm::ArrayRef<int64_t> srcShape = srcType.getShape();
    const int64_t numRows = srcShape.front();
    unsigned rowLength = srcShape[1];

    // Start from an all-zero result and overwrite it one source row at a time.
    Value flattened = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t row = 0; row != numRows; ++row) {
      Value rowVector = rewriter.create<vector::ExtractOp>(loc, source, row);
      flattened = rewriter.create<vector::InsertStridedSliceOp>(
          loc, rowVector, flattened,
          /*offsets=*/row * rowLength, /*strides=*/1);
    }

    rewriter.replaceOp(op, flattened);
    return success();
  }
};
/// Lowers a fixed-size 1-D -> 2-D `vector.shape_cast` by walking the leading
/// dimension of the result: for every result row, a `rowLength`-wide window
/// is taken from the 1-D source with `vector.extract_strided_slice`
/// (offset `row * rowLength`, stride 1) and placed into a zero-initialized
/// 2-D accumulator with `vector.insert`.
///
/// The pattern does not apply when either vector type is scalable, or when
/// the ranks are anything other than 1 (source) and 2 (result).
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();

    // Reject everything that is not a fixed-size 1-D -> 2-D cast.
    const bool anyScalable = srcType.isScalable() || dstType.isScalable();
    const bool is1DTo2D = srcType.getRank() == 1 && dstType.getRank() == 2;
    if (anyScalable || !is1DTo2D)
      return failure();

    Location loc = op.getLoc();
    Value source = op.getSource();
    llvm::ArrayRef<int64_t> dstShape = dstType.getShape();
    const int64_t numRows = dstShape.front();
    unsigned rowLength = dstShape[1];

    // Start from an all-zero result and fill it one result row at a time.
    Value expanded = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t row = 0; row != numRows; ++row) {
      Value rowVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, source, /*offsets=*/row * rowLength,
          /*sizes=*/rowLength,
          /*strides=*/1);
      expanded = rewriter.create<vector::InsertOp>(loc, rowVector, expanded, row);
    }

    rewriter.replaceOp(op, expanded);
    return success();
  }
};
/// Advances the multi-dimensional index `idx` in place, treating it as a
/// position inside a vector of type `tp`.
///
/// `initialStep` is added to dimension `dimIdx`. While the updated entry is
/// not below that dimension's size, it is reset to 0 and a carry of 1 is
/// applied to the next more-major dimension (`dimIdx - 1`, and so on). A carry
/// out of dimension 0 is dropped, so the index wraps around. A negative
/// `dimIdx` leaves `idx` untouched.
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int carry = initialStep;
  int dim = dimIdx;
  while (dim >= 0) {
    idx[dim] += carry;
    // Still inside this dimension: nothing to propagate.
    if (idx[dim] < tp.getDimSize(dim))
      return;
    // Overflowed: wrap this dimension and carry one into the next outer one.
    idx[dim] = 0;
    carry = 1;
    --dim;
  }
}
/// Generic lowering of a fixed-size `vector.shape_cast` into a fully unrolled
/// chain of per-element extract/insert operations.
///
/// Source and result positions are advanced together (via `incIdx`, last
/// dimension fastest), so the n-th element visited in the source is written
/// to the n-th position visited in the result. The result starts out as an
/// all-zero constant.
///
/// The pattern does not apply to scalable vectors, nor to the 2-D -> 1-D and
/// 1-D -> 2-D rank combinations.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();
    if (srcType.isScalable() || dstType.isScalable())
      return failure();

    // The 2-D <-> 1-D rank combinations are explicitly rejected here.
    const int64_t srcRank = srcType.getRank();
    const int64_t dstRank = dstType.getRank();
    const bool isDownCast2D = srcRank == 2 && dstRank == 1;
    const bool isUpCast2D = srcRank == 1 && dstRank == 2;
    if (isDownCast2D || isUpCast2D)
      return failure();

    // Total number of elements to move (1 for a 0-D source).
    int64_t numElts = 1;
    for (int64_t dimSize : srcType.getShape())
      numElts *= dimSize;

    Location loc = op.getLoc();
    SmallVector<int64_t> srcPos(srcRank);
    SmallVector<int64_t> dstPos(dstRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));

    for (int64_t remaining = numElts; remaining > 0; --remaining) {
      // Read the current source element; 0-D vectors have no position and
      // go through vector.extractelement instead.
      Value element;
      if (srcRank != 0) {
        element = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcPos);
      } else {
        assert(srcPos.empty() && "Unexpected indices for 0-D vector");
        element = rewriter.create<vector::ExtractElementOp>(
            loc, srcType.getElementType(), op.getSource());
      }

      // Write it to the current result position (same 0-D special case).
      if (dstRank != 0) {
        result = rewriter.create<vector::InsertOp>(loc, element, result, dstPos);
      } else {
        assert(dstPos.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, element, result);
      }

      // Step both positions to the next element. The step after the final
      // element only wraps the indices and creates no IR.
      incIdx(srcPos, srcType, srcRank - 1);
      incIdx(dstPos, dstType, dstRank - 1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Lowers a `vector.shape_cast` whose source and result types both have a
/// scalable trailing dimension and no other scalable dimension.
///
/// Data is moved in 1-D scalable chunks whose (minimum) length is the smaller
/// of the two trailing dimension sizes:
///   * the source row containing the chunk is obtained with `vector.extract`
///     (or is the source itself when the source is 1-D), and the chunk is
///     carved out of it with `vector.scalable.extract` when the row is wider
///     than the chunk;
///   * the chunk is merged into the current result row with
///     `vector.scalable.insert` when the result row is wider than the chunk,
///     otherwise the chunk is the result row;
///   * once a result row is complete it is written back into the result with
///     `vector.insert`.
/// The result starts out as an all-zero constant.
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Returns true iff `type` has rank >= 1, its last dimension is scalable,
  /// and none of the preceding dimensions is scalable.
  static bool isTrailingDimScalable(VectorType type) {
    if (type.getRank() < 1)
      return false;
    llvm::ArrayRef<bool> scalableDims = type.getScalableDims();
    return scalableDims.back() &&
           !llvm::is_contained(scalableDims.drop_back(), true);
  }

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = op.getResultVectorType();
    if (!(isTrailingDimScalable(srcType) && isTrailingDimScalable(dstType)))
      return failure();

    Location loc = op.getLoc();
    const auto srcRank = srcType.getRank();
    const auto dstRank = dstType.getRank();

    // Minimum (vscale == 1) sizes of the trailing dimensions and of the chunk
    // that is moved per iteration.
    const int64_t srcTrailing = srcType.getShape().back();
    const int64_t dstTrailing = dstType.getShape().back();
    const int64_t chunk = std::min(srcTrailing, dstTrailing);
    // Since chunk is the minimum of both, "not narrower" means "equal".
    const bool mustSliceSource = chunk < srcTrailing;
    const bool mustSliceResult = chunk < dstTrailing;

    int64_t minNumElts = 1;
    for (int64_t dimSize : srcType.getShape())
      minNumElts *= dimSize;

    auto chunkType =
        VectorType::get({chunk}, srcType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));

    SmallVector<int64_t> srcPos(srcRank);
    SmallVector<int64_t> dstPos(dstRank);

    // Rows currently being read from / assembled; null means "fetch a new one".
    Value dstRow;
    Value srcRow;
    for (int64_t done = 0; done < minNumElts; done += chunk) {
      // 1. Obtain the chunk from the source.
      if (!srcRow) {
        if (srcRank == 1) {
          srcRow = op.getSource();
        } else {
          srcRow = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcPos).drop_back());
        }
      }
      Value chunkVector = srcRow;
      if (mustSliceSource) {
        chunkVector = rewriter.create<vector::ScalableExtractOp>(
            loc, chunkType, chunkVector, srcPos.back());
      }

      // 2. Merge the chunk into the result row under construction.
      if (!dstRow) {
        if (!mustSliceResult) {
          dstRow = chunkVector;
        } else if (dstRank == 1) {
          dstRow = result;
        } else {
          dstRow = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(dstPos).drop_back());
        }
      }
      if (mustSliceResult) {
        dstRow = rewriter.create<vector::ScalableInsertOp>(
            loc, chunkVector, dstRow, dstPos.back());
      }

      // 3. Flush a completed result row and drop an exhausted source row.
      const bool dstRowComplete = dstPos.back() + chunk >= dstTrailing;
      if (dstRowComplete && dstRow != result) {
        result = rewriter.create<vector::InsertOp>(
            loc, dstRow, result, llvm::ArrayRef(dstPos).drop_back());
        dstRow = Value();
      }
      const bool srcRowExhausted = srcPos.back() + chunk >= srcTrailing;
      if (srcRowExhausted)
        srcRow = Value();

      // 4. Advance both positions by one chunk along the trailing dimension.
      incIdx(srcPos, srcType, srcRank - 1, chunk);
      incIdx(dstPos, dstType, dstRank - 1, chunk);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
}
/// Adds the `vector.shape_cast` lowering patterns defined in this file to
/// `patterns`, each registered with the given `benefit`:
/// the 2-D -> 1-D and 1-D -> 2-D patterns, the generic fixed-size pattern,
/// and the scalable trailing-dimension pattern (in that order).
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ShapeCastOp2DDownCastRewritePattern>(context, benefit);
  patterns.add<ShapeCastOp2DUpCastRewritePattern>(context, benefit);
  patterns.add<ShapeCastOpRewritePattern>(context, benefit);
  patterns.add<ScalableShapeCastOpRewritePattern>(context, benefit);
}