#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
// Pull in the TableGen-generated method definitions of
// BufferizableOpInterface; they must live in the bufferization namespace.
namespace mlir {
namespace bufferization {
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"
} // namespace bufferization
} // namespace mlir
// Out-of-line TypeID for AnalysisState; the single-argument AnalysisState
// constructor tags plain analysis states with TypeID::get<AnalysisState>().
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
// Debug-output helpers: DBGS() prefixes messages with "[<DEBUG_TYPE>] " and
// LDBG(X) prints X only when debug output for DEBUG_TYPE is enabled.
#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))
using namespace mlir;
using namespace bufferization;
/// Returns true if the op owning `region` is a bufferizable op (allowed by
/// `options`) that declares this region as repetitive. Regions of ops that
/// are not bufferizable, or are filtered out, are never repetitive.
static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  auto owningOp = options.dynCastBufferizableOp(region->getParentOp());
  return owningOp && owningOp.isRepetitiveRegion(region->getRegionNumber());
}
/// Returns the closest repetitive region enclosing `op`, or nullptr if `op`
/// is not inside a block or no enclosing region is repetitive. The answer is
/// delegated to the Block overload and memoized for `op`.
Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  // A detached op has no enclosing region at all.
  Block *parentBlock = op->getBlock();
  if (!parentBlock)
    return nullptr;

  auto cached = enclosingRepetitiveRegionCache.find_as(op);
  if (cached != enclosingRepetitiveRegionCache.end())
    return cached->second;

  // Compute first, then store: the recursive query may itself insert into the
  // cache.
  Region *enclosing = getEnclosingRepetitiveRegion(parentBlock, options);
  enclosingRepetitiveRegionCache[op] = enclosing;
  return enclosing;
}
/// Returns the closest repetitive region enclosing `value`, starting with the
/// region in which `value` is defined, or nullptr if there is none. The result
/// is memoized both for `value` and for every region visited on the way up.
Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  auto cached = enclosingRepetitiveRegionCache.find_as(value);
  if (cached != enclosingRepetitiveRegionCache.end())
    return cached->second;

  SmallVector<Region *> walked;
  Region *current = value.getParentRegion();
  for (; current != nullptr; current = current->getParentRegion()) {
    walked.push_back(current);
    if (isRepetitiveRegion(current, options))
      break;
  }

  enclosingRepetitiveRegionCache[value] = current;
  for (Region *visited : walked)
    enclosingRepetitiveRegionCache[visited] = current;
  return current;
}
/// Returns the closest repetitive region enclosing `block`, starting with the
/// region that contains `block` itself, or nullptr if there is none (including
/// when `block` is not attached to a region).
///
/// The result is memoized for `block` and, like the Value overload, for every
/// region visited during the upward walk.
Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  // Walk outwards through the region tree. Every visited region is recorded
  // so that it can be mapped to the final answer once that is known.
  Region *region = block->getParent();
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }

  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}
void AnalysisState::resetCache() { enclosingRepetitiveRegionCache.clear(); }
/// Given a repetitive `region`, returns the next repetitive region strictly
/// above it, or nullptr if there is none.
Region *bufferization::getNextEnclosingRepetitiveRegion(
    Region *region, const BufferizationOptions &options) {
  assert(isRepetitiveRegion(region, options) && "expected repetitive region");
  Region *ancestor = region->getParentRegion();
  while (ancestor && !isRepetitiveRegion(ancestor, options))
    ancestor = ancestor->getParentRegion();
  return ancestor;
}
/// Returns the closest region, starting with `region` itself, that its owning
/// bufferizable op declares parallel; nullptr if there is none. Every parallel
/// region is expected (and asserted) to also be repetitive.
Region *bufferization::getParallelRegion(Region *region,
                                         const BufferizationOptions &options) {
  for (Region *current = region; current;
       current = current->getParentRegion()) {
    auto owningOp = options.dynCastBufferizableOp(current->getParentOp());
    bool isParallel =
        owningOp && owningOp.isParallelRegion(current->getRegionNumber());
    if (!isParallel)
      continue;
    assert(isRepetitiveRegion(current, options) &&
           "expected that all parallel regions are also repetitive regions");
    return current;
  }
  return nullptr;
}
/// Returns the op that "owns" `value`: the op owning the enclosing block for
/// a block argument, or the defining op for an OpResult.
Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto blockArg = llvm::dyn_cast<BlockArgument>(value))
    return blockArg.getOwner()->getParentOp();
  return llvm::cast<OpResult>(value).getDefiningOp();
}
/// Creates a `bufferization.alloc_tensor` op whose (ranked tensor) type
/// matches `shapedValue` and returns its result.
///
/// * A ranked tensor `shapedValue` is used directly; a ranked memref is first
///   wrapped in a ToTensorOp. Unranked tensors/memrefs are rejected with an
///   error emitted on the owner of `shapedValue`.
/// * If `copy` is set, the tensor is passed as the alloc_tensor `copy` operand
///   and no memory space attribute is attached.
/// * Otherwise the dynamic sizes are computed (via shape reification when
///   possible, else via `populateDynamicDimSizes`), and a memory space
///   attribute derived from the tensor's buffer type is attached.
///
/// Returns failure if `shapedValue` is unranked or if the buffer type of the
/// tensor cannot be computed.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, bool copy) {
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    // View the memref as a tensor so that it can feed the alloc_tensor op.
    tensor = b.create<ToTensorOp>(loc, shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());
  SmallVector<Value> dynamicSizes;
  // Dynamic extents are only passed explicitly when not copying.
  // NOTE(review): presumably the copy operand supplies them otherwise;
  // confirm against AllocTensorOp's verifier.
  if (!copy) {
    // First try to reify the result shape of the defining op; this is only
    // attempted when `shapedValue` is itself a ranked-tensor OpResult.
    bool reifiedShapes = false;
    if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
        llvm::isa<OpResult>(shapedValue)) {
      ReifiedRankedShapedTypeDims resultDims;
      if (succeeded(
              reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
        reifiedShapes = true;
        auto &shape =
            resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
        // Keep one SSA value per dynamic dimension, in dimension order.
        for (const auto &dim : enumerate(tensorType.getShape()))
          if (ShapedType::isDynamic(dim.value()))
            dynamicSizes.push_back(shape[dim.index()].get<Value>());
      }
    }
    // Fallback: derive the dynamic extents from `tensor` itself.
    if (!reifiedShapes)
      populateDynamicDimSizes(b, loc, tensor, dynamicSizes);
  }
  auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,
                                               copy ? tensor : Value());
  // With a copy operand, no memory space attribute is attached.
  if (copy)
    return allocTensorOp.getResult();
  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
  if (failed(copyBufferType))
    return failure();
  // NOTE(review): getMemorySpace() yields an Attribute, so this optional is
  // always engaged (even for a null Attribute). The `!memorySpace` fallback
  // to defaultMemorySpaceFn below can therefore never trigger; confirm the
  // intended semantics (perhaps `!*memorySpace` was meant).
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}
/// Resolves read-after-write conflicts on this op by materializing tensor
/// allocations for every tensor OpOperand that `state` reports as
/// out-of-place.
///
/// Two strategies are used per out-of-place OpOperand:
///  * The OpResult is copied, after the op, and all of its uses (except the
///    new alloc_tensor and tensor.dim ops) are redirected to the copy. This
///    applies if all of the following hold:
///      - the operand has exactly one alias, and that alias is an OpResult;
///      - the operand does not bufferize to a memory write;
///      - the OpResult has exactly one aliasing OpOperand;
///      - the OpResult is not an unranked tensor.
///  * Otherwise the OpOperand itself is copied, before the op, and the op is
///    updated to use the copy.
/// The new allocation is initialized with the original contents only when
/// `state.canOmitTensorCopy` returns false for the OpOperand.
///
/// Returns failure for an out-of-place unranked tensor operand or when an
/// allocation cannot be created.
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;
  // Classify every out-of-place tensor OpOperand.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (state.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !state.bufferizesToMemoryWrite(opOperand) &&
        state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // Copy the single aliasing OpResult instead of the OpOperand.
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // Copy the OpOperand itself.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!state.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }
  // Operand copies are created right before the op.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
        copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }
  // Result copies are created right after the op.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, state.getOptions(),
        copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    // Snapshot the uses first: they are modified while iterating.
    SmallVector<OpOperand *> uses = llvm::to_vector(
        llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not rewire the op that created the copy itself.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim uses are left untouched (presumably the dims created as
      // dynamic extents of the copy; NOTE(review): confirm).
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }
  return success();
}
/// Evaluates the filter entries in order.
///
/// A matching DENY entry immediately rejects the op. If there is at least one
/// ALLOW entry, the op is allowed only if some ALLOW entry matches; without
/// ALLOW entries, every op not denied is allowed.
bool OpFilter::isOpAllowed(Operation *op) const {
  bool allowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    const bool matches = entry.fn(op);
    if (entry.type == Entry::DENY) {
      // DENY wins over any ALLOW match.
      if (matches)
        return false;
    } else if (entry.type == Entry::ALLOW) {
      allowed = allowed || matches;
    }
  }
  return allowed;
}
namespace {
/// Default function-argument type converter: maps `type` to a memref type
/// with a fully dynamic layout in `memorySpace`.
BaseMemRefType
defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
                                func::FuncOp /*funcOp*/,
                                const BufferizationOptions & /*options*/) {
  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
}

/// Default converter for tensor values of unknown ops: maps the tensor type
/// of `value` to a memref type with a fully dynamic layout in `memorySpace`.
BaseMemRefType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                            const BufferizationOptions & /*options*/) {
  auto tensorType = llvm::cast<TensorType>(value.getType());
  return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
} // namespace
/// Installs the default type converters; both produce memref types with a
/// fully dynamic layout.
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(&defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(&defaultUnknownTypeConverter) {}
/// Returns true if `op` may be bufferized under these options.
bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // Ops of the func dialect are rejected outright unless function boundaries
  // are bufferized; everything else is decided by the op filter.
  if (!bufferizeFunctionBoundaries &&
      isa_and_nonnull<func::FuncDialect>(op->getDialect()))
    return false;
  return opFilter.isOpAllowed(op);
}
/// Returns `op` as a BufferizableOpInterface, or a null interface if `op` is
/// not allowed by these options or does not implement the interface.
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  // Filtered-out ops are treated as if they were not bufferizable at all.
  if (!isOpAllowed(op))
    return nullptr;
  return dyn_cast<BufferizableOpInterface>(op);
}
/// Same as the Operation overload, applied to the owner of `value`.
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  Operation *owner = getOwnerOfValue(value);
  return dynCastBufferizableOp(owner);
}
/// Configures how tensor types at function boundaries become memref types.
///
/// IdentityLayoutMap selects static identity layouts; every other option
/// selects fully dynamic layouts. Result layout inference is enabled only for
/// InferLayoutMap.
void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  const bool useIdentityLayout =
      layoutMapOption == LayoutMapOption::IdentityLayoutMap;
  functionArgTypeConverterFn =
      [useIdentityLayout](TensorType tensorType, Attribute memorySpace,
                          func::FuncOp /*funcOp*/,
                          const BufferizationOptions & /*options*/) {
        return useIdentityLayout
                   ? bufferization::getMemRefTypeWithStaticIdentityLayout(
                         tensorType, memorySpace)
                   : bufferization::getMemRefTypeWithFullyDynamicLayout(
                         tensorType, memorySpace);
      };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}
/// Positions `b` immediately after the definition of `value`.
///
/// For an OpResult this is right after its defining op. For a block argument
/// it is the start of the owning block.
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (Operation *definingOp = value.getDefiningOp()) {
    b.setInsertionPointAfter(definingOp);
    return;
  }
  b.setInsertionPointToStart(llvm::cast<BlockArgument>(value).getOwner());
}
/// Returns the OpOperands that may alias `value`.
///
/// The answer comes from the owner's BufferizableOpInterface when available.
/// Otherwise the conservative unknown-op answer is used.
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  Operation *owner = getOwnerOfValue(value);
  if (!owner)
    return detail::unknownGetAliasingOpOperands(value);
  auto bufferizableOp = getOptions().dynCastBufferizableOp(owner);
  if (!bufferizableOp)
    return detail::unknownGetAliasingOpOperands(value);
  return bufferizableOp.getAliasingOpOperands(value, *this);
}
/// Returns the values that may alias `opOperand`.
///
/// The answer comes from the owner's BufferizableOpInterface when available.
/// Otherwise the conservative unknown-op answer is used.
AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  auto bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  if (!bufferizableOp)
    return detail::unknownGetAliasingValues(opOperand);
  return bufferizableOp.getAliasingValues(opOperand, *this);
}
/// Returns true if `opOperand` may be read from memory after bufferization.
/// Operands of unknown (or filtered-out) ops are conservatively treated as
/// read.
bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  auto bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return !bufferizableOp ||
         bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
}
/// Returns true if `opOperand` may be written to memory after bufferization.
/// Operands of unknown (or filtered-out) ops are conservatively treated as
/// written.
bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  auto bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return !bufferizableOp ||
         bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
}
/// Returns true if `opOperand` neither reads nor writes memory but only
/// creates an alias. This is never assumed for unknown (or filtered-out) ops.
bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  auto bufferizableOp =
      getOptions().dynCastBufferizableOp(opOperand.getOwner());
  return bufferizableOp &&
         bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
}
/// Returns true if `value` may be defined by a memory write.
///
/// Block arguments and results of unknown (or filtered-out) ops are
/// conservatively treated as written.
bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto result = llvm::dyn_cast<OpResult>(value);
  if (!result)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  return !bufferizableOp ||
         bufferizableOp.resultBufferizesToMemoryWrite(result, *this);
}
/// Returns true if the tensor `value` may be read.
///
/// Uses of `value` are followed transitively through alias-only ops. The
/// result is true as soon as any reached use bufferizes to a memory read.
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> worklist;
  for (OpOperand &use : value.getUses())
    worklist.push_back(&use);

  DenseSet<OpOperand *> seen;
  while (!worklist.empty()) {
    OpOperand *candidate = worklist.pop_back_val();
    if (!seen.insert(candidate).second)
      continue;
    // Alias-only uses do not read themselves, but their aliases might.
    if (bufferizesToAliasOnly(*candidate)) {
      for (AliasingValue alias : getAliasingValues(*candidate))
        for (OpOperand &aliasUse : alias.value.getUses())
          worklist.push_back(&aliasUse);
    }
    if (bufferizesToMemoryRead(*candidate))
      return true;
  }
  return false;
}
/// Walks the reverse use-def chain of `value` through aliasing OpOperands.
///
/// Returns every visited value that satisfies `condition`; the walk does not
/// continue past such values. `config` controls:
///  * whether already-visited values are revisited;
///  * whether unknown ops are followed;
///  * whether only equivalent, in-place, or same-type/cast aliases are
///    followed;
///  * whether the values where the walk stops (leaves) are added to the
///    result.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    Value value, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, worklist;
  worklist.insert(value);

  while (!worklist.empty()) {
    Value current = worklist.pop_back_val();

    // Records `current` as a leaf if the config asks for leaves.
    auto addLeaf = [&]() {
      if (config.alwaysIncludeLeaves)
        result.insert(current);
    };

    if (!config.revisitAlreadyVisitedValues && visited.contains(current)) {
      addLeaf();
      continue;
    }
    visited.insert(current);

    if (condition(current)) {
      result.insert(current);
      continue;
    }

    // Unknown or filtered-out ops end the walk unless explicitly followed.
    if (!config.followUnknownOps && !options.dynCastBufferizableOp(current)) {
      addLeaf();
      continue;
    }

    AliasingOpOperandList aliases = getAliasingOpOperands(current);
    if (aliases.getNumAliases() == 0) {
      addLeaf();
      continue;
    }

    for (AliasingOpOperand alias : aliases) {
      const bool stopHere =
          (config.followEquivalentOnly &&
           alias.relation != BufferRelation::Equivalent) ||
          (config.followInPlaceOnly && !isInPlace(*alias.opOperand)) ||
          (config.followSameTypeOrCastsOnly &&
           alias.opOperand->get().getType() != current.getType() &&
           !current.getDefiningOp<CastOpInterface>());
      if (stopHere) {
        addLeaf();
        continue;
      }
      worklist.insert(alias.opOperand->get());
    }
  }

  return result;
}
/// Returns the values on the reverse use-def chain of `value` that bufferize
/// to a memory write. Leaves of the traversal are not included.
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  auto isDefinition = [this](Value v) { return bufferizesToMemoryWrite(v); };
  return findValueInReverseUseDefChain(value, isDefinition, config);
}
/// Creates a plain analysis state, tagged with the TypeID of AnalysisState.
AnalysisState::AnalysisState(const BufferizationOptions &options)
    : AnalysisState(/*options=*/options,
                    /*type=*/TypeID::get<AnalysisState>()) {}
/// Creates an analysis state tagged with `type`, then runs every state
/// initializer registered in `options` on it, in order.
AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
    : options(options), type(type) {
  for (const auto &initFn : options.stateInitializers)
    initFn(*this);
}
/// Returns true if an out-of-place copy of `opOperand` does not need to
/// preserve the original contents (i.e. an allocation without copy suffices).
bool AnalysisState::canOmitTensorCopy(OpOperand &opOperand) const {
  // The contents are undefined anyway.
  if (hasUndefinedContents(&opOperand))
    return true;

  // The operand is written but not read: the old contents are not needed.
  if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
    return true;

  AliasingValueList aliases = getAliasingValues(opOperand);
  // The operand is read, so its contents must be preserved.
  if (bufferizesToMemoryRead(opOperand))
    return false;
  // Otherwise the copy can be skipped only if no alias is ever read.
  return llvm::none_of(
      aliases, [&](AliasingValue alias) { return isValueRead(alias.value); });
}
/// Default in-place decision without analysis information.
///
/// Operands of ToMemrefOp are always in place. Any other operand is in place
/// only if it does not bufferize to a memory write.
bool AnalysisState::isInPlace(OpOperand &opOperand) const {
  return isa<ToMemrefOp>(opOperand.getOwner()) ||
         !bufferizesToMemoryWrite(opOperand);
}
/// Without analysis information, equivalence cannot be proven; the
/// conservative answer is "not equivalent".
bool AnalysisState::areEquivalentBufferizedValues(Value /*v1*/,
                                                  Value /*v2*/) const {
  return false;
}
/// Without analysis information, aliasing cannot be ruled out; the
/// conservative answer is "may alias".
bool AnalysisState::areAliasingBufferizedValues(Value /*v1*/,
                                                Value /*v2*/) const {
  return true;
}
bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}
/// Debug-only check that a to_memref from `tensor` to `memrefType` would be
/// valid. For ranked tensors, the memref (expected to be a MemRefType) must
/// have the same rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
  if (auto rankedTensorType =
          llvm::dyn_cast<RankedTensorType>(tensor.getType())) {
    assert(llvm::cast<MemRefType>(memrefType).getRank() ==
               rankedTensorType.getRank() &&
           "to_memref would be invalid: mismatching ranks");
  }
#endif
}
/// Returns a buffer (memref) for the tensor `value`.
///
/// If `value` is produced by a ToTensorOp, its memref operand is returned
/// directly. Otherwise a ToMemrefOp is created right after the definition of
/// `value`. Fails if the buffer type cannot be computed.
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");

  // Fold "%t = to_tensor %m" back to %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  OpBuilder::InsertionGuard guard(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BaseMemRefType> bufferType = getBufferType(value, options);
  if (failed(bufferType))
    return failure();
  ensureToMemrefOpIsValid(value, *bufferType);
  auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
      value.getLoc(), *bufferType, value);
  return toMemrefOp.getResult();
}
/// Convenience overload: computes the buffer type of `value` starting from an
/// empty invocation stack.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  SmallVector<Value> freshStack;
  return getBufferType(value, options, freshStack);
}
/// Computes the buffer type of the tensor `value`.
///
/// `value` is pushed onto `invocationStack` for the duration of this call.
/// Bufferizable owners answer via their interface. For other owners, the
/// default memory space is used, and an error is emitted if no memory space
/// can be inferred.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popOnExit = llvm::make_scope_exit(
      [&invocationStack]() { invocationStack.pop_back(); });

  Operation *owner = getOwnerOfValue(value);
  if (auto bufferizableOp = options.dynCastBufferizableOp(owner))
    return bufferizableOp.getBufferType(value, options, invocationStack);

  auto tensorType = cast<TensorType>(value.getType());
  auto memSpace = options.defaultMemorySpaceFn(tensorType);
  if (!memSpace.has_value())
    return owner->emitError("could not infer memory space");
  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
/// Returns true if `op` has tensor semantics.
///
/// The op's BufferizableOpInterface is asked when present. Otherwise the
/// default check on tensor-typed block arguments, results and operands is
/// used.
bool bufferization::hasTensorSemantics(Operation *op) {
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  return bufferizableOp ? bufferizableOp.hasTensorSemantics()
                        : detail::defaultHasTensorSemantics(op);
}
/// Replaces `op` with `values` (one per result).
///
/// Tensor results must be replaced with memref values. Each such memref is
/// wrapped in a ToTensorOp created after `op`, so that existing tensor users
/// stay type-correct.
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange values) {
  assert(values.size() == op->getNumResults() &&
         "expected one value per OpResult");
  OpBuilder::InsertionGuard guard(rewriter);

  SmallVector<Value> newValues;
  for (OpResult result : op->getOpResults()) {
    Value newValue = values[result.getResultNumber()];
    if (!llvm::isa<TensorType>(result.getType())) {
      newValues.push_back(newValue);
      continue;
    }
    assert((llvm::isa<MemRefType>(newValue.getType()) ||
            llvm::isa<UnrankedMemRefType>(newValue.getType())) &&
           "tensor op result should be replaced with a memref value");
    // Re-anchor directly after `op` for every ToTensorOp that is created.
    rewriter.setInsertionPointAfter(op);
    Value asTensor = rewriter.create<bufferization::ToTensorOp>(
        newValue.getLoc(), newValue);
    newValues.push_back(asTensor);
  }
  rewriter.replaceOp(op, newValues);
}
/// Creates a buffer allocation of `type` with dynamic extents `dynShape`.
///
/// A user-provided `allocationFn` takes precedence. Otherwise a memref.alloc
/// is emitted, carrying an alignment attribute only when `bufferAlignment` is
/// non-zero.
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  memref::AllocOp allocOp =
      bufferAlignment == 0
          ? b.create<memref::AllocOp>(loc, type, dynShape)
          : b.create<memref::AllocOp>(loc, type, dynShape,
                                      b.getI64IntegerAttr(bufferAlignment));
  return allocOp.getResult();
}
/// Copies `from` into `to`, using the user-provided `memCpyFn` if set and a
/// memref.copy otherwise.
LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (!memCpyFn) {
    b.create<memref::CopyOp>(loc, from, to);
    return success();
  }
  return (*memCpyFn)(b, loc, from, to);
}
/// Returns the memref type for the tensor `value` in `memorySpace`.
///
/// * Unranked tensors map to unranked memrefs and must not have a `layout`.
/// * Ranked tensors use `layout` when one is given.
/// * Otherwise the choice is delegated to `options.unknownTypeConverterFn`.
BaseMemRefType bufferization::getMemRefType(Value value,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  auto tensorType = llvm::cast<TensorType>(value.getType());
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensorType);

  if (!rankedTensorType) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(tensorType.getElementType(), memorySpace);
  }

  if (!layout)
    return options.unknownTypeConverterFn(value, memorySpace, options);

  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
/// Returns a memref type for `tensorType` in `memorySpace` whose strided
/// layout has a dynamic offset and all-dynamic strides. Unranked tensors map
/// to unranked memrefs.
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensorType);
  if (!rankedTensorType)
    return UnrankedMemRefType::get(tensorType.getElementType(), memorySpace);

  SmallVector<int64_t> strides(rankedTensorType.getRank(),
                               ShapedType::kDynamic);
  auto layout = StridedLayoutAttr::get(tensorType.getContext(),
                                       /*offset=*/ShapedType::kDynamic,
                                       strides);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
/// Returns a memref type for `tensorType` in `memorySpace` with no explicit
/// layout attribute. Unranked tensors map to unranked memrefs.
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensorType);
  if (!rankedTensorType)
    return UnrankedMemRefType::get(tensorType.getElementType(), memorySpace);

  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(),
                         MemRefLayoutAttrInterface(), memorySpace);
}
/// Default implementation of `resultBufferizesToMemoryWrite`.
///
/// `opResult` is considered written if any of the following holds:
///  1. it has no aliasing OpOperand;
///  2. an aliasing OpOperand bufferizes to a memory write;
///  3. the reverse use-def chain of an aliasing OpOperand reaches a value that
///     is owned by an op nested in (or equal to) the defining op and
///     bufferizes to a memory write.
bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  Operation *definingOp = opResult.getDefiningOp();
  auto bufferizableOp = cast<BufferizableOpInterface>(definingOp);
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // (1)
  if (opOperands.getAliases().empty())
    return true;

  // (2)
  auto operandIsWritten = [&](AliasingOpOperand alias) {
    return state.bufferizesToMemoryWrite(*alias.opOperand);
  };
  if (llvm::any_of(opOperands, operandIsWritten))
    return true;

  // (3)
  auto isMemoryWriteInsideOp = [&](Value v) {
    return definingOp->isAncestor(getOwnerOfValue(v)) &&
           state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return llvm::any_of(opOperands, [&](AliasingOpOperand alias) {
    return !state
                .findValueInReverseUseDefChain(alias.opOperand->get(),
                                               isMemoryWriteInsideOp, config)
                .empty();
  });
}
/// Default implementation of `getAliasingOpOperands`.
///
/// The mapping is derived by inverting `getAliasingValues` over all tensor
/// OpOperands of the owner of `value`. The relation and definiteness of each
/// alias are preserved.
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  Operation *owner = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> aliasingOperands;
  for (OpOperand &operand : owner->getOpOperands()) {
    if (!llvm::isa<TensorType>(operand.get().getType()))
      continue;
    for (const auto &alias : state.getAliasingValues(operand)) {
      if (alias.value != value)
        continue;
      aliasingOperands.emplace_back(&operand, alias.relation,
                                    alias.isDefinite);
    }
  }
  return AliasingOpOperandList(std::move(aliasingOperands));
}
/// Default implementation of `getBufferType`.
///
/// * Block arguments get the plain `getMemRefType` result.
/// * An OpResult whose first aliasing OpOperand is equivalent gets the buffer
///   type of that operand.
/// * Otherwise the default memory space is used, and an error is emitted if
///   no memory space can be inferred.
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

  if (llvm::isa<BlockArgument>(value))
    return bufferization::getMemRefType(value, options);

  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState state(options);
  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);

  const bool firstAliasIsEquivalent =
      aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent;
  if (firstAliasIsEquivalent) {
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, invocationStack);
  }

  auto memSpace =
      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
  if (!memSpace.has_value())
    return getOwnerOfValue(value)->emitError("could not infer memory space");
  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
}
/// Default implementation of `isRepetitiveRegion`.
///
/// The answer comes from RegionBranchOpInterface when the op implements it.
/// Without that interface no region is considered repetitive.
bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto branchOp =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  return branchOp && branchOp.isRepetitiveRegion(index);
}
/// Conservative aliasing answer for a value whose owner is an unknown (or
/// filtered-out) op.
///
/// * Arguments of non-entry blocks, and of blocks not attached to a region,
///   are assumed not to alias anything.
/// * OpResults and entry-block arguments may alias every tensor OpOperand of
///   their owner op. The owner is the defining op for an OpResult, or the op
///   that holds the region for a block argument.
AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // TODO: Take into account successor blocks.
  if (auto bbArg = dyn_cast<BlockArgument>(value)) {
    Region *parentRegion = bbArg.getOwner()->getParent();
    if (!parentRegion ||
        bbArg.getOwner() != &parentRegion->getBlocks().front())
      return {};
  }

  // Note: `value.getDefiningOp()` is null for block arguments, so the owner
  // must be looked up via getOwnerOfValue.
  Operation *owner = getOwnerOfValue(value);
  if (!owner)
    return {};

  AliasingOpOperandList r;
  for (OpOperand &operand : owner->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}
/// Conservative aliasing answer for an OpOperand of an unknown (or
/// filtered-out) op.
///
/// The operand may alias every tensor result of its owner, and every tensor
/// argument of the entry block of each non-empty region of the owner.
AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  Operation *owner = opOperand.getOwner();
  AliasingValueList r;
  for (OpResult result : owner->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  // Restored from a mis-encoded `Region ®ion` (HTML `&reg` garbling), which
  // does not compile.
  for (Region &region : owner->getRegions()) {
    if (region.getBlocks().empty())
      continue;
    for (BlockArgument bbArg : region.getBlocks().front().getArguments())
      if (isa<TensorType>(bbArg.getType()))
        r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  }
  return r;
}
/// Default tensor-semantics check.
///
/// Returns true if any block argument in any region of `op`, any result, or
/// any operand has a tensor type.
bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorType>(t); };

  for (Region &region : op->getRegions())
    for (Block &block : region.getBlocks())
      for (BlockArgument bbArg : block.getArguments())
        if (isaTensor(bbArg.getType()))
          return true;

  return llvm::any_of(op->getResultTypes(), isaTensor) ||
         llvm::any_of(op->getOperandTypes(), isaTensor);
}