#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;
namespace mlir {
namespace scf {
namespace {
/// Cast `buffer` to the memref type `type` by creating a `memref.cast` op.
/// If the buffer already has type `type`, it is returned unchanged and no op
/// is created.
///
/// Shared by the scf.condition, scf.for, scf.while and scf.yield bufferization
/// patterns to reconcile a bufferized value with the buffer type expected by
/// the enclosing op. Preconditions (asserted): both `type` and the type of
/// `buffer` are `BaseMemRefType`s, and they are cast-compatible.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // Nothing to do if the buffer already has the requested type.
  if (buffer.getType() == type)
    return buffer;
  // NOTE(review): if `type` carries a layout that the buffer's layout cannot
  // be cast to, this assertion fires; callers are expected to request a
  // compatible (e.g. fully dynamic) layout.
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
/// Return `true` if no alias of `value` (as reported by the One-Shot analysis
/// `state`) is defined outside of `region`. Aliases listed in `exceptions` are
/// ignored.
///
/// A block argument only counts as internal if it belongs to a region nested
/// strictly inside `region`; the arguments of `region` itself are therefore
/// external unless exempted. An op result counts as internal if its defining
/// op lives anywhere within `region` (including `region` itself).
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->getBlocks().size() == 1 &&
         "expected region with single block");
  bool allAliasesInternal = true;
  state.applyOnAliases(value, [&](Value alias) {
    // Explicitly exempted aliases never count as external.
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *definingRegion = alias.getParentRegion();
    bool isExternal =
        isa<BlockArgument>(alias)
            ? !region->isProperAncestor(definingRegion)
            : (isa<OpResult>(alias) && !region->isAncestor(definingRegion));
    if (isExternal)
      allAliasesInternal = false;
  });
  return allAliasesInternal;
}
/// Bufferization model for scf.condition. Its tensor operands are read but
/// never written, alias no op result, and must be bufferized in place.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    return true;
  }

  /// Replace the scf.condition by one whose tensor operands are buffers. Each
  /// buffer is cast to the buffer type of the "after" region argument at the
  /// same position of the parent scf.while. Non-tensor operands are forwarded
  /// unchanged.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto condition = cast<scf::ConditionOp>(op);
    auto parentWhile = cast<scf::WhileOp>(condition->getParentOp());
    auto forwarded = condition.getArgs();

    SmallVector<Value> bufferizedArgs;
    for (size_t pos = 0, e = forwarded.size(); pos < e; ++pos) {
      Value arg = forwarded[pos];
      if (!isa<TensorType>(arg.getType())) {
        bufferizedArgs.push_back(arg);
        continue;
      }
      FailureOr<Value> argBuffer = getBuffer(rewriter, arg, options);
      if (failed(argBuffer))
        return failure();
      FailureOr<BaseMemRefType> expectedType = bufferization::getBufferType(
          parentWhile.getAfterArguments()[pos], options);
      if (failed(expectedType))
        return failure();
      bufferizedArgs.push_back(
          castBuffer(rewriter, *argBuffer, *expectedType));
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, condition.getCondition(), bufferizedArgs);
    return success();
  }
};
/// Return the single scf.yield that terminates a block of the region of
/// `executeRegionOp`. Returns a null op if no block, or more than one block,
/// is terminated by an scf.yield.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp unique;
  for (Block &block : executeRegionOp.getRegion()) {
    auto candidate = dyn_cast<scf::YieldOp>(block.getTerminator());
    if (!candidate)
      continue;
    // A second scf.yield means there is no unique one.
    if (unique)
      return {};
    unique = candidate;
  }
  return unique;
}
/// Bufferization model for scf.execute_region. The region may contain
/// unstructured control flow. Only regions with exactly one scf.yield
/// terminator are supported.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Reject ops without a unique scf.yield; the rest of this model relies on
  /// one being present.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  /// Block arguments defer to the generic branch-op aliasing. Result #i
  /// aliases (equivalently) operand #i of the unique scf.yield. This lets
  /// use-def traversal in the analysis see through the op, which has no
  /// tensor operands of its own.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Without a unique scf.yield, `verifyAnalysis` reports an error.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  /// Create a new scf.execute_region whose result types are those of the
  /// unique scf.yield operands, move the region over, and bufferize every
  /// block signature. Tensor results of the old op are replaced by
  /// `bufferization.to_tensor` ops on the new results.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // `verifyAnalysis` normally rejects this case, but do not rely on it
    // having run: dereferencing a null yield op would crash.
    if (!yieldOp)
      return op->emitOpError("op without unique scf.yield is not supported");
    TypeRange newResultTypes(yieldOp.getResults());

    // Create the new op and move the region over.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize the signature of every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Build replacements for the old results, re-wrapping tensors.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    rewriter.replaceOp(executeRegionOp, newResults);
    return success();
  }
};
/// Bufferization model for scf.if. The op has no tensor operands. Its results
/// alias the values yielded by the then/else branches.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  /// Result #i may alias operand #i of either branch's scf.yield. Neither
  /// alias is marked definite because only one branch runs.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  /// Create an scf.if (always with an else region) whose tensor results are
  /// replaced by their computed buffer types. Merge the old then/else blocks
  /// into it and replace the old op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute the bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create the new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over the then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
    return success();
  }

  /// Buffer type of a result.
  /// - If both branches yield the same buffer type, use that type.
  /// - If the memory spaces differ, report an error.
  /// - Otherwise, use a memref with a fully dynamic layout in the common
  ///   memory space.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    auto opResult = cast<OpResult>(value);
    unsigned resultNum = opResult.getResultNumber();

    // Buffer type of a yielded value. If the scf.yield was already bufferized,
    // use the operand type directly; otherwise compute it.
    auto getYieldedBufferType =
        [&](Value yielded) -> FailureOr<BaseMemRefType> {
      if (auto bufferType = dyn_cast<BaseMemRefType>(yielded.getType()))
        return bufferType;
      return bufferization::getBufferType(yielded, options, invocationStack);
    };

    FailureOr<BaseMemRefType> thenBufferType =
        getYieldedBufferType(thenYieldOp.getOperand(resultNum));
    if (failed(thenBufferType))
      return failure();
    FailureOr<BaseMemRefType> elseBufferType =
        getYieldedBufferType(elseYieldOp.getOperand(resultNum));
    if (failed(elseBufferType))
      return failure();

    // Best case: both branches have the exact same buffer type.
    if (*thenBufferType == *elseBufferType)
      return *thenBufferType;

    // Memory space mismatch cannot be reconciled.
    if (thenBufferType->getMemorySpace() != elseBufferType->getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layouts differ: promote to a fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType->getMemorySpace());
  }
};
/// Bufferization model for scf.index_switch. Its results alias the values
/// yielded by each case region and by the default region.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  /// Result #i may alias operand #i of the scf.yield of every case and of the
  /// default block. None of these aliases is definite.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  /// Create an scf.index_switch whose tensor results are replaced by their
  /// computed buffer types. Move all case regions and the default region into
  /// it, then replace the old op.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute the bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create the new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over the case regions and the default region.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
    return success();
  }

  /// Buffer type of a result. Start from the default block's yielded buffer
  /// type and fold in every case:
  /// - Identical types are kept.
  /// - A memory-space mismatch is an error.
  /// - Any other mismatch promotes the type to a fully dynamic layout.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Buffer type yielded by block `b` at `resultNum`. If the scf.yield was
    // already bufferized, use the operand type directly; otherwise compute it.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      return bufferization::getBufferType(yieldedValue, options,
                                          invocationStack);
    };

    // Buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Fold in the buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: same buffer type as accumulated so far.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch cannot be reconciled.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layouts differ: promote to a fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};
/// Return the positions of all values in `values` that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> tensorPositions;
  int64_t pos = 0;
  for (Value v : values) {
    if (isa<TensorType>(v.getType()))
      tensorPositions.insert(pos);
    ++pos;
  }
  return tensorPositions;
}
/// Return every position `i` at which both `bbArgs[i]` and `yieldedValues[i]`
/// are tensors whose bufferized values `state` considers equivalent. Only the
/// common prefix of the two ranges is inspected.
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  DenseSet<int64_t> equivalent;
  unsigned int common = std::min(bbArgs.size(), yieldedValues.size());
  for (unsigned int i = 0; i < common; ++i) {
    Value bbArg = bbArgs[i];
    Value yielded = yieldedValues[i];
    bool bothTensors = isa<TensorType>(bbArg.getType()) &&
                       isa<TensorType>(yielded.getType());
    if (bothTensors && state.areEquivalentBufferizedValues(bbArg, yielded))
      equivalent.insert(i);
  }
  return equivalent;
}
/// Return one value per operand in `operands`. Tensor operands are replaced
/// by their buffers (obtained with `getBuffer`); all other operands are
/// returned unchanged. Fails as soon as a buffer cannot be obtained.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> buffers;
  for (OpOperand &operand : operands) {
    Value current = operand.get();
    if (!isa<TensorType>(current.getType())) {
      buffers.push_back(current);
      continue;
    }
    FailureOr<Value> buffer = getBuffer(rewriter, current, options);
    if (failed(buffer))
      return failure();
    buffers.push_back(*buffer);
  }
  return buffers;
}
/// Build replacement values for `bbArgs`, e.g. when an old tensor-based block
/// is merged into a bufferized op. Arguments whose position is in
/// `tensorIndices` are wrapped in a `bufferization.to_tensor` op, created at
/// the rewriter's current insertion point. All other arguments are used
/// directly.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> replacements;
  for (size_t pos = 0, e = bbArgs.size(); pos < e; ++pos) {
    Value arg = bbArgs[pos];
    if (!tensorIndices.contains(pos)) {
      replacements.push_back(arg);
      continue;
    }
    Value asTensor =
        rewriter.create<bufferization::ToTensorOp>(arg.getLoc(), arg)
            .getResult();
    replacements.push_back(asTensor);
  }
  return replacements;
}
/// Compute the buffer type of a loop-carried region argument `iterArg` from
/// the buffer types of its `initArg` and of the `yieldedValue` fed back into
/// it:
/// - If both types are equal, that type is used.
/// - A memory-space mismatch is reported as an error on `loopOp`.
/// - Otherwise, a memref with a fully dynamic layout in the shared memory
///   space is returned.
/// If `iterArg` already appears at least twice on `invocationStack`, the
/// init_arg's buffer type is returned without looking at the yielded value.
/// Presumably this breaks the mutual recursion between the iter_arg and the
/// yielded value.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  FailureOr<BaseMemRefType> initType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initType))
    return failure();

  // Recursion guard (see function comment).
  if (llvm::count(invocationStack, iterArg) >= 2)
    return *initType;

  // Buffer type of the yielded value: taken as-is if already bufferized,
  // otherwise computed (which typically recurses back into this iter_arg).
  BaseMemRefType yieldedType;
  if (auto alreadyBufferized =
          dyn_cast<BaseMemRefType>(yieldedValue.getType())) {
    yieldedType = alreadyBufferized;
  } else {
    FailureOr<BaseMemRefType> computed =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(computed))
      return failure();
    yieldedType = *computed;
  }

  if (*initType == yieldedType)
    return yieldedType;

  // Types differ: promote to a fully dynamic layout if memory spaces agree.
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  BaseMemRefType initBufferType = *initType;
  if (initBufferType.getMemorySpace() != yieldedType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto rankedYieldedType = dyn_cast<MemRefType>(yieldedType)) {
    assert(
        llvm::all_equal({rankedYieldedType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif
  return getMemRefTypeWithFullyDynamicLayout(iterTensorType,
                                             yieldedType.getMemorySpace());
}
/// Return `true` if `forOp` may execute zero iterations. This is the case when
/// either bound is not a constant, or when the constant upper bound does not
/// exceed the constant lower bound.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lower = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> upper = getConstantIntValue(forOp.getUpperBound());
  if (lower && upper)
    return *upper <= *lower;
  // Unknown bounds: an empty iteration space cannot be ruled out.
  return true;
}
/// Bufferization model for scf.for. Tensor iter_args become buffers. The old
/// loop body keeps operating on tensors through `bufferization.to_tensor` ops
/// wrapped around the new (memref) region iter_args.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  /// An init_arg is read if the loop may run zero times (its result is then
  /// the init_arg itself) or if the tied region iter_arg is read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto loop = cast<scf::ForOp>(op);
    return mayHaveZeroIterations(loop) ||
           state.isValueRead(loop.getTiedLoopRegionIterArg(&opOperand));
  }

  /// Tensor init_args are always treated as written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  /// The i-th init_arg aliases only the i-th loop result. The alias is
  /// definite exactly when the buffer relation is `Equivalent`.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    OpResult tiedResult = cast<scf::ForOp>(op).getTiedLoopResult(&opOperand);
    BufferRelation rel = bufferRelation(op, tiedResult, state);
    bool isDefinite = rel == BufferRelation::Equivalent;
    return {{tiedResult, rel, isDefinite}};
  }

  /// A result is equivalent to its init_arg iff the tied region iter_arg is
  /// equivalent to the value yielded for it.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    auto loop = cast<scf::ForOp>(op);
    BlockArgument iterArg = loop.getTiedLoopRegionIterArg(opResult);
    Value yielded = loop.getTiedLoopYieldedValue(iterArg)->get();
    if (state.areEquivalentBufferizedValues(iterArg, yielded))
      return BufferRelation::Equivalent;
    return BufferRelation::Unknown;
  }

  /// Region iter_args are always writable from the point of view of nested
  /// ops.
  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Run the generic operand conflict resolution. Then, if aliasing
  /// invariants are enforced, replace every yielded tensor that may alias a
  /// value defined outside the loop body (other than its own iter_arg) with a
  /// fresh tensor allocation.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    if (failed(cast<BufferizableOpInterface>(op)
                   .resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();
    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    auto loop = cast<scf::ForOp>(op);
    auto terminator = cast<scf::YieldOp>(loop.getBody()->getTerminator());
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(terminator);

    DenseSet<int64_t> tensorPositions = getTensorIndices(loop.getInitArgs());
    // `state` is assumed to be a OneShotAnalysisState here; the interface
    // signature only exposes the base class.
    const auto &oneShotState =
        static_cast<const OneShotAnalysisState &>(state);

    auto yieldedValues = terminator.getResults();
    SmallVector<Value> newYields;
    for (size_t pos = 0, e = yieldedValues.size(); pos < e; ++pos) {
      Value yielded = yieldedValues[pos];
      bool keep = !tensorPositions.contains(pos) ||
                  doesNotAliasExternalValue(
                      yielded, &loop.getRegion(),
                      /*exceptions=*/loop.getRegionIterArg(pos), oneShotState);
      if (keep) {
        newYields.push_back(yielded);
        continue;
      }
      FailureOr<Value> copy = allocateTensorForShapedValue(
          rewriter, terminator.getLoc(), yielded, state.getOptions());
      if (failed(copy))
        return failure();
      newYields.push_back(*copy);
    }

    rewriter.modifyOpInPlace(terminator, [&]() {
      terminator.getResultsMutable().assign(newYields);
    });
    return success();
  }

  /// A loop result has the buffer type of its tied region iter_arg. A region
  /// iter_arg's buffer type is derived from its init_arg and the value
  /// yielded for it.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto loop = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto result = dyn_cast<OpResult>(value))
      return bufferization::getBufferType(
          loop.getTiedLoopRegionIterArg(result), options, invocationStack);

    unsigned idx =
        loop.getTiedLoopResult(cast<BlockArgument>(value)).getResultNumber();
    auto terminator = cast<scf::YieldOp>(loop.getBody()->getTerminator());
    Value yielded = terminator.getOperand(idx);
    BlockArgument iterArg = loop.getRegionIterArgs()[idx];
    Value init = loop.getInitArgs()[idx];
    return computeLoopRegionIterArgBufferType(op, iterArg, init, yielded,
                                              options, invocationStack);
  }

  /// Replace the loop with a new scf.for carrying buffers:
  /// - Init buffers are cast to the result buffer types.
  /// - The old body is merged in, with its tensor iter_args rewired through
  ///   `to_tensor` ops.
  /// - The old results are replaced.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto loop = cast<scf::ForOp>(op);
    Block *oldBody = loop.getBody();
    DenseSet<int64_t> tensorPositions = getTensorIndices(loop.getInitArgs());

    FailureOr<SmallVector<Value>> initBuffers =
        getBuffers(rewriter, loop.getInitArgsMutable(), options);
    if (failed(initBuffers))
      return failure();

    // Cast each tensor init buffer to the buffer type of its tied result.
    SmallVector<Value> newInits;
    for (size_t pos = 0, e = initBuffers->size(); pos < e; ++pos) {
      Value init = (*initBuffers)[pos];
      Value result = loop->getResult(pos);
      if (isa<TensorType>(result.getType())) {
        FailureOr<BaseMemRefType> resultBufferType =
            bufferization::getBufferType(result, options);
        if (failed(resultBufferType))
          return failure();
        init = castBuffer(rewriter, init, *resultBufferType);
      }
      newInits.push_back(init);
    }

    // New loop with memref iter_args; all attributes are carried over.
    auto newLoop = rewriter.create<scf::ForOp>(
        loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
        loop.getStep(), newInits);
    newLoop->setAttrs(loop->getAttrs());
    Block *newBody = newLoop.getBody();

    // The old body expects tensors: wrap the memref iter_args and prepend the
    // induction variable before merging.
    rewriter.setInsertionPointToStart(newBody);
    SmallVector<Value> blockReplacements = getBbArgReplacements(
        rewriter, newLoop.getRegionIterArgs(), tensorPositions);
    blockReplacements.insert(blockReplacements.begin(),
                             newLoop.getInductionVar());
    rewriter.mergeBlocks(oldBody, newBody, blockReplacements);

    replaceOpWithBufferizedValues(rewriter, op, newLoop->getResults());
    return success();
  }

  /// Unless `allowReturnAllocsFromLoops` is set, require every tensor result
  /// to be equivalent to its iter_arg.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &oneShotOptions =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (oneShotOptions.allowReturnAllocsFromLoops)
      return success();

    auto loop = cast<scf::ForOp>(op);
    auto terminator = cast<scf::YieldOp>(loop.getBody()->getTerminator());
    for (OpResult result : op->getOpResults()) {
      bool isTensor = isa<TensorType>(result.getType());
      if (isTensor &&
          bufferRelation(op, result, state) != BufferRelation::Equivalent)
        return terminator->emitError()
               << "Yield operand #" << result.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }
    return success();
  }
};
/// Bufferization model for scf.while. Tensor values carried through the
/// "before" and "after" regions become buffers. The old region bodies keep
/// operating on tensors through `bufferization.to_tensor` ops wrapped around
/// the new (memref) block arguments.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  /// Tensor init operands are always treated as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// Tensor init operands are always treated as written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  /// Init operand #i can only alias result #i, and only if that result exists
  /// and has the same type. Operand and result counts/types may differ.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  /// A result is equivalent to its init operand iff, at the same position:
  /// - the "before" bbArg is equivalent to the scf.condition operand, and
  /// - the "after" bbArg is equivalent to the scf.yield operand.
  /// Positions without a matching "before" bbArg of the same type are
  /// `Unknown`.
  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" bbArgs and the results may not match up.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  /// Run the generic operand conflict resolution. Then, if aliasing
  /// invariants are enforced, replace every tensor operand of scf.condition
  /// with a fresh tensor allocation unless, at that position, both the
  /// "before" and the "after" bbArgs are equivalent to what is yielded to
  /// them.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();
    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value: is it equivalent to its corresponding bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update the scf.condition operands of the "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  /// Replace the loop with a new scf.while carrying buffers:
  /// - Init buffers are cast to the "before" bbArg buffer types.
  /// - Result/"after" types are the bufferized "after" bbArg types.
  /// - Both old blocks are merged into the new regions, with tensor bbArgs
  ///   rewired through `to_tensor` ops.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Tensor positions; "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // Buffers for the init operands.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init buffers to the buffer types of the "before" bbArgs.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Result types equal the bufferized "after" bbArg types. A failed buffer
    // type computation is propagated instead of dereferencing an empty
    // FailureOr.
    SmallVector<Type> argsTypesAfter;
    for (BlockArgument bbArg : whileOp.getAfterArguments()) {
      if (!isa<TensorType>(bbArg.getType())) {
        argsTypesAfter.push_back(bbArg.getType());
        continue;
      }
      FailureOr<BaseMemRefType> bufferType =
          bufferization::getBufferType(bbArg, options);
      if (failed(bufferType))
        return failure();
      argsTypesAfter.push_back(*bufferType);
    }

    // New scf.while with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add "before"/"after" blocks with the bufferized argument types.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Move the condition block; its tensor bbArgs become to_tensor ops.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Move the body block; its tensor bbArgs become to_tensor ops.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
    return success();
  }

  /// Buffer type of a "before" bbArg, an "after" bbArg or a result:
  /// - "before" bbArgs combine their init operand and the scf.yield operand.
  /// - Results and "after" bbArgs take the buffer type of the scf.condition
  ///   operand at the same position.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: loop result or block argument of the "after" region.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Unless `allowReturnAllocsFromLoops` is set, require every tensor operand
  /// of scf.condition and scf.yield to be equivalent to the bbArg at the same
  /// position of its own block.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
/// Bufferization model for scf.yield. Its tensor operands are read but never
/// written and must be bufferized in place. Operands alias the parent's
/// results only for scf.if (non-definite) and scf.execute_region parents.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    Operation *parent = op->getParentOp();
    if (isa<scf::IfOp>(parent))
      return {{parent->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    if (isa<scf::ExecuteRegionOp>(parent))
      return {{parent->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    return true;
  }

  /// Replace the scf.yield by one yielding buffers. Each yielded buffer is
  /// cast to the buffer type the parent expects:
  /// - scf.for / scf.if / scf.index_switch: the parent result at the same
  ///   position.
  /// - scf.while: the "before" bbArg at the same position.
  /// - scf.execute_region: no cast.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    Operation *parent = yieldOp->getParentOp();
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(parent))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    auto yielded = yieldOp.getResults();
    SmallVector<Value> newResults;
    for (size_t idx = 0, e = yielded.size(); idx < e; ++idx) {
      Value value = yielded[idx];
      if (!isa<TensorType>(value.getType())) {
        newResults.push_back(value);
        continue;
      }
      FailureOr<Value> buffer = getBuffer(rewriter, value, options);
      if (failed(buffer))
        return failure();

      // The value whose buffer type the yielded buffer must match, if any.
      Value typeSource;
      if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(parent))
        typeSource = parent->getResult(idx);
      else if (auto whileOp = dyn_cast<scf::WhileOp>(parent))
        typeSource = whileOp.getBeforeArguments()[idx];

      Value result = *buffer;
      if (typeSource) {
        FailureOr<BaseMemRefType> targetType =
            bufferization::getBufferType(typeSource, options);
        if (failed(targetType))
          return failure();
        result = castBuffer(rewriter, result, *targetType);
      }
      newResults.push_back(result);
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};
/// Return `true` if `forallOp` may execute zero iterations. This is the case
/// when some dimension has a non-constant bound, or a constant lower bound
/// that is not smaller than its constant upper bound.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  auto lowerBounds = forallOp.getMixedLowerBound();
  auto upperBounds = forallOp.getMixedUpperBound();
  return llvm::any_of(llvm::zip(lowerBounds, upperBounds), [](auto bounds) {
    std::optional<int64_t> lb = getConstantIntValue(std::get<0>(bounds));
    std::optional<int64_t> ub = getConstantIntValue(std::get<1>(bounds));
    return !lb || !ub || *lb >= *ub;
  });
}
/// Bufferization model for scf.forall. Each output operand (`getOutputs()`)
/// is an equivalent alias of its tied op result. Bufferizing the op routes the
/// tied block arguments to the output buffers and moves the body into a new
/// scf.forall that has no results.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  /// An output operand counts as read if the loop may run zero iterations
  /// (then the result is just the operand), or else if the tied block
  /// argument is read inside the body.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return mayHaveZeroIterations(forallOp) ||
           state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  /// Output operands are always treated as written.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return true;
  }

  /// Each output operand aliases exactly its tied op result, equivalently.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    Value tiedResult = cast<ForallOp>(op).getTiedOpResult(&opOperand);
    return {{{tiedResult, BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    Block *oldBody = forallOp.getBody();
    Location loc = forallOp.getLoc();

    // Obtain a buffer for every output operand.
    SmallVector<Value> outBuffers;
    for (Value output : forallOp.getOutputs()) {
      FailureOr<Value> maybeBuffer = getBuffer(rewriter, output, options);
      if (failed(maybeBuffer))
        return failure();
      outBuffers.push_back(*maybeBuffer);
    }

    // The body's block arguments after the first `rank` ones are tied to the
    // outputs; reroute their uses to a to_tensor of the matching buffer,
    // created at the start of the body.
    rewriter.setInsertionPointToStart(oldBody);
    auto outputArgs = oldBody->getArguments().drop_front(forallOp.getRank());
    for (auto [outputArg, outBuffer] : llvm::zip(outputArgs, outBuffers)) {
      Value asTensor = rewriter.create<ToTensorOp>(loc, outBuffer);
      outputArg.replaceAllUsesWith(asTensor);
    }

    // Create a result-less scf.forall over the same bounds, steps and mapping,
    // keeping the original op's discardable attributes.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp = rewriter.create<ForallOp>(
        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
        forallOp.getMixedStep(), /*outputs=*/ValueRange(),
        forallOp.getMapping());
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
    Block *newBody = newForallOp.getBody();

    // Erase the terminator the builder inserted; the old body's operations
    // are merged in below.
    rewriter.eraseOp(newBody->getTerminator());

    // Leading block arguments map to the new block's arguments; the output
    // block arguments (whose uses were rerouted above) map to null values.
    SmallVector<Value> argReplacements(newBody->getArguments().begin(),
                                       newBody->getArguments().end());
    argReplacements.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(oldBody, newBody, argReplacements);

    // The old op's results are replaced by the output buffers.
    replaceOpWithBufferizedValues(rewriter, op, outBuffers);
    return success();
  }

  /// Both a tied block argument and an op result take the buffer type of the
  /// corresponding output operand.
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);
    Value output;
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      output = forallOp.getTiedOpOperand(bbArg)->get();
    else
      output = forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()];
    return bufferization::getBufferType(output, options, invocationStack);
  }

  /// The region is repetitive if any dimension has a non-constant lower
  /// bound, upper bound or step, or if lb + step < ub for some dimension
  /// (i.e. a second iteration exists).
  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);
    SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
    SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
    SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
    for (auto [lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
      std::optional<int64_t> lbVal = getConstantIntValue(lb);
      std::optional<int64_t> ubVal = getConstantIntValue(ub);
      std::optional<int64_t> stepVal = getConstantIntValue(step);
      if (!lbVal || !ubVal || !stepVal || *lbVal + *stepVal < *ubVal)
        return true;
    }
    return false;
  }

  /// Parallel exactly when repetitive.
  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
/// Bufferization model for scf::InParallelOp. Its `bufferize` hook is not
/// expected to be reached (per the message below, the op has no tensor
/// OpOperands / OpResults of its own); hitting it is a programming error.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    // llvm_unreachable does not return, so no return statement is needed.
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
  }
};
}
}
}
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
ForOp::attachInterface<ForOpInterface>(*ctx);
IfOp::attachInterface<IfOpInterface>(*ctx);
IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
ForallOp::attachInterface<ForallOpInterface>(*ctx);
InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
WhileOp::attachInterface<WhileOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
});
}