#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
/// Sparse matrix storage formats that this code generator can hand to
/// cuSPARSE. `kNone` means the operand has no admissible format.
enum class CuSparseFormat {
  kNone = 0,
  kCOO = 1,
  kCSR = 2,
  kCSC = 3,
  kBSR = 4,
};
/// Tags `topModule` with the GPU dialect's container-module unit attribute.
static void markAsGPUContainer(ModuleOp topModule) {
  MLIRContext *ctx = topModule->getContext();
  StringRef attrName = gpu::GPUDialect::getContainerModuleAttrName();
  topModule->setAttr(attrName, UnitAttr::get(ctx));
}
/// Returns the first `gpu.module` found among the top-level ops of
/// `topModule`. If there is none, marks `topModule` as a GPU container and
/// creates a new `gpu.module` named "sparse_kernels" at the start of its body.
/// Creating it moves the builder's insertion point.
static gpu::GPUModuleOp genGPUModule(OpBuilder &builder, ModuleOp topModule) {
  Region &body = topModule.getBodyRegion();
  // Reuse an existing GPU module when one is already present.
  for (gpu::GPUModuleOp existing : body.getOps<gpu::GPUModuleOp>())
    return existing;
  // Otherwise create a fresh one in front of everything else.
  markAsGPUContainer(topModule);
  builder.setInsertionPointToStart(&body.front());
  return builder.create<gpu::GPUModuleOp>(topModule->getLoc(),
                                          "sparse_kernels");
}
/// Creates an empty kernel function inside `gpuModule`. Its parameter types
/// are the types of `args`, and it is tagged with the GPU kernel attribute.
/// The kernel is named "kernelN", where N is the smallest value >= 0 whose
/// name is not already a symbol in the module. The function is inserted at
/// the start of the module body, which moves the builder's insertion point.
static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
                                 SmallVectorImpl<Value> &args) {
  // Probe kernel0, kernel1, ... until an unused symbol name is found.
  SmallString<16> kernelName;
  for (unsigned n = 0;; ++n) {
    kernelName.clear();
    ("kernel" + Twine(n)).toVector(kernelName);
    if (!gpuModule.lookupSymbol(kernelName))
      break;
  }

  // Build the kernel signature from the argument types (no results).
  builder.setInsertionPointToStart(&gpuModule.getBodyRegion().front());
  SmallVector<Type> argTypes;
  argTypes.reserve(args.size());
  for (Value v : args)
    argTypes.push_back(v.getType());
  FunctionType kernelTp =
      FunctionType::get(gpuModule->getContext(), argTypes, TypeRange());

  auto kernel =
      builder.create<gpu::GPUFuncOp>(gpuModule->getLoc(), kernelName, kernelTp);
  kernel->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                  builder.getUnitAttr());
  return kernel;
}
/// Emits an asynchronous `gpu.launch_func` of `gpuFunc` and returns the
/// launch's async token.
/// - Grid: 1x1x1. Block: `numThreads`x1x1.
/// - `args` are passed as the kernel operands.
/// - The launch depends on `tokens`.
static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
                              SmallVectorImpl<Value> &args,
                              SmallVectorImpl<Value> &tokens,
                              unsigned numThreads) {
  Location loc = gpuFunc->getLoc();
  // Null value passed for the dynamic shared memory size operand.
  Value dynSharedMemSz;
  Value c1 = constantIndex(builder, loc, 1);
  Value threads = constantIndex(builder, loc, numThreads);
  gpu::KernelDim3 gridDims{c1, c1, c1};
  gpu::KernelDim3 blockDims{threads, c1, c1};
  Type asyncTokenTp = builder.getType<gpu::AsyncTokenType>();
  auto launch = builder.create<gpu::LaunchFuncOp>(
      loc, gpuFunc, gridDims, blockDims, dynSharedMemSz, args, asyncTokenTp,
      tokens);
  return launch.getAsyncToken();
}
/// Casts the ranked memref `mem` to an unranked memref in memory space 0 with
/// the same element type. Emits `gpu.host_register` on the cast and returns
/// the cast value.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
                                   Value mem) {
  auto rankedTp = cast<MemRefType>(mem.getType());
  auto unrankedTp =
      UnrankedMemRefType::get(rankedTp.getElementType(), /*memorySpace=*/0);
  Value unrankedMem = builder.create<memref::CastOp>(loc, unrankedTp, mem);
  builder.create<gpu::HostRegisterOp>(loc, unrankedMem);
  return unrankedMem;
}
/// Emits `gpu.host_unregister` on `cast`, for example the value produced by
/// genHostRegisterMemref.
static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
                                    Value cast) {
  (void)builder.create<gpu::HostUnregisterOp>(loc, cast);
}
/// Emits an asynchronous `gpu.wait` with no dependencies and returns its
/// token. Callers use this token to start a new chain of async operations.
static Value genFirstWait(OpBuilder &builder, Location loc) {
  auto waitOp = builder.create<gpu::WaitOp>(
      loc, builder.getType<gpu::AsyncTokenType>(), ValueRange());
  return waitOp.getAsyncToken();
}
/// Emits a `gpu.wait` with no result token that depends on all `operands`.
static void genBlockingWait(OpBuilder &builder, Location loc,
                            ValueRange operands) {
  Type noResultTp; // null type: the wait produces no async token
  builder.create<gpu::WaitOp>(loc, noResultTp, operands);
}
/// Emits an asynchronous `gpu.alloc` that depends on `token`. The allocated
/// memref has the same shape and element type as `mem` and the default
/// layout. Each dynamic extent is read from `mem` with a (folded) dim op.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                   Value token) {
  auto shapedTp = cast<ShapedType>(mem.getType());
  ArrayRef<int64_t> shape = shapedTp.getShape();
  auto allocTp = MemRefType::get(shape, shapedTp.getElementType());
  SmallVector<Value> dynamicSizes;
  for (int64_t d = 0, rank = shapedTp.getRank(); d < rank; ++d)
    if (ShapedType::isDynamic(shape[d]))
      dynamicSizes.push_back(linalg::createOrFoldDimOp(builder, loc, mem, d));
  return builder.create<gpu::AllocOp>(
      loc, TypeRange({allocTp, token.getType()}), token, dynamicSizes,
      ValueRange());
}
/// Emits a `memref.alloc` of a 1-D, dynamically sized memref with `size`
/// elements of `type`, and returns the allocated memref.
static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
                           Value size) {
  MemRefType bufTp = MemRefType::get({ShapedType::kDynamic}, type);
  auto alloc = builder.create<memref::AllocOp>(loc, bufTp, size);
  return alloc.getResult();
}
/// Emits an asynchronous `gpu.alloc` that depends on `token`. The buffer is a
/// 1-D, dynamically sized memref with `size` elements of `type`.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
                                   Value size, Value token) {
  MemRefType bufTp = MemRefType::get({ShapedType::kDynamic}, type);
  Type tokenTp = token.getType();
  return builder.create<gpu::AllocOp>(loc, TypeRange({bufTp, tokenTp}), token,
                                      size, ValueRange());
}
/// Convenience overload: allocates an i8 buffer of `size` elements (bytes)
/// that depends on `token`.
static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                   Value token) {
  Type byteTp = builder.getI8Type();
  return genAllocBuffer(builder, loc, byteTp, size, token);
}
/// Emits an asynchronous `gpu.dealloc` of `mem` that depends on `token`, and
/// returns the resulting async token.
static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
                              Value token) {
  Type tokenTp = token.getType();
  auto deallocOp = builder.create<gpu::DeallocOp>(loc, tokenTp, token, mem);
  return deallocOp.getAsyncToken();
}
/// Emits an asynchronous `gpu.memcpy` from `src` into `dst` that depends on
/// `token`, and returns the resulting async token.
static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
                           Value src, Value token) {
  Type tokenTp = token.getType();
  auto copyOp = builder.create<gpu::MemcpyOp>(loc, tokenTp, token, dst, src);
  return copyOp.getAsyncToken();
}
/// Allocates a GPU memref shaped like `b` and schedules a copy of `b` into it.
/// The copy's async token is appended to `tokens`, and the new memref is
/// returned.
static Value genAllocCopy(OpBuilder &builder, Location loc, Value b,
                          SmallVectorImpl<Value> &tokens) {
  Value start = genFirstWait(builder, loc);
  gpu::AllocOp devAlloc = genAllocMemRef(builder, loc, b, start);
  Value devMem = devAlloc.getResult(0);
  Value copied =
      genCopyMemRef(builder, loc, devMem, b, devAlloc.getAsyncToken());
  tokens.push_back(copied);
  return devMem;
}
/// Emits `bufferization.to_memref` on `tensor` and returns the result. The
/// memref has the tensor's shape and element type and the default layout.
static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
                               Value tensor) {
  auto shapedTp = llvm::cast<ShapedType>(tensor.getType());
  MemRefType memrefTp =
      MemRefType::get(shapedTp.getShape(), shapedTp.getElementType());
  auto toMemref =
      rewriter.create<bufferization::ToMemrefOp>(loc, memrefTp, tensor);
  return toMemref;
}
/// Builds the kernel argument list `args`.
/// - All `scalars` come first and are passed as-is.
/// - One entry per buffer follows.
///   - If `useHostRegistrationForOut` is set, the first buffer is
///     host-registered and passed as-is.
///   - Every other buffer is replaced by a GPU copy from genAllocCopy, whose
///     copy tokens are appended to `tokens`.
/// Returns the registration cast, or a null Value when nothing was
/// registered.
static Value genParametersIn(OpBuilder &builder, Location loc,
                             SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens,
                             bool useHostRegistrationForOut) {
  Value registered;
  llvm::append_range(args, scalars);
  bool registerNext = useHostRegistrationForOut;
  for (Value b : buffers) {
    if (registerNext) {
      registered = genHostRegisterMemref(builder, loc, b);
      args.push_back(b);
      registerNext = false;
    } else {
      args.push_back(genAllocCopy(builder, loc, b, tokens));
    }
  }
  return registered;
}
/// Emits the post-launch epilogue for the buffer entries of `args`. These
/// start at index `scalars.size()` and follow the order of `buffers`.
/// - The first buffer is the output.
///   - If it was host-registered (`out` is non-null), it is only
///     unregistered.
///   - Otherwise its GPU copy is copied back into `buffers[0]` after
///     `kernelToken`.
/// - Every GPU-allocated buffer is released with an async dealloc, and the
///   dealloc tokens are appended to `tokens`.
/// NOTE(review): non-output deallocs chain off a fresh gpu.wait rather than
/// `kernelToken`. Confirm they are ordered after the kernel.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                             Value kernelToken, SmallVectorImpl<Value> &scalars,
                             SmallVectorImpl<Value> &buffers,
                             SmallVectorImpl<Value> &args,
                             SmallVectorImpl<Value> &tokens) {
  const unsigned firstBuffer = scalars.size();
  for (unsigned idx = firstBuffer, end = args.size(); idx < end; ++idx) {
    Value dep;
    if (idx != firstBuffer) {
      dep = genFirstWait(builder, loc);
    } else if (out) {
      // Host-registered output: nothing was allocated for it.
      genHostUnregisterMemref(builder, loc, out);
      continue;
    } else {
      // Copied output: bring the results back after the kernel.
      dep = genCopyMemRef(builder, loc, buffers[0], args[idx], kernelToken);
    }
    tokens.push_back(genDeallocMemRef(builder, loc, args[idx], dep));
  }
}
/// Fills the body of `gpuFunc` with a kernel-side version of `forallOp`.
/// - Each value in `constants` is re-materialized inside the kernel by
///   cloning its defining op.
/// - `scalars` followed by `buffers` are mapped, in that order, onto the
///   kernel block arguments.
/// - The parallel loop becomes an scf.for whose start and step come from
///   x-dimension block/thread ids and dims.
static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
                       scf::ParallelOp forallOp,
                       SmallVectorImpl<Value> &constants,
                       SmallVectorImpl<Value> &scalars,
                       SmallVectorImpl<Value> &buffers) {
  Location loc = gpuFunc->getLoc();
  Block &block = gpuFunc.getBody().front();
  rewriter.setInsertionPointToStart(&block);
  // Re-generate the constants and recapture all arguments.
  unsigned arg = 0;
  IRMapping irMap;
  for (Value c : constants)
    irMap.map(c, rewriter.clone(*c.getDefiningOp())->getResult(0));
  for (Value s : scalars)
    irMap.map(s, block.getArgument(arg++));
  for (Value b : buffers)
    irMap.map(b, block.getArgument(arg++));
  // Only the x dimension is used:
  //   row = blockIdx.x * blockDim.x + threadIdx.x
  //   inc = blockDim.x * gridDim.x
  Value bid = rewriter.create<gpu::BlockIdOp>(loc, gpu::Dimension::x);
  Value bsz = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
  Value tid = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value gsz = rewriter.create<gpu::GridDimOp>(loc, gpu::Dimension::x);
  Value mul = rewriter.create<arith::MulIOp>(loc, bid, bsz);
  Value row = rewriter.create<arith::AddIOp>(loc, mul, tid);
  Value inc = rewriter.create<arith::MulIOp>(loc, bsz, gsz);
  // Emit: for (r = row; r < upper; r += inc) { <cloned loop body> }
  // Only upper bound [0] of forallOp is used. The lower bound and step are
  // ignored.
  // NOTE(review): this assumes a 1-D scf.parallel with lower bound 0 and
  // step 1 — confirm at the call site.
  // irMap.lookup requires the bound to have been captured above (as a
  // constant or a scalar).
  Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
  // The scf.for builder creates a body block. Erase it so the cloned
  // scf.parallel region becomes the loop's single block.
  rewriter.eraseBlock(forOp.getBody());
  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
                             forOp.getRegion().begin(), irMap);
  // Replace the terminator cloned from scf.parallel with an scf.yield.
  rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  rewriter.replaceOpWithNewOp<scf::YieldOp>(forOp.getBody()->getTerminator());
  // Terminate the kernel after the loop.
  rewriter.setInsertionPointAfter(forOp);
  rewriter.create<gpu::ReturnOp>(gpuFunc->getLoc());
}
/// Returns true if `val` is an arith.addf/addi whose two operands are exactly
/// the first two arguments of `block`, in either order.
static bool matchAddOfArgs(Block *block, Value val) {
  Operation *def = val.getDefiningOp();
  if (!def || !isa<arith::AddFOp, arith::AddIOp>(def))
    return false;
  Value lhs = def->getOperand(0);
  Value rhs = def->getOperand(1);
  Value arg0 = block->getArgument(0);
  Value arg1 = block->getArgument(1);
  return (lhs == arg0 && rhs == arg1) || (lhs == arg1 && rhs == arg0);
}
/// Returns true if `val` is an arith.mulf/muli whose two operands are exactly
/// the first two arguments of `block`, in either order.
static bool matchMulOfArgs(Block *block, Value val) {
  Operation *def = val.getDefiningOp();
  if (!def || !isa<arith::MulFOp, arith::MulIOp>(def))
    return false;
  Value lhs = def->getOperand(0);
  Value rhs = def->getOperand(1);
  Value arg0 = block->getArgument(0);
  Value arg1 = block->getArgument(1);
  return (lhs == arg0 && rhs == arg1) || (lhs == arg1 && rhs == arg0);
}
/// Returns true if the body of `op` yields `x + (a * b)`, with the add
/// operands in either order.
/// - `x` is block argument 2.
/// - `a * b` is a product of block arguments 0 and 1 (see matchMulOfArgs).
static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  Operation *def = yieldOp.getOperand(0).getDefiningOp();
  if (!def || !isa<arith::AddFOp, arith::AddIOp>(def))
    return false;
  Block *body = op.getBlock();
  Value acc = body->getArgument(2);
  Value lhs = def->getOperand(0);
  Value rhs = def->getOperand(1);
  return (lhs == acc && matchMulOfArgs(body, rhs)) ||
         (rhs == acc && matchMulOfArgs(body, lhs));
}
/// Returns true if the body of `op` yields a sparse_tensor.reduce that
/// consumes block argument 2 (`out`) and a sparse_tensor.unary, with the
/// reduce operands in either order. In addition:
/// - the unary also consumes `out` and has an empty absent region;
/// - the unary's present region yields a product of block arguments 0 and 1;
/// - the reduce region yields a sum of its own two block arguments.
static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  Value sOut = op.getBlock()->getArgument(2);
  auto redOp = yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>();
  if (!redOp)
    return false;
  // The reduce must consume the output; find its other operand.
  Value other;
  if (redOp->getOperand(0) == sOut)
    other = redOp->getOperand(1);
  else if (redOp->getOperand(1) == sOut)
    other = redOp->getOperand(0);
  else
    return false;
  auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>();
  if (!unOp)
    return false;
  if (unOp->getOperand(0) != sOut || !unOp.getAbsentRegion().empty())
    return false;
  // Finally, check the bodies of both the unary and the reduce.
  auto yieldUn = cast<sparse_tensor::YieldOp>(
      unOp.getRegion(0).front().getTerminator());
  Block &redBody = redOp.getRegion().front();
  auto yieldRed = cast<sparse_tensor::YieldOp>(redBody.getTerminator());
  return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
         matchAddOfArgs(&redBody, yieldRed.getOperand(0));
}
/// Returns true if `v` has equal dimension and level ranks and all of its
/// levels are dense.
static bool isDenseTensor(Value v) {
  SparseTensorType stt = getSparseTensorType(v);
  if (stt.getDimRank() != stt.getLvlRank())
    return false;
  return stt.isAllDense();
}
/// Returns true if both the position and coordinate bit widths of `aTp` are
/// either 0 or at least 16. A width of 0 presumably means the default index
/// width; confirm against the SparseTensor encoding docs.
static bool isAdmissibleMetaData(SparseTensorType &aTp) {
  auto widthOk = [](unsigned width) { return width == 0 || width >= 16; };
  return widthOk(aTp.getPosWidth()) && widthOk(aTp.getCrdWidth());
}
/// Returns true if `aTp` is a 2-D COO matrix in the layout this file maps
/// onto cuSPARSE COO:
/// - identity dim-to-lvl map;
/// - level 0 is compressed, ordered and non-unique;
/// - level 1 is singleton, ordered and unique;
/// - the metadata widths are admissible.
static bool isAdmissibleCOO(SparseTensorType &aTp) {
  if (aTp.getDimRank() != 2 || aTp.getLvlRank() != 2 || !aTp.isIdentity())
    return false;
  if (!aTp.isCompressedLvl(0) || !aTp.isOrderedLvl(0) || aTp.isUniqueLvl(0))
    return false;
  if (!aTp.isSingletonLvl(1) || !aTp.isOrderedLvl(1) || !aTp.isUniqueLvl(1))
    return false;
  return isAdmissibleMetaData(aTp);
}
/// Returns true if `aTp` is a 2-D CSR matrix:
/// - identity dim-to-lvl map;
/// - level 0 is dense;
/// - level 1 is compressed, ordered and unique;
/// - the metadata widths are admissible.
static bool isAdmissibleCSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() != 2 || aTp.getLvlRank() != 2 || !aTp.isIdentity())
    return false;
  if (!aTp.isDenseLvl(0))
    return false;
  return aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
         isAdmissibleMetaData(aTp);
}
/// Returns true if `aTp` is a 2-D CSC matrix. It has the same level layout
/// as CSR (dense level 0; compressed, ordered, unique level 1; admissible
/// metadata widths), but its dim-to-lvl map is a non-identity permutation.
static bool isAdmissibleCSC(SparseTensorType &aTp) {
  if (aTp.getDimRank() != 2 || aTp.getLvlRank() != 2)
    return false;
  if (aTp.isIdentity() || !aTp.isPermutation())
    return false;
  return aTp.isDenseLvl(0) && aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) &&
         aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
}
/// Returns true if `aTp` is a 2-D block-sparse (BSR) matrix with:
/// - four levels: dense; compressed, ordered and unique; dense; dense;
/// - admissible metadata widths;
/// - a block-sparsity dim-to-lvl map;
/// - square blocks larger than 1x1.
static bool isAdmissibleBSR(SparseTensorType &aTp) {
  if (aTp.getDimRank() != 2 || aTp.getLvlRank() != 4)
    return false;
  if (!aTp.isDenseLvl(0) || !aTp.isCompressedLvl(1) || !aTp.isOrderedLvl(1) ||
      !aTp.isUniqueLvl(1) || !aTp.isDenseLvl(2) || !aTp.isDenseLvl(3) ||
      !isAdmissibleMetaData(aTp))
    return false;
  // getBlockSize() only accepts block-sparsity maps. Reject any other
  // 4-level map here instead of violating that precondition.
  AffineMap dimToLvl = aTp.getDimToLvl();
  if (!isBlockSparsity(dimToLvl))
    return false;
  // CuSparse only supports "square" blocks currently. Test the size
  // explicitly rather than relying on an assert, so that release builds
  // never index out of bounds.
  SmallVector<unsigned> dims = getBlockSize(dimToLvl);
  if (dims.size() != 2)
    return false;
  return dims[0] == dims[1] && dims[0] > 1;
}
/// Returns true if `aTp` is a 2-D matrix with three levels (dense, dense,
/// N-out-of-M) and admissible metadata widths, i.e. a 2:4 structured-sparsity
/// layout candidate.
static bool isAdmissible24(SparseTensorType &aTp) {
  if (aTp.getDimRank() != 2 || aTp.getLvlRank() != 3)
    return false;
  return aTp.isDenseLvl(0) && aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) &&
         isAdmissibleMetaData(aTp);
}
/// Returns true if `v` is produced by a sparse_tensor.convert whose source is
/// a dense tensor and whose result type is admissible per isAdmissible24.
static bool isConversionInto24(Value v) {
  auto cnv = v.getDefiningOp<ConvertOp>();
  if (!cnv)
    return false;
  SparseTensorType dstTp = getSparseTensorType(cnv.getResult());
  return isDenseTensor(cnv.getSource()) && isAdmissible24(dstTp);
}
/// Chooses the cuSPARSE format for the sparse operand `aTp`.
/// - Returns kNone when `bTp` or `cTp` has a sparse encoding.
/// - Formats are tried in the order COO, CSR, CSC, BSR.
/// - COO is accepted only when `isMatVec` (with CUSPARSE_COO_AOS defined) or
///   `enableRT` (otherwise) holds.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
                                        SparseTensorType bTp,
                                        SparseTensorType cTp, bool enableRT,
                                        bool isMatVec) {
  // The other two operands must be dense.
  if (bTp.hasEncoding() || cTp.hasEncoding())
    return CuSparseFormat::kNone;
  if (isAdmissibleCOO(aTp)) {
#ifdef CUSPARSE_COO_AOS
    const bool cooSupported = isMatVec;
#else
    const bool cooSupported = enableRT;
#endif
    return cooSupported ? CuSparseFormat::kCOO : CuSparseFormat::kNone;
  }
  if (isAdmissibleCSR(aTp))
    return CuSparseFormat::kCSR;
  if (isAdmissibleCSC(aTp))
    return CuSparseFormat::kCSC;
  if (isAdmissibleBSR(aTp))
    return CuSparseFormat::kBSR;
  return CuSparseFormat::kNone;
}
/// Returns the first index array of `a`:
/// - non-COO formats: the positions of level 1;
/// - kCOO with `enableRT`: the coordinates of level 0;
/// - kCOO without `enableRT`: the combined coordinates buffer, which
///   genSpMat consumes as AoS COO.
static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                               CuSparseFormat format, bool enableRT) {
  if (format != CuSparseFormat::kCOO)
    return builder.create<ToPositionsOp>(loc, a, 1);
  if (!enableRT)
    return builder.create<ToCoordinatesBufferOp>(loc, a);
  return builder.create<ToCoordinatesOp>(loc, a, 0);
}
/// Returns the level-1 coordinates of `a`. Returns a null Value for COO
/// without `enableRT`, where all coordinates live in the buffer returned by
/// genFirstPosOrCrds.
static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
                           CuSparseFormat format, bool enableRT) {
  const bool needsSecond = format != CuSparseFormat::kCOO || enableRT;
  if (!needsSecond)
    return Value();
  return builder.create<ToCoordinatesOp>(loc, a, 1);
}
/// Emits the op that creates a sparse-matrix handle for `aTp` in `format`,
/// depending on `token`. Returns that op; result 0 is the handle and
/// result 1 is the async token.
///   - kCOO with `enableRT`: SoA COO from `rowA` and `colA`.
///   - kCOO without `enableRT`: AoS COO from `rowA` only (`colA` must be
///     null). Only available when CUSPARSE_COO_AOS is defined.
///   - kCSR / kCSC: positions in `rowA`, coordinates in `colA`.
///   - kBSR: `sz1`, `sz2` and `nseA` are divided by the square block size
///     read from the dim-to-lvl map, giving block rows, block columns and
///     the number of blocks.
static Operation *genSpMat(OpBuilder &builder, Location loc,
                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
                           Value token, Value sz1, Value sz2, Value nseA,
                           Value rowA, Value colA, Value valA,
                           CuSparseFormat format, bool enableRT) {
  if (format == CuSparseFormat::kCOO) {
    // The runtime library path uses SoA COO (separate coordinate arrays).
    if (enableRT) {
      assert(colA);
      return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
                                              sz1, sz2, nseA, rowA, colA, valA);
    }
#ifdef CUSPARSE_COO_AOS
    // AoS COO: all coordinates are in the single buffer `rowA`.
    assert(!colA);
    return builder.create<gpu::CreateCooAoSOp>(loc, handleTp, tokenTp, token,
                                               sz1, sz2, nseA, rowA, valA);
#else
    llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
#endif
  }
  assert(colA);
  if (format == CuSparseFormat::kCSR)
    return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  if (format == CuSparseFormat::kCSC)
    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
                                            sz2, nseA, rowA, colA, valA);
  // BSR: convert element counts and sizes into block units (square b x b).
  assert(format == CuSparseFormat::kBSR);
  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
  assert(dims.size() == 2 && dims[0] == dims[1]);
  uint64_t b = dims[0];
  Value bSz = constantIndex(builder, loc, b);
  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
  Value bNum = builder.create<arith::DivUIOp>(
      loc, nseA, constantIndex(builder, loc, b * b));
  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
                                          bCols, bNum, bSz, bSz, rowA, colA,
                                          valA);
}
/// Rewrites `op` (operands: sparse matrix `a`, dense vector `x`, dense
/// vector `y`) into a cuSPARSE SpMV sequence.
/// 1. Copy the operands to the GPU.
/// 2. Create the handles, query the workspace size and run SpMV.
/// 3. Copy `y` back and free all resources.
/// Fails without changing the IR when no admissible non-BSR format exists.
/// On success, `op` is replaced with a tensor built from the host memref of
/// `y`.
static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value x = op.getOperand(1);
  Value y = op.getOperand(2);
  SmallVector<Value> tokens;
  // Only an admissible sparse matrix format with dense vectors (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType xTp = getSparseTensorType(x);
  SparseTensorType yTp = getSparseTensorType(y);
  auto format = getCuSparseFormat(aTp, xTp, yTp, enableRT, true);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();
  // Copy data from host to GPU:
  //   a : memR/memC/memV -> rowA, colA, valA  (memC/colA may be null)
  //   x : memX           -> vecX
  //   y : memY           -> vecY
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value memX = genTensorToMemref(rewriter, loc, x);
  Value vecX = genAllocCopy(rewriter, loc, memX, tokens);
  Value memY = genTensorToMemref(rewriter, loc, y);
  Value vecY = genAllocCopy(rewriter, loc, memY, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  // Create the sparse-matrix and dense-vector handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecX, szX);
  Value dnX = dvecX.getResult(0);
  token = dvecX.getAsyncToken();
  auto dvecY = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, vecY, szY);
  Value dnY = dvecY.getResult(0);
  token = dvecY.getAsyncToken();
  // The element type of y is used as the compute type for both the
  // buffer-size query and the SpMV.
  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
  // Query the workspace size and allocate the workspace.
  auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnX, dnY,
      dnYType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();
  // Perform the SpMV.
  auto spmvComp = rewriter.create<gpu::SpMVOp>(
      loc, tokenTp, token, spMatA, dnX, dnY, dnYType, buffer);
  token = spmvComp.getAsyncToken();
  // Destroy the handles, free GPU memory and copy y back to the host.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnX)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnY)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, vecX, token);
  token = genCopyMemRef(rewriter, loc, memY, vecY, token);
  token = genDeallocMemRef(rewriter, loc, vecY, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  // Done: the result is the host buffer of y.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
  return success();
}
/// Rewrites `op` (operands: sparse matrix `a`, dense matrix `b`, dense matrix
/// `c`) into a cuSPARSE SpMM sequence.
/// 1. Copy the operands to the GPU.
/// 2. Create the handles, query the workspace size and run SpMM.
/// 3. Copy `c` back and free all resources.
/// Fails without changing the IR when no admissible non-BSR format exists.
/// On success, `op` is replaced with a tensor built from the host memref of
/// `c`.
static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                 linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;

  // Only an admissible sparse matrix format with dense matrices (no BSR).
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  auto format = getCuSparseFormat(aTp, bTp, cTp, enableRT, /*isMatVec=*/false);
  if (format == CuSparseFormat::kNone || format == CuSparseFormat::kBSR)
    return failure();

  // Copy data from host to GPU:
  //   a : memR/memC/memV -> rowA, colA, valA  (memC/colA may be null)
  //   b : bufB           -> matB
  //   c : bufC           -> matC
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value memC = genSecondCrds(rewriter, loc, a, format, enableRT);
  Value memV = rewriter.create<ToValuesOp>(loc, a);
  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, b);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, c);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the sparse-matrix and dense-matrix handles.
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();

  // The element type of C is the compute type for both the buffer-size
  // query and the SpMM itself. It is computed once, as in rewriteSpMV.
  Type dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

  // Query the workspace size and allocate the workspace.
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, indexTp, tokenTp, token, spMatA, dnB, dnC,
      /*computeType=*/dnCType);
  Value bufferSz = bufferComp.getResult(0);
  token = bufferComp.getAsyncToken();
  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
  Value buffer = buf.getResult(0);
  token = buf.getAsyncToken();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType, buffer);
  token = spmmComp.getAsyncToken();

  // Destroy the handles, free GPU memory and copy C back to the host.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  if (colA)
    token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, buffer, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done: the result is the host buffer of C.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}
/// Rewrites `op` (operands: CSR matrices `a`, `b` and `c`) into a cuSPARSE
/// SpGEMM sequence:
/// 1. Copy `a` and `b` to the GPU and create handles, including an initially
///    empty CSR handle for C.
/// 2. Run work estimation and compute (each as a size query followed by a
///    real call).
/// 3. Size and attach C's coordinate and value arrays, copy the product into
///    C and bring it back to newly allocated host buffers.
/// 4. Replace `op` with an AssembleOp over those host buffers.
/// The data of the operand `c` is never read; only its type is used.
/// Fails without changing the IR unless all three operands are admissible
/// CSR.
static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
  Location loc = op.getLoc();
  Value a = op.getOperand(0);
  Value b = op.getOperand(1);
  Value c = op.getOperand(2);
  SmallVector<Value> tokens;
  // All three matrices must be CSR.
  auto format = CuSparseFormat::kCSR;
  SparseTensorType aTp = getSparseTensorType(a);
  SparseTensorType bTp = getSparseTensorType(b);
  SparseTensorType cTp = getSparseTensorType(c);
  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
    return failure();
  // Copy A and B from host to GPU (positions, coordinates, values).
  auto dnCType = cTp.getElementType();
  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
  Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
  Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT);
  Value amemV = rewriter.create<ToValuesOp>(loc, a);
  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
  Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT);
  Value bmemV = rewriter.create<ToValuesOp>(loc, b);
  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  // Create the sparse-matrix handles for A and B.
  Type indexTp = rewriter.getIndexType();
  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA =
      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
               nseA, rowA, colA, valA, format, enableRT);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  Operation *spGenB =
      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
               nseB, rowB, colB, valB, format, enableRT);
  Value spMatB = spGenB->getResult(0);
  token = spGenB->getResult(1);
  // C starts as an m x n CSR with m+1 positions and zero-sized
  // coordinate/value arrays (nnz = 0). Those arrays are replaced once the
  // real nnz is known.
  // NOTE(review): the zero-sized colC/valC allocations from e2/e3 are
  // overwritten below and never passed to gpu.dealloc. This is presumably
  // harmless for zero-sized allocations — confirm.
  Value zero = constantIndex(rewriter, loc, 0);
  Value one = constantIndex(rewriter, loc, 1);
  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
  Value rowC = e1.getResult(0);
  token = e1.getAsyncToken();
  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
  Value colC = e2.getResult(0);
  token = e2.getAsyncToken();
  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
  Value valC = e3.getResult(0);
  token = e3.getAsyncToken();
  Operation *spGenC =
      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
               zero, rowC, colC, valC, format, enableRT);
  Value spMatC = spGenC->getResult(0);
  token = spGenC->getResult(1);
  // Work estimation: first query the buffer size (size 0, no buffer), then
  // run it again with an allocated buffer.
  Operation *descOp =
      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
  Value desc = descOp->getResult(0);
  token = descOp->getResult(1);
  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  Value bufferSz1 = work1->getResult(0);
  token = work1->getResult(1);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz1, buffer1,
      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
  token = work2->getResult(1);
  // Compute: the same query-then-run pattern with a second buffer.
  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  Value bufferSz2 = compute1->getResult(0);
  token = compute1->getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
  token = compute2->getResult(1);
  // Read nnz of C (result 2 of the size query) and allocate real arrays.
  Operation *sizes = rewriter.create<gpu::SpMatGetSizeOp>(
      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
  Value nnz = sizes->getResult(2);
  token = sizes->getResult(3);
  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
  colC = a2.getResult(0);
  token = a2.getAsyncToken();
  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
  valC = a3.getResult(0);
  token = a3.getAsyncToken();
  // Point C's handle at the new arrays and copy the product into them.
  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
      loc, tokenTp, token, spMatC, rowC, colC, valC);
  token = update->getResult(0);
  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
  token = copy->getResult(0);
  // Allocate host buffers for the result.
  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
  // Destroy handles, copy C back to the host and free GPU memory.
  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
              .getAsyncToken();
  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
  token = genCopyMemRef(rewriter, loc, colH, colC, token);
  token = genCopyMemRef(rewriter, loc, valH, valC, token);
  token = genDeallocMemRef(rewriter, loc, rowA, token);
  token = genDeallocMemRef(rewriter, loc, colA, token);
  token = genDeallocMemRef(rewriter, loc, valA, token);
  token = genDeallocMemRef(rewriter, loc, rowB, token);
  token = genDeallocMemRef(rewriter, loc, colB, token);
  token = genDeallocMemRef(rewriter, loc, valB, token);
  token = genDeallocMemRef(rewriter, loc, rowC, token);
  token = genDeallocMemRef(rewriter, loc, colC, token);
  token = genDeallocMemRef(rewriter, loc, valC, token);
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();
  // Done: assemble the result tensor from the host buffers.
  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
                                          vt);
  return success();
}
/// Rewrites C = A' * B, where A' is a dense matrix A converted to a 2:4
/// layout by a sparse_tensor.convert.
/// - The dense source A is copied to the GPU and turned into a 2:4 sparse
///   handle there (PRUNE_AND_CHECK).
/// - The three SpMM workspaces are sized and allocated, SpMM runs, and C is
///   copied back.
/// Fails without changing the IR if A is not defined by a ConvertOp, or if
/// A, B or C is not a dense tensor. On success, `op` is replaced with a
/// tensor built from the host memref of C.
static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
                                     linalg::GenericOp op) {
  Location loc = op.getLoc();
  Value A = op.getOperand(0);
  Value B = op.getOperand(1);
  Value C = op.getOperand(2); // we have C = AB
  SmallVector<Value> tokens;

  // Look through the conversion to its dense source. Fail gracefully rather
  // than dereferencing a null op in release builds; no IR exists yet.
  auto cnv = A.getDefiningOp<ConvertOp>();
  if (!cnv)
    return failure();
  A = cnv.getSource();

  // All inputs must be dense tensors.
  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
    return failure();

  // Copy data from host to GPU:
  //   A : bufA -> matA
  //   B : bufB -> matB
  //   C : bufC -> matC
  Value bufA = genTensorToMemref(rewriter, loc, A);
  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
  Value bufB = genTensorToMemref(rewriter, loc, B);
  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
  Value bufC = genTensorToMemref(rewriter, loc, C);
  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Create the 2:4 sparse-matrix handle (pruned on device) and the dense
  // handles.
  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
  Type indexTp = rewriter.getIndexType();
  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
  Value token = genFirstWait(rewriter, loc);
  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
      loc, spMatHandleTp, tokenTp, token, szm, szk,
      gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
  Value spMatA = spGenA->getResult(0);
  token = spGenA->getResult(1);
  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  Value dnB = dmatB.getResult(0);
  token = dmatB.getAsyncToken();
  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matC,
      SmallVector<Value>{szm, szn});
  Value dnC = dmatC.getResult(0);
  token = dmatC.getAsyncToken();

  // The element type of C is the compute type for both the buffer-size
  // query and the SpMM. It is computed once.
  Type dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();

  // Query the three workspace sizes and allocate each workspace.
  SmallVector<Type> bufferSzTps{indexTp, indexTp, indexTp};
  TypeRange bufferTypes(bufferSzTps);
  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
      /*computeType=*/dnCType);
  token = bufferComp.getAsyncToken();
  Value bufferSz1 = bufferComp.getResult(0);
  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
  Value buffer1 = buf1.getResult(0);
  token = buf1.getAsyncToken();
  Value bufferSz2 = bufferComp.getResult(1);
  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
  Value buffer2 = buf2.getResult(0);
  token = buf2.getAsyncToken();
  Value bufferSz3 = bufferComp.getResult(2);
  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
  Value buffer3 = buf3.getResult(0);
  token = buf3.getAsyncToken();

  // Perform the SpMM.
  auto spmmComp = rewriter.create<gpu::SpMMOp>(
      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
      SmallVector<Value>{buffer1, buffer2, buffer3});
  token = spmmComp.getAsyncToken();

  // Destroy the handles, free GPU memory and copy C back to the host.
  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
              .getAsyncToken();
  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
              .getAsyncToken();
  token = genDeallocMemRef(rewriter, loc, buffer1, token);
  token = genDeallocMemRef(rewriter, loc, buffer2, token);
  token = genDeallocMemRef(rewriter, loc, buffer3, token);
  token = genDeallocMemRef(rewriter, loc, matA, token);
  token = genDeallocMemRef(rewriter, loc, matB, token);
  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
  token = genDeallocMemRef(rewriter, loc, matC, token);
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);
  tokens.clear();

  // Done: the result is the host buffer of C.
  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
  return success();
}
/// Rewrites a linalg.generic matched as an SDDMM kernel into GPU sparse
/// library operations (gpu::SDDMMBufferSizeOp followed by gpu::SDDMMOp).
///
/// Operand 0 (`a`, szm x szk) and operand 1 (`b`, szk x szn) are handled as
/// dense matrices; operand 2 (`c`, szm x szn) is the sparse matrix. Only c's
/// values buffer is written back to the host, and the generic op is replaced
/// by a sparse_tensor::LoadOp of `c`, i.e. the result is produced in place.
///
/// Returns failure, before creating any IR, when the cuSPARSE format derived
/// from the operand types is kNone, kCOO or kCSC.
static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
linalg::GenericOp op, bool enableRT) {
Location loc = op.getLoc();
Value a = op.getOperand(0);
Value b = op.getOperand(1);
Value c = op.getOperand(2);
// Async tokens of the staging copies below; awaited together in one batch.
SmallVector<Value> tokens;
SparseTensorType aTp = getSparseTensorType(a);
SparseTensorType bTp = getSparseTensorType(b);
SparseTensorType cTp = getSparseTensorType(c);
// The sparse operand's type is passed first. NOTE(review): the trailing
// `false` presumably means "not a matrix-vector kernel"; confirm against
// getCuSparseFormat.
auto format = getCuSparseFormat(cTp, bTp, aTp, enableRT, false);
// Only kCSR and kBSR survive this check.
if (format == CuSparseFormat::kNone || format == CuSparseFormat::kCOO ||
format == CuSparseFormat::kCSC)
return failure();
// Problem sizes: a is szm x szk, b is szk x szn; nseC is c's stored-entry
// count, needed to build the sparse handle.
Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
// Stage the dense inputs and c's storage arrays. Each genAllocCopy call
// presumably allocates a device buffer and starts an async copy of the host
// memref into it, recording the copy's token in `tokens`.
Value bufA = genTensorToMemref(rewriter, loc, a);
Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
Value bufB = genTensorToMemref(rewriter, loc, b);
Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
// memC may be a null Value, so colC is only staged (and later freed) when
// memC is present.
Value memC = genSecondCrds(rewriter, loc, c, format, enableRT);
Value memV = rewriter.create<ToValuesOp>(loc, c);
Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valC = genAllocCopy(rewriter, loc, memV, tokens);
// Wait for all staging copies, then reuse `tokens` for the final wait.
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
Type indexTp = rewriter.getIndexType();
Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
// From here on, every GPU op consumes the current `token` and yields the
// next one, forming a single ordered async chain.
Value token = genFirstWait(rewriter, loc);
// Dense handles for a (szm x szk) and b (szk x szn).
auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
loc, dnMatHandleTp, tokenTp, token, matA, SmallVector<Value>{szm, szk});
Value dnA = dmatA.getResult(0);
token = dmatA.getAsyncToken();
auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
loc, dnMatHandleTp, tokenTp, token, matB, SmallVector<Value>{szk, szn});
Value dnB = dmatB.getResult(0);
token = dmatB.getAsyncToken();
// Sparse handle for c (szm x szn with nseC entries) over the staged arrays;
// result 0 is the handle and result 1 is the next async token.
Operation *spGenC =
genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
nseC, rowC, colC, valC, format, enableRT);
Value spMatC = spGenC->getResult(0);
token = spGenC->getResult(1);
// Element type of c, passed to both the size query and the SDDMM itself.
auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
// Query the workspace size the SDDMM needs, then allocate that buffer.
auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
Value bufferSz = bufferComp.getResult(0);
token = bufferComp.getAsyncToken();
auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
Value buffer = buf.getResult(0);
token = buf.getAsyncToken();
// The SDDMM proper. It presumably updates the stored values of spMatC
// (backed by valC), which is why valC is copied back below.
auto sddmmComp = rewriter.create<gpu::SDDMMOp>(loc, tokenTp, token, dnA, dnB,
spMatC, dnCType, buffer);
token = sddmmComp.getAsyncToken();
// Teardown: destroy the handles and free the device buffers. Only the values
// are copied back into c's host values buffer memV; rowC and colC are just
// freed.
token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
.getAsyncToken();
token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
.getAsyncToken();
token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
.getAsyncToken();
token = genDeallocMemRef(rewriter, loc, buffer, token);
token = genDeallocMemRef(rewriter, loc, matA, token);
token = genDeallocMemRef(rewriter, loc, matB, token);
token = genDeallocMemRef(rewriter, loc, rowC, token);
if (colC)
token = genDeallocMemRef(rewriter, loc, colC, token);
token = genCopyMemRef(rewriter, loc, memV, valC, token);
token = genDeallocMemRef(rewriter, loc, valC, token);
// Block on the whole chain before the host reads c again.
tokens.push_back(token);
genBlockingWait(rewriter, loc, tokens);
tokens.clear();
// c now carries the result; replace the generic op with a load of c.
rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
return success();
}
/// Outlines an admissible sparsifier-generated `scf.parallel` loop into a GPU
/// kernel.
///
/// Every value that is defined outside the loop but used inside it is sorted
/// into one of three kinds:
/// - `arith.constant` results, which are handed to genGPUCode separately;
/// - integer/index/float scalars, which become kernel arguments;
/// - memrefs, which become kernel arguments.
///
/// Any other invariant makes the pattern fail before any IR is created. On
/// success the kernel is built by genGPUCode, launched from the host, its
/// outlined arguments are finalized, and the original loop is erased.
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  /// `nT` is stored and forwarded to genLaunchGPUFunc for every launch.
  ForallRewriter(MLIRContext *context, unsigned nT)
      : OpRewritePattern(context), numThreads(nT) {}

  LogicalResult matchAndRewrite(scf::ParallelOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Admissible loops are tagged by the sparsifier's loop emitter and have
    // exactly one induction variable, no reductions, lower bound 0 and step 1.
    if (!forallOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    if (forallOp.getNumReductions() != 0 || forallOp.getNumLoops() != 1)
      return failure();
    if (!matchPattern(forallOp.getLowerBound()[0], m_Zero()) ||
        !matchPattern(forallOp.getStep()[0], m_One()))
      return failure();

    // Gather, in walk order, every operand whose defining block lies outside
    // the loop region. The walk also visits the loop op itself, so its own
    // bound and step operands are included.
    auto &loopRegion = forallOp.getRegion();
    SetVector<Value> invariants; // insertion-ordered: stable iteration
    forallOp->walk([&](Operation *nested) {
      for (Value operand : nested->getOperands())
        if (!loopRegion.findAncestorBlockInRegion(*operand.getParentBlock()))
          invariants.insert(operand);
    });

    // Partition the invariants by how they reach the device. Give up on
    // anything that cannot be shared in one of these ways.
    SmallVector<Value> constants;
    SmallVector<Value> scalars;
    SmallVector<Value> buffers;
    for (Value v : invariants) {
      if (v.getDefiningOp<arith::ConstantOp>()) {
        constants.push_back(v);
        continue;
      }
      Type ty = v.getType();
      if (isa<FloatType>(ty) || ty.isIntOrIndex()) {
        scalars.push_back(v);
        continue;
      }
      if (!isa<MemRefType>(ty))
        return failure(); // no known way to share this value with the kernel
      buffers.push_back(v);
    }

    // Host side: prepare the outlined scalars/buffers as kernel arguments.
    Location loc = forallOp->getLoc();
    SmallVector<Value> args;
    SmallVector<Value> tokens;
    Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
                                false);

    // Build the kernel at module scope, then return to the loop's position.
    auto hostIp = rewriter.saveInsertionPoint();
    ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
    auto gpuModule = genGPUModule(rewriter, topModule);
    auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
    genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
    rewriter.restoreInsertionPoint(hostIp);

    // Wait for the argument setup, launch, finalize the outlined arguments,
    // wait again, and drop the original loop.
    genBlockingWait(rewriter, loc, tokens);
    tokens.clear();
    Value kernelToken =
        genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
    genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
                     tokens);
    genBlockingWait(rewriter, loc, tokens);
    rewriter.eraseOp(forallOp);
    return success();
  }

private:
  unsigned numThreads;
};
/// Recognizes sparse linear-algebra kernels written as a single-output
/// `linalg.generic` and dispatches each one to its library-call rewriter:
/// rewriteSpMV, rewriteSpGEMM, rewrite2To4SpMM, rewriteSpMM or rewriteSDDMM.
struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  /// `rt` is forwarded as `enableRT` to the rewriters that accept it.
  LinalgOpRewriter(MLIRContext *context, bool rt)
      : OpRewritePattern(context), enableRT(rt) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Only single-output generics are considered.
    if (op.getNumDpsInits() != 1)
      return failure();

    const unsigned numLoops = op.getNumLoops();
    const unsigned numTensors = op->getNumOperands();
    const auto iteratorTypes = op.getIteratorTypesArray();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();

    AffineExpr i, j, k;
    bindDims(getContext(), i, j, k);
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    // True iff the op's indexing maps are exactly the given expression lists.
    auto hasMaps = [&](MapList exprs) {
      return maps == AffineMap::inferFromExprList(exprs, op.getContext());
    };

    // Matrix-vector shape: loops (i parallel, j reduction) with maps
    // (i,j), (j) -> (i), and a body accepted by matchSumOfMultOfArgs.
    if (numLoops == 2 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isReductionIterator(iteratorTypes[1]) &&
        hasMaps({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op))
      return rewriteSpMV(rewriter, op, enableRT);

    // Every remaining candidate shares the matrix-matrix shape: loops
    // (i, j parallel; k reduction) with maps (i,k), (k,j) -> (i,j).
    const bool matmulShaped =
        numLoops == 3 && numTensors == 3 &&
        linalg::isParallelIterator(iteratorTypes[0]) &&
        linalg::isParallelIterator(iteratorTypes[1]) &&
        linalg::isReductionIterator(iteratorTypes[2]) &&
        hasMaps({{i, k}, {k, j}, {i, j}});
    if (!matmulShaped)
      return failure();

    if (matchSumOfMultOfArgs(op)) {
      auto lhs = op.getOperand(0);
      // Neither input is dense: SpGEMM.
      if (!isDenseTensor(lhs) && !isDenseTensor(op.getOperand(1)))
        return rewriteSpGEMM(rewriter, op, enableRT);
      // Left input accepted by isConversionInto24: the 2:4 SpMM path.
      if (isConversionInto24(lhs))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }

    // Same iteration space, body accepted by matchSumReductionOfMulUnary.
    if (matchSumReductionOfMulUnary(op))
      return rewriteSDDMM(rewriter, op, enableRT);
    return failure();
  }

private:
  bool enableRT;
};
}
/// Registers ForallRewriter, which outlines sparsifier-generated parallel
/// loops into GPU kernels. `numThreads` is forwarded to the kernel launches
/// that pattern emits.
void mlir::populateSparseGPUCodegenPatterns(RewritePatternSet &patterns,
                                            unsigned numThreads) {
  auto *ctx = patterns.getContext();
  patterns.add<ForallRewriter>(ctx, numThreads);
}
/// Registers LinalgOpRewriter, which maps recognized sparse linalg kernels
/// (SpMV, SpGEMM, 2:4 SpMM, SpMM, SDDMM) onto GPU library operations.
/// `enableRT` is forwarded to those rewrites.
void mlir::populateSparseGPULibgenPatterns(RewritePatternSet &patterns,
                                           bool enableRT) {
  auto *ctx = patterns.getContext();
  patterns.add<LinalgOpRewriter>(ctx, enableRT);
}