#include "Utils/CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace sparse_tensor;
/// Appends to `convTypes` the externalized form of every type in `types`.
///
/// A type without a sparse tensor encoding is copied through unchanged. A
/// type with a sparse tensor encoding contributes one entry for each
/// positions, coordinates, or values field reported by
/// `foreachFieldAndTypeInSparseTensor`; all other field kinds are skipped.
///
/// \param types      the types to convert.
/// \param convTypes  receives the converted types, in order.
/// \param extraTypes optional (may be null); when `directOut` is false, every
///                   converted field type is additionally appended here.
/// \param directOut  when true, the field type is used as reported; when
///                   false, it is replaced by a ranked tensor type with the
///                   same shape and element type.
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  // Callback applied to each storage field of a sparse tensor type. It does
  // not depend on the particular tensor being visited, so it is built once.
  // It always returns true.
  const auto externalizeField = [&convTypes, extraTypes,
                                 directOut](Type fieldType, FieldIndex,
                                            SparseTensorFieldKind kind, Level,
                                            LevelType) {
    const bool isArrayField = kind == SparseTensorFieldKind::PosMemRef ||
                              kind == SparseTensorFieldKind::CrdMemRef ||
                              kind == SparseTensorFieldKind::ValMemRef;
    if (!isArrayField)
      return true;

    auto external = cast<ShapedType>(fieldType);
    if (!directOut) {
      // Swap the field type for a ranked tensor of identical shape and
      // element type, and record it as an extra type when requested.
      external =
          RankedTensorType::get(external.getShape(), external.getElementType());
      if (extraTypes)
        extraTypes->push_back(external);
    }
    convTypes.push_back(external);
    return true;
  };

  for (Type type : types) {
    if (getSparseTensorEncoding(type)) {
      const SparseTensorType stt(cast<RankedTensorType>(type));
      foreachFieldAndTypeInSparseTensor(stt, externalizeField);
    } else {
      // No sparse encoding: pass the type through as is.
      convTypes.push_back(type);
    }
  }
}
/// Converts the values in `fromVals` into `toVals`, guided by `types`.
///
/// A type without a sparse tensor encoding consumes one value from
/// `fromVals` and forwards it unchanged. A type with a sparse tensor encoding
/// is handled according to the mode:
///   * `isIn`: one value is consumed from `fromVals` for every
///     positions/coordinates/values field, and the collected values are
///     combined by a `sparse_tensor::AssembleOp` whose result is appended.
///   * `!isIn && directOut`: one value (the sparse tensor) is consumed from
///     `fromVals`, and for each such field a `ToPositionsOp`,
///     `ToCoordinatesOp`, or `ToValuesOp` on that tensor is appended.
///   * `!isIn && !directOut`: one value (the sparse tensor) is consumed from
///     `fromVals`, and one value per field is taken from `extraVals` starting
///     at index `extra`. A `sparse_tensor::DisassembleOp` is created, and only
///     its leading results (one per field) are appended; the trailing
///     index-typed results are not forwarded.
///
/// \param builder   builder used to create the new operations.
/// \param loc       location attached to the new operations.
/// \param types     the original (possibly sparse) types driving conversion.
/// \param fromVals  source values, consumed front to back.
/// \param extraVals values indexed by `extra` in the disassemble mode.
/// \param toVals    receives the converted values.
/// \param extra     index of the first entry of `extraVals` to use.
/// \param isIn      selects input (assemble) versus output handling.
/// \param directOut selects direct field extraction for outputs.
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned nextFrom = 0;
  for (Type type : types) {
    // Values whose type has no sparse encoding are forwarded untouched.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[nextFrom++]);
      continue;
    }

    auto sparseTp = cast<RankedTensorType>(type);
    const SparseTensorType stt(sparseTp);
    SmallVector<Value> operands;
    SmallVector<Type> tensorTypes;
    SmallVector<Type> countTypes;

    // For outputs, the sparse tensor itself is always the first operand.
    if (!isIn)
      operands.push_back(fromVals[nextFrom++]);

    // Visits every storage field; only the positions, coordinates, and values
    // fields take part in the conversion. Always returns true.
    const auto visitField = [&](Type fieldType, FieldIndex,
                                SparseTensorFieldKind kind, Level lvl,
                                LevelType) {
      const bool isArrayField = kind == SparseTensorFieldKind::PosMemRef ||
                                kind == SparseTensorFieldKind::CrdMemRef ||
                                kind == SparseTensorFieldKind::ValMemRef;
      if (!isArrayField)
        return true;

      // Input: gather the next incoming value as an assemble operand.
      if (isIn) {
        operands.push_back(fromVals[nextFrom++]);
        return true;
      }

      // Output via disassemble: record the result types and pick up the
      // matching value from the extra values.
      if (!directOut) {
        ShapedType tensorTp = cast<ShapedType>(fieldType);
        tensorTp = RankedTensorType::get(tensorTp.getShape(),
                                         tensorTp.getElementType());
        operands.push_back(extraVals[extra++]);
        tensorTypes.push_back(tensorTp);
        countTypes.push_back(builder.getIndexType());
        return true;
      }

      // Direct output: extract the field straight from the sparse tensor.
      Value field;
      if (kind == SparseTensorFieldKind::PosMemRef)
        field = builder.create<sparse_tensor::ToPositionsOp>(loc, operands[0],
                                                             lvl);
      else if (kind == SparseTensorFieldKind::CrdMemRef)
        field = builder.create<sparse_tensor::ToCoordinatesOp>(loc, operands[0],
                                                               lvl);
      else
        field = builder.create<sparse_tensor::ToValuesOp>(loc, operands[0]);
      toVals.push_back(field);
      return true;
    };
    foreachFieldAndTypeInSparseTensor(stt, visitField);

    if (isIn) {
      // Combine the gathered operands into a single sparse tensor value.
      auto assembled =
          builder.create<sparse_tensor::AssembleOp>(loc, sparseTp, operands);
      toVals.push_back(assembled.getResult());
    } else if (!directOut) {
      // The op's result list is the tensor types followed by the index
      // types; only the leading tensor results are forwarded.
      const unsigned numTensors = tensorTypes.size();
      tensorTypes.append(countTypes);
      auto disassembled = builder.create<sparse_tensor::DisassembleOp>(
          loc, tensorTypes, operands);
      for (unsigned i = 0; i != numTensors; ++i)
        toVals.push_back(disassembled.getResult(i));
    }
  }
}
namespace {

/// Rewrite pattern that wraps a non-private function with sparse tensor
/// types in its signature.
///
/// The original function is renamed to `_internal_<name>` and made private.
/// A new public function with the original name is created. Its signature
/// uses the externalized types computed by `convTypes`; unless `directOut`
/// is set, the externalized result types are also appended as trailing
/// arguments. Its body converts the incoming values, calls the internal
/// function, converts the results, and returns them.
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  /// \param context the MLIR context the pattern is registered in.
  /// \param dO      value for `directOut`, forwarded to the result conversion.
  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  /// Returns success after creating the wrapper; returns failure when
  /// `funcOp` is private or has no sparse tensor type in its signature.
  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Private functions are never wrapped. This also keeps the pattern from
    // re-matching the renamed internal function.
    if (funcOp.isPrivate())
      return failure();

    // Bail out only when no type in the signature carries a sparse encoding.
    // Comparing type counts before and after conversion is not sufficient:
    // a sparse-encoded tensor that expands to exactly one array leaves the
    // count unchanged and would be skipped by mistake.
    if (!hasSparseAnnotation(funcOp.getArgumentTypes()) &&
        !hasSparseAnnotation(funcOp.getResultTypes()))
      return failure();

    // Compute the externalized signature. Inputs always use the tensor form;
    // outputs honor `directOut`.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

    // Turn the original function into the private internal implementation.
    // NOTE(review): the op is modified directly, not through the rewriter.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Create the public wrapper under the original name. The extra types
    // become trailing arguments; `extra` is the index of the first of them.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Build the wrapper body: convert inputs, call the internal function,
    // convert outputs, and return.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, 0, true, directOut);
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Move the emit-C-wrapper unit attribute, if present, from the internal
    // function to the new public wrapper.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  /// Returns true if any type in `types` has a sparse tensor encoding.
  static bool hasSparseAnnotation(TypeRange types) {
    for (Type type : types)
      if (getSparseTensorEncoding(type))
        return true;
    return false;
  }

  const bool directOut;
};

} // namespace
/// Adds the `SparseFuncAssembler` pattern to `patterns`.
///
/// \param patterns  the pattern set to extend; its context is used to
///                  construct the pattern.
/// \param directOut forwarded to the `SparseFuncAssembler` constructor.
void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  MLIRContext *context = patterns.getContext();
  patterns.add<SparseFuncAssembler>(context, directOut);
}