#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::linalg;
/// Materializes one part of a split of `op` as if it were a single tile.
///
/// The part covers the iteration space described by `offsets`/`sizes`, except
/// that along `dimension` it uses `offset` and `size` instead. Each tiled
/// result value is inserted with `tensor.insert_slice` into the value at the
/// same index in `resultOperands` (which need not be the actual operands of
/// `op`), and the inserted values are appended to `results` in result order.
///
/// Returns the tiled operation, or a null TilingInterface if tiling the op or
/// computing a result tile position fails. On failure `results` may be
/// partially populated and any IR created so far is left in place.
static TilingInterface
createSplitPart(RewriterBase &b, Location loc, TilingInterface op,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                ValueRange resultOperands, unsigned dimension,
                OpFoldResult size, OpFoldResult offset,
                SmallVectorImpl<Value> &results) {
  // Same iteration space as the full op, restricted along `dimension`.
  SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes);
  SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets);
  sizesCopy[dimension] = size;
  offsetsCopy[dimension] = offset;

  // Create the part as if it were a single tile.
  FailureOr<TilingResult> tilingResult =
      op.getTiledImplementation(b, offsetsCopy, sizesCopy);
  // Dereferencing a failed FailureOr is invalid; report failure to the caller
  // through the null return instead.
  if (failed(tilingResult))
    return nullptr;

  // Insert the tiled results back into the destinations and collect them.
  for (auto [index, result] : llvm::enumerate(tilingResult->tiledValues)) {
    SmallVector<OpFoldResult> resultOffsets, resultSizes;
    if (failed(op.getResultTilePosition(b, index, offsetsCopy, sizesCopy,
                                        resultOffsets, resultSizes)))
      return nullptr;
    SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
                                            b.getIndexAttr(1));
    Value inserted = b.create<tensor::InsertSliceOp>(
        loc, result, resultOperands[index], resultOffsets, resultSizes,
        resultStrides);
    results.push_back(inserted);
  }
  assert(tilingResult->tiledOps.size() == 1 &&
         "expected split part to return a single tiled operation");
  return cast<TilingInterface>(tilingResult->tiledOps[0]);
}
/// Splits `op` into two parts along iteration-space `dimension` at
/// `splitPoint`.
///
/// Returns:
///   - {op, null} without changing the IR when `dimension` is out of range or
///     the second part is statically known to be empty;
///   - {null, null} when destinations cannot be obtained or either part cannot
///     be created. In that case `op` keeps its original operands, although IR
///     already created for the parts is not erased;
///   - {firstPart, secondPart} on success, after `op` has been replaced by the
///     results of the second part.
std::pair<TilingInterface, TilingInterface>
linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
                OpFoldResult splitPoint) {
  Location loc = op.getLoc();

  // Compute the iteration space and bail out on dimension overflow.
  SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter);
  if (dimension >= iterationSpace.size())
    return {op, TilingInterface()};

  SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range(
      iterationSpace, [](const Range &range) { return range.offset; }));
  SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range(
      iterationSpace, [](const Range &range) { return range.size; }));

  // Clamp the split point: min(splitPoint, offset + size).
  AffineExpr d0, d1, d2;
  bindDims(rewriter.getContext(), d0, d1, d2);
  OpFoldResult minSplitPoint = affine::makeComposedFoldedAffineMin(
      rewriter, loc,
      AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2},
                                   rewriter.getContext())
          .front(),
      {splitPoint, offsets[dimension], sizes[dimension]});

  // NOTE(review): minSplitPoint is clamped against the absolute end
  // (offset + size) but is then used as the *size* of the first part and added
  // to the offset for the second part. These agree only when the dimension's
  // offset is 0; confirm the intended meaning of `splitPoint` for ops with
  // non-zero iteration-domain offsets.
  // Size of the second part: offset + size - minSplitPoint.
  OpFoldResult remainingSize = affine::makeComposedFoldedAffineApply(
      rewriter, loc, d0 + d1 - d2,
      {offsets[dimension], sizes[dimension], minSplitPoint});
  // Nothing to split if the second part is statically empty.
  if (isConstantIntValue(remainingSize, 0))
    return {op, TilingInterface()};

  // Compute destination tensors. Without them there is nothing to insert the
  // parts into, so report failure rather than indexing an empty list.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op,
                                             destinationTensors)))
    return {TilingInterface(), TilingInterface()};

  // Create the first part.
  SmallVector<Value> firstResults;
  TilingInterface firstPart = createSplitPart(
      rewriter, loc, op, offsets, sizes, destinationTensors, dimension,
      minSplitPoint, offsets[dimension], firstResults);
  // Bail out before touching `op`: after a failure `firstResults` may be
  // incomplete, and wiring it into the op would replace the wrong operands.
  if (!firstPart)
    return {TilingInterface(), TilingInterface()};

  // Make the op take `firstResults` as its trailing operands so the tiling
  // interface slices the second part's data from the first part's results.
  // Keep the original values so this can be undone if the second part fails.
  unsigned numOutputOperands = firstResults.size();
  unsigned firstOutputIndex = op->getNumOperands() - numOutputOperands;
  SmallVector<Value> originalOutputs =
      llvm::to_vector(op->getOperands().take_back(numOutputOperands));
  rewriter.modifyOpInPlace(op, [&]() {
    op->setOperands(firstOutputIndex, numOutputOperands, firstResults);
  });

  // Create the second part, starting right after the first one.
  OpFoldResult totalOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, d0 + d1, {offsets[dimension], minSplitPoint});
  SmallVector<Value> secondResults;
  TilingInterface secondPart =
      createSplitPart(rewriter, loc, op, offsets, sizes, firstResults,
                      dimension, remainingSize, totalOffset, secondResults);
  if (!secondPart) {
    // Restore the original operands so `op` is not left half-rewritten.
    rewriter.modifyOpInPlace(op, [&]() {
      op->setOperands(firstOutputIndex, numOutputOperands, originalOutputs);
    });
    return {TilingInterface(), TilingInterface()};
  }

  // Replace the original op with the results of the second part, which
  // contain the first part's results through the chained insert_slices.
  rewriter.replaceOp(op, secondResults);
  return {firstPart, secondPart};
}