#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "sroa"
using namespace mlir;
/// Reports the memory slot backed by this alloca as a promotion candidate.
/// Allocas that are not located in an entry block are never offered.
llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
  Block *parentBlock = getOperation()->getBlock();
  if (parentBlock->isEntryBlock())
    return {MemorySlot{getResult(), getElemType()}};
  return {};
}
/// Builds the default value of the slot: an `LLVM::UndefOp` of the slot's
/// element type, created at the location of this alloca.
Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
                                      OpBuilder &builder) {
  Type slotType = slot.elemType;
  Value undefValue = builder.create<LLVM::UndefOp>(getLoc(), slotType);
  return undefValue;
}
/// For every `LLVM::DbgDeclareOp` using this alloca, emits an
/// `LLVM::DbgValueOp` that describes `argument` with the same variable info
/// and location expression as the declare op.
void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                         BlockArgument argument,
                                         OpBuilder &builder) {
  for (Operation *user : getOperation()->getUsers()) {
    auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user);
    if (!declareOp)
      continue;
    builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
                                     declareOp.getVarInfo(),
                                     declareOp.getLocationExpr());
  }
}
std::optional<PromotableAllocationOpInterface>
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
return std::nullopt;
}
/// Reports this alloca as a destructurable slot when its array size is the
/// constant one and its element type exposes a subelement index map.
/// Returns an empty list otherwise.
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
  if (!mlir::matchPattern(getArraySize(), m_One()))
    return {};

  Type elemType = getElemType();
  if (auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType)) {
    std::optional<DenseMap<Attribute, Type>> indexMap =
        destructurable.getSubelementIndexMap();
    if (indexMap)
      return {DestructurableMemorySlot{{getResult(), elemType}, *indexMap}};
  }
  return {};
}
/// Splits this alloca into one new alloca per used subelement index.
///
/// The new allocas are inserted right after this op, reuse its location and
/// array size, and are appended to `newAllocators`. The returned map
/// associates each used index with the slot of its new alloca.
DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
    const DestructurableMemorySlot &slot,
    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
  assert(slot.ptr == getResult());
  builder.setInsertionPointAfter(*this);

  auto aggregateType = cast<DestructurableTypeInterface>(getElemType());
  DenseMap<Attribute, MemorySlot> subslotsByIndex;
  for (Attribute usedIndex : usedIndices) {
    Type subelementType = aggregateType.getTypeAtIndex(usedIndex);
    assert(subelementType && "used index must exist");

    auto pointerType = LLVM::LLVMPointerType::get(getContext());
    auto newAlloca = builder.create<LLVM::AllocaOp>(
        getLoc(), pointerType, subelementType, getArraySize());
    newAllocators.push_back(newAlloca);
    subslotsByIndex.try_emplace(
        usedIndex, MemorySlot{newAlloca.getResult(), subelementType});
  }
  return subslotsByIndex;
}
std::optional<DestructurableAllocationOpInterface>
LLVM::AllocaOp::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
assert(slot.ptr == getResult());
this->erase();
return std::nullopt;
}
bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
/// A load never writes to any slot.
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) {
  return false;
}
/// Must never be reached: `storesTo` is always false for loads, so there is
/// no stored value to provide.
Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                              Value reachingDef,
                              const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}
/// A store never reads from any slot.
bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) {
  return false;
}
bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
/// Returns true when `type` may take part in the value conversions done by
/// this file. LLVM struct/array types, LLVM fixed/scalable vector types and
/// scalable builtin vectors are rejected; everything else is accepted.
static bool isSupportedTypeForConversion(Type type) {
  const bool isAggregate =
      isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type);
  const bool isLLVMVector =
      isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type);
  if (isAggregate || isLLVMVector)
    return false;

  // Builtin vectors are only acceptable when they are not scalable.
  auto builtinVector = dyn_cast<VectorType>(type);
  return !builtinVector || !builtinVector.isScalable();
}
/// Checks whether a value of type `srcType` can be converted to `targetType`.
///
/// Identical types are always compatible. Otherwise both types must be
/// supported for conversion; pointer-to-pointer conversions require equal
/// sizes; other conversions require the target to be no larger than the
/// source when `narrowingConversion` is set, and no smaller otherwise.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
                                    Type srcType, bool narrowingConversion) {
  if (targetType == srcType)
    return true;

  const bool bothSupported = isSupportedTypeForConversion(targetType) &&
                             isSupportedTypeForConversion(srcType);
  if (!bothSupported)
    return false;

  const uint64_t targetSize = layout.getTypeSize(targetType);
  const uint64_t srcSize = layout.getTypeSize(srcType);

  const bool pointerToPointer = isa<LLVM::LLVMPointerType>(targetType) &&
                                isa<LLVM::LLVMPointerType>(srcType);
  if (pointerToPointer)
    return targetSize == srcSize;

  return narrowingConversion ? targetSize <= srcSize : targetSize >= srcSize;
}
/// Returns true when the data layout's endianness entry is a string
/// attribute equal to "big". A missing or non-string entry yields false.
static bool isBigEndian(const DataLayout &dataLayout) {
  auto endianness = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
  if (!endianness)
    return false;
  return endianness == "big";
}
/// Converts `val` to an integer of the same bit size. Integers are returned
/// unchanged, pointers go through `ptrtoint`, everything else is bitcast.
static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                const DataLayout &dataLayout) {
  Type valType = val.getType();
  assert(isSupportedTypeForConversion(valType) &&
         "expected value to have a convertible type");

  if (isa<IntegerType>(valType))
    return val;

  IntegerType sameSizedInt =
      builder.getIntegerType(dataLayout.getTypeSizeInBits(valType));
  if (!isa<LLVM::LLVMPointerType>(valType))
    return builder.createOrFold<LLVM::BitcastOp>(loc, sameSizedInt, val);
  return builder.createOrFold<LLVM::PtrToIntOp>(loc, sameSizedInt, val);
}
/// Converts the integer `val` to `targetType`. Nothing is emitted when the
/// types already match; pointer targets go through `inttoptr`, all other
/// targets are bitcast.
static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
                                         Value val, Type targetType) {
  Type intType = val.getType();
  assert(isa<IntegerType>(intType) &&
         "expected value to have an integer type");
  assert(isSupportedTypeForConversion(targetType) &&
         "expected the target type to be supported for conversions");

  if (intType == targetType)
    return val;
  if (!isa<LLVM::LLVMPointerType>(targetType))
    return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
  return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
}
/// Converts `srcValue` to `targetType`, which callers use for types of equal
/// size. Identical types need no cast, pointer-to-pointer conversions use an
/// address space cast, and all other pairs are cast through a same-sized
/// integer.
static Value castSameSizedTypes(OpBuilder &builder, Location loc,
                                Value srcValue, Type targetType,
                                const DataLayout &dataLayout) {
  Type srcType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, targetType, srcType,
                                 /*narrowingConversion=*/true) &&
         "expected that the compatibility was checked before");

  if (srcType == targetType)
    return srcValue;

  const bool pointerToPointer = isa<LLVM::LLVMPointerType>(targetType) &&
                                isa<LLVM::LLVMPointerType>(srcType);
  if (pointerToPointer)
    return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                       srcValue);

  Value asInteger = castToSameSizedInt(builder, loc, srcValue, dataLayout);
  return castIntValueToSameSizedType(builder, loc, asInteger, targetType);
}
/// Produces the value of type `targetType` obtained by reading the leading
/// bytes of `srcValue`, where `targetType` is at most as large as the type of
/// `srcValue`.
///
/// Same-sized types are converted directly. Otherwise the source is turned
/// into a same-sized integer, shifted right on big-endian layouts so that the
/// leading bytes end up in the low bits, truncated to the target bit size,
/// and finally cast to `targetType`.
static Value createExtractAndCast(OpBuilder &builder, Location loc,
                                  Value srcValue, Type targetType,
                                  const DataLayout &dataLayout) {
  Type srcType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, targetType, srcType,
                                 /*narrowingConversion=*/true) &&
         "expected that the compatibility was checked before");

  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
  if (srcTypeSize == targetTypeSize)
    return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);

  // All bit manipulation below operates on the integer view of the source.
  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);

  if (isBigEndian(dataLayout)) {
    // The shift must be built on the integer-cast value and its type. Using
    // the original source value/type would produce an integer constant and a
    // shift of a non-integer type whenever the source is not an integer.
    uint64_t shiftAmount = srcTypeSize - targetTypeSize;
    auto shiftConstant = builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(replacement.getType(), shiftAmount));
    replacement =
        builder.createOrFold<LLVM::LShrOp>(loc, replacement, shiftConstant);
  }

  replacement = builder.create<LLVM::TruncOp>(
      loc, builder.getIntegerType(targetTypeSize), replacement);

  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
}
/// Builds the slot value that results from writing `srcValue` over the
/// leading bytes of `reachingDef`; the result has the type of `reachingDef`.
///
/// Same-sized types are converted directly. Otherwise both values are turned
/// into integers, the stored value is zero-extended (and shifted left on
/// big-endian layouts), the overwritten bits of the reaching definition are
/// masked out, and the two are or-ed together.
static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                 Value srcValue, Value reachingDef,
                                 const DataLayout &dataLayout) {
  Type slotType = reachingDef.getType();
  Type valueType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, slotType, valueType,
                                 /*narrowingConversion=*/false) &&
         "expected that the compatibility was checked before");

  const uint64_t valueBits = dataLayout.getTypeSizeInBits(valueType);
  const uint64_t slotBits = dataLayout.getTypeSizeInBits(slotType);
  if (slotBits == valueBits)
    return castSameSizedTypes(builder, loc, srcValue, slotType, dataLayout);

  // Work on integer views of both the old slot content and the new value.
  Value slotAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
  Value storedAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
  Type slotIntType = slotAsInt.getType();
  storedAsInt =
      builder.createOrFold<LLVM::ZExtOp>(loc, slotIntType, storedAsInt);

  const uint64_t untouchedBits = slotBits - valueBits;
  APInt keepMask;
  if (isBigEndian(dataLayout)) {
    // The stored bytes land in the most significant bits: move the value
    // there and keep only the low `untouchedBits` of the old content.
    Value shiftAmount = builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(slotIntType, untouchedBits));
    storedAsInt =
        builder.createOrFold<LLVM::ShlOp>(loc, storedAsInt, shiftAmount);
    keepMask = APInt::getAllOnes(untouchedBits).zext(slotBits);
  } else {
    // The stored bytes land in the least significant bits: keep everything
    // of the old content above them.
    keepMask = APInt::getAllOnes(valueBits).zext(slotBits);
    keepMask.flipAllBits();
  }

  Value maskConstant = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(slotIntType, keepMask));
  Value preserved =
      builder.createOrFold<LLVM::AndOp>(loc, slotAsInt, maskConstant);
  Value merged = builder.createOrFold<LLVM::OrOp>(loc, preserved, storedAsInt);
  return castIntValueToSameSizedType(builder, loc, merged, slotType);
}
/// Returns the new content of the slot after this store, obtained by
/// inserting the stored value into `reachingDef`. The reaching definition
/// must exist and have the slot's element type.
Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");
  Value storedValue = getValue();
  return createInsertAndCast(builder, getLoc(), storedValue, reachingDef,
                             dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areConversionCompatible(dataLayout, getResult().getType(),
slot.elemType, true) &&
!getVolatile_();
}
/// Replaces every use of the loaded value with a value extracted from the
/// reaching definition, then requests deletion of the load.
DeletionKind LLVM::LoadOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  Type loadedType = getResult().getType();
  Value forwarded = createExtractAndCast(builder, getLoc(), reachingDefinition,
                                         loadedType, dataLayout);
  getResult().replaceAllUsesWith(forwarded);
  return DeletionKind::Delete;
}
bool LLVM::StoreOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
Value blockingUse = (*blockingUses.begin())->get();
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
getValue().getType(),
false) &&
!getVolatile_();
}
/// Nothing has to be rewritten for a store; it is only flagged for deletion.
DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
/// Returns true when an access of type `accessType` is no larger than the
/// element type of `slot`.
static bool isValidAccessType(const MemorySlot &slot, Type accessType,
                              const DataLayout &dataLayout) {
  const auto accessSize = dataLayout.getTypeSize(accessType);
  const auto slotSize = dataLayout.getTypeSize(slot.elemType);
  return accessSize <= slotSize;
}
/// Succeeds when the load does not address the slot, or when the loaded type
/// fits within the slot's element type.
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getAddr() != slot.ptr)
    return success();
  return success(isValidAccessType(slot, getType(), dataLayout));
}
/// Succeeds when the store does not address the slot, or when the stored
/// type fits within the slot's element type.
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getAddr() != slot.ptr)
    return success();
  return success(isValidAccessType(slot, getValue().getType(), dataLayout));
}
/// Looks up the subelement type stored at `index` in the index map of the
/// slot's element type. Returns a null type when the element type provides
/// no index map or the map has no entry for `index`.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
                           Attribute index) {
  auto destructurable = cast<DestructurableTypeInterface>(slot.elemType);
  auto indexMap = destructurable.getSubelementIndexMap();
  if (!indexMap)
    return {};
  assert(!indexMap->empty());
  return indexMap->lookup(index);
}
bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (getVolatile_())
return false;
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
Type subslotType = getTypeAtIndex(slot, index);
if (!subslotType)
return false;
if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
return false;
usedIndices.insert(index);
return true;
}
/// Redirects the load to the pointer of the subslot at i32 index 0 and keeps
/// the operation.
DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
                                  OpBuilder &builder,
                                  const DataLayout &dataLayout) {
  Type i32 = IntegerType::get(getContext(), 32);
  auto firstSubslot = subslots.find(IntegerAttr::get(i32, 0));
  assert(firstSubslot != subslots.end());
  getAddrMutable().set(firstSubslot->getSecond().ptr);
  return DeletionKind::Keep;
}
bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (getVolatile_())
return false;
if (getValue() == slot.ptr)
return false;
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
Type subslotType = getTypeAtIndex(slot, index);
if (!subslotType)
return false;
if (dataLayout.getTypeSize(getValue().getType()) >
dataLayout.getTypeSize(subslotType))
return false;
usedIndices.insert(index);
return true;
}
/// Redirects the store to the pointer of the subslot at i32 index 0 and
/// keeps the operation.
DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                   DenseMap<Attribute, MemorySlot> &subslots,
                                   OpBuilder &builder,
                                   const DataLayout &dataLayout) {
  Type i32 = IntegerType::get(getContext(), 32);
  auto firstSubslot = subslots.find(IntegerAttr::get(i32, 0));
  assert(firstSubslot != subslots.end());
  getAddrMutable().set(firstSubslot->getSecond().ptr);
  return DeletionKind::Keep;
}
/// Appends every use of every result of `op` to `newBlockingUses`. Always
/// returns true.
static bool forwardToUsers(Operation *op,
                           SmallVectorImpl<OpOperand *> &newBlockingUses) {
  for (Value opResult : op->getResults()) {
    for (OpOperand &resultUse : opResult.getUses())
      newBlockingUses.push_back(&resultUse);
  }
  return true;
}
bool LLVM::BitcastOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
return forwardToUsers(*this, newBlockingUses);
}
/// Nothing has to be rewritten; the bitcast is only flagged for deletion.
DeletionKind LLVM::BitcastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
return forwardToUsers(*this, newBlockingUses);
}
/// Nothing has to be rewritten; the cast is only flagged for deletion.
DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
/// A lifetime-start marker can always be removed; it adds no new blocking
/// uses.
bool LLVM::LifetimeStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
/// Nothing has to be rewritten; the marker is only flagged for deletion.
DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
/// A lifetime-end marker can always be removed; it adds no new blocking
/// uses.
bool LLVM::LifetimeEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
/// Nothing has to be rewritten; the marker is only flagged for deletion.
DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
/// An invariant-start marker can always be removed; it adds no new blocking
/// uses.
bool LLVM::InvariantStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
/// Nothing has to be rewritten; the marker is only flagged for deletion.
DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
/// An invariant-end marker can always be removed; it adds no new blocking
/// uses.
bool LLVM::InvariantEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
/// Nothing has to be rewritten; the marker is only flagged for deletion.
DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
/// A debug declaration can always be removed; it adds no new blocking uses.
bool LLVM::DbgDeclareOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
/// Nothing has to be rewritten here; the declaration is only flagged for
/// deletion.
DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
bool LLVM::DbgValueOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (blockingUses.size() != 1)
return false;
return (*blockingUses.begin())->get() == getValue();
}
/// Keeps the debug op but makes it describe an undef value of the same type,
/// created right before the op at the location of the old value.
DeletionKind LLVM::DbgValueOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  builder.setInsertionPoint(*this);
  Value describedValue = getValue();
  auto undefReplacement = builder.create<UndefOp>(describedValue.getLoc(),
                                                  describedValue.getType());
  getValueMutable().assign(undefReplacement);
  return DeletionKind::Keep;
}
/// Debug declarations always ask to be notified of replaced values.
bool LLVM::DbgDeclareOp::requiresReplacedValues() {
  return true;
}
/// After each defining operation in `definitions`, emits an
/// `LLVM::DbgValueOp` describing the associated value with this
/// declaration's location, variable info and location expression.
void LLVM::DbgDeclareOp::visitReplacedValues(
    ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
  for (const std::pair<Operation *, Value> &definition : definitions) {
    Operation *definingOp = definition.first;
    Value definedValue = definition.second;
    builder.setInsertionPointAfter(definingOp);
    builder.create<LLVM::DbgValueOp>(getLoc(), definedValue, getVarInfo(),
                                     getLocationExpr());
  }
}
/// Returns true when every index of `gepOp` is a constant equal to zero.
/// Any dynamic or non-zero index yields false.
static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
  for (auto index : gepOp.getIndices()) {
    auto constantIndex = llvm::dyn_cast_if_present<IntegerAttr>(index);
    if (!constantIndex || constantIndex.getValue() != 0)
      return false;
  }
  return true;
}
bool LLVM::GEPOp::canUsesBeRemoved(
const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
if (!hasAllZeroIndices(*this))
return false;
return forwardToUsers(*this, newBlockingUses);
}
/// Nothing has to be rewritten; the GEP is only flagged for deletion.
DeletionKind LLVM::GEPOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}
static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
LLVM::GEPOp gep) {
SmallVector<uint64_t> indices;
for (auto index : gep.getIndices()) {
auto constIndex = dyn_cast<IntegerAttr>(index);
if (!constIndex)
return {};
int64_t gepIndex = constIndex.getInt();
if (gepIndex < 0)
return {};
indices.push_back(gepIndex);
}
Type currentType = gep.getElemType();
uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
for (uint64_t index : llvm::drop_begin(indices)) {
bool shouldCancel =
TypeSwitch<Type, bool>(currentType)
.Case([&](LLVM::LLVMArrayType arrayType) {
offset +=
index * dataLayout.getTypeSize(arrayType.getElementType());
currentType = arrayType.getElementType();
return false;
})
.Case([&](LLVM::LLVMStructType structType) {
ArrayRef<Type> body = structType.getBody();
assert(index < body.size() && "expected valid struct indexing");
for (uint32_t i : llvm::seq(index)) {
if (!structType.isPacked())
offset = llvm::alignTo(
offset, dataLayout.getTypeABIAlignment(body[i]));
offset += dataLayout.getTypeSize(body[i]);
}
if (!structType.isPacked())
offset = llvm::alignTo(
offset, dataLayout.getTypeABIAlignment(body[index]));
currentType = body[index];
return false;
})
.Default([&](Type type) {
LLVM_DEBUG(llvm::dbgs()
<< "[sroa] Unsupported type for offset computations"
<< type << "\n");
return true;
});
if (shouldCancel)
return std::nullopt;
}
return offset;
}
namespace {
/// Describes which subslot of a destructurable slot a GEP points into.
struct SubslotAccessInfo {
  /// Index of the accessed subslot within the parent slot.
  uint32_t index;
  /// Byte offset of the access relative to the start of that subslot.
  uint64_t subslotOffset;
};
} // namespace
static std::optional<SubslotAccessInfo>
getSubslotAccessInfo(const DestructurableMemorySlot &slot,
const DataLayout &dataLayout, LLVM::GEPOp gep) {
std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
if (!offset)
return {};
auto isOutOfBoundsGEPIndex = [](uint64_t index) {
return index >= (1 << LLVM::kGEPConstantBitWidth);
};
Type type = slot.elemType;
if (*offset >= dataLayout.getTypeSize(type))
return {};
return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
.Case([&](LLVM::LLVMArrayType arrayType)
-> std::optional<SubslotAccessInfo> {
uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
uint64_t index = *offset / elemSize;
if (isOutOfBoundsGEPIndex(index))
return {};
return SubslotAccessInfo{static_cast<uint32_t>(index),
*offset - (index * elemSize)};
})
.Case([&](LLVM::LLVMStructType structType)
-> std::optional<SubslotAccessInfo> {
uint64_t distanceToStart = 0;
for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
uint64_t elemSize = dataLayout.getTypeSize(elem);
if (!structType.isPacked()) {
distanceToStart = llvm::alignTo(
distanceToStart, dataLayout.getTypeABIAlignment(elem));
if (offset < distanceToStart)
return {};
}
if (offset < distanceToStart + elemSize) {
if (isOutOfBoundsGEPIndex(index))
return {};
return SubslotAccessInfo{static_cast<uint32_t>(index),
*offset - distanceToStart};
}
distanceToStart += elemSize;
}
return {};
});
}
/// Builds the LLVM array type made of `size` elements of type i8.
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
                                            unsigned size) {
  IntegerType i8 = IntegerType::get(context, 8);
  return LLVM::LLVMArrayType::get(context, i8, size);
}
/// A GEP based on the slot pointer is safe when it has a computable constant
/// offset that lies inside the slot. Its result is then registered as a slot
/// of the remaining bytes that must itself be safely used. GEPs based on
/// other pointers are ignored.
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getBase() != slot.ptr)
    return success();

  std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, *this);
  if (!byteOffset)
    return failure();

  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
  if (*byteOffset >= slotSize)
    return failure();

  LLVM::LLVMArrayType remainingBytes =
      getByteArrayType(getContext(), slotSize - *byteOffset);
  mustBeSafelyUsed.push_back(MemorySlot{getRes(), remainingBytes});
  return success();
}
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
return false;
if (getBase() != slot.ptr)
return false;
std::optional<SubslotAccessInfo> accessInfo =
getSubslotAccessInfo(slot, dataLayout, *this);
if (!accessInfo)
return false;
auto indexAttr =
IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
assert(slot.subelementTypes.contains(indexAttr));
usedIndices.insert(indexAttr);
Type subslotType = slot.subelementTypes.lookup(indexAttr);
uint64_t slotSize = dataLayout.getTypeSize(subslotType);
LLVM::LLVMArrayType remainingSlotType =
getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
return true;
}
/// Replaces the GEP by a byte-typed GEP into the subslot it points to, using
/// the offset inside that subslot, and requests deletion of the original op.
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                 DenseMap<Attribute, MemorySlot> &subslots,
                                 OpBuilder &builder,
                                 const DataLayout &dataLayout) {
  std::optional<SubslotAccessInfo> access =
      getSubslotAccessInfo(slot, dataLayout, *this);
  assert(access && "expected access info to be checked before");

  Type i32 = IntegerType::get(getContext(), 32);
  auto subslotIndex = IntegerAttr::get(i32, access->index);
  const MemorySlot &targetSubslot = subslots.at(subslotIndex);

  auto i8 = IntegerType::get(builder.getContext(), 8);
  auto rewiredPtr = builder.createOrFold<LLVM::GEPOp>(
      getLoc(), getResult().getType(), i8, targetSubslot.ptr,
      ArrayRef<GEPArg>(access->subslotOffset), getInbounds());
  getResult().replaceAllUsesWith(rewiredPtr);
  return DeletionKind::Delete;
}
namespace {
template <class MemIntr>
std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
APInt memIntrLen;
if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
return {};
if (memIntrLen.getBitWidth() > 64)
return {};
return memIntrLen.getZExtValue();
}
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
APInt memIntrLen = op.getLen();
if (memIntrLen.getBitWidth() > 64)
return {};
return memIntrLen.getZExtValue();
}
}
/// Returns true when `op` writes to the slot pointer (which must have an
/// LLVM pointer type) with a statically known length no larger than the
/// slot's element type.
template <class MemIntr>
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
                                           const DataLayout &dataLayout) {
  const bool writesToSlot =
      isa<LLVM::LLVMPointerType>(slot.ptr.getType()) &&
      op.getDst() == slot.ptr;
  if (!writesToSlot)
    return false;

  std::optional<uint64_t> staticLen = getStaticMemIntrLen(op);
  if (!staticLen)
    return false;
  return *staticLen <= dataLayout.getTypeSize(slot.elemType);
}
static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
return llvm::all_of(llvm::make_first_range(slot.subelementTypes),
[&](Attribute index) {
auto intIndex = dyn_cast<IntegerAttr>(index);
return intIndex && intIndex.getType() == i32;
});
}
/// A memset never reads from any slot.
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) {
  return false;
}
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}
/// Builds the integer value a slot of integer type holds after this memset.
///
/// For i8 slots this is the memset value itself. For wider integers the
/// byte is zero-extended and repeatedly or-ed with a shifted copy of itself,
/// doubling the number of covered bits each round. Slots of any other type
/// are not supported and must not reach this function.
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  auto intType = dyn_cast<IntegerType>(slot.elemType);
  if (!intType)
    llvm_unreachable(
        "getStored should not be called on memset to unsupported type");

  if (intType.getWidth() == 8)
    return getVal();

  assert(intType.getWidth() % 8 == 0);

  Value pattern = builder.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
  for (uint64_t coveredBits = 8; coveredBits < intType.getWidth();
       coveredBits *= 2) {
    Value shiftAmount =
        builder.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
    Value shiftedPattern =
        builder.create<LLVM::ShlOp>(getLoc(), pattern, shiftAmount);
    pattern = builder.create<LLVM::OrOp>(getLoc(), pattern, shiftedPattern);
  }
  return pattern;
}
bool LLVM::MemsetOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
const DataLayout &dataLayout) {
bool canConvertType =
TypeSwitch<Type, bool>(slot.elemType)
.Case([](IntegerType intType) {
return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
})
.Default([](Type) { return false; });
if (!canConvertType)
return false;
if (getIsVolatile())
return false;
return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
}
/// Nothing has to be rewritten; the memset is only flagged for deletion.
DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
/// Succeeds only when the memset provably writes to the slot without
/// exceeding its size.
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  const bool staysInBounds =
      definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
  return success(staysInBounds);
}
/// A memset can be rewired when the slot's element type belongs to the same
/// dialect as this op, the memset is not volatile, the element type exposes
/// a subelement index map with only i32 indices, and the write provably
/// stays within the slot.
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  const bool sameDialect =
      &slot.elemType.getDialect() == getOperation()->getDialect();
  if (!sameDialect)
    return false;

  if (getIsVolatile())
    return false;

  auto destructurable = cast<DestructurableTypeInterface>(slot.elemType);
  if (!destructurable.getSubelementIndexMap())
    return false;

  if (!areAllIndicesI32(slot))
    return false;

  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
}
/// Splits the memset into one memset per used subslot it overlaps and
/// requests deletion of the original.
///
/// Subelements are visited in increasing i32 index order while tracking how
/// many bytes of the original memset are covered (including alignment
/// padding for non-packed structs). Each used subslot receives a memset of
/// at most its own size, clamped to the bytes that remain.
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  std::optional<DenseMap<Attribute, Type>> subelementTypes =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

  IntegerAttr lenAttr;
  bool successfulMatch =
      matchPattern(getLen(), m_Constant<IntegerAttr>(&lenAttr));
  (void)successfulMatch;
  assert(successfulMatch);

  auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType);
  const bool isPacked = structType ? structType.isPacked() : false;

  Type i32 = IntegerType::get(getContext(), 32);
  const uint64_t totalLen = lenAttr.getValue().getZExtValue();
  uint64_t coveredBytes = 0;
  for (size_t i = 0, e = subelementTypes->size(); i != e; ++i) {
    // Indices are rebuilt in order since the map has no ordering guarantee.
    Attribute index = IntegerAttr::get(i32, i);
    Type subelementType = subelementTypes->at(index);
    uint64_t subelementSize = dataLayout.getTypeSize(subelementType);

    if (!isPacked)
      coveredBytes = llvm::alignTo(
          coveredBytes, dataLayout.getTypeABIAlignment(subelementType));

    if (coveredBytes >= totalLen)
      break;

    // Unused subslots only contribute to the running offset.
    if (subslots.contains(index)) {
      uint64_t subLen = std::min(totalLen - coveredBytes, subelementSize);
      auto subLenConstant = builder.create<LLVM::ConstantOp>(
          getLen().getLoc(), IntegerAttr::get(lenAttr.getType(), subLen));
      Value subLenValue = subLenConstant.getResult();
      builder.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
                                     getVal(), subLenValue, getIsVolatile());
    }

    coveredBytes += subelementSize;
  }

  return DeletionKind::Delete;
}
template <class MemcpyLike>
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
return op.getSrc() == slot.ptr;
}
template <class MemcpyLike>
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
return op.getDst() == slot.ptr;
}
/// Produces the value written to the slot by `op` as a load of the slot's
/// element type from the copy source.
template <class MemcpyLike>
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
                             OpBuilder &builder) {
  Value loadedSource =
      builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
  return loadedSource;
}
/// A memcpy-like op can be removed when source and destination differ, it
/// is not volatile, and its static length equals the size of the slot's
/// element type.
template <class MemcpyLike>
static bool
memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
                       SmallVectorImpl<OpOperand *> &newBlockingUses,
                       const DataLayout &dataLayout) {
  const bool copiesOntoItself = op.getDst() == op.getSrc();
  if (copiesOntoItself || op.getIsVolatile())
    return false;

  std::optional<uint64_t> staticLen = getStaticMemIntrLen(op);
  return staticLen == dataLayout.getTypeSize(slot.elemType);
}
/// When `op` reads from the slot, stores the reaching definition to the copy
/// destination instead. The op is flagged for deletion in all cases.
template <class MemcpyLike>
static DeletionKind
memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                         const SmallPtrSetImpl<OpOperand *> &blockingUses,
                         OpBuilder &builder, Value reachingDefinition) {
  const bool readsSlot = op.loadsFrom(slot);
  if (readsSlot) {
    builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
                                  op.getDst());
  }
  return DeletionKind::Delete;
}
/// Succeeds only when `op` provably writes to the slot without exceeding its
/// size, using the data layout closest to `op`.
template <class MemcpyLike>
static LogicalResult
memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
  DataLayout closestLayout = DataLayout::closest(op);
  const bool staysInBounds =
      definitelyWritesOnlyWithinSlot(op, slot, closestLayout);
  return success(staysInBounds);
}
template <class MemcpyLike>
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
SmallPtrSetImpl<Attribute> &usedIndices,
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
const DataLayout &dataLayout) {
if (op.getIsVolatile())
return false;
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;
if (!areAllIndicesI32(slot))
return false;
if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
return false;
if (op.getSrc() == slot.ptr)
for (Attribute index : llvm::make_first_range(slot.subelementTypes))
usedIndices.insert(index);
return true;
}
namespace {

/// Creates an op of the same kind as `toReplace` that copies as many bytes
/// as the size of `toCpy` from `src` to `dst`. The length is materialized as
/// a constant with the type of the original length operand.
template <class MemcpyLike>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               MemcpyLike toReplace, Value dst, Value src,
                               Type toCpy, bool isVolatile) {
  Location loc = toReplace.getLoc();
  auto lengthAttr = IntegerAttr::get(toReplace.getLen().getType(),
                                     layout.getTypeSize(toCpy));
  Value lengthValue = builder.create<LLVM::ConstantOp>(loc, lengthAttr);
  builder.create<MemcpyLike>(loc, dst, src, lengthValue, isVolatile);
}

/// Specialization for `LLVM::MemcpyInlineOp`, which takes its length as an
/// integer attribute. The attribute keeps the bit width of the original
/// length.
template <>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               LLVM::MemcpyInlineOp toReplace, Value dst,
                               Value src, Type toCpy, bool isVolatile) {
  unsigned lengthBitWidth = toReplace.getLen().getBitWidth();
  Type lengthType = IntegerType::get(toReplace->getContext(), lengthBitWidth);
  auto lengthAttr = IntegerAttr::get(lengthType, layout.getTypeSize(toCpy));
  builder.create<LLVM::MemcpyInlineOp>(toReplace.getLoc(), dst, src,
                                       lengthAttr, isVolatile);
}

} // namespace
/// Splits a full-slot memcpy-like op into one copy per subslot and requests
/// deletion of the original.
///
/// For each subslot present in `subslots`, visited in increasing index
/// order, a GEP `[0, index]` into the other pointer of the copy (typed with
/// the slot's element type) locates the matching subelement, and a new copy
/// of the subslot's type is created between it and the subslot, in the same
/// direction as the original copy.
template <class MemcpyLike>
static DeletionKind
memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
             const DataLayout &dataLayout) {
  if (subslots.empty())
    return DeletionKind::Delete;

  // The slot must be exactly one of the two ends of the copy.
  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
  const bool slotIsDst = slot.ptr == op.getDst();
  Value otherPtr = slotIsDst ? op.getSrc() : op.getDst();

  // Only read by the assertion at the end of the function.
  [[maybe_unused]] size_t slotsTreated = 0;

  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
  const size_t numSubelements = slot.subelementTypes.size();
  for (size_t i = 0; i != numSubelements; ++i) {
    IntegerAttr index = IntegerAttr::get(indexType, i);
    auto subslotIt = subslots.find(index);
    if (subslotIt == subslots.end())
      continue;
    const MemorySlot &subslot = subslotIt->second;
    ++slotsTreated;

    // Locate the counterpart of this subslot on the other side of the copy.
    SmallVector<LLVM::GEPArg> gepIndices{
        0, static_cast<int32_t>(index.getValue().getZExtValue())};
    Value counterpartPtr = builder.create<LLVM::GEPOp>(
        op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
        otherPtr, gepIndices);

    Value newDst = slotIsDst ? subslot.ptr : counterpartPtr;
    Value newSrc = slotIsDst ? counterpartPtr : subslot.ptr;
    createMemcpyLikeToReplace(builder, dataLayout, op, newDst, newSrc,
                              subslot.elemType, op.getIsVolatile());
  }

  assert(subslots.size() == slotsTreated);
  return DeletionKind::Delete;
}
/// Delegates to `memcpyLoadsFrom`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  bool readsSlot = memcpyLoadsFrom(*this, slot);
  return readsSlot;
}
/// Delegates to `memcpyStoresTo`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  bool writesSlot = memcpyStoresTo(*this, slot);
  return writesSlot;
}
/// Delegates to `memcpyGetStored` with this operation, `slot` and `builder`.
/// `reachingDef` and `dataLayout` are not used by this implementation.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  Value stored = memcpyGetStored(*this, slot, builder);
  return stored;
}
/// Delegates to `memcpyCanUsesBeRemoved`, forwarding every argument unchanged
/// along with this operation.
bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  bool removable = memcpyCanUsesBeRemoved(*this, slot, blockingUses,
                                          newBlockingUses, dataLayout);
  return removable;
}
/// Delegates to `memcpyRemoveBlockingUses` with this operation, `slot`,
/// `blockingUses`, `builder` and `reachingDefinition`. `dataLayout` is not
/// used by this implementation.
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  DeletionKind outcome = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, builder, reachingDefinition);
  return outcome;
}
/// Delegates to `memcpyEnsureOnlySafeAccesses` with this operation, `slot` and
/// `mustBeSafelyUsed`. `dataLayout` is not used by this implementation.
LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  LogicalResult status =
      memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
  return status;
}
/// Delegates to `memcpyCanRewire`, forwarding every argument unchanged along
/// with this operation.
bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  bool rewirable = memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                                   dataLayout);
  return rewirable;
}
/// Delegates to `memcpyRewire`, forwarding every argument unchanged along with
/// this operation.
DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  DeletionKind outcome =
      memcpyRewire(*this, slot, subslots, builder, dataLayout);
  return outcome;
}
/// Delegates to `memcpyLoadsFrom`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  bool readsSlot = memcpyLoadsFrom(*this, slot);
  return readsSlot;
}
/// Delegates to `memcpyStoresTo`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  bool writesSlot = memcpyStoresTo(*this, slot);
  return writesSlot;
}
/// Delegates to `memcpyGetStored` with this operation, `slot` and `builder`.
/// `reachingDef` and `dataLayout` are not used by this implementation.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  Value stored = memcpyGetStored(*this, slot, builder);
  return stored;
}
/// Delegates to `memcpyCanUsesBeRemoved`, forwarding every argument unchanged
/// along with this operation.
bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  bool removable = memcpyCanUsesBeRemoved(*this, slot, blockingUses,
                                          newBlockingUses, dataLayout);
  return removable;
}
/// Delegates to `memcpyRemoveBlockingUses` with this operation, `slot`,
/// `blockingUses`, `builder` and `reachingDefinition`. `dataLayout` is not
/// used by this implementation.
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  DeletionKind outcome = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, builder, reachingDefinition);
  return outcome;
}
/// Delegates to `memcpyEnsureOnlySafeAccesses` with this operation, `slot` and
/// `mustBeSafelyUsed`. `dataLayout` is not used by this implementation.
LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  LogicalResult status =
      memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
  return status;
}
/// Delegates to `memcpyCanRewire`, forwarding every argument unchanged along
/// with this operation.
bool LLVM::MemcpyInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  bool rewirable = memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                                   dataLayout);
  return rewirable;
}
/// Delegates to `memcpyRewire`, forwarding every argument unchanged along with
/// this operation.
DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  DeletionKind outcome =
      memcpyRewire(*this, slot, subslots, builder, dataLayout);
  return outcome;
}
/// Delegates to `memcpyLoadsFrom`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
  bool readsSlot = memcpyLoadsFrom(*this, slot);
  return readsSlot;
}
/// Delegates to `memcpyStoresTo`, the helper shared by the memcpy-like
/// operations, passing this operation and `slot`.
bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
  bool writesSlot = memcpyStoresTo(*this, slot);
  return writesSlot;
}
/// Delegates to `memcpyGetStored` with this operation, `slot` and `builder`.
/// `reachingDef` and `dataLayout` are not used by this implementation.
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  Value stored = memcpyGetStored(*this, slot, builder);
  return stored;
}
/// Delegates to `memcpyCanUsesBeRemoved`, forwarding every argument unchanged
/// along with this operation.
bool LLVM::MemmoveOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  bool removable = memcpyCanUsesBeRemoved(*this, slot, blockingUses,
                                          newBlockingUses, dataLayout);
  return removable;
}
/// Delegates to `memcpyRemoveBlockingUses` with this operation, `slot`,
/// `blockingUses`, `builder` and `reachingDefinition`. `dataLayout` is not
/// used by this implementation.
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  DeletionKind outcome = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, builder, reachingDefinition);
  return outcome;
}
/// Delegates to `memcpyEnsureOnlySafeAccesses` with this operation, `slot` and
/// `mustBeSafelyUsed`. `dataLayout` is not used by this implementation.
LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  LogicalResult status =
      memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
  return status;
}
/// Delegates to `memcpyCanRewire`, forwarding every argument unchanged along
/// with this operation.
bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  bool rewirable = memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                                   dataLayout);
  return rewirable;
}
/// Delegates to `memcpyRewire`, forwarding every argument unchanged along with
/// this operation.
DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     OpBuilder &builder,
                                     const DataLayout &dataLayout) {
  DeletionKind outcome =
      memcpyRewire(*this, slot, subslots, builder, dataLayout);
  return outcome;
}
/// Builds the map from subelement index to subelement type for this struct:
/// the key for body element number `k` is the 32-bit integer attribute of
/// value `k`, and the mapped value is that element's type. A map is always
/// returned (possibly empty when the body is empty).
std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMStructType::getSubelementIndexMap() {
  auto body = getBody();
  Type indexType = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> indexToType;
  for (size_t position = 0, count = body.size(); position < count; ++position)
    indexToType.insert({IntegerAttr::get(indexType, position), body[position]});
  return indexToType;
}
/// Returns the type of the struct body element designated by `index`, or a
/// null `Type` when `index` is not a 32-bit integer attribute or when its
/// value is negative or past the end of the body.
Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
  auto intIndex = llvm::dyn_cast<IntegerAttr>(index);
  // Only integer attributes are valid struct indices.
  if (!intIndex)
    return {};
  // ... and only those typed as i32.
  if (!intIndex.getType().isInteger(32))
    return {};

  int32_t position = intIndex.getInt();
  if (position < 0)
    return {};

  ArrayRef<Type> body = getBody();
  // The cast is safe: `position` is known to be non-negative here.
  if (static_cast<uint32_t>(position) >= body.size())
    return {};
  return body[position];
}
/// Builds the map from subelement index to subelement type for this array:
/// every position in `[0, getNumElements())` is keyed by a 32-bit integer
/// attribute and mapped to the array's element type. Arrays with more than
/// 16 elements are not destructured and yield an empty optional.
std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMArrayType::getSubelementIndexMap() const {
  constexpr size_t maxArraySizeForDestructuring = 16;
  // Refuse arrays that are too large to destructure.
  if (getNumElements() > maxArraySizeForDestructuring)
    return {};

  // Cannot truncate: the count is at most 16 at this point.
  int32_t elementCount = getNumElements();
  Type indexType = IntegerType::get(getContext(), 32);
  Type elementType = getElementType();

  DenseMap<Attribute, Type> indexToType;
  for (int32_t position = 0; position != elementCount; ++position)
    indexToType.insert({IntegerAttr::get(indexType, position), elementType});
  return indexToType;
}
/// Returns the array's element type when `index` is a 32-bit integer
/// attribute whose value lies within `[0, getNumElements())`, and a null
/// `Type` otherwise.
Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
  auto intIndex = llvm::dyn_cast<IntegerAttr>(index);
  // Only integer attributes are valid array indices.
  if (!intIndex)
    return {};
  // ... and only those typed as i32.
  if (!intIndex.getType().isInteger(32))
    return {};

  int32_t position = intIndex.getInt();
  if (position < 0)
    return {};
  // The cast is safe: `position` is known to be non-negative here.
  if (static_cast<uint32_t>(position) >= getNumElements())
    return {};
  return getElementType();
}