#include <utility>
#include "mlir/Dialect/Async/Passes.h"
#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include <optional>
// Pull in the tablegen-generated pass definitions. The two GEN_PASS_DEF_*
// macros request the generated base classes used below
// (impl::AsyncToAsyncRuntimeBase and impl::AsyncFuncToAsyncRuntimeBase).
namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::async;

// Debug category for the LLVM_DEBUG output emitted in this file.
#define DEBUG_TYPE "async-to-async-runtime"

// Symbol name given to every function outlined from an async.execute body.
// NOTE(review): all outlined functions are created with this same name, and
// uniqueness presumably comes from SymbolTable::insert renaming on clashes —
// confirm.
static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
namespace {
/// Module pass: outlines every async.execute body into a coroutine function
/// and lowers the high-level async operations to async.runtime operations.
class AsyncToAsyncRuntimePass
    : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
public:
  AsyncToAsyncRuntimePass() = default;

  void runOnOperation() override;
};

/// Module pass: rewrites async::FuncOp / async::CallOp / async::ReturnOp into
/// func dialect coroutines driven by async.runtime operations.
class AsyncFuncToAsyncRuntimePass
    : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
public:
  AsyncFuncToAsyncRuntimePass() = default;

  void runOnOperation() override;
};
} // namespace
namespace {
/// Handles to the IR that setupCoroMachinery adds around a func.func to turn
/// it into a coroutine; the lowering patterns use these to wire suspension
/// points, returns and error paths into the coroutine CFG.
struct CoroMachinery {
  /// The coroutine function.
  func::FuncOp func;

  /// Completion token returned as the first result. Present only when the
  /// function's first result type is TokenType.
  std::optional<Value> asyncToken;
  /// One async value per remaining (non-token) result, in result order.
  llvm::SmallVector<Value, 4> returnValues;

  /// Handle produced by CoroBeginOp in the entry block.
  Value coroHandle;

  /// Block holding the coroutine setup; it falls through to the original
  /// body.
  Block *entry = nullptr;
  /// Block that marks the token/values as errored and branches to `cleanup`.
  /// Created lazily by setupSetErrorBlock.
  std::optional<Block *> setError;
  /// Frees the coroutine frame and branches to `suspend`.
  Block *cleanup = nullptr;
  /// Duplicate of `cleanup`, used as the destroy successor of CoroSuspendOp.
  Block *cleanupForDestroy = nullptr;
  /// Ends the coroutine and returns the token and async values.
  Block *suspend = nullptr;
};
} // namespace
// Map from each coroutine func.func to its CoroMachinery. It is held through
// a shared_ptr so that several rewrite patterns (and, in
// populateAsyncFuncToAsyncRuntimeConversionPatterns, a legality callback)
// can share and update the same map.
using FuncCoroMapPtr =
    std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
/// Turns `func` into a coroutine by adding CFG scaffolding around its body,
/// and returns handles to the pieces that were created.
///
/// Resulting layout:
///   entry:   RuntimeCreateOp for the returned token (only when the first
///            result type is TokenType) and for each remaining result, then
///            CoroIdOp / CoroBeginOp and a branch into the original body.
///            No initial suspend is inserted here.
///   cleanup, cleanupForDestroy:
///            CoroFreeOp, then branch to `suspend`.
///   suspend: CoroEndOp, then return the token and async values.
/// The function also gets passthrough = ["presplitcoroutine"]. The set-error
/// block is not created here; see setupSetErrorBlock.
static CoroMachinery setupCoroMachinery(func::FuncOp func) {
  assert(!func.getBlocks().empty() && "Function must have an entry block");

  MLIRContext *ctx = func.getContext();
  Block *entryBlock = &func.getBlocks().front();
  // Move the original body into its own block so the coroutine setup emitted
  // below runs before it.
  Block *originalEntryBlock =
      entryBlock->splitBlock(entryBlock->getOperations().begin());
  auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);

  // A leading TokenType result marks the function as stateful; the other
  // results are async values. Check for emptiness before calling front() so
  // a result-less function does not trip ArrayRef's assertion.
  ArrayRef<Type> resultTypes = func.getResultTypes();
  bool isStateful =
      !resultTypes.empty() && isa<TokenType>(resultTypes.front());

  std::optional<Value> retToken;
  if (isStateful)
    retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));

  llvm::SmallVector<Value, 4> retValues;
  ArrayRef<Type> resValueTypes =
      isStateful ? resultTypes.drop_front() : resultTypes;
  for (auto resType : resValueTypes)
    retValues.emplace_back(
        builder.create<RuntimeCreateOp>(resType).getResult());

  // Obtain the coroutine id and handle, then enter the original body.
  auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
  auto coroHdlOp =
      builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
  builder.create<cf::BranchOp>(originalEntryBlock);

  Block *cleanupBlock = func.addBlock();
  Block *cleanupBlockForDestroy = func.addBlock();
  Block *suspendBlock = func.addBlock();

  // Both cleanup blocks free the coroutine frame and continue to `suspend`.
  auto buildCleanupBlock = [&](Block *cb) {
    builder.setInsertionPointToStart(cb);
    builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
    builder.create<cf::BranchOp>(suspendBlock);
  };
  buildCleanupBlock(cleanupBlock);
  buildCleanupBlock(cleanupBlockForDestroy);

  // Suspend block: mark the end of the coroutine and return the optional
  // token followed by the async values.
  builder.setInsertionPointToStart(suspendBlock);
  builder.create<CoroEndOp>(coroHdlOp.getHandle());

  SmallVector<Value, 4> ret;
  if (retToken)
    ret.push_back(*retToken);
  ret.insert(ret.end(), retValues.begin(), retValues.end());
  builder.create<func::ReturnOp>(ret);

  // Tag the function with the "presplitcoroutine" passthrough attribute.
  func->setAttr("passthrough", builder.getArrayAttr(
                                   StringAttr::get(ctx, "presplitcoroutine")));

  CoroMachinery machinery;
  machinery.func = func;
  machinery.asyncToken = retToken;
  machinery.returnValues = retValues;
  machinery.coroHandle = coroHdlOp.getHandle();
  machinery.entry = entryBlock;
  machinery.setError = std::nullopt;
  machinery.cleanup = cleanupBlock;
  machinery.cleanupForDestroy = cleanupBlockForDestroy;
  machinery.suspend = suspendBlock;
  return machinery;
}
/// Returns the block that puts the coroutine's token and every returned async
/// value into the error state and then branches to the cleanup block.
/// On the first call the block is created, placed right before `coro.cleanup`
/// and cached in `coro.setError`; later calls return the cached block.
static Block *setupSetErrorBlock(CoroMachinery &coro) {
  if (!coro.setError) {
    Block *errorBlock = coro.func.addBlock();
    coro.setError = errorBlock;
    errorBlock->moveBefore(coro.cleanup);

    auto builder =
        ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), errorBlock);

    // Flag the completion token (when present) and each async result.
    if (coro.asyncToken)
      builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
    for (Value asyncResult : coro.returnValues)
      builder.create<RuntimeSetErrorOp>(asyncResult);

    // Continue with the regular coroutine teardown.
    builder.create<cf::BranchOp>(coro.cleanup);
  }
  return *coro.setError;
}
/// Outlines the body region of `execute` into a new private func.func
/// coroutine, replaces `execute` with a func::CallOp to it, and returns the
/// new function together with its coroutine machinery.
///
/// The outlined function's arguments are the distinct values among the
/// execute's dependencies, its body operands, and the values defined above
/// the region that the body uses, in order of first appearance. The function:
///   1. awaits every dependency and every body operand (the awaited results
///      of the body operands replace the region's block arguments);
///   2. runs a clone of the body;
///   3. is wrapped in the coroutine CFG.
/// Its entry block ends with RuntimeResumeOp + CoroSuspendOp, so the body
/// runs once the runtime resumes the coroutine.
static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
  ModuleOp module = execute->getParentOfType<ModuleOp>();

  MLIRContext *ctx = module.getContext();
  Location loc = execute.getLoc();

  // Clone constants into the region so they do not become extra arguments.
  cloneConstantsIntoTheRegion(execute.getBodyRegion());

  // Collect the outlined function inputs. SetVector drops duplicates, so a
  // value listed more than once (e.g. the same token as two dependencies)
  // becomes a single function argument.
  SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
                                        execute.getDependencies().end());
  functionInputs.insert(execute.getBodyOperands().begin(),
                        execute.getBodyOperands().end());
  getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);

  auto typesRange = llvm::map_range(
      functionInputs, [](Value value) { return value.getType(); });
  SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
  auto outputTypes = execute.getResultTypes();

  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
  auto funcAttrs = ArrayRef<NamedAttribute>();

  func::FuncOp func =
      func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
  symbolTable.insert(func);

  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
  auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());

  // Build the outlined function body.
  {
    // Map every captured value to its function argument. Dependencies and
    // body operands are resolved through this mapping rather than by
    // argument position: after de-duplication the positions need not line
    // up with the execute's operand lists.
    IRMapping valueMapping;
    valueMapping.map(functionInputs, func.getArguments());

    // Await on all dependencies before starting to execute the body region.
    for (Value dependency : execute.getDependencies())
      builder.create<AwaitOp>(valueMapping.lookup(dependency));

    // Await on all async value operands and unwrap the payload.
    SmallVector<Value, 4> unwrappedOperands;
    unwrappedOperands.reserve(execute.getBodyOperands().size());
    for (Value operand : execute.getBodyOperands())
      unwrappedOperands.push_back(
          builder.create<AwaitOp>(loc, valueMapping.lookup(operand))
              .getResult());

    // The region's block arguments receive the unwrapped payloads.
    valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);

    // Clone the execute body into the outlined function.
    for (Operation &op : execute.getBodyRegion().getOps())
      builder.clone(op, valueMapping);
  }

  // Add the entry/cleanup/suspend blocks.
  CoroMachinery coro = setupCoroMachinery(func);

  // Replace the entry block's branch into the body with a suspension point:
  // save the state, hand the coroutine to the runtime, and suspend with the
  // original body as the resume destination.
  {
    cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
    builder.setInsertionPointToEnd(coro.entry);

    auto coroSaveOp =
        builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
    builder.create<RuntimeResumeOp>(coro.coroHandle);
    builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
                                  branch.getDest(), coro.cleanupForDestroy);

    branch.erase();
  }

  // Replace the original async.execute with a call to the outlined function.
  {
    ImplicitLocOpBuilder callBuilder(loc, execute);
    auto callOutlinedFunc = callBuilder.create<func::CallOp>(
        func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
    execute.replaceAllUsesWith(callOutlinedFunc.getResults());
    execute.erase();
  }

  return {func, coro};
}
namespace {
/// Replaces CreateGroupOp with RuntimeCreateGroupOp. The result has GroupType
/// and the operands are forwarded from the adaptor.
class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto groupType = GroupType::get(op->getContext());
    rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(op, groupType,
                                                      adaptor.getOperands());
    return success();
  }
};
} // namespace
namespace {
/// Replaces AddToGroupOp with RuntimeAddToGroupOp. The result is index-typed
/// and the operands are forwarded from the adaptor.
class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto indexTy = rewriter.getIndexType();
    rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(op, indexTy,
                                                     adaptor.getOperands());
    return success();
  }
};
} // namespace
namespace {
/// Converts an async::FuncOp into a func::FuncOp with the same name,
/// function type, visibility and attributes, moves the body over, and wraps
/// it in the coroutine CFG built by setupCoroMachinery. The resulting
/// machinery is recorded in the shared `coros` map so the return, await,
/// yield and assert lowerings can find it.
class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
public:
  AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
      : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}

  LogicalResult
  matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto newFuncOp =
        rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());

    SymbolTable::setSymbolVisibility(newFuncOp,
                                     SymbolTable::getSymbolVisibility(op));
    // Copy every attribute except the symbol name, which the builder set.
    for (const auto &namedAttr : op->getAttrs()) {
      if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
    }

    rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    // Unlike outlined async.execute bodies, no suspend is added at the
    // entry, so the body starts executing directly in the caller.
    CoroMachinery coro = setupCoroMachinery(newFuncOp);
    (*coros)[newFuncOp] = coro;

    rewriter.eraseOp(op);
    return success();
  }

private:
  FuncCoroMapPtr coros;
};

/// Replaces async::CallOp with a func::CallOp to the same callee, keeping the
/// result types unchanged.
class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
public:
  explicit AsyncCallOpLowering(MLIRContext *ctx)
      : OpConversionPattern<async::CallOp>(ctx) {}

  LogicalResult
  matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Forward the adaptor's (possibly remapped) operands rather than the
    // original op operands, as conversion patterns are expected to.
    rewriter.replaceOpWithNewOp<func::CallOp>(
        op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
    return success();
  }
};

/// Lowers async::ReturnOp inside a converted coroutine. It stores each
/// returned value into the matching async value, marks those values and the
/// completion token (if any) available, and branches to the cleanup block.
class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
public:
  AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
      : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}

  LogicalResult
  matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto func = op->getParentOfType<func::FuncOp>();
    auto funcCoro = coros->find(func);
    if (funcCoro == coros->end())
      return rewriter.notifyMatchFailure(
          op, "operation is not inside the async coroutine function");

    Location loc = op->getLoc();
    const CoroMachinery &coro = funcCoro->getSecond();
    rewriter.setInsertionPointAfter(op);

    // Publish each returned value through its async value storage.
    for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
      Value returnValue = std::get<0>(tuple);
      Value asyncValue = std::get<1>(tuple);
      rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
      rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
    }

    // Signal completion through the token, if the function has one.
    if (coro.asyncToken)
      rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);

    rewriter.eraseOp(op);
    rewriter.create<cf::BranchOp>(loc, coro.cleanup);
    return success();
  }

private:
  FuncCoroMapPtr coros;
};
} // namespace
namespace {
/// Shared lowering for AwaitOp and AwaitAllOp.
///
/// An instantiation only handles ops whose operand type is `AwaitableType`.
///  * Inside a coroutine function (one recorded in `coros`), the op becomes a
///    suspension point. The pattern saves the coroutine state (CoroSaveOp),
///    asks the runtime to resume the coroutine when the operand is ready
///    (RuntimeAwaitAndResumeOp), and suspends (CoroSuspendOp). On resumption
///    it branches to the coroutine's set-error block if the operand is in
///    the error state.
///  * Outside a coroutine, and only when `shouldLowerBlockingWait` is set, the
///    op becomes a blocking RuntimeAwaitOp followed by a cf::AssertOp that
///    the operand is not in the error state. Otherwise the op is left alone.
///
/// Subclasses override getReplacementValue to supply a value that replaces
/// the op's result; by default the op is erased.
template <typename AwaitType, typename AwaitableType>
class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
  using AwaitAdaptor = typename AwaitType::Adaptor;

public:
  AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
                      bool shouldLowerBlockingWait)
      : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
        shouldLowerBlockingWait(shouldLowerBlockingWait) {}

  LogicalResult
  matchAndRewrite(AwaitType op, AwaitAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Leave operands of other awaitable types to the sibling patterns.
    if (!isa<AwaitableType>(op.getOperand().getType()))
      return rewriter.notifyMatchFailure(op, "unsupported awaitable type");

    auto func = op->template getParentOfType<func::FuncOp>();
    auto funcCoro = coros->find(func);
    const bool isInCoroutine = funcCoro != coros->end();

    // Outside a coroutine, only lower to a blocking wait when requested
    // (e.g. the op may still be nested in an async.execute region).
    if (!isInCoroutine && !shouldLowerBlockingWait)
      return rewriter.notifyMatchFailure(
          op, "blocking wait lowering is disabled outside of coroutines");

    Location loc = op->getLoc();
    Value operand = adaptor.getOperand();
    Type i1 = rewriter.getI1Type();

    if (!isInCoroutine) {
      // Block the calling thread until the operand becomes available.
      ImplicitLocOpBuilder builder(loc, rewriter);
      builder.create<RuntimeAwaitOp>(loc, operand);

      // Assert the operand is not in the error state: notError = isError ^ 1.
      Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
      Value notError = builder.create<arith::XOrIOp>(
          isError, builder.create<arith::ConstantOp>(
                       loc, i1, builder.getIntegerAttr(i1, 1)));
      builder.create<cf::AssertOp>(notError,
                                   "Awaited async operand is in error state");
    } else {
      CoroMachinery &coro = funcCoro->getSecond();
      Block *suspended = op->getBlock();

      ImplicitLocOpBuilder builder(loc, rewriter);
      MLIRContext *ctx = op->getContext();

      // Save the coroutine state and have the runtime resume the coroutine
      // once the operand is available.
      auto coroSaveOp =
          builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
      builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);

      // Split before the await and terminate the first half with a suspend.
      Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
      builder.setInsertionPointToEnd(suspended);
      builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
                                    coro.cleanupForDestroy);

      // Split the resume block into an error check and the continuation.
      Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
      builder.setInsertionPointToStart(resume);
      auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
      builder.create<cf::CondBranchOp>(isError,
                                       /*trueDest=*/setupSetErrorBlock(coro),
                                       /*trueArgs=*/ArrayRef<Value>(),
                                       /*falseDest=*/continuation,
                                       /*falseArgs=*/ArrayRef<Value>());

      // Any replacement value must be built in the continuation block.
      rewriter.setInsertionPointToStart(continuation);
    }

    if (Value replaceWith = getReplacementValue(op, operand, rewriter))
      rewriter.replaceOp(op, replaceWith);
    else
      rewriter.eraseOp(op);

    return success();
  }

  /// Hook for subclasses: returns the value that replaces the await op's
  /// result, built at the rewriter's current insertion point. A null Value
  /// (the default) means the op is erased instead.
  virtual Value getReplacementValue(AwaitType op, Value operand,
                                    ConversionPatternRewriter &rewriter) const {
    return Value();
  }

private:
  FuncCoroMapPtr coros;
  bool shouldLowerBlockingWait;
};

/// Lowers AwaitOp on a TokenType operand; the op is erased after the wait.
class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
  using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;

public:
  using Base::Base;
};

/// Lowers AwaitOp on a ValueType operand; the op's result is replaced by a
/// RuntimeLoadOp of the payload.
class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
  using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;

public:
  using Base::Base;

  Value
  getReplacementValue(AwaitOp op, Value operand,
                      ConversionPatternRewriter &rewriter) const override {
    auto valueType = cast<ValueType>(operand.getType()).getValueType();
    return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
  }
};

/// Lowers AwaitAllOp on a GroupType operand; the op is erased after the wait.
class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
  using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;

public:
  using Base::Base;
};
} // namespace
namespace {
/// Lowers async::YieldOp inside a coroutine function. It stores each yielded
/// value into the matching async value, marks those values and the
/// completion token (if any) available, and branches to the cleanup block.
/// Ops outside a coroutine are reported as a match failure.
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
  YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
      : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}

  LogicalResult
  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto func = op->getParentOfType<func::FuncOp>();
    auto funcCoro = coros->find(func);
    if (funcCoro == coros->end())
      return rewriter.notifyMatchFailure(
          op, "operation is not inside the async coroutine function");

    Location loc = op->getLoc();
    const CoroMachinery &coro = funcCoro->getSecond();

    // Publish each yielded value through its async value storage.
    for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
      Value yieldValue = std::get<0>(tuple);
      Value asyncValue = std::get<1>(tuple);
      rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
      rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
    }

    // Signal completion through the token, if present.
    if (coro.asyncToken)
      rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);

    rewriter.eraseOp(op);
    rewriter.create<cf::BranchOp>(loc, coro.cleanup);
    return success();
  }

private:
  FuncCoroMapPtr coros;
};
} // namespace
namespace {
/// Inside a coroutine function, turns cf::AssertOp into a conditional branch.
/// When the condition holds, execution continues with the rest of the block;
/// otherwise it jumps to the coroutine's set-error block, which is created on
/// demand. Ops outside a coroutine are reported as a match failure.
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public:
  AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
      : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}

  LogicalResult
  matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto func = op->getParentOfType<func::FuncOp>();
    auto funcCoro = coros->find(func);
    if (funcCoro == coros->end())
      return rewriter.notifyMatchFailure(
          op, "operation is not inside the async coroutine function");

    Location loc = op->getLoc();
    CoroMachinery &coro = funcCoro->getSecond();

    // Split at the assert: the assert and everything after it move to
    // `cont`, and the conditional branch terminates the original block.
    Block *head = op->getBlock();
    Block *cont = rewriter.splitBlock(head, Block::iterator(op));
    rewriter.setInsertionPointToEnd(head);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
                                      /*trueDest=*/cont,
                                      /*trueArgs=*/ArrayRef<Value>(),
                                      /*falseDest=*/setupSetErrorBlock(coro),
                                      /*falseArgs=*/ArrayRef<Value>());
    rewriter.eraseOp(op);

    return success();
  }

private:
  FuncCoroMapPtr coros;
};
} // namespace
/// Outlines every ExecuteOp in the module into a coroutine function, then runs
/// a partial conversion. The conversion lowers CreateGroupOp, AddToGroupOp,
/// AwaitOp, AwaitAllOp and async::YieldOp, plus cf::AssertOp inside
/// coroutines, to async.runtime operations. SCF ops that contain async ops
/// inside a coroutine are also lowered to cf branch-based control flow.
void AsyncToAsyncRuntimePass::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbolTable(module);

  // Coroutine setups of the functions outlined from async.execute bodies.
  FuncCoroMapPtr coros =
      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();

  module.walk([&](ExecuteOp execute) {
    coros->insert(outlineExecuteOp(symbolTable, execute));
  });

  LLVM_DEBUG({
    llvm::dbgs() << "Outlined " << coros->size()
                 << " functions built from async.execute operations\n";
  });

  // Returns true if `op` is nested inside one of the outlined coroutines.
  auto isInCoroutine = [&](Operation *op) -> bool {
    return coros->contains(op->getParentOfType<func::FuncOp>());
  };

  MLIRContext *ctx = module->getContext();
  RewritePatternSet asyncPatterns(ctx);

  // SCF ops holding async ops inside a coroutine are made illegal below;
  // these patterns rewrite them into cf branch-based control flow.
  populateSCFToControlFlowConversionPatterns(asyncPatterns);

  asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
  asyncPatterns
      .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
          ctx, coros, /*shouldLowerBlockingWait=*/true);
  asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);

  ConversionTarget runtimeTarget(*ctx);
  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
  runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();

  // An SCF op is legal unless it (or an op nested in it) is an async dialect
  // op located inside a coroutine.
  runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
    auto walkResult = op->walk([&](Operation *nested) {
      // Unregistered operations have no dialect. isa_and_nonnull treats them
      // as non-async instead of tripping isa<>'s null-pointer assertion.
      bool isAsync =
          llvm::isa_and_nonnull<async::AsyncDialect>(nested->getDialect());
      return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
                                              : WalkResult::advance();
    });
    return !walkResult.wasInterrupted();
  });
  runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
                           cf::BranchOp, cf::CondBranchOp>();

  // cf.assert is legal only outside coroutines; inside one it must become a
  // branch to the set-error block. This dynamic rule is the only legality
  // registered for cf::AssertOp.
  runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
      [&](cf::AssertOp op) -> bool {
        auto func = op->getParentOfType<func::FuncOp>();
        return !coros->contains(func);
      });

  if (failed(applyPartialConversion(module, runtimeTarget,
                                    std::move(asyncPatterns))))
    signalPassFailure();
}
/// Adds to `patterns` the lowerings for async::FuncOp, async::CallOp and
/// async::ReturnOp, plus the await/await_all/yield/cf.assert lowerings used
/// inside the resulting coroutines. Also registers a dynamic legality rule on
/// `target` for AwaitOp, AwaitAllOp, YieldOp and cf::AssertOp: such an op is
/// legal when nested in an ExecuteOp, or when its enclosing func.func is not
/// one of the converted coroutines.
void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
    RewritePatternSet &patterns, ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

  // One coroutine map shared by all patterns and the legality callback;
  // AsyncFuncOpLowering records every coroutine it creates here.
  FuncCoroMapPtr coroutines =
      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();

  patterns.add<AsyncCallOpLowering>(context);
  patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(context,
                                                           coroutines);
  // Awaits outside coroutines are not turned into blocking waits here.
  patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
      context, coroutines, /*shouldLowerBlockingWait=*/false);
  patterns.add<YieldOpLowering, AssertOpLowering>(context, coroutines);

  auto isLegal = [coroutines](Operation *op) -> bool {
    // Ops nested inside an async.execute are treated as legal.
    if (op->getParentOfType<ExecuteOp>())
      return true;
    return !coroutines->contains(op->getParentOfType<func::FuncOp>());
  };
  target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
      isLegal);
}
void AsyncFuncToAsyncRuntimePass::runOnOperation() {
ModuleOp module = getOperation();
MLIRContext *ctx = module->getContext();
RewritePatternSet asyncPatterns(ctx);
ConversionTarget runtimeTarget(*ctx);
populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
runtimeTarget);
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
cf::BranchOp, cf::CondBranchOp>();
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
return;
}
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
return std::make_unique<AsyncToAsyncRuntimePass>();
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createAsyncFuncToAsyncRuntimePass() {
return std::make_unique<AsyncFuncToAsyncRuntimePass>();
}