#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
// Pull in the generated pass definitions for this pass; this provides the
// `impl::AsyncRuntimeRefCountingOptBase` base class used below.
namespace mlir {
#define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir

// Debug tag consulted by the LLVM_DEBUG output emitted in this file.
#define DEBUG_TYPE "async-ref-counting"

using namespace mlir;
using namespace mlir::async;
namespace {

/// Pass that looks for `async.runtime.add_ref` / `async.runtime.drop_ref`
/// pairs on the same reference-counted value that cancel each other out, and
/// erases them.
class AsyncRuntimeRefCountingOptPass
    : public impl::AsyncRuntimeRefCountingOptBase<
          AsyncRuntimeRefCountingOptPass> {
  /// Maps a cancellable `drop_ref` operation to the `add_ref` operation it
  /// cancels.
  using CancellableMap = llvm::SmallDenseMap<Operation *, Operation *>;

  /// Inspects the users of the reference-counted `value` and records every
  /// cancellable `drop_ref -> add_ref` pair into `cancellable`.
  LogicalResult optimizeReferenceCounting(Value value,
                                          CancellableMap &cancellable);

public:
  AsyncRuntimeRefCountingOptPass() = default;

  /// Collects cancellable pairs across the whole operation, then erases them.
  void runOnOperation() override;
};

} // namespace
/// Finds `async.runtime.add_ref` / `async.runtime.drop_ref` pairs on `value`
/// that can be erased together and records them in `cancellable` as
/// `drop_ref -> add_ref` entries.
///
/// Users are grouped per block. A user nested inside a region is recorded in
/// its own block and, transitively, via each enclosing operation up to the
/// region that defines `value`. Enclosing operations therefore act as regular
/// users of `value` in their block.
///
/// Within one block, an `add_ref` is paired with the first `drop_ref` (in
/// block order) that satisfies all of the following:
///   * it has the same `count`,
///   * it does not come before the `add_ref`,
///   * it has not already been claimed by another `add_ref`,
///   * among the users strictly between the two, no non-call user follows a
///     `func.call` user.
///
/// The last rule presumably exists because a callee takes ownership of one
/// reference and drops it. The later user would then need the extra reference
/// the pair provides. This is an assumed convention; confirm against the
/// runtime ref-counting contract. Example that must NOT be cancelled:
///
///   async.runtime.add_ref %token {count = 1 : i64} : !async.token
///   call @pass_token(%token) : (!async.token) -> ()
///   async.await %token : !async.token
///   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
///
/// NOTE(review): only `func::CallOp` users that sit directly in the block are
/// treated as calls. A call nested inside a region is seen only through its
/// enclosing operation, which counts as a regular user. Confirm this is
/// intended.
///
/// Always returns success().
LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
    Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
  Region *definingRegion = value.getParentRegion();

  // Operations touching `value` within one block, split by kind.
  struct BlockUsersInfo {
    llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
    llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
    llvm::SmallVector<Operation *, 4> users;
  };
  llvm::DenseMap<Block *, BlockUsersInfo> usersByBlock;

  // Files `op` under its parent block. `add_ref` / `drop_ref` ops are also
  // kept in dedicated lists so they can be paired afterwards.
  auto recordUser = [&usersByBlock](Operation *op) {
    BlockUsersInfo &entry = usersByBlock[op->getBlock()];
    entry.users.push_back(op);
    if (auto addRef = dyn_cast<RuntimeAddRefOp>(op))
      entry.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<RuntimeDropRefOp>(op))
      entry.dropRefs.push_back(dropRef);
  };

  // Record each user and every ancestor of it, stopping at the operation
  // that sits directly in `definingRegion`.
  for (Operation *user : value.getUsers()) {
    for (Operation *op = user;;) {
      recordUser(op);
      if (op->getParentRegion() == definingRegion)
        break;
      op = op->getParentOp();
      assert(op != nullptr && "value user lies outside of the value region");
    }
  }

  // Strict block order. Every pair compared here shares a single block.
  auto comesBefore = [](Operation *lhs, Operation *rhs) -> bool {
    return lhs->isBeforeInBlock(rhs);
  };

  // True when, among the users strictly between `addOp` and `dropOp`, a
  // non-call user appears after a `func.call` user. Requires `info.users` to
  // be sorted in block order.
  auto regularUseFollowsCall = [](const BlockUsersInfo &info, Operation *addOp,
                                  Operation *dropOp) -> bool {
    bool sawCall = false;
    for (Operation *user : info.users) {
      // Skip everything up to and including the `add_ref` ...
      if (user == addOp || user->isBeforeInBlock(addOp))
        continue;
      // ... and stop at the `drop_ref` (users are sorted).
      if (user == dropOp || dropOp->isBeforeInBlock(user))
        break;
      if (isa<func::CallOp>(user))
        sawCall = true;
      else if (sawCall)
        return true;
    }
    return false;
  };

  for (auto &entry : usersByBlock) {
    BlockUsersInfo &info = entry.second;
    llvm::sort(info.addRefs, comesBefore);
    llvm::sort(info.dropRefs, comesBefore);
    llvm::sort(info.users, comesBefore);

    for (RuntimeAddRefOp addRef : info.addRefs) {
      Operation *addOp = addRef.getOperation();
      for (RuntimeDropRefOp dropRef : info.dropRefs) {
        Operation *dropOp = dropRef.getOperation();

        // Candidate must release the same count and must not precede add_ref.
        if (dropRef.getCount() != addRef.getCount() ||
            dropOp->isBeforeInBlock(addOp))
          continue;

        if (regularUseFollowsCall(info, addOp, dropOp))
          continue;

        // A drop_ref already claimed by an earlier add_ref is skipped.
        // Otherwise the pair is recorded and this add_ref is done.
        if (cancellable.try_emplace(dropOp, addOp).second)
          break;
      }
    }
  }

  return success();
}
/// Runs the optimization over the current operation.
///
/// Reference-counted values defined by block arguments are visited first,
/// then those defined by operation results. For each value, the cancellable
/// `drop_ref -> add_ref` pairs are collected. Only after both walks complete
/// are the collected pairs erased, so neither walk visits an erased operation.
/// If the analysis fails, the pass signals failure and returns without
/// erasing anything.
void AsyncRuntimeRefCountingOptPass::runOnOperation() {
  Operation *op = getOperation();

  // Maps each cancellable `drop_ref` operation to the `add_ref` it cancels.
  llvm::SmallDenseMap<Operation *, Operation *> cancellable;

  // Values defined by block arguments.
  WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg, cancellable)))
          return WalkResult::interrupt();
    return WalkResult::advance();
  });
  // Bail out right away. Continuing would waste work, and the erase loop
  // below would mutate IR using an incomplete analysis.
  if (blockWalk.wasInterrupted())
    return signalPassFailure();

  // Values defined by operation results.
  WalkResult opWalk = op->walk([&](Operation *nested) -> WalkResult {
    for (OpResult result : nested->getResults())
      if (isRefCounted(result.getType()))
        if (failed(optimizeReferenceCounting(result, cancellable)))
          return WalkResult::interrupt();
    return WalkResult::advance();
  });
  if (opWalk.wasInterrupted())
    return signalPassFailure();

  LLVM_DEBUG({
    llvm::dbgs() << "Found " << cancellable.size()
                 << " cancellable reference counting operations\n";
  });

  // Erase both operations of every cancellable pair.
  for (auto &[dropRef, addRef] : cancellable) {
    dropRef->erase();
    addRef->erase();
  }
}
/// Creates a fresh instance of the Async runtime reference-counting
/// optimization pass.
std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
  std::unique_ptr<Pass> pass =
      std::make_unique<AsyncRuntimeRefCountingOptPass>();
  return pass;
}