#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::bufferization;
/// Return true if every value in `neededValues` is accepted as usable by an op
/// that would be created right before `insertionPoint`.
///
/// * A block argument is accepted when its owner block contains
///   `insertionPoint` or one of its ancestor ops (i.e.
///   `findAncestorOpInBlock` returns non-null).
/// * An op result is accepted when its defining op *properly* dominates
///   `insertionPoint`.
///
/// \param domInfo        Dominance information used for the op-result check.
/// \param insertionPoint Op in front of which a replacement would be created.
/// \param neededValues   Values the replacement is going to use.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
      Block *owner = bbArg.getOwner();
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      auto opResult = val.cast<OpResult>();
      // Proper dominance is required here: `dominates` is reflexive, so it
      // would accept an insertion point that *is* the defining op. Since new
      // ops are created *before* the insertion point, they would then use the
      // value ahead of its definition. Such a candidate can occur because
      // findValidInsertionPoint proposes the op following each needed value's
      // definition, which may itself define another needed value.
      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  return true;
}
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
Operation *insertionPoint,
Operation *allocTensorOp) {
for (Operation *user : allocTensorOp->getUsers())
if (!domInfo.dominates(insertionPoint, user))
return false;
return true;
}
/// Pick an op in front of which a replacement for `allocTensorOp` can be
/// created, given that the replacement uses the values in `neededValues`.
///
/// Candidates are tried in this order: `allocTensorOp` itself, then, for each
/// needed value, either the first op of the owning block (block arguments) or
/// the op right after the defining op (op results). The first candidate that
/// passes both `neededValuesDominateInsertionPoint` and
/// `insertionPointDominatesUses` is returned.
///
/// NOTE(review): this assumes every owner block is non-empty and every
/// defining op has a successor in its block; otherwise a null/invalid
/// candidate would be handed to the dominance checks. Confirm with callers.
///
/// \returns the selected op, or nullptr if no candidate qualifies.
static Operation *
findValidInsertionPoint(Operation *allocTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  // Collect the candidates, starting with the alloc_tensor op's own position.
  SmallVector<Operation *> candidates;
  candidates.push_back(allocTensorOp);
  for (Value needed : neededValues) {
    auto blockArg = needed.dyn_cast<BlockArgument>();
    Operation *candidate =
        blockArg ? &blockArg.getOwner()->getOperations().front()
                 : needed.getDefiningOp()->getNextNode();
    candidates.push_back(candidate);
  }

  // The first candidate satisfying both conditions wins. The use check is
  // only evaluated when the needed values are already known to be in scope.
  for (Operation *candidate : candidates) {
    bool valuesInScope =
        neededValuesDominateInsertionPoint(domInfo, candidate, neededValues);
    if (valuesInScope &&
        insertionPointDominatesUses(domInfo, candidate, allocTensorOp))
      return candidate;
  }

  // Nothing suitable.
  return nullptr;
}
/// Walk all ops nested under `op` and try to replace AllocTensorOps with the
/// value produced by `rewriteFunc`.
///
/// For each operand that `state` reports as in-place and that
/// `anchorMatchFunc` accepts, the reverse use-def chain is followed from the
/// operand's value. If that traversal yields exactly one value and it is
/// defined by an AllocTensorOp, a valid insertion point is looked up,
/// `rewriteFunc` is invoked there, and a non-null result replaces the
/// AllocTensorOp.
///
/// \param rewriter        Rewriter used to create and replace ops; its
///                        insertion point is restored on return.
/// \param op              Root of the walk.
/// \param state           Analysis state queried for in-place and
///                        equivalence information.
/// \param anchorMatchFunc Decides whether an operand is an anchor and fills in
///                        the values the replacement needs.
/// \param rewriteFunc     Builds the replacement value; may return null.
/// \returns failure only if the walk was interrupted. The callback below never
///          returns an interrupt itself.
LogicalResult mlir::bufferization::eliminateAllocTensors(
    RewriterBase &rewriter, Operation *op, AnalysisState &state,
    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) {
  OpBuilder::InsertionGuard guard(rewriter);

  // Traversal predicate: the reverse use-def walk stops at `val` once this
  // returns true.
  auto isChainEnd = [&](Value val) {
    auto result = val.dyn_cast<OpResult>();
    // Values that are not op results (block arguments) end the chain.
    if (!result)
      return true;
    SmallVector<OpOperand *> aliasing = state.getAliasingOpOperand(result);
    bool allInPlace = llvm::all_of(aliasing, [&](OpOperand *aliasingOperand) {
      return state.isInPlace(*aliasingOperand);
    });
    if (!allInPlace)
      return true;
    // Keep going only while every aliasing operand is equivalent to the
    // result.
    bool allEquivalent =
        llvm::all_of(aliasing, [&](OpOperand *aliasingOperand) {
          return state.areEquivalentBufferizedValues(aliasingOperand->get(),
                                                     result);
        });
    return !allEquivalent;
  };

  WalkResult walkStatus = op->walk([&](Operation *nestedOp) {
    for (OpOperand &operand : nestedOp->getOpOperands()) {
      // Only operands that are in-place according to `state` qualify.
      if (!state.isInPlace(operand))
        continue;

      // Filled by the matcher with everything the replacement will use.
      SmallVector<Value> neededValues;
      if (!anchorMatchFunc(operand, neededValues))
        continue;

      SetVector<Value> chainEnds =
          state.findValueInReverseUseDefChain(operand.get(), isChainEnd);

      // The chain must end at exactly one value defined by an AllocTensorOp.
      // NOTE(review): returning here also abandons the remaining operands of
      // `nestedOp`; confirm that this is intended rather than `continue`.
      bool endsAtSingleAlloc =
          chainEnds.size() == 1 &&
          chainEnds.front().getDefiningOp<AllocTensorOp>();
      if (!endsAtSingleAlloc)
        return WalkResult::skip();

      Value allocTensor = chainEnds.front();
      Operation *allocOp = allocTensor.getDefiningOp();

      Operation *insertBefore = findValidInsertionPoint(allocOp, neededValues);
      if (!insertBefore)
        continue;

      // Build the replacement and, if one was produced, swap it in.
      rewriter.setInsertionPoint(insertBefore);
      Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand);
      if (!replacement)
        continue;
      rewriter.replaceOp(allocOp, replacement);
    }

    return WalkResult::advance();
  });

  return failure(walkStatus.wasInterrupted());
}
/// Run `eliminateAllocTensors` with tensor.insert_slice as the anchor op.
///
/// An operand is an anchor when it is operand #0 of a tensor::InsertSliceOp
/// (presumably the slice source, given the use of `getSourceType()` below;
/// confirm against the op definition). The matching AllocTensorOp is replaced
/// by a tensor::ExtractSliceOp that reads from the insert_slice destination
/// using the same offsets, sizes and strides.
///
/// \param rewriter Rewriter used to create the extract_slice ops.
/// \param op       Root op under which anchors are searched.
/// \param state    Analysis state forwarded to `eliminateAllocTensors`.
/// \returns the result of `eliminateAllocTensors`.
LogicalResult
mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
  // Anchor matcher: accepts operand #0 of an insert_slice and records every
  // value the extract_slice replacement is going to use.
  auto matchInsertSliceAnchor = [](OpOperand &operand,
                                   SmallVector<Value> &neededValues) {
    auto sliceOp = dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
    if (!sliceOp || &operand != &sliceOp->getOpOperand(0))
      return false;
    llvm::append_range(neededValues, sliceOp.getOffsets());
    llvm::append_range(neededValues, sliceOp.getSizes());
    llvm::append_range(neededValues, sliceOp.getStrides());
    neededValues.push_back(sliceOp.getDest());
    return true;
  };

  // Replacement builder: extract_slice of the destination, typed like the
  // insert_slice source.
  auto buildExtractSlice = [](OpBuilder &b, Location loc, OpOperand &operand) {
    auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
    return b
        .create<tensor::ExtractSliceOp>(
            loc, insertOp.getSourceType(), insertOp.getDest(),
            insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
            insertOp.getMixedStrides())
        .getResult();
  };

  return eliminateAllocTensors(rewriter, op, state, matchInsertSliceAnchor,
                               buildExtractSlice);
}
namespace {
/// Pass that runs the insert_slice-anchored alloc_tensor elimination on the
/// op it is scheduled on (see `runOnOperation`).
struct AllocTensorElimination
    : public AllocTensorEliminationBase<AllocTensorElimination> {
  AllocTensorElimination() = default;

  /// Pass entry point; defined out of line.
  void runOnOperation() override;

  /// Register the dialects this pass declares as dependencies: bufferization
  /// and tensor.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace
void AllocTensorElimination::runOnOperation() {
Operation *op = getOperation();
OneShotBufferizationOptions options;
OneShotAnalysisState state(op, options);
if (failed(analyzeOp(op, state))) {
signalPassFailure();
return;
}
IRRewriter rewriter(op->getContext());
if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep(
rewriter, op, state)))
signalPassFailure();
}
/// Factory: create a new instance of the alloc_tensor elimination pass.
std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() {
  auto pass = std::make_unique<AllocTensorElimination>();
  return pass;
}