#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::vector;
namespace mlir {
namespace vector {
namespace {
/// Bufferization model for `vector.transfer_read` whose source is a tensor.
///
/// The op reads its tensor operand and produces a vector, so no result aliases
/// the tensor operand. Bufferization re-creates the read on the buffer that
/// backs the tensor source, forwarding all other operands and attributes.
struct TransferReadOpInterface
    : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
                                                    vector::TransferReadOp> {
  /// The (ranked) tensor operand is always read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  /// A read never writes its tensor operand.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  /// The vector result does not alias the tensor operand.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Replaces the tensor-based read with the same read on a buffer.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto transferRead = cast<vector::TransferReadOp>(op);
    assert(isa<TensorType>(transferRead.getShapedType()) &&
           "only tensor types expected");

    // Buffer backing the tensor source; bail out if it cannot be obtained.
    FailureOr<Value> sourceBuffer =
        getBuffer(rewriter, transferRead.getSource(), options);
    if (failed(sourceBuffer))
      return failure();

    // Everything except the source operand is forwarded unchanged.
    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
        rewriter, transferRead, transferRead.getVectorType(), *sourceBuffer,
        transferRead.getIndices(), transferRead.getPermutationMap(),
        transferRead.getPadding(), transferRead.getMask(),
        transferRead.getInBoundsAttr());
    return success();
  }
};
/// Bufferization model for `vector.transfer_write` into a tensor destination.
///
/// The destination tensor is a destination-style (DPS) init operand, and
/// aliasing is supplied by `DstBufferizableOpInterfaceExternalModel`. This
/// model decides whether the old destination contents are read, and rewrites
/// the op into a memref-based transfer_write.
struct TransferWriteOpInterface
    : public DstBufferizableOpInterfaceExternalModel<TransferWriteOpInterface,
                                                     vector::TransferWriteOp> {
  /// Returns false only when the write provably overwrites the entire
  /// destination, so that the old contents can never be observed. Every case
  /// that cannot be proven reports a read, which is the conservative answer.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);

    // Destination must have a static shape.
    if (!writeOp.getShapedType().hasStaticShape())
      return true;

    // All offsets must be the constant 0.
    for (Value offset : writeOp.getIndices()) {
      if (getConstantIntValue(offset) != 0)
        return true;
    }

    // A mask may leave some elements untouched.
    if (writeOp.isMasked())
      return true;

    // The per-dimension size check below pairs tensor dim i with vector dim i.
    // That pairing is only meaningful for an identity permutation map.
    // Otherwise rank-reducing or transposing writes are misjudged. Example:
    // vector<8xf32> written into tensor<4x8xf32> at [0, 0] would only compare
    // 4 against 8 and wrongly conclude that the whole tensor is overwritten,
    // although rows 1..3 keep their old values.
    if (!writeOp.getPermutationMap().isIdentity())
      return true;

    // Must write at least the full size of every dimension.
    for (auto [tensorDim, vectorDim] :
         llvm::zip(writeOp.getShapedType().getShape(),
                   writeOp.getVectorType().getShape())) {
      if (tensorDim > vectorDim)
        return true;
    }

    // The vector overwrites the entire destination.
    return false;
  }

  /// Re-creates the write on the destination buffer and uses that buffer as
  /// the bufferized replacement of the op's tensor result.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(isa<TensorType>(writeOp.getShapedType()) &&
           "only tensor types expected");

    // Buffer backing the destination tensor; it is also the op's result.
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, writeOp.getSource(), options);
    if (failed(resultBuffer))
      return failure();

    // The new op is built without a result type: it writes into the buffer.
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), *resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),
        writeOp.getMask(), writeOp.getInBoundsAttr());
    replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
    return success();
  }
};
/// Bufferization model for `vector.gather` whose base is a tensor.
///
/// The gather only reads the tensor base and yields a vector, so nothing
/// aliases the tensor operand. Bufferization swaps the base for its buffer and
/// keeps every other operand as-is.
struct GatherOpInterface
    : public BufferizableOpInterface::ExternalModel<GatherOpInterface,
                                                    vector::GatherOp> {
  /// The (ranked) tensor base is always read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return true;
  }

  /// A gather never writes its tensor base.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    assert(isa<RankedTensorType>(opOperand.get().getType()) &&
           "only tensor types expected");
    return false;
  }

  /// The vector result does not alias the tensor base.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  /// Replaces the tensor-based gather with the same gather on a buffer.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto gather = cast<vector::GatherOp>(op);
    assert(isa<TensorType>(gather.getBaseType()) &&
           "only tensor types expected");

    // Obtain the buffer for the tensor base, propagating failure.
    FailureOr<Value> baseBuffer =
        getBuffer(rewriter, gather.getBase(), options);
    if (failed(baseBuffer))
      return failure();

    // Indices, index vector, mask and pass-through are forwarded unchanged.
    replaceOpWithNewBufferizedOp<vector::GatherOp>(
        rewriter, gather, gather.getVectorType(), *baseBuffer,
        gather.getIndices(), gather.getIndexVec(), gather.getMask(),
        gather.getPassThru());
    return success();
  }
};
/// Bufferization model for `vector.mask`.
///
/// A vector.mask has no tensor operands of its own. Each result is the value
/// yielded at the same position by the `vector.yield` terminator of its single
/// block.
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  /// Result #i of the mask is equivalent to operand #i of the terminator.
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    auto maskOp = cast<vector::MaskOp>(op);
    // `value` is expected to be one of this op's results.
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  /// Resolves operand conflicts as usual, then rejects bodies that contain a
  /// `bufferization.alloc_tensor`.
  ///
  /// Such an op in the body indicates something had to bufferize out-of-place
  /// inside the mask. The rewrite in `bufferize` below does not support that.
  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");
    return success();
  }

  /// Rebuilds the mask so that it only yields results of the masked op.
  ///
  /// Yielded values that are not results of the masked op are typically
  /// buffers that now live outside the mask. They directly replace the
  /// corresponding results of the old mask. The remaining results map, in
  /// order, onto the results of the newly created mask.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // `getMaskableOp()` is not guaranteed to be non-null: an empty mask body
    // contains only the terminator. Passing null to `dynCastBufferizableOp`,
    // or dereferencing it below, would crash. An empty mask only forwards
    // values defined above, so it goes through the generic path with no
    // yielded value kept inside the new mask.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (maskedOp && !options.dynCastBufferizableOp(maskedOp))
      return success();

    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());

    // Split the yielded values: results of the masked op stay yielded, while
    // everything else becomes a direct replacement for the old mask result.
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value yielded = it.value();
      if (maskedOp && llvm::is_contained(maskedOp->getOpResults(), yielded))
        newYieldedValues.push_back(yielded);
      else
        newReturnValues[it.index()] = yielded;
    }
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create the new mask with one result per still-yielded value, and move
    // the original body into it.
    ValueRange newYieldedValuesRange(newYieldedValues);
    TypeRange newResultTypes(newYieldedValuesRange);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Fill the remaining slots, in order, with results of the new mask.
    unsigned nextNewResult = 0;
    for (Value &replacement : newReturnValues) {
      if (!replacement)
        replacement = newOp->getResult(nextNewResult++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }
};
/// Bufferization model for `vector.yield`; only supported as the terminator
/// of a `vector.mask`.
///
/// Operand #i of the yield is equivalent to result #i of the parent op.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  /// Yielded values are treated as read.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  /// Yielding never writes memory.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  /// Operand #i aliases (is equivalent to) result #i of the parent op.
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    Value parentResult =
        op->getParentOp()->getResult(opOperand.getOperandNumber());
    return {{parentResult, BufferRelation::Equivalent}};
  }

  /// Yield operands are always bufferized in place. This avoids creating an
  /// alloc + copy inside the mask body.
  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    return true;
  }

  /// Replaces tensor operands with their buffers, keeping the operand count so
  /// that the parent vector.mask can later decide which ones to drop.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yield = cast<vector::YieldOp>(op);

    // Only vector.mask terminators are handled.
    auto parentMask = dyn_cast<vector::MaskOp>(yield->getParentOp());
    if (!parentMask)
      return yield->emitError("unsupported vector::YieldOp parent");

    // Leave the terminator untouched when the first op in the mask body is
    // not bufferizable.
    Operation *firstBodyOp = &parentMask.getMaskRegion().front().front();
    if (!options.dynCastBufferizableOp(firstBodyOp))
      return success();

    SmallVector<Value> bufferizedOperands;
    bufferizedOperands.reserve(yield->getNumOperands());
    for (Value operand : yield.getOperands()) {
      // Non-tensor values (e.g. vectors) pass through unchanged.
      if (!isa<TensorType>(operand.getType())) {
        bufferizedOperands.push_back(operand);
        continue;
      }
      FailureOr<Value> operandBuffer = getBuffer(rewriter, operand, options);
      if (failed(operandBuffer))
        return failure();
      bufferizedOperands.push_back(*operandBuffer);
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op,
                                                  bufferizedOperands);
    return success();
  }
};
}
}
}
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
GatherOp::attachInterface<GatherOpInterface>(*ctx);
MaskOp::attachInterface<MaskOpInterface>(*ctx);
YieldOp::attachInterface<YieldOpInterface>(*ctx);
});
}