#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
/// Recursively collects `op` and everything reachable from it into
/// `forwardSlice`, inserting each operation only after the operations reached
/// through it (post-order).
///
/// "Reachable" means: operations nested in any region of `op`, and users of
/// any result of `op`, each explored transitively. A null `op`, or an `op`
/// rejected by a non-null `filter`, is neither inserted nor traversed.
static void
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
                    const SliceOptions::TransitiveFilter &filter = nullptr) {
  if (!op)
    return;

  // A rejected op cuts the traversal: nothing is visited through it.
  if (filter && !filter(op))
    return;

  // Descend into the operations nested in the regions of `op`.
  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &blockOp : block)
        if (forwardSlice->count(&blockOp) == 0)
          getForwardSliceImpl(&blockOp, forwardSlice, filter);

  // Follow the def-use edges leaving each result of `op`.
  for (Value result : op->getResults()) {
    for (Operation *userOp : result.getUsers())
      if (forwardSlice->count(userOp) == 0)
        getForwardSliceImpl(userOp, forwardSlice, filter);
  }

  // `op` goes in last, after everything reached through it; the public entry
  // points reverse the resulting vector.
  forwardSlice->insert(op);
}
/// Fills `forwardSlice` with the forward slice rooted at `op`.
///
/// The traversal honours `options.filter`. When `options.inclusive` is false,
/// `op` itself is dropped from the result. The helper inserts operations in
/// post-order, so the whole content of `forwardSlice` is reversed before
/// returning.
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
                           const ForwardSliceOptions &options) {
  getForwardSliceImpl(op, forwardSlice, options.filter);

  // The root was only the query point; remove it unless asked to keep it.
  if (!options.inclusive)
    forwardSlice->remove(op);

  // Drain the set and re-insert its elements back to front.
  SmallVector<Operation *, 0> postOrder(forwardSlice->takeVector());
  for (auto it = postOrder.rbegin(), end = postOrder.rend(); it != end; ++it)
    forwardSlice->insert(*it);
}
/// Fills `forwardSlice` with the forward slice of the value `root`, i.e. the
/// slices rooted at each user of `root`, subject to `options.filter`.
///
/// As with the operation overload, the helper produces a post-order, so the
/// whole content of `forwardSlice` is reversed before returning.
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
                           const SliceOptions &options) {
  for (Operation *consumer : root.getUsers())
    getForwardSliceImpl(consumer, forwardSlice, options.filter);

  // Drain the set and re-insert its elements back to front.
  SmallVector<Operation *, 0> postOrder(forwardSlice->takeVector());
  for (auto it = postOrder.rbegin(), end = postOrder.rend(); it != end; ++it)
    forwardSlice->insert(*it);
}
/// Recursively collects into `backwardSlice` the operations that `op`
/// transitively depends on through its operands, followed by `op` itself.
///
/// Nothing is inserted or traversed when `op` is null, carries the
/// `IsIsolatedFromAbove` trait, or is rejected by a non-null `options.filter`.
/// For an operand that is a block argument, the operation owning the block is
/// visited instead, unless `options.omitBlockArguments` is set.
static void getBackwardSliceImpl(Operation *op,
                                 SetVector<Operation *> *backwardSlice,
                                 const BackwardSliceOptions &options) {
  if (!op)
    return;
  if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return;
  if (options.filter && !options.filter(op))
    return;

  for (Value operand : op->getOperands()) {
    // Operand produced by an operation: recurse into the producer once.
    if (Operation *producer = operand.getDefiningOp()) {
      if (backwardSlice->count(producer) == 0)
        getBackwardSliceImpl(producer, backwardSlice, options);
      continue;
    }

    // Otherwise the operand has to be a block argument.
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg)
      llvm_unreachable("No definingOp and not a block argument.");
    if (options.omitBlockArguments)
      continue;

    Operation *parentOp = blockArg.getOwner()->getParentOp();
    if (!parentOp || backwardSlice->count(parentOp) != 0)
      continue;

    // Only a parent with exactly one region holding exactly one block is
    // supported here.
    assert(parentOp->getNumRegions() == 1 &&
           parentOp->getRegion(0).getBlocks().size() == 1);
    getBackwardSliceImpl(parentOp, backwardSlice, options);
  }

  // Dependencies first, then the operation itself.
  backwardSlice->insert(op);
}
/// Fills `backwardSlice` with the backward slice rooted at `op`, as computed
/// by `getBackwardSliceImpl` under `options`. When `options.inclusive` is
/// false, `op` itself is removed from the result.
void mlir::getBackwardSlice(Operation *op,
                            SetVector<Operation *> *backwardSlice,
                            const BackwardSliceOptions &options) {
  getBackwardSliceImpl(op, backwardSlice, options);

  // The root was only the query point; remove it unless asked to keep it.
  if (!options.inclusive)
    backwardSlice->remove(op);
}
/// Fills `backwardSlice` with the backward slice of the value `root`.
///
/// The slice is rooted at the operation defining `root`; when `root` has no
/// defining operation it is cast to a block argument and the slice is rooted
/// at the operation that owns its block.
void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
                            const BackwardSliceOptions &options) {
  Operation *sliceRoot = root.getDefiningOp();
  if (!sliceRoot)
    sliceRoot = cast<BlockArgument>(root).getOwner()->getParentOp();
  getBackwardSlice(sliceRoot, backwardSlice, options);
}
/// Computes the closure of `op` under backward and forward slicing and
/// returns it topologically sorted.
///
/// The result set doubles as a worklist: every operation it contains is
/// visited once by index, and the backward and forward slices of that
/// operation are appended, until no unvisited operation remains.
SetVector<Operation *>
mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
               const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> worklist;
  worklist.insert(op);

  // Scratch sets, cleared before each use.
  SetVector<Operation *> producers;
  SetVector<Operation *> consumers;

  // `worklist.size()` is re-read every iteration because the body grows it.
  for (unsigned idx = 0; idx != worklist.size(); ++idx) {
    Operation *current = worklist[idx];

    producers.clear();
    getBackwardSlice(current, &producers, backwardSliceOptions);
    worklist.insert(producers.begin(), producers.end());

    consumers.clear();
    getForwardSlice(current, &consumers, forwardSliceOptions);
    worklist.insert(consumers.begin(), consumers.end());
  }

  return topologicalSort(worklist);
}
/// Returns true if `value` is one of `iterCarriedArgs`, or if any operation in
/// the filtered backward slice of `value` has one of `iterCarriedArgs` as an
/// operand.
///
/// The backward slice is restricted by a filter built from `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
                                 ArrayRef<BlockArgument> iterCarriedArgs,
                                 Operation *ancestorOp) {
  SmallPtrSet<Value, 8> iterCarriedValSet(iterCarriedArgs.begin(),
                                          iterCarriedArgs.end());

  // Trivial case first: no slice is needed when the value is itself carried.
  if (iterCarriedValSet.contains(value))
    return true;

  // Compute the backward slice of the value.
  // NOTE(review): the filter accepts only operations for which
  // `ancestorOp->isAncestor(op)` is false, so operations nested under
  // `ancestorOp` are excluded from the slice. Confirm this polarity is
  // intended.
  SetVector<Operation *> slice;
  BackwardSliceOptions sliceOptions;
  sliceOptions.filter = [&](Operation *op) {
    return !ancestorOp->isAncestor(op);
  };
  getBackwardSlice(value, &slice, sliceOptions);

  // Any operand in the slice that is a carried value makes `value` dependent.
  for (Operation *op : slice)
    for (Value operand : op->getOperands())
      if (iterCarriedValSet.contains(operand))
        return true;

  return false;
}
/// Tries to match a reduction on the iteration-carried argument at position
/// `redPos` of `iterCarriedArgs`.
///
/// On success, returns the value being reduced, i.e. the operand of the
/// combiner operation that is not the carried argument, and `combinerOps`
/// holds the single combiner operation. Returns a null `Value` when no
/// reduction is matched.
///
/// `combinerOps` is only appended to and the match requires its size to be
/// exactly one afterwards, so callers are expected to pass it empty. It may
/// have been appended to even when a null `Value` is returned.
Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
                           unsigned redPos,
                           SmallVectorImpl<Operation *> &combinerOps) {
  assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds");

  // The carried value must feed exactly one operation.
  BlockArgument redCarriedVal = iterCarriedArgs[redPos];
  if (!redCarriedVal.hasOneUse())
    return nullptr;

  // The first combiner op must be a binary op.
  Operation *combinerOp = *redCarriedVal.getUsers().begin();
  if (combinerOp->getNumOperands() != 2)
    return nullptr;
  Value reducedVal = combinerOp->getOperand(0) == redCarriedVal
                         ? combinerOp->getOperand(1)
                         : combinerOp->getOperand(0);

  // The reduced value must not itself depend on any carried value.
  Operation *redRegionOp =
      iterCarriedArgs.front().getOwner()->getParent()->getParentOp();
  if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp))
    return nullptr;

  // Walk the single-use chain from the first combiner op until an operation
  // that might be a terminator is reached, collecting the combiner ops on the
  // way.
  while (!combinerOp->mightHaveTrait<OpTrait::IsTerminator>()) {
    if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 ||
        !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp)
      return nullptr;

    combinerOps.push_back(combinerOp);
    combinerOp = *combinerOp->getUsers().begin();
  }

  // Only reductions with a single combiner op are matched.
  if (combinerOps.size() != 1)
    return nullptr;

  // The combined value must be passed to the terminator at the same position
  // as the carried argument. Nothing above checks that the terminator belongs
  // to `redRegionOp`, so it may have fewer operands than `redPos + 1`; treat
  // that as "no match" rather than reading an operand out of range.
  Operation *terminatorOp = combinerOp;
  if (redPos >= terminatorOp->getNumOperands())
    return nullptr;
  if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
    return nullptr;

  return reducedVal;
}