* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "akg/Dialect/Linalg/IR/LinalgExtOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalgExt;
using namespace mlir::bufferization;
namespace {
/// Rewrites the destination-style op `op` so that it operates on buffers.
///
/// Every DPS input is replaced by the buffer returned from `getBuffer`, except
/// operands that `op.isScalar` reports as scalars, which are forwarded as-is.
/// For each result, the buffer of the matching DPS init operand is looked up.
/// A clone of `op` without results is then created on the new operands, the
/// single region of `op` is moved into the clone, and the results of `op` are
/// replaced with the init buffers.
///
/// \returns success without touching the IR when `op` already has pure buffer
/// semantics; an emitted error when `op` does not have pure tensor semantics;
/// failure when `getBuffer` fails for any operand; success otherwise.
static LogicalResult
bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
                                     DestinationStyleOpInterface op,
                                     const BufferizationOptions &options) {
  // Take the guard first so the insertion point set below is scoped to this
  // function.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);

  // Already bufferized: nothing to rewrite.
  if (op.hasPureBufferSemantics())
    return success();

  // Mixed tensor/buffer operands are rejected.
  if (!op.hasPureTensorSemantics())
    return op->emitError() << "op does not have pure tensor semantics";

  // Operands of the clone: inputs first, init buffers appended afterwards.
  SmallVector<Value> bufferizedOperands;
  bufferizedOperands.reserve(op.getNumDpsInputs());
  for (OpOperand *input : op.getDpsInputOperands()) {
    Value replacement = input->get();
    if (!op.isScalar(input)) {
      FailureOr<Value> inputBuffer = getBuffer(rewriter, replacement, options);
      if (failed(inputBuffer))
        return failure();
      replacement = *inputBuffer;
    }
    bufferizedOperands.push_back(replacement);
  }

  // One buffer per result, taken from the init operand tied to that result.
  SmallVector<Value> initBuffers;
  for (OpResult result : op->getOpResults()) {
    Value init = op.getDpsInitOperand(result.getResultNumber())->get();
    FailureOr<Value> initBuffer = getBuffer(rewriter, init, options);
    if (failed(initBuffer))
      return failure();
    initBuffers.push_back(*initBuffer);
  }
  bufferizedOperands.append(initBuffers.begin(), initBuffers.end());

  // Re-anchor at `op` before cloning (presumably because the buffer lookups
  // above may have created ops -- kept exactly as in the original sequence).
  rewriter.setInsertionPoint(op);
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");

  // The clone has no results: TypeRange{} is passed as the result types.
  auto bufferizedOp = cast<DestinationStyleOpInterface>(
      cloneWithoutRegions(rewriter, op, TypeRange{}, bufferizedOperands));
  Region &clonedRegion = bufferizedOp->getRegion(0);
  rewriter.inlineRegionBefore(op->getRegion(0), clonedRegion,
                              clonedRegion.begin());

  // Uses of the old results now refer to the init buffers.
  replaceOpWithBufferizedValues(rewriter, op, initBuffers);
  return success();
}
/// External model implementing the bufferizable-op interface for the op type
/// `OpTy`, built on top of DstBufferizableOpInterfaceExternalModel.
template <typename OpTy>
struct LinalgExtOpBufferizationTait
    : public DstBufferizableOpInterfaceExternalModel<
          LinalgExtOpBufferizationTait<OpTy>, OpTy> {
  /// Reports `opOperand` as read exactly when the payload of `op` (viewed as a
  /// linalg::LinalgOp) uses the value coming from that operand.
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return cast<linalg::LinalgOp>(op).payloadUsesValueFromOperand(&opOperand);
  }

  /// Reports `opOperand` as written exactly when it is a DPS init operand of
  /// `op`.
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }

  /// Delegates the rewrite to bufferizeDestinationStyleOpInterface.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto dstStyleOp = cast<DestinationStyleOpInterface>(op);
    return bufferizeDestinationStyleOpInterface(rewriter, dstStyleOp, options);
  }
};
/// Attaches LinalgExtOpBufferizationTait to every op type in the pack `Ops`.
template <typename... Ops>
struct LinalgExtOpInterfaceHelper {
  /// Attaches the external model to each op in `Ops`, in pack order, using the
  /// context `ctx`.
  static void registerOpInterface(MLIRContext *ctx) {
    (attachModelTo<Ops>(ctx), ...);
  }

private:
  /// Attaches the external model to the single op type `Op`.
  template <typename Op>
  static void attachModelTo(MLIRContext *ctx) {
    Op::template attachInterface<LinalgExtOpBufferizationTait<Op>>(*ctx);
  }
};
}
/// Adds an extension to `registry` that attaches the bufferization external
/// model to every op listed in the generated LinalgExtOps.cpp.inc op list.
///
/// The attachment runs inside the extension callback, which receives the
/// MLIRContext and the LinalgExt dialect instance.
void mlir::linalgExt::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalgExt::LinalgExtDialect *dialect) {
    // GET_OP_LIST makes the generated include expand to the comma-separated
    // list of op classes, which becomes the helper's template argument pack.
    LinalgExtOpInterfaceHelper<
#define GET_OP_LIST
#include "akg/Dialect/Linalg/IR/LinalgExtOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}