#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "inlining"
using namespace mlir;
/// Rewrites the location of every operation nested in `inlinedBlocks` to a
/// CallSiteLoc whose callee is the op's current location and whose caller is
/// `callerLoc`. Each distinct original location is wrapped only once; later
/// ops with the same location reuse the cached CallSiteLoc.
static void
remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks,
                      Location callerLoc) {
  DenseMap<Location, Location> callSiteCache;
  auto rewriteLoc = [&](Operation *op) {
    Location original = op->getLoc();
    auto cached = callSiteCache.find(original);
    if (cached != callSiteCache.end()) {
      op->setLoc(cached->second);
      return;
    }
    Location wrapped = CallSiteLoc::get(original, callerLoc);
    callSiteCache.try_emplace(original, wrapped);
    op->setLoc(wrapped);
  };
  for (Block &block : inlinedBlocks)
    block.walk(rewriteLoc);
}
/// For every operation nested in `inlinedBlocks`, replaces each operand that
/// has an entry in `mapper` with the mapped value. Operands without a mapping
/// are left untouched.
static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
                                 IRMapping &mapper) {
  for (Block &block : inlinedBlocks) {
    block.walk([&](Operation *op) {
      for (OpOperand &use : op->getOpOperands()) {
        Value replacement = mapper.lookupOrNull(use.get());
        if (replacement)
          use.set(replacement);
      }
    });
  }
}
/// Returns whether `callable` may be inlined into `call`. The decision is made
/// by the inliner interface of the call's dialect. Without such a handler, the
/// answer is "no".
bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable,
                                       bool wouldBeCloned) const {
  const DialectInlinerInterface *handler = getInterfaceFor(call);
  return handler && handler->isLegalToInline(call, callable, wouldBeCloned);
}
/// Returns whether region `src` may be inlined into region `dest`. The decision
/// is delegated to the inliner interface of the op that owns `dest`. Without
/// such a handler, inlining is refused.
bool InlinerInterface::isLegalToInline(Region *dest, Region *src,
                                       bool wouldBeCloned,
                                       IRMapping &valueMapping) const {
  const DialectInlinerInterface *handler =
      getInterfaceFor(dest->getParentOp());
  return handler &&
         handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping);
}
/// Returns whether the single operation `op` may be inlined into region
/// `dest`. The decision is delegated to the inliner interface of `op`'s
/// dialect. An op whose dialect has no handler is not legal to inline.
bool InlinerInterface::isLegalToInline(Operation *op, Region *dest,
                                       bool wouldBeCloned,
                                       IRMapping &valueMapping) const {
  const DialectInlinerInterface *handler = getInterfaceFor(op);
  return handler &&
         handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping);
}
/// Returns whether the legality check should descend into the regions of `op`.
/// This defaults to true when `op`'s dialect provides no inliner interface.
bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const {
  if (const DialectInlinerInterface *handler = getInterfaceFor(op))
    return handler->shouldAnalyzeRecursively(op);
  return true;
}
/// Multi-block form. Hands the terminator `op` of an inlined block, together
/// with `newDest`, to the inliner interface of `op`'s dialect. A handler must
/// exist (asserted).
void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const {
  const DialectInlinerInterface *handler = getInterfaceFor(op);
  assert(handler && "expected valid dialect handler");
  handler->handleTerminator(op, newDest);
}
/// Single-block form. Hands the terminator `op`, together with the values that
/// the inlined region replaces, to the inliner interface of `op`'s dialect.
/// A handler must exist (asserted).
void InlinerInterface::handleTerminator(Operation *op,
                                        ValueRange valuesToRepl) const {
  const DialectInlinerInterface *handler = getInterfaceFor(op);
  assert(handler && "expected valid dialect handler");
  handler->handleTerminator(op, valuesToRepl);
}
/// Lets the dialect of `callable` (not of `call`) post-process `argument`,
/// given that argument's attribute dictionary. Returns the value the handler
/// produced. A handler must exist (asserted).
Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
                                       Operation *callable, Value argument,
                                       DictionaryAttr argumentAttrs) const {
  const DialectInlinerInterface *handler = getInterfaceFor(callable);
  assert(handler && "expected valid dialect handler");
  Value handled = handler->handleArgument(builder, call, callable, argument,
                                          argumentAttrs);
  return handled;
}
/// Lets the dialect of `callable` (not of `call`) post-process `result`, given
/// that result's attribute dictionary. Returns the value the handler produced.
/// A handler must exist (asserted).
Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
                                     Operation *callable, Value result,
                                     DictionaryAttr resultAttrs) const {
  const DialectInlinerInterface *handler = getInterfaceFor(callable);
  assert(handler && "expected valid dialect handler");
  Value handled =
      handler->handleResult(builder, call, callable, result, resultAttrs);
  return handled;
}
/// Forwards the freshly inlined blocks of a call to the inliner interface of
/// the call's dialect. A handler must exist (asserted).
void InlinerInterface::processInlinedCallBlocks(
    Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
  const DialectInlinerInterface *handler = getInterfaceFor(call);
  assert(handler && "expected valid dialect handler");
  handler->processInlinedCallBlocks(call, inlinedBlocks);
}
/// Returns true if every operation in `src` may legally be inlined into
/// `insertRegion`, according to `interface`.
///
/// Operations are checked one by one. For each op where
/// `interface.shouldAnalyzeRecursively` returns true, the regions nested in
/// that op are checked recursively with the same parameters. Under
/// LLVM_DEBUG, the first op rejected at each level is dumped.
static bool isLegalToInline(InlinerInterface &interface, Region *src,
                            Region *insertRegion, bool shouldCloneInlinedRegion,
                            IRMapping &valueMapping) {
  for (auto &block : *src) {
    for (auto &op : block) {
      // Ask the interface about this operation itself.
      if (!interface.isLegalToInline(&op, insertRegion,
                                     shouldCloneInlinedRegion, valueMapping)) {
        LLVM_DEBUG({
          llvm::dbgs() << "* Illegal to inline because of op: ";
          op.dump();
        });
        return false;
      }
      // Then descend into its nested regions, unless the interface opts out.
      // (The parameter below is `Region &region`; it had been mis-encoded as
      // the registered-trademark sign, which did not compile.)
      if (interface.shouldAnalyzeRecursively(&op) &&
          llvm::any_of(op.getRegions(), [&](Region &region) {
            return !isLegalToInline(interface, &region, insertRegion,
                                    shouldCloneInlinedRegion, valueMapping);
          }))
        return false;
    }
  }
  return true;
}
/// Runs `interface.handleArgument` on the value currently bound to each
/// argument of `callable`'s region, then rebinds that argument in `mapper` to
/// the value the handler returned.
///
/// Each argument is paired with its dictionary from `callable.getArgAttrsAttr()`,
/// or with an empty dictionary when the callable carries no argument
/// attributes. The handler must not change the value's type (asserted).
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
                               CallOpInterface call,
                               CallableOpInterface callable,
                               IRMapping &mapper) {
  Region *calleeRegion = callable.getCallableRegion();

  // Start with an empty dictionary per argument and overwrite from the
  // explicit attribute array when there is one.
  SmallVector<DictionaryAttr> argAttrs(calleeRegion->getNumArguments(),
                                       builder.getDictionaryAttr({}));
  ArrayAttr arrayAttr = callable.getArgAttrsAttr();
  if (arrayAttr) {
    assert(arrayAttr.size() == argAttrs.size());
    for (unsigned i = 0, e = arrayAttr.size(); i != e; ++i)
      argAttrs[i] = cast<DictionaryAttr>(arrayAttr[i]);
  }

  for (BlockArgument blockArg : calleeRegion->getArguments()) {
    Value incoming = mapper.lookup(blockArg);
    Value newArgument =
        interface.handleArgument(builder, call, callable, incoming,
                                 argAttrs[blockArg.getArgNumber()]);
    assert(newArgument.getType() == mapper.lookup(blockArg).getType() &&
           "expected the argument type to not change");
    mapper.map(blockArg, newArgument);
  }
}
/// Runs `interface.handleResult` on each value in `results`. Each value is
/// paired with its dictionary from `callable.getResAttrsAttr()`, or with an
/// empty dictionary when the callable carries no result attributes.
///
/// Only the uses of a result that existed before the handler ran are
/// redirected to the handler's return value. Uses created by the handler
/// itself (e.g. an op it builds on top of the result) are left intact. The
/// handler must not change the value's type (asserted).
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
                             CallOpInterface call, CallableOpInterface callable,
                             ValueRange results) {
  // Unpack the result attributes if there are any.
  SmallVector<DictionaryAttr> resAttrs(results.size(),
                                       builder.getDictionaryAttr({}));
  if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) {
    assert(arrayAttr.size() == resAttrs.size());
    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
      resAttrs[idx] = cast<DictionaryAttr>(attr);
  }

  // (An unused `resultAttributes` vector that used to be declared here has
  // been removed.)
  for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
    // Snapshot the users before the handler runs, so that its own new uses
    // are not rewritten below.
    DenseSet<Operation *> resultUsers;
    for (Operation *user : result.getUsers())
      resultUsers.insert(user);

    Value newResult =
        interface.handleResult(builder, call, callable, result, resAttr);
    assert(newResult.getType() == result.getType() &&
           "expected the result type to not change");

    result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
      return resultUsers.count(operand.getOwner());
    });
  }
}
/// Core inliner. Moves or clones the blocks of `src` into `inlineBlock` at
/// `inlinePoint`, then replaces `resultsToReplace` with the values produced by
/// the inlined region.
///
/// \param interface                 Dialect hooks used for legality checks and
///                                  argument/result/terminator handling.
/// \param src                       Region to inline; must be non-empty.
/// \param inlineBlock               Block that receives the inlined code.
/// \param inlinePoint               Position in `inlineBlock` to inline at.
/// \param mapper                    Must already map every entry-block
///                                  argument of `src`. The argument handler may
///                                  rebind those entries.
/// \param resultsToReplace          Values to be replaced by the region's
///                                  results.
/// \param regionResultTypes         One type per entry of `resultsToReplace`
///                                  (lengths are asserted equal).
/// \param inlineLoc                 If set and not UnknownLoc, each inlined
///                                  op's location is wrapped in a CallSiteLoc
///                                  whose caller is this location.
/// \param shouldCloneInlinedRegion  If true, clone `src`; if false, splice its
///                                  blocks out of `src`.
/// \param call                      Optional call op. The argument and result
///                                  attribute handlers run only when it is set
///                                  and `src`'s parent is a CallableOpInterface.
/// \returns failure() if `src` is empty, an entry argument is unmapped, or a
///          legality check fails. All of these are detected before this
///          function itself mutates any IR.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, IRMapping &mapper,
                 ValueRange resultsToReplace, TypeRange regionResultTypes,
                 std::optional<Location> inlineLoc,
                 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
  assert(resultsToReplace.size() == regionResultTypes.size());
  // An empty region has nothing to inline.
  if (src->empty())
    return failure();
  // Every entry-block argument needs a replacement value up front.
  auto *srcEntryBlock = &src->front();
  if (llvm::any_of(srcEntryBlock->getArguments(),
                   [&](BlockArgument arg) { return !mapper.contains(arg); }))
    return failure();
  // Legality is checked first for the region as a whole, then for each
  // (possibly nested) operation.
  Region *insertRegion = inlineBlock->getParent();
  if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
                                 mapper) ||
      !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion,
                       mapper))
    return failure();
  // Run the argument handler before any blocks move. Its builder is
  // positioned at `inlinePoint`, and it rebinds entry arguments in `mapper`.
  OpBuilder builder(inlineBlock, inlinePoint);
  auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
  if (call && callable)
    handleArgumentImpl(interface, builder, call, callable, mapper);
  // Split at the inline point. The inlined blocks are placed between
  // `inlineBlock` and the tail block `postInsertBlock`, either as clones or
  // by splicing them out of `src`.
  Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
  if (shouldCloneInlinedRegion)
    src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
  else
    insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
                                     src->getBlocks(), src->begin(),
                                     src->end());
  // The inlined blocks are exactly those strictly between the two halves.
  auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
                                    postInsertBlock->getIterator());
  Block *firstNewBlock = &*newBlocks.begin();
  if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc))
    remapInlinedLocations(newBlocks, *inlineLoc);
  // cloneInto() received `mapper` directly. Spliced blocks still reference
  // the original values, so their operands are remapped here.
  if (!shouldCloneInlinedRegion)
    remapInlinedOperands(newBlocks, mapper);
  if (call)
    interface.processInlinedCallBlocks(call, newBlocks);
  interface.processInlinedBlocks(newBlocks);
  if (std::next(newBlocks.begin()) == newBlocks.end()) {
    // Single inlined block: its terminator's operands are the region results.
    // New ops from the result handler go right before that terminator.
    Operation *firstBlockTerminator = firstNewBlock->getTerminator();
    builder.setInsertionPoint(firstBlockTerminator);
    if (call && callable)
      handleResultImpl(interface, builder, call, callable,
                       firstBlockTerminator->getOperands());
    // The interface is given the terminator and `resultsToReplace`
    // (presumably so it can replace their uses). Afterwards the terminator is
    // erased and the tail of the original block is merged back in.
    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
    firstBlockTerminator->erase();
    firstNewBlock->getOperations().splice(firstNewBlock->end(),
                                          postInsertBlock->getOperations());
    postInsertBlock->erase();
  } else {
    // Multiple inlined blocks: `postInsertBlock` gains one argument per
    // replaced value, typed from `regionResultTypes` and located at that
    // value's location. All uses of the value are redirected to the argument.
    for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) {
      resultToRepl.value().replaceAllUsesWith(
          postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()],
                                       resultToRepl.value().getLoc()));
    }
    // The result handler works on those new block arguments; its ops go at
    // the start of `postInsertBlock`.
    builder.setInsertionPointToStart(postInsertBlock);
    if (call && callable)
      handleResultImpl(interface, builder, call, callable,
                       postInsertBlock->getArguments());
    // Every inlined block's terminator is passed to the interface together
    // with `postInsertBlock`.
    for (auto &newBlock : newBlocks)
      interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
  }
  // Fold the first inlined block into `inlineBlock`, then delete it.
  inlineBlock->getOperations().splice(inlineBlock->end(),
                                      firstNewBlock->getOperations());
  firstNewBlock->erase();
  return success();
}
/// Convenience overload that builds the argument mapping itself. Each entry
/// argument of `src` is bound to the matching value in `inlinedOperands`.
/// The overload fails if `src` is empty, if the counts differ, or if any
/// operand's type differs from its region argument's type. Otherwise it
/// forwards to the IRMapping-based implementation, using
/// `resultsToReplace.getTypes()` as the region result types.
static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                 Block::iterator inlinePoint, ValueRange inlinedOperands,
                 ValueRange resultsToReplace, std::optional<Location> inlineLoc,
                 bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
  if (src->empty())
    return failure();

  Block &entry = src->front();
  if (entry.getNumArguments() != inlinedOperands.size())
    return failure();

  IRMapping mapper;
  for (auto [regionArg, operand] :
       llvm::zip(entry.getArguments(), inlinedOperands)) {
    // No implicit conversions here: types must match exactly.
    if (operand.getType() != regionArg.getType())
      return failure();
    mapper.map(regionArg, operand);
  }

  TypeRange expectedResultTypes = resultsToReplace.getTypes();
  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
                          resultsToReplace, expectedResultTypes, inlineLoc,
                          shouldCloneInlinedRegion, call);
}
/// Inlines `src` immediately after the operation `inlinePoint`, in that
/// operation's parent block, using a caller-provided argument mapping.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint, IRMapping &mapper,
                                 ValueRange resultsToReplace,
                                 TypeRange regionResultTypes,
                                 std::optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  Block *parentBlock = inlinePoint->getBlock();
  Block::iterator insertPos = ++inlinePoint->getIterator();
  return inlineRegion(interface, src, parentBlock, insertPos, mapper,
                      resultsToReplace, regionResultTypes, inlineLoc,
                      shouldCloneInlinedRegion);
}
/// Inlines `src` into `inlineBlock` at `inlinePoint`, using a caller-provided
/// argument mapping. No call op is involved, so `inlineRegionImpl` receives an
/// empty CallOpInterface.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Block *inlineBlock,
                                 Block::iterator inlinePoint, IRMapping &mapper,
                                 ValueRange resultsToReplace,
                                 TypeRange regionResultTypes,
                                 std::optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
                          resultsToReplace, regionResultTypes, inlineLoc,
                          shouldCloneInlinedRegion, /*call=*/{});
}
/// Inlines `src` immediately after the operation `inlinePoint`, in that
/// operation's parent block. The entry arguments are bound positionally to
/// `inlinedOperands`.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Operation *inlinePoint,
                                 ValueRange inlinedOperands,
                                 ValueRange resultsToReplace,
                                 std::optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  Block *parentBlock = inlinePoint->getBlock();
  Block::iterator insertPos = ++inlinePoint->getIterator();
  return inlineRegion(interface, src, parentBlock, insertPos, inlinedOperands,
                      resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}
/// Inlines `src` into `inlineBlock` at `inlinePoint`. The entry arguments are
/// bound positionally to `inlinedOperands`. No call op is involved, so
/// `inlineRegionImpl` receives an empty CallOpInterface.
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                 Block *inlineBlock,
                                 Block::iterator inlinePoint,
                                 ValueRange inlinedOperands,
                                 ValueRange resultsToReplace,
                                 std::optional<Location> inlineLoc,
                                 bool shouldCloneInlinedRegion) {
  return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
                          inlinedOperands, resultsToReplace, inlineLoc,
                          shouldCloneInlinedRegion, /*call=*/{});
}
/// Asks `interface` to build a cast from `arg` to `type` at the builder's
/// current insertion point.
///
/// On success, the cast op is appended to `castOps` so the caller can roll it
/// back, and its single result is returned. A null Value is returned when
/// `interface` is null or declines to build a cast. The cast is asserted to
/// take exactly `arg` and produce exactly one result of `type`.
static Value materializeConversion(const DialectInlinerInterface *interface,
                                   SmallVectorImpl<Operation *> &castOps,
                                   OpBuilder &castBuilder, Value arg, Type type,
                                   Location conversionLoc) {
  Operation *castOp =
      interface ? interface->materializeCallConversion(castBuilder, arg, type,
                                                       conversionLoc)
                : nullptr;
  if (castOp == nullptr)
    return nullptr;

  castOps.push_back(castOp);
  assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg &&
         castOp->getNumResults() == 1 && *castOp->result_type_begin() == type);
  return castOp->getResult(0);
}
/// Inlines `src` (the body of `callable`) at the position of `call`.
///
/// A call operand or result whose type differs from the region's argument
/// type or `callable`'s result type is bridged with a cast. The cast comes from
/// the call dialect's `materializeCallConversion`. Every cast created here is
/// rolled back if any later step fails. The inlining itself is delegated to
/// `inlineRegionImpl`: the insertion point is right after `call`, and the
/// call's location is used as the inline location.
///
/// \returns failure() if `src` is empty, the operand or result counts do not
///          match, a needed cast cannot be built, the interface rejects the
///          inlining, or `inlineRegionImpl` fails.
/// NOTE(review): `call` itself is not erased in this function; presumably the
/// caller erases it after a successful inline. Confirm.
LogicalResult mlir::inlineCall(InlinerInterface &interface,
                               CallOpInterface call,
                               CallableOpInterface callable, Region *src,
                               bool shouldCloneInlinedRegion) {
  if (src->empty())
    return failure();
  auto *entryBlock = &src->front();
  ArrayRef<Type> callableResultTypes = callable.getResultTypes();
  // The call and the region must have the same arity, for both arguments and
  // results.
  SmallVector<Value, 8> callOperands(call.getArgOperands());
  SmallVector<Value, 8> callResults(call->getResults());
  if (callOperands.size() != entryBlock->getNumArguments() ||
      callResults.size() != callableResultTypes.size())
    return failure();
  // Every cast materialized so far, kept for rollback.
  SmallVector<Operation *, 4> castOps;
  castOps.reserve(callOperands.size() + callResults.size());
  // Roll back: route each cast's users back to its input, erase the cast, and
  // report failure.
  auto cleanupState = [&] {
    for (auto *op : castOps) {
      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
      op->erase();
    }
    return failure();
  };
  // Operand casts are built with the builder positioned at `call`. It is moved
  // after `call` before any result casts are built.
  OpBuilder castBuilder(call);
  Location castLoc = call.getLoc();
  // May be null. In that case materializeConversion() returns a null Value, so
  // any type mismatch leads to failure.
  const auto *callInterface = interface.getInterfaceFor(call->getDialect());
  IRMapping mapper;
  for (unsigned i = 0, e = callOperands.size(); i != e; ++i) {
    BlockArgument regionArg = entryBlock->getArgument(i);
    Value operand = callOperands[i];
    Type regionArgType = regionArg.getType();
    // On a type mismatch, bind the region argument to a cast of the operand.
    if (operand.getType() != regionArgType) {
      if (!(operand = materializeConversion(callInterface, castOps, castBuilder,
                                            operand, regionArgType, castLoc)))
        return cleanupState();
    }
    mapper.map(regionArg, operand);
  }
  castBuilder.setInsertionPointAfter(call);
  for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
    Value callResult = callResults[i];
    if (callResult.getType() == callableResultTypes[i])
      continue;
    // Build a cast that consumes `callResult` and yields the call's original
    // result type. Then move every existing user onto the cast's result. The
    // RAUW also rewrites the cast's own operand, so that operand is pointed
    // back at `callResult`. When inlining later redirects the uses of
    // `callResult`, the cast bridges the inlined value back to the type the
    // original users expect.
    Value castResult =
        materializeConversion(callInterface, castOps, castBuilder, callResult,
                              callResult.getType(), castLoc);
    if (!castResult)
      return cleanupState();
    callResult.replaceAllUsesWith(castResult);
    castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
  }
  if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion))
    return cleanupState();
  if (failed(inlineRegionImpl(interface, src, call->getBlock(),
                              ++call->getIterator(), mapper, callResults,
                              callableResultTypes, call.getLoc(),
                              shouldCloneInlinedRegion, call)))
    return cleanupState();
  return success();
}