#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
// Presumably expands the generated pass base class used below
// (impl::GpuAsyncRegionPassBase) from the pass definitions in Passes.h.inc.
#define GEN_PASS_DEF_GPUASYNCREGIONPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// Pass that threads `!gpu.async.token`s through GPU ops of a function. The
/// rewrite is split into three walk callbacks, declared here and defined out
/// of line; runOnOperation applies them one after another.
class GpuAsyncRegionPass
    : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  void runOnOperation() override;

private:
  // Rewrite phases, listed in the order runOnOperation uses them.
  struct ThreadTokenCallback;
  struct DeferWaitCallback;
  struct SingleTokenUseCallback;
};
} // namespace
/// Returns true if `op` has, or might have, the IsTerminator trait. Per
/// mightHaveTrait this is presumably also true for unregistered ops.
static bool isTerminator(Operation *op) {
  const bool mayTerminateBlock = op->mightHaveTrait<OpTrait::IsTerminator>();
  return mayTerminateBlock;
}
/// Returns true unless `op` is known to be free of memory effects.
static bool hasSideEffects(Operation *op) {
  if (isMemoryEffectFree(op))
    return false;
  return true;
}
/// Walk callback that rewrites one block at a time so GPU work is issued
/// asynchronously.
///
/// Ops implementing gpu::AsyncOpInterface are chained through a
/// `!gpu.async.token` held in `currentToken`. Before a terminator or an op with
/// side effects, a `gpu.wait` without a result type is inserted on the pending
/// token so the host synchronizes. A `gpu.launch` op fails the walk; it must be
/// replaced with `gpu.launch_func` first.
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  /// Visits the ops of `block` in order and interrupts the walk on the first
  /// failure.
  WalkResult operator()(Block *block) {
    // Early-increment iteration: visit() may erase the op it is handed.
    for (Operation &current : make_early_inc_range(*block)) {
      LogicalResult status = visit(&current);
      if (failed(status))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  }

private:
  /// Updates the token chain for a single op.
  LogicalResult visit(Operation *op) {
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");

    // Pre-existing gpu.wait: make it depend on the pending token (if any) and
    // adopt its own result token (presumably null for a synchronous wait).
    if (auto existingWait = llvm::dyn_cast<gpu::WaitOp>(op)) {
      if (currentToken)
        existingWait.addAsyncDependency(currentToken);
      currentToken = existingWait.getAsyncToken();
      return success();
    }

    // Anything created from here on is placed immediately before `op`.
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp);

    // With a pending token, synchronize the host before terminators and ops
    // with side effects. The wait is built without a result type, so the
    // token it yields (presumably null) ends the chain.
    const bool needsHostSync =
        currentToken && (isTerminator(op) || hasSideEffects(op));
    if (needsHostSync)
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

  /// Makes `asyncOp` depend on the pending token. If the op does not return a
  /// token itself, it is replaced by a clone carrying an extra trailing token
  /// result, which becomes the new pending token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    Operation *original = asyncOp.getOperation();
    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // No pending token: start a chain with a token-producing gpu.wait that has
    // no dependencies.
    if (!currentToken)
      currentToken = createWaitOp(original->getLoc(), tokenType, {});
    asyncOp.addAsyncDependency(currentToken);

    // An op that already returns a token simply becomes the new chain head.
    currentToken = asyncOp.getAsyncToken();
    if (currentToken)
      return success();

    // Build a clone with the original result types plus one token result;
    // operands, discardable attributes, properties and successors carry over.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(original->getNumResults() + 1);
    resultTypes.append(original->getResultTypes().begin(),
                       original->getResultTypes().end());
    resultTypes.push_back(tokenType);
    Operation *clone = Operation::create(
        original->getLoc(), original->getName(), resultTypes,
        original->getOperands(), original->getDiscardableAttrDictionary(),
        original->getPropertiesStorage(), original->getSuccessors(),
        original->getNumRegions());

    // Copy every region body into the clone, sharing one value mapping.
    IRMapping mapping;
    for (unsigned idx = 0, numRegions = original->getNumRegions();
         idx < numRegions; ++idx)
      original->getRegion(idx).cloneInto(&clone->getRegion(idx), mapping);

    // Swap the clone in: its trailing result is the new pending token and its
    // leading results take over all uses of the original op.
    auto cloneResults = clone->getResults();
    currentToken = cloneResults.back();
    builder.insert(clone);
    original->replaceAllUsesWith(cloneResults.drop_back());
    original->erase();
    return success();
  }

  /// Creates a gpu.wait at the builder's insertion point and returns its async
  /// token.
  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    auto waitOp = builder.create<gpu::WaitOp>(loc, resultType, operands);
    return waitOp.getAsyncToken();
  }

  OpBuilder builder;

  /// Head of the current token chain; null when no chain is open.
  Value currentToken = {};
};
/// Replaces `executeOp` with a clone that additionally returns `results`.
///
/// The values in `results` are appended to the operands of the body's
/// terminator. The clone gets one extra result per appended value, and all uses
/// of the original op's results are redirected to the clone's leading results.
/// The original op is erased, so callers must continue with the returned op.
///
/// File-local helper: `static` gives it internal linkage, like the other
/// helpers in this file.
static async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
                                          ValueRange results) {
  // Append the extra values to the body terminator (the async.yield).
  Operation *yieldOp = executeOp.getBody()->getTerminator();
  yieldOp->insertOperands(yieldOp->getNumOperands(), results);

  // Rebuild the full result type list: unwrap each !async.value<T> to T, keep
  // the token type as is, then append the types of the new values.
  SmallVector<Type, 2> resultTypes;
  resultTypes.reserve(executeOp.getNumResults() + results.size());
  transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
            [](Type type) {
              if (auto valueType = dyn_cast<async::ValueType>(type))
                return valueType.getValueType();
              assert(isa<async::TokenType>(type) && "expected token type");
              return type;
            });
  transform(results, std::back_inserter(resultTypes),
            [](Value value) { return value.getType(); });

  // Create the clone right before the original. The leading token type is
  // dropped from the list passed to the builder.
  OpBuilder builder(executeOp);
  auto newOp = builder.create<async::ExecuteOp>(
      executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
      executeOp.getDependencies(), executeOp.getBodyOperands());

  // Discard whatever blocks the builder created and clone the original body.
  IRMapping mapper;
  newOp.getRegion().getBlocks().clear();
  executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

  // Redirect existing uses to the clone's leading results and drop the
  // original op.
  executeOp.getOperation()->replaceAllUsesWith(
      newOp.getResults().drop_back(results.size()));
  executeOp.erase();

  return newOp;
}
// Walk callback for `async.execute` ops. While walking, it queues synchronous
// `gpu.wait` ops found at the end of an execute body in `worklist`. The
// destructor then erases each queued wait and makes the execute op return the
// wait's dependencies instead. Users of the execute token are updated to
// depend on those values (see addAsyncDependencyAfter).
//
// NOTE(review): the rewrite happens in the destructor, yet this struct is
// implicitly copyable. A copy would replay `worklist` on ops that were already
// erased. This presumably relies on walk() taking the callback by reference
// (runOnOperation passes a temporary); verify before reusing it elsewhere.
struct GpuAsyncRegionPass::DeferWaitCallback {
// Queues the last synchronous `gpu.wait` in `executeOp`'s body. This happens
// only when every user of the execute token is an `async.execute` or
// `async.await`, and no side-effecting op (other than the terminator) follows
// the wait.
void operator()(async::ExecuteOp executeOp) {
if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
return;
// Scan backwards, starting at the op just before the terminator.
for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
// The first gpu.wait found ends the scan. It is queued only if it has no
// result token, i.e. it is a host-synchronizing wait.
if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
if (!waitOp.getAsyncToken())
worklist.push_back(waitOp);
return;
}
// A side-effecting op after the last wait stops the scan without
// queuing anything.
if (hasSideEffects(&op))
return;
}
}
// Performs the deferred rewrite for every queued wait. The index-based loop is
// deliberate: addAsyncDependencyAfter() may append more waits to `worklist`
// while it is being processed.
~DeferWaitCallback() {
for (size_t i = 0; i < worklist.size(); ++i) {
auto waitOp = worklist[i];
auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
// Remove the wait and have the execute op return its dependencies instead.
SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
waitOp.erase();
// addExecuteResults erases the old execute op and returns its replacement.
executeOp = addExecuteResults(executeOp, dependencies);
// The newly appended trailing results carry the forwarded dependencies.
auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
// Copy the users first, since addAsyncDependencyAfter modifies the IR
// around them.
SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
executeOp.getToken().user_end());
for (Operation *user : users)
addAsyncDependencyAfter(asyncTokens, user);
}
}
private:
// Returns true if `token` has at least one use and every user is an
// `async.execute` or `async.await` op.
static bool areAllUsersExecuteOrAwait(Value token) {
return !token.use_empty() &&
llvm::all_of(token.getUsers(),
llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
}
// Makes the code following `op`, a user of an execute token, depend on
// `asyncTokens`:
//  - for an `async.await` user, an `async.await` is inserted after `op` for
//    each token;
//  - for an `async.execute` user, the tokens are appended as body operands and
//    received as new `!gpu.async.token` block arguments.
// The resulting values are attached to the next op that is a terminator or has
// side effects. They become async dependencies if that op implements
// AsyncOpInterface; otherwise a synchronous `gpu.wait` on them is inserted
// before it.
void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
OpBuilder builder(op->getContext());
auto loc = op->getLoc();
Block::iterator it;
SmallVector<Value, 1> tokens;
tokens.reserve(asyncTokens.size());
TypeSwitch<Operation *>(op)
// NOTE(review): the `awaitOp` parameter is unused; the lambda uses `op`.
.Case<async::AwaitOp>([&](auto awaitOp) {
builder.setInsertionPointAfter(op);
for (auto asyncToken : asyncTokens)
tokens.push_back(
builder.create<async::AwaitOp>(loc, asyncToken).getResult());
// Resume the search right after the inserted awaits.
it = builder.getInsertionPoint();
})
.Case<async::ExecuteOp>([&](auto executeOp) {
// Search from the top of the body. Forward the tokens as body operands
// and collect the matching new block arguments.
it = executeOp.getBody()->begin();
executeOp.getBodyOperandsMutable().append(asyncTokens);
SmallVector<Type, 1> tokenTypes(
asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
executeOp.getLoc());
copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
std::back_inserter(tokens));
});
// Advance to the first terminator or side-effecting op.
// NOTE(review): a default-constructed Block::iterator is not this block's
// end(). The search presumably relies on the block ending in a terminator,
// which always matches.
it = std::find_if(it, Block::iterator(), [](Operation &op) {
return isTerminator(&op) || hasSideEffects(&op);
});
// An async op can take the values directly as dependencies.
if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
for (auto token : tokens)
asyncOp.addAsyncDependency(token);
return;
}
// Otherwise insert a gpu.wait (no result type) on the values before `it`.
builder.setInsertionPoint(it->getBlock(), it);
auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
// If `it` is the last op of an async.execute body whose token users are all
// execute/await ops, queue the new wait so it is deferred as well.
auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
!it->getNextNode())
worklist.push_back(waitOp);
}
// Synchronous gpu.wait ops queued for deferral; drained by the destructor.
SmallVector<gpu::WaitOp, 8> worklist;
};
/// Walk callback for `async.execute` ops. It ensures that every
/// `!async.value<!gpu.async.token>` result has at most one use. For each extra
/// use of such a result, the yielded token is returned once more, and that use
/// is redirected to its own new result.
struct GpuAsyncRegionPass::SingleTokenUseCallback {
  void operator()(async::ExecuteOp executeOp) {
    // Body results (i.e. excluding the leading !async.token) that wrap a
    // !gpu.async.token and have more than one use.
    auto multiUseResults = llvm::make_filter_range(
        executeOp.getBodyResults(), [](OpResult result) {
          if (result.use_empty() || result.hasOneUse())
            return false;
          auto valueType = dyn_cast<async::ValueType>(result.getType());
          return valueType &&
                 isa<gpu::AsyncTokenType>(valueType.getValueType());
        });
    if (multiUseResults.empty())
      return;

    // Positions within the body results / yield operands (result number minus
    // the leading token). Stored as indices because `executeOp` is replaced
    // below, and new results are only ever appended.
    SmallVector<unsigned, 4> indices;
    transform(multiUseResults, std::back_inserter(indices),
              [](OpResult result) { return result.getResultNumber() - 1; });

    for (unsigned index : indices) {
      assert(!executeOp.getBodyResults()[index].getUses().empty());
      // Yield the token once more for every use after the first one.
      auto extraUses =
          llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
      auto count = std::distance(extraUses.begin(), extraUses.end());
      auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
      SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
      executeOp = addExecuteResults(executeOp, operands);

      // Snapshot the extra uses of the replacement result before rewiring.
      // OpOperand::set() relinks an operand into the new value's use-list, so
      // iterating the live list while setting would stop after the first
      // rewire and leave the remaining uses on the original result.
      SmallVector<OpOperand *, 4> usesToRewire;
      for (OpOperand &use :
           llvm::drop_begin(executeOp.getBodyResults()[index].getUses()))
        usesToRewire.push_back(&use);
      assert(usesToRewire.size() == static_cast<size_t>(count) &&
             "use count changed while replacing the execute op");

      // Give each extra use its own freshly appended result.
      auto newResults = executeOp.getBodyResults().take_back(count);
      for (auto [use, result] : llvm::zip(usesToRewire, newResults))
        use->set(result);
    }
  }
};
void GpuAsyncRegionPass::runOnOperation() {
if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
return signalPassFailure();
getOperation().getRegion().walk(DeferWaitCallback());
getOperation().getRegion().walk(SingleTokenUseCallback());
}
/// Creates a new instance of the GPU async-region pass, owned by the caller.
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
  auto pass = std::make_unique<GpuAsyncRegionPass>();
  return pass;
}