#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;
namespace {
/// Generic bufferization for any op implementing DestinationStyleOpInterface.
///
/// - If `op` already has pure buffer semantics, it is left unchanged.
/// - If `op` mixes tensor and buffer operands, an error is emitted.
/// - Otherwise a new op with the same name and attributes, but no result
///   types, is created. Its operands are the buffers of the DPS inputs
///   (scalar inputs are forwarded unchanged) followed by the buffers of the
///   DPS inits tied to each result. The single region of `op` is moved into
///   the new op, and the results of `op` are replaced with the init buffers.
///
/// Returns failure if a buffer cannot be obtained for some operand.
static LogicalResult
bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
                                     DestinationStyleOpInterface op,
                                     const BufferizationOptions &options) {
  // Restore the caller's insertion point on every return path.
  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do: the op already works on buffers.
  if (op.hasPureBufferSemantics())
    return success();

  // Mixed tensor/buffer operands are not supported.
  if (!op.hasPureTensorSemantics())
    return op->emitError() << "op does not have pure tensor semantics";

  // Buffers for the DPS inputs, in operand order.
  SmallVector<Value> inputBuffers;
  inputBuffers.reserve(op.getNumDpsInputs());
  for (OpOperand *input : op.getDpsInputOperands()) {
    // Scalar inputs are not converted; they are passed through as-is.
    if (op.isScalar(input)) {
      inputBuffers.push_back(input->get());
      continue;
    }
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, input->get(), options);
    if (failed(maybeBuffer))
      return failure();
    inputBuffers.push_back(*maybeBuffer);
  }

  // One buffer per result, obtained from the init operand tied to it.
  SmallVector<Value> initBuffers;
  for (OpResult result : op->getOpResults()) {
    OpOperand *init = op.getDpsInitOperand(result.getResultNumber());
    FailureOr<Value> maybeBuffer = getBuffer(rewriter, init->get(), options);
    if (failed(maybeBuffer))
      return failure();
    initBuffers.push_back(*maybeBuffer);
  }

  // Operands of the new op: all inputs first, then all inits.
  SmallVector<Value> operands;
  operands.reserve(inputBuffers.size() + initBuffers.size());
  operands.append(inputBuffers.begin(), inputBuffers.end());
  operands.append(initBuffers.begin(), initBuffers.end());

  // Re-anchor right before `op`, where the replacement op will live.
  rewriter.setInsertionPoint(op);

  // Build the replacement: same name and attributes, no result types.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  OperationState newOpState(op->getLoc(), op->getName(), operands,
                            TypeRange{}, op->getAttrs());
  newOpState.addRegion();
  Operation *bufferizedOp = Operation::create(newOpState);

  // Move (rather than clone) the body blocks of `op` into the new op.
  Region &newBody = bufferizedOp->getRegion(0);
  newBody.getBlocks().splice(newBody.begin(), op->getRegion(0).getBlocks());

  // Insert into the IR only once the op is fully constructed.
  rewriter.insert(bufferizedOp);

  // Replace the results of `op` with the corresponding init buffers.
  replaceOpWithBufferizedValues(rewriter, op, initBuffers);
  return success();
}
/// Bufferization external model shared by the Linalg structured ops
/// (instantiated once per op type `OpTy`).
///
/// - An operand is considered read iff the op's payload uses its value.
/// - An operand is considered written iff it is a DPS init.
/// - Bufferization is delegated to `bufferizeDestinationStyleOpInterface`.
template <typename OpTy>
struct LinalgOpInterface
    : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
  /// Returns true if the payload of the op uses the value of `opOperand`.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return cast<linalg::LinalgOp>(op).payloadUsesValueFromOperand(&opOperand);
  }

  /// Returns true if `opOperand` is one of the op's DPS init operands.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }

  /// Returns true if the op accesses the given operands elementwise.
  ///
  /// That requires: no sparse operands, all loops parallel, and an identity
  /// indexing map for every ranked-tensor / memref operand that appears in
  /// `opOperands`. Operands of any other type, or not listed in
  /// `opOperands`, are ignored.
  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
                                     ArrayRef<OpOperand *> opOperands) const {
    auto linalgOp = cast<linalg::LinalgOp>(op);

    // Sparse operands or any non-parallel loop rule out elementwise access.
    bool hasSparseOperand = sparse_tensor::hasAnySparseOperand(linalgOp);
    if (hasSparseOperand ||
        linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
      return false;

    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
    assert(linalgOp->getNumOperands() == maps.size() &&
           "unexpected number of indexing maps");
    for (auto [operand, map] : llvm::zip(linalgOp->getOpOperands(), maps)) {
      // Only shaped (ranked tensor / memref) operands the caller asked about
      // are constrained; they must be indexed with an identity map.
      bool isConstrained =
          isa<RankedTensorType, MemRefType>(operand.get().getType()) &&
          llvm::is_contained(opOperands, &operand);
      if (isConstrained && !map.isIdentity())
        return false;
    }
    return true;
  }

  /// Rewrites the op to operate on buffers via the generic DPS path.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return bufferizeDestinationStyleOpInterface(rewriter, dpsOp, options);
  }
};
template <typename... Ops>
struct LinalgOpInterfaceHelper {
static void registerOpInterface(MLIRContext *ctx) {
(Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
}
};
/// Bufferization external model for `linalg.softmax`.
///
/// The op is replaced by a `linalg.softmax` with no results that takes the
/// buffers of its input and output operands; the original result is then
/// replaced by the output buffer.
struct SoftmaxOpInterface
    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
                                                     linalg::SoftmaxOp> {
  /// Only the input operand is read; the output operand is not.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    OpOperand &inputOperand = cast<linalg::SoftmaxOp>(op).getInputMutable();
    return &inputOperand == &opOperand;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto softmax = cast<linalg::SoftmaxOp>(op);

    FailureOr<Value> maybeSrc =
        getBuffer(rewriter, softmax.getInput(), options);
    if (failed(maybeSrc))
      return failure();
    FailureOr<Value> maybeDst =
        getBuffer(rewriter, softmax.getOutput(), options);
    if (failed(maybeDst))
      return failure();

    Value src = *maybeSrc;
    Value dst = *maybeDst;
    // Buffer form of the op: no result types, same reduction dimension.
    rewriter.create<linalg::SoftmaxOp>(softmax.getLoc(),
                                       /*resultTypes=*/TypeRange(), src, dst,
                                       softmax.getDimension());
    replaceOpWithBufferizedValues(rewriter, op, dst);
    return success();
  }
};
}
/// Registers the bufferization external models of the Linalg dialect with
/// `registry`:
/// - `LinalgOpInterface<Op>` for every op class in the generated
///   LinalgStructuredOps op list, and
/// - `SoftmaxOpInterface` for `linalg.softmax`.
///
/// The attachment is done inside a dialect extension keyed on
/// `linalg::LinalgDialect` (presumably run when that dialect is loaded into a
/// context — confirm against DialectRegistry::addExtension).
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // The generated GET_OP_LIST section is expected to expand to a
    // comma-separated list of op classes, forming the helper's template pack.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);

    // linalg.softmax is not in the structured-op list and has its own model.
    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
  });
}