#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
/// Finds an init ("out") operand of `op` that can take over the role of the
/// input operand `in`.
///
/// A candidate qualifies when all of these hold:
///   * the payload does not read its block argument,
///   * its value has exactly the same type as `in`'s value, and
///   * its indexing map equals the indexing map of `in`.
/// Init operands are examined in order and the first qualifying one is
/// returned. Returns nullptr when no init operand qualifies.
static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
  const auto wantedType = in->get().getType();
  for (OpOperand &init : op.getDpsInitsMutable()) {
    // Short-circuit evaluation keeps the checks in the same order as before:
    // payload usage first, then type, then indexing map.
    const bool usable = !op.payloadUsesValueFromOperand(&init) &&
                        init.get().getType() == wantedType &&
                        op.getMatchingIndexingMap(&init) ==
                            op.getMatchingIndexingMap(in);
    if (usable)
      return &init;
  }
  return nullptr;
}
/// Eliminates tensor.empty ops that feed input operands of all-parallel
/// LinalgOps nested in `op` (including `op` itself if it is a LinalgOp).
///
/// The rewrite applies to an input operand `in` when two conditions hold:
///   * Its value traces back, through equivalent tensors only, to one or more
///     tensor.empty values of the same type.
///   * An init operand `out` exists that the payload never reads and that has
///     the same type and indexing map as `in`.
///
/// The rewrite then proceeds as follows:
///   1. Every use of those tensor.empty values is redirected to `out`'s value,
///      provided `out`'s value properly dominates each tensor.empty op. The
///      tensor.empty ops themselves are not erased here, so `domInfo` stays
///      valid (presumably they are folded away by a later cleanup).
///   2. `in`'s value is moved into the `out` slot, and `in` is pointed at the
///      first tensor.empty value. Payload uses of `in`'s block argument are
///      rewired to `out`'s block argument.
///
/// The now-unused input operand is not removed by this function. The cached
/// analysis in `state` is reset after each rewrite.
///
/// \returns success() unconditionally.
LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  DominanceInfo domInfo;

  // Reverse use-def traversal settings are identical for every operand, so
  // build them once. Only equivalent tensors are followed, and leaves that do
  // not satisfy the condition are not reported.
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;

  // The lambda parameter is named `linalgOp` so that it does not shadow the
  // function parameter `op`.
  op->walk([&](LinalgOp linalgOp) {
    // Only ops whose loops are all parallel are handled. (The default walk is
    // post-order, so nested ops were already visited and `skip` has the same
    // effect as `advance` here.)
    if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
      return WalkResult::skip();

    for (OpOperand *in : linalgOp.getDpsInputOperands()) {
      // Only ranked tensor inputs are candidates.
      if (!isa<RankedTensorType>(in->get().getType()))
        continue;

      // Collect tensor.empty values of the same type on the reverse
      // use-def chain of the input.
      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
          in->get(),
          [&](Value val) {
            return val.getDefiningOp<tensor::EmptyOp>() &&
                   val.getType() == in->get().getType();
          },
          config);
      if (emptyTensors.empty())
        continue;

      // Find an init operand whose value is unused and that matches `in`.
      OpOperand *out = getUnusedOutOperand(linalgOp, in);
      if (!out)
        continue;

      // Replacing the tensor.empty values with `out`'s value is only legal
      // if that value properly dominates every tensor.empty op.
      if (!llvm::all_of(emptyTensors, [&](Value v) {
            return domInfo.properlyDominates(out->get(), v.getDefiningOp());
          }))
        continue;

      // Redirect uses of the tensor.empty values without erasing the ops, so
      // that `domInfo` is not invalidated.
      for (Value v : emptyTensors) {
        assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
        rewriter.replaceAllUsesWith(v, out->get());
      }

      // Turn the "in" into an "out". Statement order matters: `out` receives
      // the current input value before `in` is re-pointed at the
      // tensor.empty.
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        out->set(in->get());
        in->set(emptyTensors.front());
        BlockArgument outArg = linalgOp.getMatchingBlockArgument(out);
        assert(outArg.getUses().empty() && "expected that out has no uses");
        BlockArgument inArg = linalgOp.getMatchingBlockArgument(in);
        rewriter.replaceAllUsesWith(inArg, outArg);
        assert(!linalgOp.payloadUsesValueFromOperand(in) &&
               "expected that the in operand is now unused");
      });

      // The IR changed, so cached analysis results may be stale.
      state.resetCache();
    }

    return WalkResult::advance();
  });
  return success();
}