#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
// Pull in the TableGen-generated BufferizableOpInterface definitions, wrapped
// in the mlir::bufferization namespace.
namespace mlir {
namespace bufferization {
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
} // namespace bufferization
} // namespace mlir
// Debug-output tag for this file. DBGS() starts a debug line prefixed with
// "[bufferizable-op-interface] "; LDBG(X) streams X through LLVM_DEBUG.
#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
using namespace mlir;
using namespace bufferization;
/// Return the operation that "owns" `value`: the defining op for an OpResult,
/// or the op whose region contains the block for a BlockArgument (may be null
/// if that block is not nested in an op).
static Operation *getOwnerOfValue(Value value) {
  if (auto bbArg = value.dyn_cast<BlockArgument>())
    return bbArg.getOwner()->getParentOp();
  return value.cast<OpResult>().getDefiningOp();
}
/// Return true if the allocation produced by `opResult` is known not to
/// escape, based on the op's "escape" annotation (one BoolAttr per result).
///
/// The answer is conservative: if the annotation is missing, is not an
/// ArrayAttr, has no entry for this result, or the entry is not a BoolAttr,
/// the allocation may escape and `false` is returned.
///
/// Precondition (checked in debug builds): the defining op implements
/// BufferizableOpInterface and bufferizes `opResult` to an allocation.
bool bufferization::allocationDoesNotEscape(OpResult opResult) {
#ifndef NDEBUG
  auto bufferizableOp = opResult.getDefiningOp<BufferizableOpInterface>();
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected op that bufferizes to an allocation");
#endif
  Operation *op = opResult.getDefiningOp();

  // Single lookup; a missing or mistyped annotation yields a null attr.
  auto escapeAttr =
      op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName);
  if (!escapeAttr)
    return false;

  // Guard against annotations that have no entry for this result.
  unsigned resultNumber = opResult.getResultNumber();
  if (resultNumber >= escapeAttr.size())
    return false;

  // The entry states whether the allocation escapes; it does not escape only
  // when the entry is a well-formed `false`.
  auto escapes = escapeAttr[resultNumber].dyn_cast<BoolAttr>();
  return escapes && !escapes.getValue();
}
/// Create a `bufferization.alloc_tensor` op that allocates a new tensor with
/// the type and shape of `shapedValue`.
///
/// `shapedValue` must be a RankedTensorType or MemRefType value; a memref is
/// first wrapped in a ToTensorOp. If `copy` is true, the new op receives the
/// tensor as its copy operand and no dynamic sizes are computed. Otherwise the
/// dynamic sizes are computed and a memory-space attribute, taken from the
/// tensor's buffer type, is attached. The op is always annotated with a
/// one-element "escape" bool array holding `escape`.
///
/// Returns the alloc_tensor result, or failure if the buffer type (and thus
/// the memory space) of the tensor could not be determined.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue, bool escape,
const BufferizationOptions &options, bool copy) {
// Obtain a ranked tensor view of `shapedValue`.
Value tensor;
if (shapedValue.getType().isa<RankedTensorType>()) {
tensor = shapedValue;
} else if (shapedValue.getType().isa<MemRefType>()) {
tensor = b.create<ToTensorOp>(loc, shapedValue);
} else {
llvm_unreachable("expected RankedTensorType or MemRefType");
}
RankedTensorType tensorType = tensor.getType().cast<RankedTensorType>();
// Dynamic dimension sizes are only needed when there is no copy operand.
SmallVector<Value> dynamicSizes;
if (!copy) {
// First try to reify the result shape through the defining op's
// ReifyRankedShapedTypeOpInterface (only possible for tensor OpResults).
bool reifiedShapes = false;
if (shapedValue.getType().isa<RankedTensorType>() &&
shapedValue.isa<OpResult>()) {
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
shapedValue.getDefiningOp())) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
reifiedShapes = true;
auto &shape =
resultDims[shapedValue.cast<OpResult>().getResultNumber()];
// Keep only the reified sizes of the dynamic dimensions.
for (const auto &dim : enumerate(tensorType.getShape()))
if (ShapedType::isDynamic(dim.value()))
dynamicSizes.push_back(shape[dim.index()]);
}
}
}
// Fallback when reification was not possible.
if (!reifiedShapes)
populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
}
auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
copy ? tensor : Value());
allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
b.getBoolArrayAttr({escape}));
// With a copy operand, no explicit memory space attribute is attached.
if (copy)
return allocTensorOp.getResult();
// Otherwise attach the memory space of the tensor's buffer type.
FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
if (failed(copyBufferType))
return failure();
allocTensorOp.setMemorySpaceAttr(
b.getIntegerAttr(b.getIntegerType(64, false),
copyBufferType->getMemorySpaceAsInt()));
return allocTensorOp.getResult();
}
/// Resolve conflicts of this op's out-of-place tensor operands by inserting
/// `alloc_tensor` ops.
///
/// Each tensor OpOperand that `state` does not consider in-place is handled
/// in one of two ways:
///  * If the operand does not bufferize to a memory write, has exactly one
///    aliasing OpResult, and that OpResult has exactly one aliasing OpOperand,
///    a new tensor is allocated for the OpResult (after the op). All uses of
///    the OpResult, except the new alloc_tensor itself, are redirected to it.
///  * Otherwise a new tensor is allocated for the OpOperand (before the op)
///    and the op is updated to use it.
/// The original contents are only copied when `canOmitTensorCopy` returns
/// false; otherwise only a same-shaped allocation is created. An allocation
/// is marked as escaping when deallocations are disabled or an aliasing
/// OpResult may be yielded.
///
/// Returns failure (with an error on the op) if an unranked tensor operand
/// would need a copy, or failure if creating an allocation fails.
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  DenseSet<OpOperand *> escapingOpOperandCopies;
  SmallVector<OpResult> outOfPlaceOpResults;
  DenseSet<OpResult> copiedOpResults;
  DenseSet<OpResult> escapingOpResultCopies;

  // Classify every out-of-place tensor OpOperand.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!operandType.isa<TensorType>())
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (operandType.isa<UnrankedTensorType>())
      return op->emitError("copies of unranked tensors are not supported");

    SmallVector<OpResult> aliasingOpResults =
        state.getAliasingOpResult(opOperand);
    // The allocation escapes (and must not be deallocated) if deallocations
    // are disabled or if any aliasing OpResult may be yielded.
    bool escape = !state.getOptions().createDeallocs ||
                  llvm::any_of(aliasingOpResults, [&](Value v) {
                    return state.isTensorYielded(v);
                  });

    if (aliasingOpResults.size() == 1 &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
      // Non-writing op with a 1:1 alias: allocate for the OpResult instead of
      // the OpOperand. NOTE(review): presumably because the result can be
      // smaller than the operand (e.g. a slice) -- confirm.
      outOfPlaceOpResults.push_back(aliasingOpResults.front());
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpResults.insert(aliasingOpResults.front());
      if (escape)
        escapingOpResultCopies.insert(aliasingOpResults.front());
    } else {
      // All other cases: allocate for the OpOperand itself.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
      if (escape)
        escapingOpOperandCopies.insert(&opOperand);
    }
  }

  // Allocations for OpOperands are created right before the op.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(),
        escapingOpOperandCopies.contains(opOperand), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Allocations for OpResults are created right after the op.
  rewriter.setInsertionPointAfter(op);
  for (OpResult opResult : outOfPlaceOpResults) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opResult,
        escapingOpResultCopies.contains(opResult), state.getOptions(),
        copiedOpResults.contains(opResult));
    if (failed(copy))
      return failure();
    // Snapshot the uses first: redirecting them mutates the use list.
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        opResult.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not rewire the alloc_tensor op that was just created from it.
      if (use->getOwner() != copy->getDefiningOp())
        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }
  return success();
}
/// Return true if the buffer allocated for `opResult` should be deallocated.
///
/// An explicit "escape" annotation on the owning op (one BoolAttr per result)
/// takes precedence: the buffer is deallocated iff its entry is `false`.
/// Without an annotation, buffers are deallocated only when
/// `options.createDeallocs` is set and an ad-hoc analysis finds that the
/// tensor is not yielded.
///
/// Precondition (checked in debug builds): the owning op is an allowed
/// BufferizableOpInterface op that bufferizes `opResult` to an allocation.
bool bufferization::shouldDeallocateOpResult(
    OpResult opResult, const BufferizationOptions &options) {
  Operation *op = opResult.getOwner();
#ifndef NDEBUG
  // Check for null before querying the interface, so the assert cannot
  // itself dereference a null interface.
  BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(op);
  assert(bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult) &&
         "expected that op allocates");
#endif

  if (auto escapeAttr =
          op->getAttrOfType<ArrayAttr>(BufferizationDialect::kEscapeAttrName)) {
    // Read the entry belonging to this result; for single-result ops such as
    // alloc_tensor this is entry 0.
    unsigned resultNumber = opResult.getResultNumber();
    assert(resultNumber < escapeAttr.size() &&
           "escape attribute has no entry for this result");
    return !escapeAttr[resultNumber].cast<BoolAttr>().getValue();
  }

  // No "escape" annotation found.
  if (!options.createDeallocs)
    return false;

  // Build the analysis state (which runs all state initializers) only on the
  // path that actually needs it.
  AnalysisState analysisState(options);
  return !analysisState.isTensorYielded(opResult);
}
/// Decide whether `op` passes this filter.
///
/// Rules are evaluated in order. A matching DENY rule rejects the op
/// immediately. If there is at least one ALLOW rule, the op must match one of
/// them; without ALLOW rules, every op not denied is allowed.
bool OpFilter::isOpAllowed(Operation *op) const {
  bool allowed = !hasAllowRule();
  for (const Entry &rule : entries) {
    const bool matches = rule.fn(op);
    if (rule.type == Entry::DENY) {
      // A DENY match wins, even if an ALLOW rule also matches.
      if (matches)
        return false;
    } else if (rule.type == Entry::ALLOW) {
      allowed = allowed || matches;
    }
  }
  return allowed;
}
/// Default type converter for tensors whose buffer layout is unknown: map the
/// tensor type of `value` to a memref type with a fully dynamic layout in
/// `memorySpace`. `options` is not consulted.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
                            const BufferizationOptions &options) {
  TensorType tensorType = value.getType().cast<TensorType>();
  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
/// Default options: unknown tensor layouts are converted with
/// defaultUnknownTypeConverter (fully dynamic layout).
BufferizationOptions::BufferizationOptions()
    : unknownTypeConverterFn(&defaultUnknownTypeConverter) {}
/// Return true if `op` may be bufferized under these options.
///
/// Ops from the `func` dialect are rejected unless function boundary
/// bufferization is enabled; every other decision is delegated to `opFilter`.
bool BufferizationOptions::isOpAllowed(Operation *op) const {
  if (!bufferizeFunctionBoundaries &&
      isa_and_nonnull<func::FuncDialect>(op->getDialect()))
    return false;
  return opFilter.isOpAllowed(op);
}
/// Return `op` as a BufferizableOpInterface if it implements the interface
/// and is allowed by these options; otherwise return a null interface.
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (auto iface = dyn_cast<BufferizableOpInterface>(op))
    if (isOpAllowed(op))
      return iface;
  return nullptr;
}
/// Return the defining op of `value` as an allowed BufferizableOpInterface,
/// or a null interface (always null for block arguments).
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  Operation *definingOp = value.getDefiningOp();
  if (!definingOp)
    return nullptr;
  return dynCastBufferizableOp(definingOp);
}
/// Register a callback that creates dialect-specific analysis state.
///
/// The returned state is inserted under `name` whenever an AnalysisState runs
/// its registered initializers. Note: `name` is captured as a StringRef, so
/// the underlying string must outlive these options.
void BufferizationOptions::addDialectStateInitializer(
    StringRef name, const DialectStateInitFn &fn) {
  auto initializer = [name, fn](AnalysisState &state) {
    state.insertDialectState(name, fn());
  };
  stateInitializers.push_back(initializer);
}
/// Move the insertion point of `b` to just after the definition of `value`:
/// after its defining op, or at the start of its block for a block argument.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (Operation *definingOp = value.getDefiningOp()) {
    b.setInsertionPointAfter(definingOp);
    return;
  }
  b.setInsertionPointToStart(value.cast<BlockArgument>().getOwner());
}
/// Return the OpOperands that may alias `result`, as reported by its defining
/// op; empty if that op is not an allowed bufferizable op.
SmallVector<OpOperand *>
AnalysisState::getAliasingOpOperand(OpResult result) const {
  Operation *definingOp = result.getDefiningOp();
  if (!definingOp)
    return {};
  BufferizableOpInterface bufferizableOp =
      getOptions().dynCastBufferizableOp(definingOp);
  if (!bufferizableOp)
    return {};
  return bufferizableOp.getAliasingOpOperand(result, *this);
}
/// Return the OpResults that may alias `opOperand`, as reported by its owner;
/// empty if the owner is not an allowed bufferizable op.
SmallVector<OpResult>
AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
  BufferizableOpInterface bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  if (!bufferizableOp)
    return {};
  return bufferizableOp.getAliasingOpResult(opOperand, *this);
}
/// Return true if `opOperand` may be read from memory. Operands of unknown or
/// disallowed ops are conservatively treated as read.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  BufferizableOpInterface bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return !bufferizableOp ||
         bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
}
/// Return true if `opOperand` may be written to memory. Operands of unknown
/// or disallowed ops are conservatively treated as written.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  BufferizableOpInterface bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return !bufferizableOp ||
         bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
}
/// Return true if `opOperand` only creates an alias (no read, no write).
/// Operands of unknown or disallowed ops are assumed not to be alias-only.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  BufferizableOpInterface bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return bufferizableOp &&
         bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
}
/// Return true if the tensor `value` may be read, directly or through any
/// chain of alias-only uses.
bool AnalysisState::isValueRead(Value value) const {
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  SmallVector<OpOperand *> worklist;
  auto enqueueUsesOf = [&](Value v) {
    for (OpOperand &use : v.getUses())
      worklist.push_back(&use);
  };
  enqueueUsesOf(value);

  while (!worklist.empty()) {
    OpOperand &candidate = *worklist.pop_back_val();
    // Look through uses that merely create an alias.
    if (bufferizesToAliasOnly(candidate))
      for (OpResult aliased : getAliasingOpResult(candidate))
        enqueueUsesOf(aliased);
    if (bufferizesToMemoryRead(candidate))
      return true;
  }
  return false;
}
/// Walk the reverse use-def chain from `value` through aliasing OpOperands.
/// Collect every value where the walk stops: values satisfying `condition`,
/// block arguments, values with no aliasing OpOperand, and values whose
/// defining op is not allowed by the options.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition) const {
  llvm::SetVector<Value> matches;
  llvm::SetVector<Value> frontier;
  frontier.insert(value);

  while (!frontier.empty()) {
    Value current = frontier.pop_back_val();

    // Stop at values that satisfy `condition` and at block arguments.
    if (condition(current) || current.isa<BlockArgument>()) {
      matches.insert(current);
      continue;
    }

    // `current` is an OpResult. Stop as well when nothing aliases it or its
    // defining op is excluded; otherwise continue through its aliases.
    SmallVector<OpOperand *> aliasingOperands =
        getAliasingOpOperand(current.cast<OpResult>());
    bool stopHere = aliasingOperands.empty() ||
                    !options.isOpAllowed(current.getDefiningOp());
    if (stopHere) {
      matches.insert(current);
      continue;
    }
    for (OpOperand *operand : aliasingOperands)
      frontier.insert(operand->get());
  }
  return matches;
}
/// Find the values that last wrote into the buffer of `value`, walking the
/// reverse use-def chain. The walk stops at block arguments, at results of
/// unknown or disallowed ops, and at results that are memory writes.
llvm::SetVector<Value>
AnalysisState::findLastPrecedingWrite(Value value) const {
  auto isStopPoint = [&](Value candidate) -> bool {
    Operation *definingOp = candidate.getDefiningOp();
    if (!definingOp)
      return true;
    BufferizableOpInterface bufferizableOp =
        options.dynCastBufferizableOp(definingOp);
    return !bufferizableOp ||
           bufferizableOp.isMemoryWrite(candidate.cast<OpResult>(), *this);
  };
  return findValueInReverseUseDefChain(value, isStopPoint);
}
/// Create an analysis state for `options` and run every registered state
/// initializer on it.
AnalysisState::AnalysisState(const BufferizationOptions &options)
    : options(options) {
  for (const auto &initializer : options.stateInitializers)
    initializer(*this);
}
/// Return true if the contents of `opOperand` need not be copied when it is
/// bufferized out-of-place. This holds if:
///  * the tensor has undefined contents, or
///  * the operand is written without being read, or
///  * neither the operand nor any aliasing OpResult is ever read.
bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  if (hasUndefinedContents(&opOperand))
    return true;
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  SmallVector<OpResult> aliases = getAliasingOpResult(opOperand);
  if (bufferizesToMemoryRead(opOperand))
    return false;
  return llvm::none_of(aliases,
                       [&](OpResult alias) { return isValueRead(alias); });
}
/// Without analysis information, an operand is in-place if it belongs to a
/// to_memref op or does not bufferize to a memory write.
bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  return isa<ToMemrefOp>(opOperand.getOwner()) ||
         !bufferizesToMemoryWrite(opOperand);
}
/// Without analysis information, equivalence cannot be established, so this
/// conservatively answers "not equivalent".
bool AnalysisState::areEquivalentBufferizedValues(Value v1, Value v2) const {
  (void)v1;
  (void)v2;
  return false;
}
/// Without analysis information, aliasing cannot be ruled out, so this
/// conservatively answers "may alias".
bool AnalysisState::areAliasingBufferizedValues(Value v1, Value v2) const {
  (void)v1;
  (void)v2;
  return true;
}
/// Without analysis information, contents are assumed to be defined.
bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  (void)opOperand;
  return false;
}
/// Return true if `tensor` may be yielded from its block.
///
/// Values not produced by an alloc_tensor op conservatively return true. For
/// alloc_tensor results, the uses are followed through aliasing OpResults of
/// allowed bufferizable ops. The answer is true if a to_memref op or a
/// region-return-like op is reached. Uses by non-bufferizable or disallowed
/// ops are skipped.
bool AnalysisState::isTensorYielded(Value tensor) const {
  if (!tensor.getDefiningOp<AllocTensorOp>())
    return true;

  SmallVector<OpOperand *> pending;
  auto enqueueUses = [&](Value v) {
    for (OpOperand &use : v.getUses())
      pending.push_back(&use);
  };
  enqueueUses(tensor);

  while (!pending.empty()) {
    OpOperand *use = pending.pop_back_val();
    Operation *user = use->getOwner();
    if (!options.dynCastBufferizableOp(user))
      continue;
    // to_memref cannot be analyzed further; returning ops yield the value.
    if (isa<ToMemrefOp>(user) || isRegionReturnLike(user))
      return true;
    for (OpResult alias : getAliasingOpResult(*use))
      enqueueUses(alias);
  }
  return false;
}
/// Debug-only sanity check that a to_memref from `tensor` to `memrefType`
/// would be valid: for ranked tensors, the memref must be a MemRefType of the
/// same rank. Does nothing in NDEBUG builds.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
    assert(memrefType.cast<MemRefType>().getRank() ==
               rankedTensorType.getRank() &&
           "to_memref would be invalid: mismatching ranks");
  }
#endif
}
/// Return a buffer (memref) for the tensor `value`.
///
/// If `value` comes from a to_tensor op, that op's memref is returned.
/// Otherwise a to_memref op is created right after the definition of `value`.
/// Its type comes from getBufferType, and failure is returned if that type
/// cannot be determined. The rewriter's insertion point is restored on exit.
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  OpBuilder::InsertionGuard guard(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> bufferType = getBufferType(value, options);
  if (failed(bufferType))
    return failure();
  ensureToMemrefOpIsValid(value, *bufferType);
  auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
      value.getLoc(), *bufferType, value);
  return toMemrefOp.getResult();
}
/// Compute the buffer (memref) type that the tensor `value` bufferizes to.
///
/// Resolution order:
///  1. a to_tensor result takes the type of its memref operand;
///  2. a block argument of an allowed bufferizable op asks that op;
///  3. otherwise the memory space comes from the defining op (if it is an
///     allocation that reports one), then from `options.defaultMemorySpace`.
///     The type itself is built with getMemRefType and no explicit layout.
/// An error is emitted on the owning op if no memory space can be inferred.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    Operation *parentOp = bbArg.getOwner()->getParentOp();
    if (auto bufferizableOp = options.dynCastBufferizableOp(parentOp))
      return bufferizableOp.getBufferType(bbArg, options);
  }

  Optional<unsigned> memorySpace = None;
  if (auto opResult = value.dyn_cast<OpResult>()) {
    auto bufferizableOp =
        options.dynCastBufferizableOp(opResult.getDefiningOp());
    if (bufferizableOp && bufferizableOp.bufferizesToAllocation(opResult)) {
      FailureOr<unsigned> queried = bufferizableOp.getMemorySpace(opResult);
      if (succeeded(queried))
        memorySpace = *queried;
    }
  }

  if (!memorySpace.has_value())
    memorySpace = options.defaultMemorySpace;
  if (!memorySpace.has_value())
    return getOwnerOfValue(value)->emitError("could not infer memory space");
  return getMemRefType(value, options, {}, *memorySpace);
}
/// Replace `op` with `values` (one per OpResult) during bufferization.
///
/// For each tensor OpResult, the matching replacement must be a ranked or
/// unranked memref. It is wrapped in a to_tensor op created after `op`, so
/// that existing tensor users still see a tensor. Non-tensor results are
/// replaced directly. The rewriter's insertion point is restored on exit.
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
Operation *op,
ValueRange values) {
assert(values.size() == op->getNumResults() &&
"expected one value per OpResult");
OpBuilder::InsertionGuard g(rewriter);
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
if (opResult.getType().isa<TensorType>()) {
assert((replacement.getType().isa<MemRefType>() ||
replacement.getType().isa<UnrankedMemRefType>()) &&
"tensor op result should be replaced with a memref value");
// Reset on every iteration so each to_tensor is created directly after
// `op`. NOTE(review): hoisting this out of the loop would presumably
// change the relative order of the created to_tensor ops.
rewriter.setInsertionPointAfter(op);
replacement = rewriter.create<bufferization::ToTensorOp>(
replacement.getLoc(), replacement);
}
replacements.push_back(replacement);
}
rewriter.replaceOp(op, replacements);
}
/// Allocate a buffer of `type` with dynamic sizes `dynShape`.
///
/// A user-provided `allocationFn` takes precedence and receives
/// `bufferAlignment`. Otherwise a memref.alloc is created, carrying an
/// alignment attribute only when `bufferAlignment` is non-zero.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  memref::AllocOp allocOp =
      bufferAlignment != 0
          ? b.create<memref::AllocOp>(loc, type, dynShape,
                                      b.getI64IntegerAttr(bufferAlignment))
          : b.create<memref::AllocOp>(loc, type, dynShape);
  return allocOp.getResult();
}
/// Deallocate `allocatedBuffer`, using the user-provided `deallocationFn` if
/// set and a memref.dealloc otherwise.
LogicalResult BufferizationOptions::createDealloc(OpBuilder &b, Location loc,
                                                  Value allocatedBuffer) const {
  if (!deallocationFn) {
    b.create<memref::DeallocOp>(loc, allocatedBuffer);
    return success();
  }
  return (*deallocationFn)(b, loc, allocatedBuffer);
}
/// Copy `from` into `to`, using the user-provided `memCpyFn` if set and a
/// memref.copy otherwise.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (!memCpyFn) {
    b.create<memref::CopyOp>(loc, from, to);
    return success();
  }
  return (*memCpyFn)(b, loc, from, to);
}
/// Return true if `value` is a block argument whose block's parent op is a
/// func::FuncOp.
bool bufferization::isFunctionArgument(Value value) {
  auto bbArg = value.dyn_cast<BlockArgument>();
  return bbArg && isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}
/// Return the memref type for the tensor `value` in `memorySpace`.
///
///  * Unranked tensor: an UnrankedMemRefType (`layout` must be null).
///  * Ranked tensor with `layout`: a MemRefType with that layout.
///  * Ranked tensor without `layout`: delegated to
///    `options.unknownTypeConverterFn`.
BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            unsigned memorySpace) {
  auto tensorType = value.getType().cast<TensorType>();
  auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
  if (rankedTensorType && !layout)
    return options.unknownTypeConverterFn(value, memorySpace, options);

  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  if (!rankedTensorType) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(tensorType.getElementType(),
                                   memorySpaceAttr);
  }
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpaceAttr);
}
/// Return a memref type for `tensorType` in `memorySpace` whose layout has a
/// dynamic offset and dynamic strides in every dimension. Unranked tensors map
/// to an UnrankedMemRefType.
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   unsigned memorySpace) {
  auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
  if (!rankedTensorType)
    return UnrankedMemRefType::get(tensorType.getElementType(), memorySpace);

  MLIRContext *ctx = rankedTensorType.getContext();
  SmallVector<int64_t> strides(rankedTensorType.getRank(),
                               ShapedType::kDynamicStrideOrOffset);
  int64_t offset = ShapedType::kDynamicStrideOrOffset;
  AffineMap layoutMap = makeStridedLinearLayoutMap(strides, offset, ctx);
  auto memorySpaceAttr =
      IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layoutMap,
                         memorySpaceAttr);
}
/// Return a memref type for `tensorType` in `memorySpace` with a null
/// (identity) layout. Unranked tensors map to an UnrankedMemRefType.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     unsigned memorySpace) {
  auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
  if (!rankedTensorType)
    return UnrankedMemRefType::get(tensorType.getElementType(), memorySpace);

  auto memorySpaceAttr = IntegerAttr::get(
      IntegerType::get(tensorType.getContext(), 64), memorySpace);
  MemRefLayoutAttrInterface identityLayout{};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), identityLayout,
                         memorySpaceAttr);
}