#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
using namespace mlir;
using namespace mlir::func;
namespace mlir {
namespace sparse_tensor {

/// Returns `true` if any of the given types carries a sparse tensor encoding,
/// i.e. `getSparseTensorEncoding` yields a non-null attribute for it.
static bool containsSparseTensor(TypeRange types) {
  for (Type t : types)
    if (getSparseTensorEncoding(t))
      return true;
  return false;
}

/// A module pass that runs bufferization (`bufferization::bufferizeOp`) only
/// on operations that do not involve sparse tensors. The following are
/// rejected by the op filter and are not bufferized by this pass:
/// - ops with a sparse tensor operand or result;
/// - `func.func` ops whose signature mentions a sparse tensor.
class BufferizeDenseOpsPass
    : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
public:
  /// Creates the pass, storing a copy of the given bufferization options.
  /// `explicit` prevents accidental implicit conversion from an options
  /// object to a pass.
  explicit BufferizeDenseOpsPass(
      const bufferization::OneShotBufferizationOptions &options)
      : PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>>(),
        options(options) {}

  void runOnOperation() override {
    // Disallow every op that touches a sparse tensor, so that only the dense
    // part of the IR is bufferized.
    bufferization::OpFilter opFilter;
    opFilter.allowOperation([](Operation *op) {
      // Any sparse operand or result disqualifies the op.
      if (containsSparseTensor(TypeRange(op->getResults())) ||
          containsSparseTensor(TypeRange(op->getOperands())))
        return false;
      // A `func.func` declares its tensor types in its function signature
      // rather than through SSA operands/results, so check the signature too.
      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(funcType.getInputs()) ||
            containsSparseTensor(funcType.getResults()))
          return false;
      }
      return true;
    });

    if (failed(bufferization::bufferizeOp(getOperation(), options,
                                          /*copyBeforeWrite=*/false,
                                          &opFilter)))
      signalPassFailure();
  }

private:
  /// Options forwarded to `bufferization::bufferizeOp`.
  bufferization::OneShotBufferizationOptions options;
};

} // namespace sparse_tensor
} // namespace mlir
/// Factory for the dense-only bufferization pass. The returned pass keeps its
/// own copy of `options` and bufferizes only ops that do not involve sparse
/// tensors.
std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
    const bufferization::OneShotBufferizationOptions &options) {
  using sparse_tensor::BufferizeDenseOpsPass;
  return std::make_unique<BufferizeDenseOpsPass>(options);
}