#include "mlir/Transforms/FoldUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
/// Find the region into which constants for operations in `insertionBlock`
/// should be materialized.
///
/// Starting from the region that owns `insertionBlock`, walk outwards and
/// return the first region whose parent operation either
///   * might be isolated from above,
///   * is not itself nested inside a block, or
///   * belongs to a dialect whose `DialectFoldInterface` asks for constants to
///     be materialized into that region (`shouldMaterializeInto`).
/// Reaching a block with no parent region is a programming error.
static Region *
getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
                   Block *insertionBlock) {
  for (Region *region = insertionBlock->getParent(); region != nullptr;
       region = insertionBlock->getParent()) {
    Operation *owner = region->getParentOp();

    // Stop climbing at an op that is isolated from above or that is not
    // nested inside another block.
    const bool isBoundary =
        owner->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
        owner->getBlock() == nullptr;
    if (isBoundary)
      return region;

    // Give the owning op's dialect a chance to claim this region.
    auto *foldIface = interfaces.getInterfaceFor(owner);
    if (LLVM_UNLIKELY(foldIface && foldIface->shouldMaterializeInto(region)))
      return region;

    // Continue with the block that contains the owning operation.
    insertionBlock = owner->getBlock();
  }
  llvm_unreachable("expected valid insertion region");
}
/// Ask `dialect` to materialize a constant of `type` holding `value` at the
/// current insertion point of `builder`, using location `loc`.
///
/// Returns the new operation, or nullptr if the dialect could not
/// materialize the constant. In debug builds this also checks that the
/// dialect hook left the builder's insertion point unchanged and that the
/// produced operation matches `m_Constant()`.
static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
                                      Attribute value, Type type,
                                      Location loc) {
  // Captured only for the debug assertion below.
  auto savedInsertPt = builder.getInsertionPoint();
  (void)savedInsertPt;

  Operation *constOp = dialect->materializeConstant(builder, value, type, loc);
  if (!constOp)
    return nullptr;

  assert(savedInsertPt == builder.getInsertionPoint());
  assert(matchPattern(constOp, m_Constant()));
  return constOp;
}
/// Try to fold `op`.
///
/// Folder-owned (uniqued) constants are never folded again. They are only
/// moved back to the front of their block when a non-folder-owned op now
/// precedes them, and failure is returned.
///
/// On success, either `op` was updated in place (reported through
/// `*inPlaceUpdate` when non-null, plus a `notifyOperationModified` callback
/// if the rewriter has a `RewriterBase::Listener`), or all of its results were
/// replaced and `op` was removed through the rewriter.
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
  auto setInPlace = [&](bool updated) {
    if (inPlaceUpdate)
      *inPlaceUpdate = updated;
  };
  setInPlace(false);

  if (isFolderOwnedConstant(op)) {
    // Rehoist if a non-folder-owned op was inserted ahead of this constant.
    Block *parentBlock = op->getBlock();
    const bool atFront = &parentBlock->front() == op;
    if (!atFront && !isFolderOwnedConstant(op->getPrevNode())) {
      op->moveBefore(&parentBlock->front());
      op->setLoc(erasedFoldedLocation);
    }
    return failure();
  }

  SmallVector<Value, 8> replacements;
  if (failed(tryToFold(op, replacements)))
    return failure();

  if (!replacements.empty()) {
    // Every result has a replacement: drop `op` from the folder's bookkeeping
    // and replace it.
    notifyRemoval(op);
    rewriter.replaceOp(op, replacements);
    return success();
  }

  // No replacements means the fold hook modified `op` in place. The fold API
  // does not notify listeners itself, so do it here.
  setInPlace(true);
  if (auto *listener = dyn_cast_if_present<RewriterBase::Listener>(
          rewriter.getListener()))
    listener->notifyOperationModified(op);
  return success();
}
/// Register the constant operation `op` with the folder.
///
/// `constValue` is the attribute value of `op`. When it is null it is
/// recovered with `m_Constant`; when it is provided, debug builds check that
/// it matches the op's actual constant value.
///
/// Returns true if `op` is (now) the folder-owned constant for its
/// (dialect, value, type) key in its insertion region. In that case it may
/// have been moved to the front of the insertion region's entry block.
/// Returns false if an equivalent folder-owned constant already existed; `op`
/// is then replaced by that constant through the rewriter.
bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
Block *opBlock = op->getBlock();
// Already uniqued by the folder: nothing to insert. Only rehoist it to the
// front of its block if a non-folder-owned op now precedes it.
if (isFolderOwnedConstant(op)) {
if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
op->moveBefore(&opBlock->front());
op->setLoc(erasedFoldedLocation);
}
return true;
}
// Determine (or, in debug builds, verify) the attribute value used as part of
// the uniquing key.
if (!constValue) {
matchPattern(op, m_Constant(&constValue));
assert(constValue && "expected `op` to be a constant");
} else {
#ifndef NDEBUG
Attribute expectedValue;
matchPattern(op, m_Constant(&expectedValue));
assert(
expectedValue == constValue &&
"provided constant value was not the expected value of the constant");
#endif
}
// Look up the uniquing slot keyed by (op dialect, value, first result type)
// in the constant map of the insertion region. `folderConstOp` is a reference
// into that map; it is filled in below when `op` becomes the uniqued constant.
Region *insertRegion = getInsertionRegion(interfaces, opBlock);
auto &uniquedConstants = foldScopes[insertRegion];
Operation *&folderConstOp = uniquedConstants[std::make_tuple(
op->getDialect(), constValue, *op->result_type_begin())];
// An equivalent constant is already uniqued: replace `op` with it. `op` is
// not folder-owned at this point (checked above), so notifyRemoval is
// expected to find nothing to unregister.
if (folderConstOp) {
notifyRemoval(op);
rewriter.replaceOp(op, folderConstOp->getResults());
folderConstOp->setLoc(erasedFoldedLocation);
return false;
}
// `op` becomes the uniqued constant. Move it to the front of the insertion
// region's entry block unless it is already in that block and either at the
// front or directly after another folder-owned constant.
Block *insertBlock = &insertRegion->front();
if (opBlock != insertBlock || (&insertBlock->front() != op &&
!isFolderOwnedConstant(op->getPrevNode()))) {
op->moveBefore(&insertBlock->front());
op->setLoc(erasedFoldedLocation);
}
// Record ownership: fill the map slot and remember which dialect key refers
// to `op`, so notifyRemoval can erase it later.
folderConstOp = op;
referencedDialects[op].push_back(op->getDialect());
return true;
}
/// Forget `op` if it is a folder-owned constant.
///
/// Every (dialect, value, type) key that refers to `op` is erased from the
/// constant map of `op`'s insertion region, and `op`'s ownership record is
/// dropped. Ops the folder does not own are ignored.
void OperationFolder::notifyRemoval(Operation *op) {
  auto refIt = referencedDialects.find(op);
  if (refIt == referencedDialects.end())
    return;

  // Recompute the attribute value that was used as part of the uniquing key.
  Attribute uniquedValue;
  matchPattern(op, m_Constant(&uniquedValue));
  assert(uniquedValue);

  Region *ownerRegion = getInsertionRegion(interfaces, op->getBlock());
  auto &constantMap = foldScopes[ownerRegion];
  Type resultType = op->getResult(0).getType();

  // One key per dialect that was mapped onto this operation.
  for (Dialect *refDialect : refIt->second)
    constantMap.erase(std::make_tuple(refDialect, uniquedValue, resultType));
  referencedDialects.erase(refIt);
}
/// Drop all state: both the ownership records of folder-owned constants and
/// the per-region constant maps.
void OperationFolder::clear() {
  referencedDialects.clear();
  foldScopes.clear();
}
/// Return the uniqued constant of `type` holding `value` for `dialect`, as
/// seen from `block`, materializing it at the start of the insertion region's
/// entry block if needed.
///
/// Returns a null `Value` if the constant could not be materialized.
Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
                                           Attribute value, Type type) {
  Region *insertRegion = getInsertionRegion(interfaces, block);
  Block &entryBlock = insertRegion->front();
  rewriter.setInsertionPoint(&entryBlock, entryBlock.begin());

  Operation *constOp =
      tryGetOrCreateConstant(foldScopes[insertRegion], dialect, value, type,
                             erasedFoldedLocation);
  if (!constOp)
    return Value();
  return constOp->getResult(0);
}
/// Return true if `op` is a constant that this folder has uniqued, i.e. one
/// that has an ownership record.
bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
  return referencedDialects.find(op) != referencedDialects.end();
}
/// Run `op`'s fold hook and convert its results into replacement values
/// appended to `results`. Fails if the fold hook fails or if any folded
/// attribute cannot be materialized as a constant.
LogicalResult OperationFolder::tryToFold(Operation *op,
                                         SmallVectorImpl<Value> &results) {
  SmallVector<OpFoldResult, 8> foldResults;
  if (succeeded(op->fold(foldResults)))
    return processFoldResults(op, results, foldResults);
  return failure();
}
/// Turn the `foldResults` produced by `op->fold()` into replacement values for
/// `op`'s results and append them to `results`.
///
/// `Value` fold results are forwarded unchanged. `Attribute` fold results are
/// resolved to uniqued constants (reused or newly materialized) at the start
/// of the entry block of `op`'s insertion region. An empty `foldResults` means
/// `op` was updated in place; `results` is left untouched and success is
/// returned.
///
/// If any attribute cannot be materialized, `results` is cleared, constants
/// materialized by this call are erased again, and failure is returned.
LogicalResult
OperationFolder::processFoldResults(Operation *op,
                                    SmallVectorImpl<Value> &results,
                                    ArrayRef<OpFoldResult> foldResults) {
  // The fold hook updated `op` in place.
  if (foldResults.empty())
    return success();
  assert(foldResults.size() == op->getNumResults());

  // New constants go in front of the current first op of the insertion
  // region's entry block.
  auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
  auto &entry = insertRegion->front();
  rewriter.setInsertionPoint(&entry, entry.begin());

  auto &uniquedConstants = foldScopes[insertRegion];
  auto *dialect = op->getDialect();

  for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
    assert(!foldResults[i].isNull() && "expected valid OpFoldResult");

    // The result folded to an existing SSA value.
    if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
      results.emplace_back(repl);
      continue;
    }

    // The result folded to an attribute: find or create a uniqued constant.
    auto res = op->getResult(i);
    Attribute attrRepl = foldResults[i].get<Attribute>();
    if (auto *constOp =
            tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
                                   res.getType(), erasedFoldedLocation)) {
      // Make sure the constant dominates `op`. That is not automatic when
      // `op` lives in the insertion block and was placed before the constant.
      Block *opBlock = op->getBlock();
      if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
        constOp->moveBefore(&opBlock->front());

      results.push_back(constOp->getResult(0));
      continue;
    }

    // Materialization failed: roll back. The ops in front of the saved
    // insertion point are the constants materialized by this call. They may
    // also include pre-existing uniqued constants that the dominance fix above
    // hoisted to the front of the block. Such constants can still have
    // users elsewhere, and erasing an op that has uses is invalid, so only
    // ops without uses are erased. (None of this call's new constants have
    // uses yet, because `op` has not been replaced.)
    for (Operation &candidate : llvm::make_early_inc_range(
             llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
      if (!candidate.use_empty())
        continue;
      notifyRemoval(&candidate);
      rewriter.eraseOp(&candidate);
    }

    results.clear();
    return failure();
  }

  return success();
}
/// Look up, or materialize, a constant of `type` holding `value` on behalf of
/// `dialect`, uniqued in `uniquedConstants`.
///
/// A reused constant whose location differs from `loc` has its location reset
/// to `erasedFoldedLocation`. A missing constant is materialized through
/// `dialect` at the rewriter's current insertion point. If the materializer
/// returns an op from a different dialect, the constant is keyed under both
/// dialects. An already-uniqued constant for that other dialect is preferred,
/// and the freshly materialized op is erased in its favor.
///
/// Returns the constant operation, or nullptr if materialization failed.
Operation *
OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
                                        Dialect *dialect, Attribute value,
                                        Type type, Location loc) {
  // Reuse an existing mapping if there is one. Note that operator[] creates
  // an empty slot for the key when none exists yet.
  auto constKey = std::make_tuple(dialect, value, type);
  Operation *&constOp = uniquedConstants[constKey];
  if (constOp) {
    if (loc != constOp->getLoc())
      constOp->setLoc(erasedFoldedLocation);
    return constOp;
  }

  // Otherwise, try to materialize one.
  if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc))) {
    // Remove the placeholder slot created above so that a failed
    // materialization does not leave a null entry behind.
    uniquedConstants.erase(constKey);
    return nullptr;
  }

  // The materialized constant belongs to the requested dialect.
  auto *newDialect = constOp->getDialect();
  if (newDialect == dialect) {
    referencedDialects[constOp].push_back(dialect);
    return constOp;
  }

  // The materializer produced an op from another dialect, so the mapping for
  // that dialect must be kept consistent as well.
  auto newKey = std::make_tuple(newDialect, value, type);

  // If a constant is already uniqued for the other dialect, drop the freshly
  // materialized op and reuse the existing one for both keys.
  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
    notifyRemoval(constOp);
    rewriter.eraseOp(constOp);
    referencedDialects[existingOp].push_back(dialect);
    if (loc != existingOp->getLoc())
      existingOp->setLoc(erasedFoldedLocation);
    return constOp = existingOp;
  }

  // Otherwise, map the new dialect to the materialized operation as well.
  // Assign the slot explicitly: insert() would keep a pre-existing (null)
  // slot for `newKey` and we would wrongly report failure for an op that was
  // created and registered. Copy the pointer first because operator[] may
  // rehash the map and invalidate the `constOp` reference.
  Operation *materialized = constOp;
  referencedDialects[materialized].assign({dialect, newDialect});
  uniquedConstants[newKey] = materialized;
  return materialized;
}