#include "LLVMInlining.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
// Debug category of this file; the LLVM_DEBUG statements further down are
// emitted under this name.
#define DEBUG_TYPE "llvm-inliner"
// Lets the rest of the file use MLIR names (Operation, Value, OpBuilder, ...)
// without the `mlir::` qualifier.
using namespace mlir;
/// Returns true if the result of `allocaOp` feeds a lifetime start or lifetime
/// end operation, either directly or through any chain of bitcasts. Users
/// other than bitcasts are not followed.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) {
  // Depth-first traversal over the users reachable through bitcasts.
  SmallVector<Operation *> worklist;
  llvm::append_range(worklist, allocaOp->getUsers());
  while (!worklist.empty()) {
    Operation *user = worklist.pop_back_val();
    if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(user))
      return true;
    // Only bitcasts are looked through; every other user ends the search on
    // this path.
    if (!isa<LLVM::BitcastOp>(user))
      continue;
    llvm::append_range(worklist, user->getUsers());
  }
  return false;
}
/// Post-processes the alloca operations of a freshly inlined callee body.
///
/// - Constant-sized allocas found at the top level of the first inlined block
///   (the callee's former entry block) are moved to the start of the caller's
///   entry block, each fed by a newly created size constant.
/// - For every moved alloca with a non-zero size and no existing lifetime
///   markers, a lifetime start is inserted where the alloca used to be and a
///   lifetime end is inserted before each return-like terminator of the
///   inlined blocks.
/// - If a top-level alloca with a non-constant size is found in any inlined
///   block, a stack save is inserted at the start of the first inlined block
///   and a stack restore before each return-like terminator.
///
/// Allocas nested inside regions of inlined operations are not inspected.
///
/// \param call           The call operation that is being inlined.
/// \param inlinedBlocks  The blocks cloned from the callee.
static void
handleInlinedAllocas(Operation *call,
                     iterator_range<Region::iterator> inlinedBlocks) {
  // Walk up from the call until the parent operation might be isolated from
  // above or might define an automatic allocation scope. The first block of
  // the region that holds the op reached at that point is the destination for
  // the moved allocas.
  Block *callerEntryBlock = nullptr;
  Operation *currentOp = call;
  while (Operation *parentOp = currentOp->getParentOp()) {
    if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
        parentOp->mightHaveTrait<OpTrait::AutomaticAllocationScope>()) {
      callerEntryBlock = &currentOp->getParentRegion()->front();
      break;
    }
    currentOp = parentOp;
  }

  // Nothing to relocate if no destination block was found or if the callee
  // was inlined straight into that block.
  Block *calleeEntryBlock = &(*inlinedBlocks.begin());
  if (!callerEntryBlock || callerEntryBlock == calleeEntryBlock)
    return;

  // Each entry: the alloca to move, its constant array size, and whether
  // lifetime markers have to be created for it.
  SmallVector<std::tuple<LLVM::AllocaOp, IntegerAttr, bool>> allocasToMove;
  bool shouldInsertLifetimes = false;
  bool hasDynamicAlloca = false;
  for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) {
    IntegerAttr arraySize;
    if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) {
      hasDynamicAlloca = true;
      continue;
    }
    bool shouldInsertLifetime =
        arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp);
    shouldInsertLifetimes |= shouldInsertLifetime;
    allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
  }

  // The remaining inlined blocks are only scanned for dynamically sized
  // allocas; the scan stops as soon as one is known to exist.
  for (Block &block : llvm::drop_begin(inlinedBlocks)) {
    if (hasDynamicAlloca)
      break;
    hasDynamicAlloca =
        llvm::any_of(block.getOps<LLVM::AllocaOp>(), [](auto allocaOp) {
          return !matchPattern(allocaOp.getArraySize(), m_Constant());
        });
  }
  if (allocasToMove.empty() && !hasDynamicAlloca)
    return;

  // Save the stack pointer at the very start of the inlined code so that it
  // can be restored before each return.
  OpBuilder builder(calleeEntryBlock, calleeEntryBlock->begin());
  Value stackPtr;
  if (hasDynamicAlloca) {
    stackPtr = builder.create<LLVM::StackSaveOp>(
        call->getLoc(), LLVM::LLVMPointerType::get(call->getContext()));
  }

  builder.setInsertionPoint(callerEntryBlock, callerEntryBlock->begin());
  for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
    // The original size constant stays with the inlined code, so a fresh one
    // is created in the caller's entry block for the moved alloca to use.
    auto newConstant = builder.create<LLVM::ConstantOp>(
        allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
    // The lifetime start is placed where the alloca sits before the move; the
    // guard restores the caller-entry insertion point afterwards.
    if (shouldInsertLifetime) {
      OpBuilder::InsertionGuard insertionGuard(builder);
      builder.setInsertionPoint(allocaOp);
      builder.create<LLVM::LifetimeStartOp>(
          allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
          allocaOp.getResult());
    }
    allocaOp->moveAfter(newConstant);
    allocaOp.getArraySizeMutable().assign(newConstant.getResult());
  }
  if (!shouldInsertLifetimes && !hasDynamicAlloca)
    return;

  // Close the scope before every return-like terminator of the inlined code:
  // first the stack restore (if needed), then the lifetime ends.
  for (Block &block : inlinedBlocks) {
    if (!block.getTerminator()->hasTrait<OpTrait::ReturnLike>())
      continue;
    builder.setInsertionPoint(block.getTerminator());
    if (hasDynamicAlloca)
      builder.create<LLVM::StackRestoreOp>(call->getLoc(), stackPtr);
    for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
      if (shouldInsertLifetime)
        builder.create<LLVM::LifetimeEndOp>(
            allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
            allocaOp.getResult());
    }
  }
}
/// Replaces every alias scope, and the domain it belongs to, that is
/// referenced by an operation in `inlinedBlocks` with a newly built attribute
/// carrying the same description. The original-to-replacement pairs are kept
/// in a local map so that all references to one original attribute are
/// rewritten to the same replacement.
///
/// NOTE(review): this presumably relies on the `get` builders of the scope and
/// domain attributes producing a new distinct attribute on each call, and on
/// the walker visiting each attribute only once - confirm against the
/// attribute and walker definitions.
static void
deepCloneAliasScopes(iterator_range<Region::iterator> inlinedBlocks) {
// Original attribute -> its replacement.
DenseMap<Attribute, Attribute> mapping;
AttrTypeWalker walker;
// Visiting a domain records a replacement domain with the same description.
walker.addWalk([&](LLVM::AliasScopeDomainAttr domainAttr) {
mapping[domainAttr] = LLVM::AliasScopeDomainAttr::get(
domainAttr.getContext(), domainAttr.getDescription());
});
// Visiting a scope records a replacement scope that points at the replacement
// of its domain. The cast requires the domain to be in `mapping` already,
// i.e. the walker must have visited the domain before the scope
// (NOTE(review): assumes nested attributes are visited first - confirm).
walker.addWalk([&](LLVM::AliasScopeAttr scopeAttr) {
mapping[scopeAttr] = LLVM::AliasScopeAttr::get(
cast<LLVM::AliasScopeDomainAttr>(mapping.lookup(scopeAttr.getDomain())),
scopeAttr.getDescription());
});
// Maps an array of scopes to the array of their replacements; a null array
// stays null.
auto convertScopeList = [&](ArrayAttr arrayAttr) -> ArrayAttr {
if (!arrayAttr)
return nullptr;
// Make sure every element of the array has an entry in `mapping`.
walker.walk(arrayAttr);
return ArrayAttr::get(arrayAttr.getContext(),
llvm::map_to_vector(arrayAttr, [&](Attribute attr) {
return mapping.lookup(attr);
}));
};
for (Block &block : inlinedBlocks) {
block.walk([&](Operation *op) {
// Memory operations: rewrite both their alias and noalias scope lists.
if (auto aliasInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(op)) {
aliasInterface.setAliasScopes(
convertScopeList(aliasInterface.getAliasScopesOrNull()));
aliasInterface.setNoAliasScopes(
convertScopeList(aliasInterface.getNoAliasScopesOrNull()));
}
// Scope declarations: rewrite the single scope they declare.
if (auto noAliasScope = dyn_cast<LLVM::NoAliasScopeDeclOp>(op)) {
walker.walk(noAliasScope.getScopeAttr());
noAliasScope.setScopeAttr(cast<LLVM::AliasScopeAttr>(
mapping.lookup(noAliasScope.getScopeAttr())));
}
});
}
}
/// Returns the concatenation of `lhs` and `rhs` (elements of `lhs` first).
/// If one of the two is null the other one is returned unchanged; if both are
/// null the result is null.
static ArrayAttr concatArrayAttr(ArrayAttr lhs, ArrayAttr rhs) {
  // With at most one non-null input there is nothing to merge.
  if (!lhs || !rhs)
    return lhs ? lhs : rhs;

  SmallVector<Attribute> combined;
  llvm::append_range(combined, lhs);
  llvm::append_range(combined, rhs);
  return ArrayAttr::get(lhs.getContext(), combined);
}
/// Strips GEP and address-space-cast operations from `pointerValue` and
/// returns the first value that is defined by neither of the two.
static Value getUnderlyingObject(Value pointerValue) {
  for (;;) {
    if (auto gepOp = pointerValue.getDefiningOp<LLVM::GEPOp>()) {
      // A GEP is based on its base pointer.
      pointerValue = gepOp.getBase();
    } else if (auto addrCast =
                   pointerValue.getDefiningOp<LLVM::AddrSpaceCastOp>()) {
      // An address space cast is based on the pointer it casts.
      pointerValue = addrCast.getOperand();
    } else {
      return pointerValue;
    }
  }
}
/// Collects the values that `pointerValue` may be based on.
///
/// Starting from `pointerValue`, GEPs and address space casts are stripped
/// (see `getUnderlyingObject`), both inputs of a select are followed, and a
/// block argument is followed to the operands forwarded by the branch
/// terminators of all predecessor blocks. If some predecessor does not
/// terminate with a `BranchOpInterface` op, or the forwarded operand is null,
/// the block argument itself is added to the result instead. Each value is
/// processed at most once.
static SmallVector<Value> getUnderlyingObjectSet(Value pointerValue) {
  SmallVector<Value> objects;
  SmallVector<Value> pending{pointerValue};
  // Guards against revisiting values, e.g. along cyclic block arguments.
  SmallPtrSet<Value, 4> visited;

  while (!pending.empty()) {
    Value base = getUnderlyingObject(pending.pop_back_val());
    if (!visited.insert(base).second)
      continue;

    if (auto selectOp = base.getDefiningOp<LLVM::SelectOp>()) {
      pending.push_back(selectOp.getTrueValue());
      pending.push_back(selectOp.getFalseValue());
      continue;
    }

    auto blockArg = dyn_cast<BlockArgument>(base);
    if (!blockArg) {
      // Not looked through any further: this is one of the based-on objects.
      objects.push_back(base);
      continue;
    }

    // Gather the value forwarded to this argument by every predecessor.
    Block *owner = blockArg.getParentBlock();
    SmallVector<Value> incoming;
    bool allKnown = true;
    for (auto predIt = owner->pred_begin(), predEnd = owner->pred_end();
         predIt != predEnd; ++predIt) {
      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
      Value forwarded;
      if (branch)
        forwarded = branch.getSuccessorOperands(
            predIt.getSuccessorIndex())[blockArg.getArgNumber()];
      if (!forwarded) {
        allKnown = false;
        break;
      }
      incoming.push_back(forwarded);
    }

    // Either follow all forwarded operands, or fall back to the block argument
    // itself when at least one of them could not be determined.
    if (allKnown)
      llvm::append_range(pending, incoming);
    else
      objects.push_back(blockArg);
  }
  return objects;
}
/// Converts the `noalias` parameter markers of an inlined call into alias
/// scope annotations on the inlined memory operations.
///
/// The markers are `LLVM::SSACopyOp` operations that use one of the call's
/// argument operands and carry the noalias attribute (they are created by
/// `LLVMInlinerInterface::handleArgument`). For each marker a new alias scope
/// is created in a domain built from the callee's symbol name and declared
/// with a `NoAliasScopeDeclOp` inserted via a builder positioned at the call.
/// Each inlined operation implementing `AliasAnalysisOpInterface` then
/// receives, unless one of its accessed pointers is based on a value of
/// unknown origin:
/// - as noalias scopes, the scopes of all markers it is not based on;
/// - as alias scopes, the scopes of all markers it is based on, but only if it
///   is not a call and is not also based on an alloca or address-of.
///
/// All markers are removed again before the function returns.
static void createNewAliasScopesFromNoAliasParameter(
Operation *call, iterator_range<Region::iterator> inlinedBlocks) {
// Collect the marker ops among the users of the call's argument operands.
SetVector<LLVM::SSACopyOp> noAliasParams;
for (Value argument : cast<LLVM::CallOp>(call).getArgOperands()) {
for (Operation *user : argument.getUsers()) {
auto ssaCopy = llvm::dyn_cast<LLVM::SSACopyOp>(user);
if (!ssaCopy)
continue;
if (!ssaCopy->hasAttr(LLVM::LLVMDialect::getNoAliasAttrName()))
continue;
noAliasParams.insert(ssaCopy);
}
}
if (noAliasParams.empty())
return;
// Runs on every exit path below: forward each marker's operand to the
// marker's users and erase the marker.
auto exit = llvm::make_scope_exit([&] {
for (LLVM::SSACopyOp ssaCopyOp : noAliasParams) {
ssaCopyOp.replaceAllUsesWith(ssaCopyOp.getOperand());
ssaCopyOp->erase();
}
});
// One domain for this inlined call, described by the callee's symbol name.
auto functionDomain = LLVM::AliasScopeDomainAttr::get(
call->getContext(), cast<LLVM::CallOp>(call).getCalleeAttr().getAttr());
// Marker result -> the scope created for it.
DenseMap<Value, LLVM::AliasScopeAttr> pointerScopes;
for (LLVM::SSACopyOp copyOp : noAliasParams) {
auto scope = LLVM::AliasScopeAttr::get(functionDomain);
pointerScopes[copyOp] = scope;
OpBuilder(call).create<LLVM::NoAliasScopeDeclOp>(call->getLoc(), scope);
}
for (Block &inlinedBlock : inlinedBlocks) {
inlinedBlock.walk([&](LLVM::AliasAnalysisOpInterface aliasInterface) {
SmallVector<Value> pointerArgs = aliasInterface.getAccessedOperands();
// Union of the underlying objects of all accessed pointers.
SmallPtrSet<Value, 4> basedOnPointers;
for (Value pointer : pointerArgs)
llvm::copy(getUnderlyingObjectSet(pointer),
std::inserter(basedOnPointers, basedOnPointers.begin()));
bool aliasesOtherKnownObject = false;
// The predicate returns true for an object of unknown origin, in which case
// the operation is left untouched. Constants and markers are accepted;
// allocas and address-ofs are accepted too but remembered through
// `aliasesOtherKnownObject` (a side effect of the predicate).
if (llvm::any_of(basedOnPointers, [&](Value object) {
if (matchPattern(object, m_Constant()))
return false;
if (noAliasParams.contains(object.getDefiningOp<LLVM::SSACopyOp>()))
return false;
if (isa_and_nonnull<LLVM::AllocaOp, LLVM::AddressOfOp>(
object.getDefiningOp())) {
aliasesOtherKnownObject = true;
return false;
}
return true;
}))
return;
// The operation does not alias any noalias parameter it is not based on.
SmallVector<Attribute> noAliasScopes;
for (LLVM::SSACopyOp noAlias : noAliasParams) {
if (basedOnPointers.contains(noAlias))
continue;
noAliasScopes.push_back(pointerScopes[noAlias]);
}
if (!noAliasScopes.empty())
aliasInterface.setNoAliasScopes(
concatArrayAttr(aliasInterface.getNoAliasScopesOrNull(),
ArrayAttr::get(call->getContext(), noAliasScopes)));
// Alias scopes are only attached to non-call operations whose accessed
// pointers are not also based on an alloca or address-of.
if (aliasesOtherKnownObject ||
isa<LLVM::CallOp>(aliasInterface.getOperation()))
return;
SmallVector<Attribute> aliasScopes;
for (LLVM::SSACopyOp noAlias : noAliasParams)
if (basedOnPointers.contains(noAlias))
aliasScopes.push_back(pointerScopes[noAlias]);
if (!aliasScopes.empty())
aliasInterface.setAliasScopes(
concatArrayAttr(aliasInterface.getAliasScopesOrNull(),
ArrayAttr::get(call->getContext(), aliasScopes)));
});
}
}
/// Appends the alias scopes and noalias scopes attached to `call` to every
/// operation in `inlinedBlocks` (including nested ones) that implements
/// `AliasAnalysisOpInterface`. Does nothing if the call does not implement the
/// interface or carries neither kind of scope.
static void
appendCallOpAliasScopes(Operation *call,
                        iterator_range<Region::iterator> inlinedBlocks) {
  auto callInterface = dyn_cast<LLVM::AliasAnalysisOpInterface>(call);
  if (!callInterface)
    return;

  ArrayAttr callAliasScopes = callInterface.getAliasScopesOrNull();
  ArrayAttr callNoAliasScopes = callInterface.getNoAliasScopesOrNull();
  if (!(callAliasScopes || callNoAliasScopes))
    return;

  for (Block &block : inlinedBlocks) {
    block.walk([&](LLVM::AliasAnalysisOpInterface memOp) {
      // Existing scopes of the operation come first, the call's scopes last.
      if (callAliasScopes) {
        ArrayAttr merged =
            concatArrayAttr(memOp.getAliasScopesOrNull(), callAliasScopes);
        memOp.setAliasScopes(merged);
      }
      if (callNoAliasScopes) {
        ArrayAttr merged =
            concatArrayAttr(memOp.getNoAliasScopesOrNull(), callNoAliasScopes);
        memOp.setNoAliasScopes(merged);
      }
    });
  }
}
/// Updates the alias scope information of the inlined operations in three
/// steps, in this order: replace the scopes already present in the inlined
/// code with clones, create scopes for the call's noalias parameters, and
/// append the scopes attached to the call operation itself.
static void handleAliasScopes(Operation *call,
iterator_range<Region::iterator> inlinedBlocks) {
deepCloneAliasScopes(inlinedBlocks);
createNewAliasScopesFromNoAliasParameter(call, inlinedBlocks);
appendCallOpAliasScopes(call, inlinedBlocks);
}
/// Appends the access groups attached to `call` to every top-level operation
/// of `inlinedBlocks` that implements `AccessGroupOpInterface`. Does nothing
/// if the call does not implement the interface or has no access groups.
static void handleAccessGroups(Operation *call,
                               iterator_range<Region::iterator> inlinedBlocks) {
  auto callInterface = dyn_cast<LLVM::AccessGroupOpInterface>(call);
  if (!callInterface)
    return;

  auto callGroups = callInterface.getAccessGroupsOrNull();
  if (!callGroups)
    return;

  for (Block &block : inlinedBlocks) {
    for (auto memOp : block.getOps<LLVM::AccessGroupOpInterface>()) {
      // Existing groups of the operation come first, the call's groups last.
      ArrayAttr merged =
          concatArrayAttr(memOp.getAccessGroupsOrNull(), callGroups);
      memOp.setAccessGroups(merged);
    }
  }
}
/// Rewrites the start and end locations stored in the loop annotations of the
/// inlined operations: each non-null location is wrapped into a call-site
/// location at `call` and fused with the subprogram of the function that
/// contains the call. Nothing is changed when the call has no enclosing
/// function or that function's location is not a fused location whose
/// metadata is a `DISubprogramAttr`.
static void
handleLoopAnnotations(Operation *call,
                      iterator_range<Region::iterator> inlinedBlocks) {
  // Extract the subprogram of the caller, bailing out at the first piece that
  // is missing.
  auto callerFunc = call->getParentOfType<FunctionOpInterface>();
  if (!callerFunc)
    return;
  LocationAttr callerLoc = callerFunc->getLoc();
  auto callerFusedLoc = dyn_cast_if_present<FusedLoc>(callerLoc);
  if (!callerFusedLoc)
    return;
  auto callerScope = dyn_cast_if_present<LLVM::DISubprogramAttr>(
      callerFusedLoc.getMetadata());
  if (!callerScope)
    return;

  // Builds the location to store in place of `original`; a null location
  // stays null.
  auto wrapInCallSite = [&](FusedLoc original) -> FusedLoc {
    if (!original)
      return {};
    Location inlinedAt = CallSiteLoc::get(original, call->getLoc());
    return FusedLoc::get(original.getContext(), inlinedAt, callerScope);
  };

  AttrTypeReplacer replacer;
  replacer.addReplacement([&](LLVM::LoopAnnotationAttr annotation)
                              -> std::pair<Attribute, WalkResult> {
    FusedLoc startLoc = wrapInCallSite(annotation.getStartLoc());
    FusedLoc endLoc = wrapInCallSite(annotation.getEndLoc());
    if (startLoc || endLoc) {
      // Rebuild the annotation with all other fields carried over unchanged.
      auto updated = LLVM::LoopAnnotationAttr::get(
          annotation.getContext(), annotation.getDisableNonforced(),
          annotation.getVectorize(), annotation.getInterleave(),
          annotation.getUnroll(), annotation.getUnrollAndJam(),
          annotation.getLicm(), annotation.getDistribute(),
          annotation.getPipeline(), annotation.getPeeled(),
          annotation.getUnswitch(), annotation.getMustProgress(),
          annotation.getIsVectorized(), startLoc, endLoc,
          annotation.getParallelAccesses());
      return {updated, WalkResult::advance()};
    }
    // Neither location is set: keep the annotation as it is.
    return {annotation, WalkResult::advance()};
  });

  for (Block &block : inlinedBlocks) {
    for (Operation &op : block)
      replacer.recursivelyReplaceElementsIn(&op);
  }
}
/// Tries to raise the alignment of `alloca` to `requestedAlignment` and
/// returns the alignment the alloca has afterwards (a missing alignment counts
/// as 1).
///
/// The alignment is raised only if the data layout's stack alignment is 0, if
/// the requested alignment does not exceed that stack alignment, or if the
/// alloca's current alignment already exceeds it. The alignments are
/// multiplied by 8 before being compared with the stack alignment, which is
/// presumably expressed in bits.
static uint64_t tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca,
                                            uint64_t requestedAlignment,
                                            DataLayout const &dataLayout) {
  const uint64_t currentAlignment = alloca.getAlignment().value_or(1);
  // Already aligned strictly enough.
  if (currentAlignment >= requestedAlignment)
    return currentAlignment;

  const uint64_t stackAlignmentBits = dataLayout.getStackAlignment();
  const bool stackAlignmentIsZero = stackAlignmentBits == 0;
  const bool requestWithinStackAlignment =
      8 * requestedAlignment <= stackAlignmentBits;
  const bool alreadyAboveStackAlignment =
      8 * currentAlignment > stackAlignmentBits;
  if (!(stackAlignmentIsZero || requestWithinStackAlignment ||
        alreadyAboveStackAlignment))
    return currentAlignment;

  alloca.setAlignment(requestedAlignment);
  return requestedAlignment;
}
/// Determines the alignment that can be assumed for the pointer `value`,
/// raising it towards `requestedAlignment` where this function knows how to.
///
/// - Result of an alloca: the alloca's alignment after an attempt to raise it
///   (see `tryToEnforceAllocaAlignment`).
/// - Result of an address-of whose global can be resolved: the alignment of
///   that global, or 1 if it has none. The global is not modified.
/// - Argument of the entry block of an `LLVM::LLVMFuncOp`: the value of the
///   argument's align attribute, if present.
/// - Anything else: 1.
///
/// \returns the alignment in the same unit as `requestedAlignment`.
static uint64_t tryToEnforceAlignment(Value value, uint64_t requestedAlignment,
                                      DataLayout const &dataLayout) {
  if (Operation *definingOp = value.getDefiningOp()) {
    if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
      return tryToEnforceAllocaAlignment(alloca, requestedAlignment,
                                         dataLayout);
    if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
      if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
              definingOp, addressOf.getGlobalNameAttr()))
        return global.getAlignment().value_or(1);
    // Any other defining operation is not handled; assume no alignment.
    return 1;
  }

  // Without a defining operation the value is a block argument. Its argument
  // number only indexes the function's argument attributes if the argument
  // belongs to the function's entry block; arguments of any other block are
  // unrelated to the function signature and must not be looked up there.
  auto blockArg = llvm::cast<BlockArgument>(value);
  Block *owner = blockArg.getOwner();
  auto func = dyn_cast_if_present<LLVM::LLVMFuncOp>(owner->getParentOp());
  if (!func || !owner->isEntryBlock())
    return 1;

  if (Attribute alignAttr = func.getArgAttr(
          blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
    return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
  return 1;
}
/// Creates a stack copy of the object `argument` points to and returns the
/// pointer to the copy.
///
/// A single-element alloca of `elementType` with alignment `targetAlignment`
/// is created at the start of the first block of the region containing
/// `argument`; at the builder's current insertion point a memcpy of
/// `elementTypeSize` bytes from `argument` into the alloca is emitted.
static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
                                     Value argument, Type elementType,
                                     uint64_t elementTypeSize,
                                     uint64_t targetAlignment) {
  Value stackSlot;
  {
    // Temporarily move to the start of the entry block for the allocation; the
    // guard restores the previous insertion point at the end of this scope.
    OpBuilder::InsertionGuard guard(builder);
    Block &entryBlock = argument.getParentRegion()->front();
    builder.setInsertionPointToStart(&entryBlock);
    Value elementCount = builder.create<LLVM::ConstantOp>(
        loc, builder.getI64Type(), builder.getI64IntegerAttr(1));
    stackSlot = builder.create<LLVM::AllocaOp>(
        loc, argument.getType(), elementType, elementCount, targetAlignment);
  }

  // Copy the pointee into the new stack slot (non-volatile).
  Value byteCount = builder.create<LLVM::ConstantOp>(
      loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize));
  builder.create<LLVM::MemcpyOp>(loc, stackSlot, argument, byteCount,
                                 /*isVolatile=*/false);
  return stackSlot;
}
/// Returns the value to use inside the inlined body for a by-value pointer
/// argument of `callable`.
///
/// The original `argument` is reused when the callee has a memory effects
/// attribute whose argument-memory effect is neither Mod nor ModRef, and the
/// pointer is (or can be made) aligned to `requestedAlignment`. In all other
/// cases a copy on the stack is created, aligned to the larger of
/// `requestedAlignment` and the ABI alignment of `elementType`.
static Value handleByValArgument(OpBuilder &builder, Operation *callable,
                                 Value argument, Type elementType,
                                 uint64_t requestedAlignment) {
  auto func = cast<LLVM::LLVMFuncOp>(callable);
  DataLayout dataLayout = DataLayout::closest(callable);
  const uint64_t abiAlignment = dataLayout.getTypeABIAlignment(elementType);

  // Without a memory effects attribute the callee is assumed to potentially
  // write to its argument memory.
  bool mayWriteArgMem = true;
  if (LLVM::MemoryEffectsAttr effects = func.getMemoryAttr()) {
    auto argMem = effects.getArgMem();
    mayWriteArgMem = argMem == LLVM::ModRefInfo::ModRef ||
                     argMem == LLVM::ModRefInfo::Mod;
  }

  if (!mayWriteArgMem) {
    // No copy needed if the ABI alignment already satisfies the request ...
    if (requestedAlignment <= abiAlignment)
      return argument;
    // ... or if the pointer's alignment is, or can be made, large enough.
    if (tryToEnforceAlignment(argument, requestedAlignment, dataLayout) >=
        requestedAlignment)
      return argument;
  }

  return handleByValArgumentInit(
      builder, func.getLoc(), argument, elementType,
      dataLayout.getTypeSize(elementType),
      std::max(requestedAlignment, abiAlignment));
}
namespace {
/// Inliner interface of the LLVM dialect: decides which calls, regions and
/// operations may be inlined, and adapts arguments, terminators and the
/// inlined blocks.
struct LLVMInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
/// Builds the set of passthrough function attributes that prevent inlining
/// once, so that `isLegalToInline` only has to do set lookups.
LLVMInlinerInterface(Dialect *dialect)
: DialectInlinerInterface(dialect),
disallowedFunctionAttrs({
StringAttr::get(dialect->getContext(), "noduplicate"),
StringAttr::get(dialect->getContext(), "presplitcoroutine"),
StringAttr::get(dialect->getContext(), "returns_twice"),
StringAttr::get(dialect->getContext(), "strictfp"),
}) {}
/// Returns true if `call` may be inlined with the body of `callable`.
///
/// Inlining is refused (with a debug message in most cases) when the callee
/// would not be cloned, the call is not an `LLVM::CallOp`, the callable is not
/// an `LLVM::LLVMFuncOp`, or the function is marked no-inline, is variadic,
/// has an inalloca argument, has a personality, or lists one of the
/// disallowed attributes in its passthrough attributes.
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
if (!wouldBeCloned)
return false;
if (!isa<LLVM::CallOp>(call)) {
LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
<< LLVM::CallOp::getOperationName() << "' op\n");
return false;
}
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!funcOp) {
LLVM_DEBUG(llvm::dbgs()
<< "Cannot inline: callable is not an '"
<< LLVM::LLVMFuncOp::getOperationName() << "' op\n");
return false;
}
if (funcOp.isNoInline()) {
LLVM_DEBUG(llvm::dbgs()
<< "Cannot inline: function is marked no_inline\n");
return false;
}
if (funcOp.isVarArg()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
return false;
}
// Reject functions with any inalloca argument.
if (auto attrs = funcOp.getArgAttrs()) {
for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
<< ": inalloca arguments not supported\n");
return false;
}
}
}
if (funcOp.getPersonality()) {
LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
<< ": unhandled function personality\n");
return false;
}
// Only string attributes in the passthrough list are checked against the
// disallowed set; other attribute kinds are ignored.
if (funcOp.getPassthrough()) {
if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
auto stringAttr = dyn_cast<StringAttr>(attr);
if (!stringAttr)
return false;
if (disallowedFunctionAttrs.contains(stringAttr)) {
LLVM_DEBUG(llvm::dbgs()
<< "Cannot inline " << funcOp.getSymName()
<< ": found disallowed function attribute "
<< stringAttr << "\n");
return true;
}
return false;
}))
return false;
}
return true;
}
/// Region-into-region inlining is always allowed.
bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
return true;
}
/// Every operation except `LLVM::VaStartOp` may be inlined.
bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
return !isa<LLVM::VaStartOp>(op);
}
/// Replaces an inlined `LLVM::ReturnOp` with a branch to `newDest` that
/// forwards the return operands. Other terminators are left untouched.
void handleTerminator(Operation *op, Block *newDest) const final {
auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
if (!returnOp)
return;
OpBuilder builder(op);
builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
op->erase();
}
/// Replaces all uses of `valuesToRepl` with the corresponding operands of the
/// inlined `LLVM::ReturnOp`; `op` must be such a return with a matching
/// operand count. The return operation itself is not erased here.
void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
auto returnOp = cast<LLVM::ReturnOp>(op);
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (auto [dst, src] : llvm::zip(valuesToRepl, returnOp.getOperands()))
dst.replaceAllUsesWith(src);
}
/// Returns the value the inlined body should use for `argument`, based on the
/// attributes of the corresponding callee parameter.
///
/// - byval: delegates to `handleByValArgument`, passing the pointee type and
///   the align attribute's value (1 if absent).
/// - noalias (and the argument has uses): wraps the argument in an
///   `LLVM::SSACopyOp` tagged with the noalias attribute; that marker is picked
///   up again by `createNewAliasScopesFromNoAliasParameter`.
/// - otherwise: the argument itself.
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
Value argument,
DictionaryAttr argumentAttrs) const final {
if (std::optional<NamedAttribute> attr =
argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
Type elementType = cast<TypeAttr>(attr->getValue()).getValue();
uint64_t requestedAlignment = 1;
if (std::optional<NamedAttribute> alignAttr =
argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
requestedAlignment = cast<IntegerAttr>(alignAttr->getValue())
.getValue()
.getLimitedValue();
}
return handleByValArgument(builder, callable, argument, elementType,
requestedAlignment);
}
if (argumentAttrs.contains(LLVM::LLVMDialect::getNoAliasAttrName())) {
// An unused argument needs no marker.
if (argument.use_empty())
return argument;
auto copyOp = builder.create<LLVM::SSACopyOp>(call->getLoc(), argument);
copyOp->setDiscardableAttr(
builder.getStringAttr(LLVM::LLVMDialect::getNoAliasAttrName()),
builder.getUnitAttr());
return copyOp;
}
return argument;
}
/// Post-processes the blocks inlined for `call`: allocas first, then alias
/// scopes, access groups and loop annotations.
void processInlinedCallBlocks(
Operation *call,
iterator_range<Region::iterator> inlinedBlocks) const override {
handleInlinedAllocas(call, inlinedBlocks);
handleAliasScopes(call, inlinedBlocks);
handleAccessGroups(call, inlinedBlocks);
handleLoopAnnotations(call, inlinedBlocks);
}
/// Passthrough function attributes that make a function illegal to inline.
const DenseSet<StringAttr> disallowedFunctionAttrs;
};
}
/// Registers `LLVMInlinerInterface` with the given LLVM dialect instance.
void LLVM::detail::addLLVMInlinerInterface(LLVM::LLVMDialect *dialect) {
dialect->addInterfaces<LLVMInlinerInterface>();
}