#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
}
}
using namespace mlir;
using namespace mlir::bufferization;
/// Returns true if every value in `neededValues` is in scope at
/// `insertionPoint`, i.e. an op created right before `insertionPoint` may use
/// any of them as operands.
///
/// An op result is in scope if its defining op properly dominates
/// `insertionPoint`. A block argument is in scope if `insertionPoint` lies in
/// the owning block, either directly or nested inside an op of that block.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value needed : neededValues) {
    // Op result: the producer must come strictly before the insertion point.
    if (auto result = dyn_cast<OpResult>(needed)) {
      if (domInfo.properlyDominates(result.getOwner(), insertionPoint))
        continue;
      return false;
    }
    // Block argument: the insertion point must be (nested) inside its block.
    Block *ownerBlock = cast<BlockArgument>(needed).getOwner();
    if (ownerBlock->findAncestorOpInBlock(*insertionPoint) == nullptr)
      return false;
  }
  return true;
}
static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
Operation *insertionPoint,
Operation *emptyTensorOp) {
return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
return domInfo.dominates(insertionPoint, user);
});
}
/// Finds an op before which a replacement for `emptyTensorOp` can be inserted,
/// assuming the replacement may use any value in `neededValues`.
///
/// Candidates are tried in this order:
///   1. `emptyTensorOp` itself.
///   2. For each needed value:
///      - a block argument contributes the first op of its owning block;
///      - an op result contributes the op right after its defining op.
///
/// The first candidate at which all needed values are in scope, and which
/// dominates all uses of `emptyTensorOp`, is returned. Returns nullptr if no
/// candidate qualifies.
static Operation *
findValidInsertionPoint(Operation *emptyTensorOp,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;

  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.reserve(neededValues.size() + 1);
  insertionPointCandidates.push_back(emptyTensorOp);
  for (Value val : neededValues) {
    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
      // Guard against an empty block: there is no op to insert before.
      Block *owner = bbArg.getOwner();
      if (!owner->getOperations().empty())
        insertionPointCandidates.push_back(&owner->getOperations().front());
    } else if (Operation *next = val.getDefiningOp()->getNextNode()) {
      // The defining op may be the last op of its block (e.g. in a region
      // without a terminator). In that case there is no "next" op. A null
      // candidate would later be dereferenced, so it is skipped instead.
      insertionPointCandidates.push_back(next);
    }
  }

  // Select the first candidate that satisfies both dominance requirements.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // All needed values must be in scope at the insertion point.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // The insertion point must come before all uses of the empty tensor.
    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}
/// Replaces `tensor.empty` ops with subset extractions built by subset
/// insertion ops, using the results of a prior One-Shot analysis in `state`.
///
/// The procedure visits every `SubsetInsertionOpInterface` op nested under
/// `op` whose source operand bufferizes in place:
///   - It searches the reverse use-def chain of the source for `tensor.empty`
///     ops. Only equivalent values of the same type (or casts) are followed.
///   - Each empty tensor found is replaced by the value from
///     `buildSubsetExtraction`, inserted at the first valid insertion point.
///   - If the shapes differ, a `tensor.cast` bridges the two types.
///
/// A replacement is skipped if:
///   - no insertion point exists;
///   - nothing can be built;
///   - the built value is the empty tensor itself;
///   - the element types differ.
///
/// The cache in `state` is reset after each replacement. Always returns
/// success().
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  // Restore the rewriter's insertion point on exit.
  OpBuilder::InsertionGuard guard(rewriter);

  op->walk([&](SubsetInsertionOpInterface insertOp) {
    OpOperand &source = insertOp.getSourceOperand();
    // Only consider sources that bufferize in place.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // Every value the replacement (subset extraction) may need as an operand.
    SmallVector<Value> neededValues =
        insertOp.getValuesNeededToBuildSubsetExtraction();

    // Walk the reverse use-def chain through equivalent values only, and only
    // through same-type values or casts. Leaves that are not tensor.empty ops
    // are not reported.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    config.alwaysIncludeLeaves = false;
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        source.get(),
        /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
        config);

    for (Value emptyTensor : emptyTensors) {
      Operation *emptyTensorOp = emptyTensor.getDefiningOp();

      Operation *insertionPoint =
          findValidInsertionPoint(emptyTensorOp, neededValues);
      if (!insertionPoint)
        continue;

      rewriter.setInsertionPoint(insertionPoint);
      Value replacement =
          insertOp.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
      // Nothing was built, or the "replacement" is the empty tensor itself.
      if (!replacement || replacement.getDefiningOp() == emptyTensorOp)
        continue;

      Type emptyType = emptyTensor.getType();
      if (replacement.getType() != emptyType) {
        // Only shape differences can be bridged with a tensor.cast.
        if (cast<ShapedType>(replacement.getType()).getElementType() !=
            cast<ShapedType>(emptyType).getElementType())
          continue;
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(emptyTensor.getLoc(),
                                                      emptyType, replacement);
      }

      rewriter.replaceOp(emptyTensorOp, replacement);
      // The IR changed; drop any cached analysis data.
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}
namespace {
/// Pass that runs empty tensor elimination on the pass's root operation.
struct EmptyTensorElimination
    : public bufferization::impl::EmptyTensorEliminationBase<
          EmptyTensorElimination> {
  EmptyTensorElimination() = default;

  void runOnOperation() override;

  /// Declares the dialects this pass depends on. For example, the rewrite may
  /// create `tensor.cast` ops.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace
/// Convenience overload: runs One-Shot analysis on `op`, then eliminates empty
/// tensors using the analysis results.
///
/// If `op` is a ModuleOp, function boundaries are included in the analysis
/// (via analyzeModuleOp); otherwise analyzeOp is used. Returns failure if the
/// analysis fails.
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
                                                         Operation *op) {
  ModuleOp moduleOp = dyn_cast<ModuleOp>(op);

  OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;
  if (moduleOp)
    options.bufferizeFunctionBoundaries = true;

  OneShotAnalysisState state(op, options);
  // Module analysis accounts for function boundaries; the regular analysis
  // handles a single op.
  LogicalResult analysis =
      moduleOp ? analyzeModuleOp(moduleOp, state) : analyzeOp(op, state);
  if (failed(analysis))
    return failure();

  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}
void EmptyTensorElimination::runOnOperation() {
IRRewriter rewriter(getOperation()->getContext());
if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
signalPassFailure();
}
/// Creates a new instance of the empty tensor elimination pass.
std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
  std::unique_ptr<Pass> pass = std::make_unique<EmptyTensorElimination>();
  return pass;
}