#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
namespace mlir {
struct ScfToSPIRVContextImpl {
DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
};
}
/// Creates the context together with a freshly allocated, empty private state.
ScfToSPIRVContext::ScfToSPIRVContext()
    : impl(std::make_unique<mlir::ScfToSPIRVContextImpl>()) {}

/// Defaulted out of line. Presumably this is so the header only needs a
/// forward declaration of ScfToSPIRVContextImpl (TODO confirm against the
/// header).
ScfToSPIRVContext::~ScfToSPIRVContext() = default;
namespace {
template <typename ScfOp, typename OpTy>
void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
ConversionPatternRewriter &rewriter,
ScfToSPIRVContextImpl *scfToSPIRVContext,
ArrayRef<Type> returnTypes) {
Location loc = scfOp.getLoc();
auto &allocas = scfToSPIRVContext->outputVars[newOp];
allocas.clear();
SmallVector<Value, 8> resultValue;
for (Type convertedType : returnTypes) {
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
rewriter.setInsertionPoint(newOp);
auto alloc = rewriter.create<spirv::VariableOp>(
loc, pointerType, spirv::StorageClass::Function,
nullptr);
allocas.push_back(alloc);
rewriter.setInsertionPointAfter(newOp);
Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
}
/// Returns an iterator to the block at position `index` of `region`,
/// counting from the first block.
///
/// `index` must not exceed the number of blocks in `region`. The iterator is
/// advanced with std::next and is not bounds-checked.
Region::iterator getBlockIt(Region &region, unsigned index) {
  return std::next(region.begin(), index);
}
/// Common base for the SCF-to-SPIR-V conversion patterns in this file.
///
/// Holds the shared lowering state and a reference to the SPIR-V type
/// converter. The converter is also handed to the OpConversionPattern base.
template <typename OpTy>
class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
  using BaseT = OpConversionPattern<OpTy>;

public:
  SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
                    ScfToSPIRVContextImpl *scfToSPIRVContext)
      : BaseT(converter, context), scfToSPIRVContext(scfToSPIRVContext),
        typeConverter(converter) {}

protected:
  /// State shared across patterns (e.g. the result variables per region op).
  ScfToSPIRVContextImpl *scfToSPIRVContext;
  /// The SPIR-V type converter these patterns were created with.
  SPIRVTypeConverter &typeConverter;
};
/// Lowers `scf.for` to a `spirv::LoopOp`.
///
/// Blocks created or moved inside the loop body region:
///   - entry (first block): branches to the header with the converted lower
///     bound followed by the converted init args.
///   - header (position 1): block arguments are the induction variable and
///     one argument per init arg. It branches to the former scf.for body when
///     `iv < upperBound` (signed comparison, SLessThanOp), otherwise to the
///     merge block.
///   - body (inlined at position 2): the scf.for region blocks, with their
///     arguments remapped to the header arguments.
///   - continue block (`loopOp.getContinueBlock()`): computes `iv + step` and
///     branches back to the header.
/// The scf.for results are replaced through replaceSCFOutputValue.
struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = forOp.getLoc();
    // Create the SPIR-V loop together with its entry and merge blocks.
    auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
    loopOp.addEntryAndMergeBlock(rewriter);
    OpBuilder::InsertionGuard guard(rewriter);
    // Create the header block at position 1, i.e. right after the first
    // (entry) block.
    Block *header = rewriter.createBlock(&loopOp.getBody(),
                                         getBlockIt(loopOp.getBody(), 1));
    rewriter.setInsertionPointAfter(loopOp);
    // Header arguments: the induction variable (typed like the converted
    // lower bound) followed by one argument per converted init arg.
    Value adapLowerBound = adaptor.getLowerBound();
    BlockArgument newIndVar =
        header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
    for (Value arg : adaptor.getInitArgs())
      header->addArgument(arg.getType(), arg.getLoc());
    Block *body = forOp.getBody();
    // Remap the scf.for body block arguments (induction variable first, then
    // iter args) onto the corresponding header block arguments.
    TypeConverter::SignatureConversion signatureConverter(
        body->getNumArguments());
    signatureConverter.remapInput(0, newIndVar);
    for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
      signatureConverter.remapInput(i, header->getArgument(i));
    body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
                                             signatureConverter);
    // Move the converted body blocks into the loop, right after the header.
    rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 2));
    // Entry block: branch to the header with the initial induction variable
    // value and the initial iter-arg values.
    SmallVector<Value, 8> args(1, adaptor.getLowerBound());
    args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
    rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
    rewriter.create<spirv::BranchOp>(loc, header, args);
    // Header block: enter the body while `iv < upperBound` (signed), else
    // leave through the merge block. Neither edge passes block arguments.
    rewriter.setInsertionPointToEnd(header);
    auto *mergeBlock = loopOp.getMergeBlock();
    auto cmpOp = rewriter.create<spirv::SLessThanOp>(
        loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
    rewriter.create<spirv::BranchConditionalOp>(
        loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
    // Continue block: add the step and branch back to the header. Only the
    // induction variable is passed here. The yielded iter args are presumably
    // appended later by TerminatorOpConversion, which rewrites the BranchOp
    // terminating the yield's block when the parent is a spirv::LoopOp.
    Block *continueBlock = loopOp.getContinueBlock();
    rewriter.setInsertionPointToEnd(continueBlock);
    Value updatedIndVar = rewriter.create<spirv::IAddOp>(
        loc, newIndVar.getType(), newIndVar, adaptor.getStep());
    rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
    // The result types are taken from the already-converted init args rather
    // than by converting the scf.for result types.
    SmallVector<Type, 8> initTypes;
    for (auto arg : adaptor.getInitArgs())
      initTypes.push_back(arg.getType());
    replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
                          initTypes);
    return success();
  }
};
/// Lowers `scf.if` to a `spirv::SelectionOp`.
///
/// The selection body gets:
///   - a header block that branches conditionally on the converted condition,
///   - the inlined `then` blocks (and `else` blocks, if present), each
///     terminated by a branch to the merge block,
///   - a merge block terminated by a MergeOp.
/// If there is no `else` region, the false edge goes straight to the merge
/// block. The scf.if results are replaced through replaceSCFOutputValue.
struct IfOpConversion final : SCFToSPIRVPattern<scf::IfOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = ifOp.getLoc();

    // Convert the result types before touching the IR. If a type cannot be
    // converted, the pattern then fails with the IR left unchanged, rather
    // than after the selection op was built and the regions were inlined.
    SmallVector<Type, 8> returnTypes;
    for (auto result : ifOp.getResults()) {
      auto convertedType = typeConverter.convertType(result.getType());
      if (!convertedType)
        return rewriter.notifyMatchFailure(
            loc,
            llvm::formatv("failed to convert type '{0}'", result.getType()));
      returnTypes.push_back(convertedType);
    }

    // Create the selection op and its merge block (terminated by MergeOp).
    auto selectionOp =
        rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
    auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
                                            selectionOp.getBody().end());
    rewriter.create<spirv::MergeOp>(loc);

    OpBuilder::InsertionGuard guard(rewriter);
    // The header block is created in front of the merge block.
    auto *selectionHeaderBlock =
        rewriter.createBlock(&selectionOp.getBody().front());

    // Terminate the `then` region with a branch to the merge block, then
    // inline it before the merge block.
    auto &thenRegion = ifOp.getThenRegion();
    auto *thenBlock = &thenRegion.front();
    rewriter.setInsertionPointToEnd(&thenRegion.back());
    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
    rewriter.inlineRegionBefore(thenRegion, mergeBlock);

    // Without an `else` region the false edge targets the merge block.
    auto *elseBlock = mergeBlock;
    if (!ifOp.getElseRegion().empty()) {
      auto &elseRegion = ifOp.getElseRegion();
      elseBlock = &elseRegion.front();
      rewriter.setInsertionPointToEnd(&elseRegion.back());
      rewriter.create<spirv::BranchOp>(loc, mergeBlock);
      rewriter.inlineRegionBefore(elseRegion, mergeBlock);
    }

    // Header block: branch on the converted condition.
    rewriter.setInsertionPointToEnd(selectionHeaderBlock);
    rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
                                                thenBlock, ArrayRef<Value>(),
                                                elseBlock, ArrayRef<Value>());

    replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
                          returnTypes);
    return success();
  }
};
/// Lowers `scf.yield`.
///
/// If the yield has operands, each converted operand is stored into the
/// corresponding variable recorded for the parent op in `outputVars`. When the
/// parent is already a `spirv::LoopOp`, the operands are also appended to the
/// `spirv::BranchOp` that terminates the current block. The yield itself is
/// then erased.
struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
public:
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ValueRange operands = adaptor.getOperands();
    Operation *parent = terminatorOp->getParentOp();

    // Among SCF parents, only scf.if / scf.for / scf.while are supported.
    // The dialect is identified through the op name, because
    // Operation::getDialect() returns null for ops whose dialect is not
    // loaded, and dereferencing that would crash.
    if (parent->getName().getDialectNamespace() ==
            scf::SCFDialect::getDialectNamespace() &&
        !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
      return rewriter.notifyMatchFailure(
          terminatorOp,
          llvm::formatv("conversion not supported for parent op: '{0}'",
                        parent->getName()));

    if (!operands.empty()) {
      // Look up (rather than operator[]) so a miss does not insert an empty
      // entry into the shared map.
      auto &outputVars = scfToSPIRVContext->outputVars;
      auto varsIt = outputVars.find(parent);
      if (varsIt == outputVars.end() ||
          varsIt->second.size() != operands.size())
        return rewriter.notifyMatchFailure(
            terminatorOp,
            "no matching result variables recorded for the parent op");
      auto &allocas = varsIt->second;

      auto loc = terminatorOp.getLoc();
      for (unsigned i = 0, e = operands.size(); i < e; i++)
        rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
      if (isa<spirv::LoopOp>(parent)) {
        // Rebuild the block's terminating branch with the yielded values
        // appended after its existing arguments.
        auto br = cast<spirv::BranchOp>(
            rewriter.getInsertionBlock()->getTerminator());
        SmallVector<Value, 8> args(br.getBlockArguments());
        args.append(operands.begin(), operands.end());
        rewriter.setInsertionPoint(br);
        rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
                                         args);
        rewriter.eraseOp(br);
      }
    }
    rewriter.eraseOp(terminatorOp);
    return success();
  }
};
/// Lowers `scf.while` to a `spirv::LoopOp`.
///
/// The "before" region is inlined at position 1 (loop header) and the "after"
/// region at position 2 (loop body). The entry block branches to the header
/// with the converted inits. scf.condition becomes a BranchConditionalOp that
/// goes to the body (forwarding the condition's args) or to the merge block.
/// scf.yield becomes a branch back to the header.
///
/// Unlike scf.for / scf.if, the values the whole op returns are the
/// scf.condition args, not the scf.yield operands. So a Function-storage
/// variable is created per condition arg (before the loop) and stored at the
/// end of the "before" block. The scf.while results are then replaced by
/// loads of those variables placed after the loop.
struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
  using SCFToSPIRVPattern::SCFToSPIRVPattern;

  LogicalResult
  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = whileOp.getLoc();
    auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
    loopOp.addEntryAndMergeBlock(rewriter);
    Region &beforeRegion = whileOp.getBefore();
    Region &afterRegion = whileOp.getAfter();
    // Convert the block signatures of both regions with the SPIR-V type
    // converter.
    if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
        failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
      return rewriter.notifyMatchFailure(whileOp,
                                         "Failed to convert region types");
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *loopOp.getEntryBlock();
    Block &beforeBlock = beforeRegion.front();
    Block &afterBlock = afterRegion.front();
    Block &mergeBlock = *loopOp.getMergeBlock();
    // Fetch the converted values used by both terminators before the regions
    // are moved. Any failed lookup aborts the pattern.
    auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
    SmallVector<Value> condArgs;
    if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
      return failure();
    Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
    if (!conditionVal)
      return failure();
    auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
    SmallVector<Value> yieldArgs;
    if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
      return failure();
    // Loop header: the "before" region, right after the entry block.
    rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 1));
    // Loop body: the "after" region, right after the header.
    rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
                                getBlockIt(loopOp.getBody(), 2));
    // Entry block: jump to the header with the converted initial values.
    rewriter.setInsertionPointToEnd(&entryBlock);
    rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
    auto condLoc = cond.getLoc();
    SmallVector<Value> resultValues(condArgs.size());
    // One variable per condition arg: declared before the loop, read after
    // the loop, and written at the end of the "before" block on every pass.
    for (const auto &it : llvm::enumerate(condArgs)) {
      auto res = it.value();
      auto i = it.index();
      auto pointerType =
          spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
      rewriter.setInsertionPoint(loopOp);
      auto alloc = rewriter.create<spirv::VariableOp>(
          condLoc, pointerType, spirv::StorageClass::Function,
          /*initializer=*/nullptr);
      rewriter.setInsertionPointAfter(loopOp);
      auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
      resultValues[i] = loadResult;
      rewriter.setInsertionPointToEnd(&beforeBlock);
      rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
    }
    // scf.condition -> continue into the body with the condition args, or
    // exit to the merge block with no arguments.
    rewriter.setInsertionPointToEnd(&beforeBlock);
    rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
        cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
    // scf.yield -> branch back to the header with the yielded values.
    rewriter.setInsertionPointToEnd(&afterBlock);
    rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
                                                 yieldArgs);
    rewriter.replaceOp(whileOp, resultValues);
    return success();
  }
};
}
/// Adds the scf.for, scf.if, scf.yield and scf.while lowering patterns to
/// `patterns`.
///
/// All of them share the state owned by `scfToSPIRVContext`, which must
/// therefore outlive any use of the returned patterns.
void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                      ScfToSPIRVContext &scfToSPIRVContext,
                                      RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  auto *sharedState = scfToSPIRVContext.getImpl();
  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
               WhileOpConversion>(ctx, typeConverter, sharedState);
}