#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include <optional>
using namespace mlir;
using namespace mlir::bufferization;
/// Returns the operation inside `placementBlock` that the liveness analysis
/// associates with the start of `allocValue`.
///
/// The liveness info of `placementBlock` is queried for the start operation of
/// `allocValue`. If that operation already lives in `placementBlock` it is
/// returned directly. Otherwise the result is the ancestor of that operation
/// located in `placementBlock`, or, when no such ancestor exists, the
/// terminator of `placementBlock`.
Operation *BufferPlacementAllocs::getStartOperation(Value allocValue,
                                                    Block *placementBlock,
                                                    const Liveness &liveness) {
  const LivenessBlockInfo *blockInfo = liveness.getLiveness(placementBlock);
  Operation *candidate = blockInfo->getStartOperation(allocValue);

  // Common case: the reported operation is already in the requested block.
  if (candidate->getBlock() == placementBlock)
    return candidate;

  // The reported operation is in some other block. Prefer the operation of
  // `placementBlock` that encloses it.
  if (Operation *enclosing = placementBlock->findAncestorOpInBlock(*candidate))
    return enclosing;

  // No enclosing operation in this block: fall back to the block terminator.
  return placementBlock->getTerminator();
}
/// Constructs the allocation list by running `build` on `op`.
BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) {
  this->build(op);
}
/// Walks every operation nested under `op` that implements
/// `MemoryEffectOpInterface` and records an (allocated value, dealloc op)
/// entry in `allocs` for each operation that qualifies.
///
/// An operation qualifies when exactly one of its memory effects is an
/// `Allocate` effect on an op result whose resource is not the automatic
/// allocation scope resource, and `memref::findDealloc` yields a value for
/// that result.
void BufferPlacementAllocs::build(Operation *op) {
  op->walk([&](MemoryEffectOpInterface effectOp) {
    SmallVector<MemoryEffects::EffectInstance, 2> allEffects;
    effectOp.getEffects(allEffects);

    // An effect is of interest if it allocates an op result and is not tied
    // to the automatic allocation scope resource.
    auto isResultAllocation = [](MemoryEffects::EffectInstance &effect) {
      Value target = effect.getValue();
      if (!isa<MemoryEffects::Allocate>(effect.getEffect()))
        return false;
      if (!target || !isa<OpResult>(target))
        return false;
      return effect.getResource() !=
             SideEffects::AutomaticAllocationScopeResource::get();
    };

    SmallVector<MemoryEffects::EffectInstance, 2> resultAllocations;
    for (MemoryEffects::EffectInstance &effect : allEffects)
      if (isResultAllocation(effect))
        resultAllocations.push_back(effect);

    // Only operations with exactly one such allocation are tracked.
    if (resultAllocations.size() != 1)
      return;
    Value allocValue = resultAllocations.front().getValue();

    // Skip the allocation when no dealloc result is reported for it.
    // NOTE(review): the contained pointer itself may presumably be null when
    // the value has no dealloc at all - confirm against memref::findDealloc.
    std::optional<Operation *> dealloc = memref::findDealloc(allocValue);
    if (!dealloc.has_value())
      return;

    allocs.push_back(std::make_tuple(allocValue, *dealloc));
  });
}
/// Constructs the `aliases`, `allocs` and `liveness` members, each from `op`.
BufferPlacementTransformationBase::BufferPlacementTransformationBase(
Operation *op)
: aliases(op), allocs(op), liveness(op) {}
/// Returns a `memref::GlobalOp` in the module enclosing `constantOp` that
/// holds the value of `constantOp`, creating one if no suitable global exists.
///
/// An existing global is reused only if it has an initial value equal to the
/// constant's value, its alignment (0 when unset) equals `alignment`, and its
/// memref type lives in the same memory space as the type that would be
/// created for this request. Otherwise a new private, constant global named
/// `__constant_<shape>x<element type>` is inserted through the module's symbol
/// table and moved to the front of the module.
///
/// \param constantOp  Constant to materialize; its type must be a
///                    `RankedTensorType` (enforced by `cast`).
/// \param alignment   Requested alignment; 0 means no alignment attribute.
/// \param memorySpace Optional memory space for the global's memref type; a
///                    null attribute leaves the converted type unchanged.
/// \returns The reused or newly created global, or failure if `constantOp`
///          has no enclosing `ModuleOp`.
FailureOr<memref::GlobalOp>
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
                            Attribute memorySpace) {
  auto type = cast<RankedTensorType>(constantOp.getType());
  auto moduleOp = constantOp->getParentOfType<ModuleOp>();
  if (!moduleOp)
    return failure();

  // Compute the target buffer type before scanning for reusable globals so
  // that candidates can be checked against the requested memory space. The
  // comparison below uses the memory space of the built type rather than the
  // raw `memorySpace` attribute, so that any normalization applied when the
  // type is constructed is taken into account.
  BufferizeTypeConverter typeConverter;
  auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
  if (memorySpace)
    memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);

  // Reuse an existing global if one matches. Matching on value and alignment
  // alone is not enough: a global placed in a different memory space would
  // hand the caller a buffer of the wrong type.
  // NOTE(review): the `constant` flag of the candidate is not checked here;
  // confirm whether reusing a non-constant global is acceptable.
  for (Operation &op : moduleOp.getRegion().getOps()) {
    auto globalOp = dyn_cast<memref::GlobalOp>(&op);
    if (!globalOp)
      continue;
    if (!globalOp.getInitialValue().has_value())
      continue;
    uint64_t opAlignment = globalOp.getAlignment().value_or(0);
    Attribute initialValue = globalOp.getInitialValue().value();
    if (opAlignment == alignment && initialValue == constantOp.getValue() &&
        globalOp.getType().getMemorySpace() == memrefType.getMemorySpace())
      return globalOp;
  }

  // The builder has no insertion point; the op is placed via the symbol table.
  OpBuilder globalBuilder(moduleOp.getContext());
  SymbolTable symbolTable(moduleOp);

  // Build a readable name suffix of the form "<d0>x<d1>x...x<element type>".
  SmallString<64> buf;
  llvm::raw_svector_ostream os(buf);
  interleave(type.getShape(), os, "x");
  os << "x" << type.getElementType();

  // Only attach an alignment attribute when a non-zero alignment is requested.
  IntegerAttr memrefAlignment =
      alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment)
                    : IntegerAttr();

  auto global = globalBuilder.create<memref::GlobalOp>(
      constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
      /*sym_visibility=*/globalBuilder.getStringAttr("private"),
      /*type=*/memrefType,
      /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
      /*constant=*/true,
      /*alignment=*/memrefAlignment);
  symbolTable.insert(global);
  // Relocate the freshly inserted global so it is the first op of the module.
  global->moveBefore(&moduleOp.front());
  return global;
}