#include "PassDetail.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/ErrorHandling.h"
using namespace mlir;
using namespace mlir::bufferization;
/// Materialization callback used for both argument and source
/// materializations: wraps the single memref-typed input into a
/// `bufferization.to_tensor` op producing a value of tensor type `type`.
static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  Value memrefValue = inputs[0];
  auto toTensor =
      builder.create<bufferization::ToTensorOp>(loc, type, memrefValue);
  return toTensor;
}
/// Type converter mapping tensor types to memref types.
///
/// - Non-tensor types are passed through unchanged.
/// - Ranked tensors become memrefs with the same shape and element type (no
///   explicit layout or memory space is passed to `MemRefType::get`).
/// - Unranked tensors become unranked memrefs in memory space 0.
///
/// Materializations bridge the two type worlds with `to_tensor` / `to_memref`
/// ops, or with a cast/realloc between memref types.
BufferizeTypeConverter::BufferizeTypeConverter() {
  addConversion([](Type type) { return type; });
  addConversion([](RankedTensorType tensorType) -> Type {
    return MemRefType::get(tensorType.getShape(),
                           tensorType.getElementType());
  });
  addConversion([](UnrankedTensorType tensorType) -> Type {
    return UnrankedMemRefType::get(tensorType.getElementType(), 0);
  });

  addArgumentMaterialization(materializeToTensor);
  addSourceMaterialization(materializeToTensor);

  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1 && "expected exactly one input");
    Value input = inputs[0];

    // Tensor -> memref: insert an explicit to_memref op.
    if (input.getType().isa<TensorType>())
      return builder.create<bufferization::ToMemrefOp>(loc, type, input);

    // Otherwise the input must be a ranked memref that differs from `type`.
    auto inputType = input.getType().dyn_cast<MemRefType>();
    if (!inputType)
      llvm_unreachable("only tensor/memref input types supported");
    assert(inputType != type && "expected different types");

    // Only ranked destinations can be reached via cast/realloc.
    auto rankedDestType = type.dyn_cast<MemRefType>();
    if (!rankedDestType)
      return nullptr;

    FailureOr<Value> replacement =
        castOrReallocMemRefValue(builder, input, rankedDestType);
    return failed(replacement) ? Value() : *replacement;
  });
}
/// Marks the bufferization materialization ops (`to_tensor`, `to_memref`) as
/// legal on `target`, so that a conversion may leave them in the IR.
void mlir::bufferization::populateBufferizeMaterializationLegality(
    ConversionTarget &target) {
  target.addLegalOp<ToTensorOp, ToMemrefOp>();
}
namespace {
/// Conversion pattern that removes a `bufferization.to_tensor` op by replacing
/// all of its uses with the memref operand provided by the adaptor.
struct BufferizeToTensorOp
    : public OpConversionPattern<bufferization::ToTensorOp> {
  using OpConversionPattern<bufferization::ToTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(bufferization::ToTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value sourceMemref = adaptor.getMemref();
    rewriter.replaceOp(op, sourceMemref);
    return success();
  }
};
} // namespace
namespace {
/// Conversion pattern that removes a `bufferization.to_memref` op by replacing
/// all of its uses with the tensor operand provided by the adaptor.
struct BufferizeToMemrefOp
    : public OpConversionPattern<bufferization::ToMemrefOp> {
  using OpConversionPattern<bufferization::ToMemrefOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(bufferization::ToMemrefOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value sourceTensor = adaptor.getTensor();
    rewriter.replaceOp(op, sourceTensor);
    return success();
  }
};
} // namespace
/// Adds the patterns that fold away `to_tensor` / `to_memref` materializations
/// to `patterns`, parameterized by `typeConverter`.
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter, ctx);
}
namespace {
/// Function pass that eliminates the remaining `to_tensor` / `to_memref`
/// materializations via a full dialect conversion. An op is legal when the
/// `BufferizeTypeConverter` reports its types as legal; the pass fails if the
/// conversion fails.
struct FinalizingBufferizePass
    : public FinalizingBufferizeBase<FinalizingBufferizePass> {
  using FinalizingBufferizeBase<
      FinalizingBufferizePass>::FinalizingBufferizeBase;

  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    BufferizeTypeConverter converter;
    ConversionTarget target(ctx);
    RewritePatternSet patterns(&ctx);

    populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
    target.markUnknownOpDynamicallyLegal(
        [&converter](Operation *op) { return converter.isLegal(op); });

    if (succeeded(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      return;
    signalPassFailure();
  }
};
/// Parses the textual spelling of a layout map option.
///
/// Accepted values: "fully-dynamic-layout-map", "identity-layout-map",
/// "infer-layout-map".
///
/// The string comes from user-provided pass options. An unknown value is
/// therefore reported with a fatal error, which also fires in release builds.
/// `llvm_unreachable` is not used because reaching it in a release build is
/// undefined behavior.
static BufferizationOptions::LayoutMapOption
parseLayoutMapOption(const std::string &s) {
  if (s == "fully-dynamic-layout-map")
    return BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap;
  if (s == "identity-layout-map")
    return BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
  if (s == "infer-layout-map")
    return BufferizationOptions::LayoutMapOption::InferLayoutMap;
  llvm::report_fatal_error(llvm::Twine("invalid layout map option: '") + s +
                           "'");
}
/// Module pass driving One-Shot Bufferize.
///
/// If the pass was constructed with explicit `OneShotBufferizationOptions`,
/// those are used as-is. Otherwise options are assembled from the pass's own
/// options (flags, layout map spellings, dialect filter).
///
/// After a successful, non-analysis-only run, a best-effort cleanup pipeline
/// (canonicalize, CSE, LICM) is run on the module.
struct OneShotBufferizePass
    : public OneShotBufferizeBase<OneShotBufferizePass> {
  OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}

  explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
      : options(options) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
    registerAllocationOpInterfaceExternalModels(registry);
  }

  void runOnOperation() override {
    OneShotBufferizationOptions opt;
    if (!options) {
      // No explicit options: build them from the pass options.
      opt.allowReturnAllocs = allowReturnAllocs;
      opt.allowUnknownOps = allowUnknownOps;
      opt.analysisFuzzerSeed = analysisFuzzerSeed;
      opt.createDeallocs = createDeallocs;
      opt.functionBoundaryTypeConversion =
          parseLayoutMapOption(functionBoundaryTypeConversion);
      if (mustInferMemorySpace)
        opt.defaultMemorySpace = None;
      opt.printConflicts = printConflicts;
      opt.testAnalysisOnly = testAnalysisOnly;
      opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;

      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
          parseLayoutMapOption(unknownTypeConversion);
      // The unknown-type converter below only implements identity and fully
      // dynamic layouts. Reject "infer-layout-map" with a diagnostic. Without
      // this check it would trip the assertion in debug builds and silently
      // fall back to fully dynamic layouts in release builds.
      if (unknownTypeConversionOption ==
          BufferizationOptions::LayoutMapOption::InferLayoutMap) {
        getOperation().emitError(
            "invalid option: 'infer-layout-map' is not a valid value for "
            "'unknown-type-conversion'");
        signalPassFailure();
        return;
      }
      opt.unknownTypeConverterFn = [unknownTypeConversionOption](
                                       Value value, unsigned memorySpace,
                                       const BufferizationOptions &) {
        auto tensorType = value.getType().cast<TensorType>();
        if (unknownTypeConversionOption ==
            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
          return bufferization::getMemRefTypeWithStaticIdentityLayout(
              tensorType, memorySpace);
        assert(
            unknownTypeConversionOption ==
                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
            "invalid layout map option");
        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                  memorySpace);
      };

      // Restrict bufferization to the listed dialects, if a filter was given.
      // The filter captures `this`; it is only used while `opt` is alive
      // within this method.
      OpFilter::Entry::FilterFn filterFn = [this](Operation *op) {
        if (this->dialectFilter.hasValue())
          return llvm::is_contained(this->dialectFilter,
                                    op->getDialect()->getNamespace());
        return true;
      };
      opt.opFilter.allowOperation(filterFn);
    } else {
      opt = *options;
    }

    ModuleOp moduleOp = getOperation();
    if (opt.bufferizeFunctionBoundaries) {
      if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    } else {
      if (failed(runOneShotBufferize(moduleOp, opt))) {
        signalPassFailure();
        return;
      }
    }

    if (opt.testAnalysisOnly)
      return;

    // Best-effort cleanup of the bufferized IR; a failure here does not fail
    // the pass.
    OpPassManager cleanupPipeline("builtin.module");
    cleanupPipeline.addPass(createCanonicalizerPass());
    cleanupPipeline.addPass(createCSEPass());
    cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
    (void)runPipeline(cleanupPipeline, moduleOp);
  }

private:
  /// Explicit options given at construction time, if any.
  llvm::Optional<OneShotBufferizationOptions> options;
};
}
namespace {
/// Pass that bufferizes ops of the bufferization dialect. It uses the partial
/// bufferization options, which allow unknown ops, and fails if `bufferizeOp`
/// fails.
struct BufferizationBufferizePass
    : public BufferizationBufferizeBase<BufferizationBufferizePass> {
  void runOnOperation() override {
    BufferizationOptions options = getPartialBufferizationOptions();
    options.opFilter.allowDialect<BufferizationDialect>();

    if (failed(bufferizeOp(getOperation(), options)))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
  }
};
} // namespace
/// Creates a pass that bufferizes bufferization-dialect ops.
std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
  std::unique_ptr<Pass> pass = std::make_unique<BufferizationBufferizePass>();
  return pass;
}
/// Creates a One-Shot Bufferize pass configured from its pass options.
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
  std::unique_ptr<Pass> pass = std::make_unique<OneShotBufferizePass>();
  return pass;
}
/// Creates a One-Shot Bufferize pass that uses the given explicit `options`
/// instead of its pass options.
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
    const OneShotBufferizationOptions &options) {
  std::unique_ptr<Pass> pass = std::make_unique<OneShotBufferizePass>(options);
  return pass;
}
/// Creates the function pass that removes leftover bufferization
/// materializations.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::bufferization::createFinalizingBufferizePass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<FinalizingBufferizePass>();
  return pass;
}
/// Returns true if `t` is a `TensorType`.
static bool isaTensor(Type t) {
  return static_cast<bool>(t.dyn_cast<TensorType>());
}
/// Returns true if `op` touches tensors.
/// - For function-like ops: checks the argument and result types of the
///   function signature.
/// - For all other ops: checks the op's result and operand types.
static bool hasTensorSemantics(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
    return llvm::any_of(funcOp.getArgumentTypes(), isaTensor) ||
           llvm::any_of(funcOp.getResultTypes(), isaTensor);
  return llvm::any_of(op->getResultTypes(), isaTensor) ||
         llvm::any_of(op->getOperandTypes(), isaTensor);
}
namespace {
/// IRRewriter used by `bufferizeOp`. It listens to IR mutations made during
/// bufferization and keeps the driver's bookkeeping in sync:
/// - Removed ops are recorded in `erasedOps` and dropped from `toMemrefOps`.
/// - Inserted ops are dropped from `erasedOps`.
/// - Inserted `to_memref` ops are recorded in `toMemrefOps`.
/// - Inserted ops are appended to `worklist` if they meet all of these
///   conditions:
///   - they are not `to_tensor` ops;
///   - they have tensor semantics;
///   - they pass both the options' filter and the optional `opFilter`.
class BufferizationRewriter : public IRRewriter {
public:
  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
                        DenseSet<Operation *> &toMemrefOps,
                        SmallVector<Operation *> &worklist,
                        const BufferizationOptions &options,
                        const OpFilter *opFilter)
      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
        worklist(worklist), analysisState(options), opFilter(opFilter) {}

protected:
  void notifyOperationRemoved(Operation *op) override {
    IRRewriter::notifyOperationRemoved(op);
    erasedOps.insert(op);
    toMemrefOps.erase(op);
  }

  void notifyOperationInserted(Operation *op) override {
    IRRewriter::notifyOperationInserted(op);
    // The pointer may have been recycled from a previously erased op.
    erasedOps.erase(op);

    if (isa<ToMemrefOp>(op)) {
      toMemrefOps.insert(op);
      return;
    }

    if (isa<ToTensorOp>(op) || !hasTensorSemantics(op) || !isAllowed(op))
      return;

#ifndef NDEBUG
    // Newly created tensor ops must not bufferize to a memory write.
    for (OpOperand &operand : op->getOpOperands()) {
      if (!operand.get().getType().isa<TensorType>())
        continue;
      assert(!analysisState.bufferizesToMemoryWrite(operand) &&
             "creating tensor ops that bufferize to a memory write is not "
             "allowed during bufferization");
    }
#endif

    worklist.push_back(op);
  }

private:
  /// Returns true if `op` passes the options' filter and the optional extra
  /// `opFilter`.
  bool isAllowed(Operation *op) const {
    if (!analysisState.getOptions().isOpAllowed(op))
      return false;
    return !opFilter || opFilter->isOpAllowed(op);
  }

  DenseSet<Operation *> &erasedOps;
  DenseSet<Operation *> &toMemrefOps;
  SmallVector<Operation *> &worklist;
  const AnalysisState analysisState;
  const OpFilter *opFilter;
};
} // namespace
/// Bufferizes `op` and all ops nested in it.
///
/// Steps:
/// 1. If `copyBeforeWrite` is set, tensor copies are inserted first via
///    `insertTensorCopies`.
/// 2. A worklist is gathered: `func::FuncOp`s with tensor semantics first,
///    then every other op with tensor semantics, in post-order.
/// 3. Each worklist entry is bufferized if it meets all of these conditions:
///    - it was not erased;
///    - it is bufferizable;
///    - it passes `opFilter`;
///    - it still has tensor semantics.
///    Ops created during bufferization are appended to the worklist by the
///    rewriter.
/// 4. `to_memref(to_tensor(x))` pairs are folded.
/// 5. Unless `options.allowUnknownOps` is set, any remaining allowed op with
///    tensor semantics is reported as "op was not bufferized". Ops that are
///    skipped by this check:
///    - ops that are unused and side-effect free;
///    - `to_tensor` / `to_memref` ops.
///
/// Returns failure on the first op that fails to bufferize or is left
/// unbufferized.
LogicalResult bufferization::bufferizeOp(Operation *op,
                                         const BufferizationOptions &options,
                                         bool copyBeforeWrite,
                                         const OpFilter *opFilter) {
  if (copyBeforeWrite) {
    AnalysisState state(options);
    if (failed(insertTensorCopies(op, state)))
      return failure();
  }

  // Track all to_memref ops (existing ones now, new ones via the rewriter).
  DenseSet<Operation *> toMemrefOps;
  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });

  // Gather ops to bufferize: functions first, then everything else.
  SmallVector<Operation *> worklist;
  op->walk([&](func::FuncOp funcOp) {
    if (hasTensorSemantics(funcOp))
      worklist.push_back(funcOp);
  });
  op->walk<WalkOrder::PostOrder>([&](Operation *nestedOp) {
    if (hasTensorSemantics(nestedOp) && !isa<func::FuncOp>(nestedOp))
      worklist.push_back(nestedOp);
  });

  DenseSet<Operation *> erasedOps;
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter);

  // Index-based loop: bufferizing an op may append to `worklist` (through the
  // rewriter), which can reallocate its storage.
  for (unsigned i = 0; i < worklist.size(); ++i) {
    Operation *nextOp = worklist[i];
    if (erasedOps.contains(nextOp))
      continue;
    auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
    if (!bufferizableOp)
      continue;
    if (opFilter && !opFilter->isOpAllowed(nextOp))
      continue;
    // The op may have lost its tensor semantics after earlier rewrites.
    if (!hasTensorSemantics(nextOp))
      continue;
    rewriter.setInsertionPoint(nextOp);
    if (failed(bufferizableOp.bufferize(rewriter, options)))
      return nextOp->emitError("failed to bufferize op");
  }

  // Fold all to_memref(to_tensor(x)) pairs. Folding erases ops through the
  // rewriter, and its listener removes them from `toMemrefOps`. So iterate
  // over a snapshot instead of the live set, and skip ops that are no longer
  // tracked.
  SmallVector<Operation *> toMemrefSnapshot(toMemrefOps.begin(),
                                            toMemrefOps.end());
  for (Operation *toMemrefOp : toMemrefSnapshot) {
    if (!toMemrefOps.contains(toMemrefOp))
      continue;
    rewriter.setInsertionPoint(toMemrefOp);
    (void)bufferization::foldToMemrefToTensorPair(
        rewriter, cast<ToMemrefOp>(toMemrefOp));
  }

  if (options.allowUnknownOps)
    return success();

  // Verify that every remaining op with tensor semantics was handled.
  for (Operation *remainingOp : worklist) {
    if (erasedOps.contains(remainingOp))
      continue;
    if (!hasTensorSemantics(remainingOp))
      continue;
    if (!options.isOpAllowed(remainingOp))
      continue;
    if (opFilter && !opFilter->isOpAllowed(remainingOp))
      continue;
    // Dead, side-effect-free ops are tolerated.
    if (remainingOp->getUses().empty() &&
        MemoryEffectOpInterface::hasNoEffect(remainingOp))
      continue;
    if (isa<ToTensorOp, ToMemrefOp>(remainingOp))
      continue;
    return remainingOp->emitError("op was not bufferized");
  }

  return success();
}
/// Returns `BufferizationOptions` suitable for partial (dialect-by-dialect)
/// bufferization:
/// - unknown ops are allowed;
/// - no deallocs are created;
/// - aliasing invariants are not enforced;
/// - unknown tensor types map to memrefs with a static identity layout;
/// - the op filter allows the bufferization dialect.
BufferizationOptions bufferization::getPartialBufferizationOptions() {
  BufferizationOptions partialOptions;
  partialOptions.allowUnknownOps = true;
  partialOptions.createDeallocs = false;
  partialOptions.enforceAliasingInvariants = false;
  partialOptions.unknownTypeConverterFn =
      [](Value value, unsigned memorySpace, const BufferizationOptions &) {
        auto tensorType = value.getType().cast<TensorType>();
        return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
      };
  partialOptions.opFilter.allowDialect<BufferizationDialect>();
  return partialOptions;
}