#include "mlir/ExecutionEngine/AsyncRuntime.h"
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"
using namespace mlir::runtime;
namespace mlir {
namespace runtime {
namespace {
class RefCounted;

/// Runtime shared by all async objects: owns the thread pool used to run
/// asynchronous work and counts the RefCounted objects that are still alive.
class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  /// Blocks until all work submitted to the thread pool has finished, then
  /// asserts (debug builds only) that no RefCounted object is still alive.
  ~AsyncRuntime() {
    threadPool.wait();
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  /// Returns the current number of live RefCounted objects.
  int64_t getNumRefCountedObjects() {
    const int64_t live = numRefCountedObjects.load(std::memory_order_relaxed);
    return live;
  }

  /// Returns the thread pool that executes asynchronous work.
  llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }

private:
  // RefCounted bumps the live-object counter from its ctor and dtor.
  friend class RefCounted;

  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }

  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  // Updated and read with relaxed ordering; it orders no other memory access.
  std::atomic<int64_t> numRefCountedObjects;
  llvm::DefaultThreadPool threadPool;
};
/// Lifecycle state of an async token or value, wrapping the raw StateEnum with
/// query helpers. kAvailable and kError are the terminal states.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying data is not yet available.
    kUnavailable = 0,
    // The underlying data is available.
    kAvailable = 1,
    // The object completed with an error.
    kError = 2,
  };

  // Intentionally implicit: call sites pass `State::kAvailable` etc. directly.
  State(StateEnum s) : state(s) {}

  // Implicit conversion back to the raw enum. `const` so that const State
  // objects (and const references) can convert as well.
  operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  /// Returns a human readable name for the state.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // The int8_t-backed enum can hold values other than the enumerators;
    // without this return, control would fall off the end of a non-void
    // function, which is undefined behavior.
    return "unknown";
  }

private:
  StateEnum state;
};
/// Base class for runtime objects with an atomic intrusive reference count.
/// Construction registers the object with its AsyncRuntime and destruction
/// unregisters it; destroy() is invoked when the count drops to zero.
class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    this->runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  // Non-copyable: each object owns exactly one reference count.
  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  /// Adds `count` references.
  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  /// Releases `count` references and destroys the object if none remain.
  void dropRef(int64_t count = 1) {
    const int64_t before = refCount.fetch_sub(count);
    assert(before >= count && "reference count should not go below zero");
    const bool releasedLast = (before == count);
    if (releasedLast)
      destroy();
  }

protected:
  /// Called once the count reaches zero; deletes the object by default.
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};
}
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
static auto runtime = std::make_unique<AsyncRuntime>();
return runtime;
}
static void resetDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().reset();
}
static AsyncRuntime *getDefaultAsyncRuntime() {
return getDefaultAsyncRuntimeInstance().get();
}
/// Async token: a completion signal without a payload.
struct AsyncToken : public RefCounted {
  // Starts with a reference count of 2. setTokenState releases one of them
  // after the token reaches a terminal state; the other presumably belongs to
  // the creator/consumer (NOTE(review): confirm against the async lowering).
  // `explicit` prevents accidental implicit conversion from AsyncRuntime*.
  explicit AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, 2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Guards state transitions and `awaiters`; `cv` wakes blocking waiters.
  std::mutex mu;
  std::condition_variable cv;

  // Callbacks executed once the token becomes available or errored.
  std::vector<std::function<void()>> awaiters;
};
/// Async value: like an async token, but also owns a `size`-byte buffer for
/// the payload (value-initialized, i.e. zeroed).
struct AsyncValue : public RefCounted {
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(static_cast<size_t>(size)) {}

  std::atomic<State::StateEnum> state;

  // Payload bytes; exposed through mlirAsyncRuntimeGetValueStorage.
  std::vector<std::byte> storage;

  // Guards state transitions and `awaiters`; `cv` wakes blocking waiters.
  std::mutex mu;
  std::condition_variable cv;

  // Callbacks executed once the value becomes available or errored.
  std::vector<std::function<void()>> awaiters;
};
/// A group of async tokens of a fixed expected size; completes when every
/// added token has become available or errored.
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(static_cast<int>(size)) {}

  // Tokens still outstanding; reaching zero releases the group's waiters.
  std::atomic<int> pendingTokens;
  // Number of added tokens that completed in the error state.
  std::atomic<int> numErrors{0};
  // Next rank handed out by mlirAsyncRuntimeAddTokenToGroup.
  std::atomic<int> rank{0};

  // Guards `awaiters` and completion; `cv` wakes blocking waiters.
  std::mutex mu;
  std::condition_variable cv;

  // Callbacks executed once all tokens in the group are ready.
  std::vector<std::function<void()>> awaiters;
};
// Adds `count` references to a token, value or group passed as an opaque
// pointer.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  static_cast<RefCounted *>(ptr)->addRef(count);
}

// Drops `count` references; the object is destroyed when none remain.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  static_cast<RefCounted *>(ptr)->dropRef(count);
}
// Creates a new, unavailable token owned by the default runtime. It is freed
// when its reference count drops to zero.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  return new AsyncToken(getDefaultAsyncRuntime());
}

// Creates a new, unavailable value with `size` bytes of payload storage.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  return new AsyncValue(getDefaultAsyncRuntime(), size);
}

// Creates a new group expecting `size` tokens.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  return new AsyncGroup(getDefaultAsyncRuntime(), size);
}
// Adds `token` to `group` and returns the token's rank: the value of
// group->rank before this call incremented it, i.e. its insertion index.
// If the token is already available or errored, the group's pending count is
// updated right away. Otherwise a callback is appended to the token's
// awaiters; setTokenState runs it later.
// Lock order is token->mu first, then group->mu. setTokenState follows the
// same order, since it runs the callback while holding token->mu.
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
AsyncGroup *group) {
std::unique_lock<std::mutex> lockToken(token->mu);
std::unique_lock<std::mutex> lockGroup(group->mu);
// The rank is fixed at insertion time, independent of completion order.
int rank = group->rank.fetch_add(1);
// Must be called with group->mu held. Counts an error if the token failed.
// When the last pending token completes, it wakes blocking waiters and runs
// the group's awaiters.
auto onTokenReady = [group, token]() {
if (State(token->state).isError())
group->numErrors.fetch_add(1);
// Adding more tokens than the group size would push the count below zero.
assert(group->pendingTokens > 0 && "wrong group size");
if (group->pendingTokens.fetch_sub(1) == 1) {
group->cv.notify_all();
for (auto &awaiter : group->awaiters)
awaiter();
}
};
if (State(token->state).isAvailableOrError()) {
// Token already completed. Both locks are still held here.
onTokenReady();
} else {
// The callback runs later, so keep `group` alive until it has run.
group->addRef();
token->awaiters.emplace_back([group, onTokenReady]() {
// Scope the lock so it is released before dropRef() can destroy the group,
// and with it the mutex the lock refers to.
{
std::unique_lock<std::mutex> lockGroup(group->mu);
onTokenReady();
}
group->dropRef();
});
}
return rank;
}
// Moves `token` from kUnavailable to the terminal `state` (kAvailable or
// kError). It then wakes threads blocked on token->cv, runs every registered
// awaiter, and finally releases one reference to the token. The token's
// constructor starts the count at 2.
static void setTokenState(AsyncToken *token, State state) {
assert(state.isAvailableOrError() && "must be terminal state");
assert(State(token->state).isUnavailable() && "token must be unavailable");
{
// Awaiters run synchronously while token->mu is held.
std::unique_lock<std::mutex> lock(token->mu);
token->state = state;
token->cv.notify_all();
for (auto &awaiter : token->awaiters)
awaiter();
}
// Dropped only after the lock scope ends. If this is the last reference, the
// token and its mutex are then not destroyed while still locked.
token->dropRef();
}
// Moves `value` from kUnavailable to the terminal `state` (kAvailable or
// kError). It then wakes threads blocked on value->cv, runs every registered
// awaiter, and finally releases one reference to the value. The value's
// constructor starts the count at 2.
static void setValueState(AsyncValue *value, State state) {
assert(state.isAvailableOrError() && "must be terminal state");
assert(State(value->state).isUnavailable() && "value must be unavailable");
{
// Awaiters run synchronously while value->mu is held.
std::unique_lock<std::mutex> lock(value->mu);
value->state = state;
value->cv.notify_all();
for (auto &awaiter : value->awaiters)
awaiter();
}
// Dropped only after the lock scope ends. If this is the last reference, the
// value and its mutex are then not destroyed while still locked.
value->dropRef();
}
// Marks `token` as available and runs its awaiters.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State(State::kAvailable));
}

// Marks `value` as available (its storage is ready) and runs its awaiters.
extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State(State::kAvailable));
}

// Marks `token` as errored and runs its awaiters.
extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State(State::kError));
}

// Marks `value` as errored and runs its awaiters.
extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State(State::kError));
}
// Returns true if `token` completed in the error state.
extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return token->state.load() == State::kError;
}

// Returns true if `value` completed in the error state.
extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return value->state.load() == State::kError;
}

// Returns true if at least one token added to `group` completed with an error.
extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  const int errors = group->numErrors.load();
  return errors > 0;
}
// Blocks the calling thread until `token` is available or errored.
// The predicate overload of condition_variable::wait checks the predicate
// before blocking, so an already-ready token returns immediately.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> guard(token->mu);
  token->cv.wait(guard, [token] {
    return State(token->state).isAvailableOrError();
  });
}

// Blocks the calling thread until `value` is available or errored.
extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> guard(value->mu);
  value->cv.wait(guard, [value] {
    return State(value->state).isAvailableOrError();
  });
}

// Blocks the calling thread until no tokens in `group` are pending.
extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> guard(group->mu);
  group->cv.wait(guard, [group] { return group->pendingTokens == 0; });
}
// Returns a pointer to the payload bytes of `value`. Must not be called on an
// errored value (checked in debug builds).
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  auto &bytes = value->storage;
  return bytes.data();
}
// Schedules `resume(handle)` to run on the default runtime's thread pool.
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  llvm::ThreadPoolInterface &pool = getDefaultAsyncRuntime()->getThreadPool();
  pool.async([handle, resume]() { (*resume)(handle); });
}
// Resumes the coroutine once `token` is ready. If the token is already ready,
// it resumes inline on the calling thread. Otherwise it registers an awaiter
// that setTokenState runs.
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto resumeCoroutine = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> guard(token->mu);
  if (!State(token->state).isAvailableOrError()) {
    token->awaiters.emplace_back(resumeCoroutine);
    return;
  }
  // Release the lock before resuming so the coroutine does not run under it.
  guard.unlock();
  resumeCoroutine();
}
// Resumes the coroutine once `value` is ready. If the value is already ready,
// it resumes inline on the calling thread. Otherwise it registers an awaiter
// that setValueState runs.
extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto resumeCoroutine = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> guard(value->mu);
  if (!State(value->state).isAvailableOrError()) {
    value->awaiters.emplace_back(resumeCoroutine);
    return;
  }
  // Release the lock before resuming so the coroutine does not run under it.
  guard.unlock();
  resumeCoroutine();
}
// Resumes the coroutine once every token in `group` is ready. If none are
// pending, it resumes inline on the calling thread. Otherwise it registers a
// group awaiter, which the last completing token runs.
extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto resumeCoroutine = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> guard(group->mu);
  if (group->pendingTokens != 0) {
    group->awaiters.emplace_back(resumeCoroutine);
    return;
  }
  // Release the lock before resuming so the coroutine does not run under it.
  guard.unlock();
  resumeCoroutine();
}
extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
}
// Prints the id of the calling thread to stdout (debugging aid). The id is
// cached in a thread_local on each thread's first call.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local const std::thread::id currentId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << '\n';
}
// Registers every C API entry point of this runtime in `exportSymbols`
// (symbol name -> function address). Presumably invoked by the MLIR
// ExecutionEngine when it loads this library (NOTE(review): confirm).
// NOTE(review): the separate declaration below looks like a compiler
// workaround; keep it unless confirmed unnecessary.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
__mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
// Inserts one symbol; asserts (debug builds) that no name is registered twice.
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
assert(exportSymbols.count(name) == 0 && "symbol already exists");
exportSymbols[name] = reinterpret_cast<void *>(ptr);
};
// Each string must match the corresponding function name exactly.
exportSymbol("mlirAsyncRuntimeAddRef",
&mlir::runtime::mlirAsyncRuntimeAddRef);
exportSymbol("mlirAsyncRuntimeDropRef",
&mlir::runtime::mlirAsyncRuntimeDropRef);
exportSymbol("mlirAsyncRuntimeExecute",
&mlir::runtime::mlirAsyncRuntimeExecute);
exportSymbol("mlirAsyncRuntimeGetValueStorage",
&mlir::runtime::mlirAsyncRuntimeGetValueStorage);
exportSymbol("mlirAsyncRuntimeCreateToken",
&mlir::runtime::mlirAsyncRuntimeCreateToken);
exportSymbol("mlirAsyncRuntimeCreateValue",
&mlir::runtime::mlirAsyncRuntimeCreateValue);
exportSymbol("mlirAsyncRuntimeEmplaceToken",
&mlir::runtime::mlirAsyncRuntimeEmplaceToken);
exportSymbol("mlirAsyncRuntimeEmplaceValue",
&mlir::runtime::mlirAsyncRuntimeEmplaceValue);
exportSymbol("mlirAsyncRuntimeSetTokenError",
&mlir::runtime::mlirAsyncRuntimeSetTokenError);
exportSymbol("mlirAsyncRuntimeSetValueError",
&mlir::runtime::mlirAsyncRuntimeSetValueError);
exportSymbol("mlirAsyncRuntimeIsTokenError",
&mlir::runtime::mlirAsyncRuntimeIsTokenError);
exportSymbol("mlirAsyncRuntimeIsValueError",
&mlir::runtime::mlirAsyncRuntimeIsValueError);
exportSymbol("mlirAsyncRuntimeIsGroupError",
&mlir::runtime::mlirAsyncRuntimeIsGroupError);
exportSymbol("mlirAsyncRuntimeAwaitToken",
&mlir::runtime::mlirAsyncRuntimeAwaitToken);
exportSymbol("mlirAsyncRuntimeAwaitValue",
&mlir::runtime::mlirAsyncRuntimeAwaitValue);
exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
exportSymbol("mlirAsyncRuntimeCreateGroup",
&mlir::runtime::mlirAsyncRuntimeCreateGroup);
exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
&mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
&mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
// Note: "Runtim" (missing 'e') is the function's actual, exported name.
exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
&mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}
// Tears down the process-wide default runtime. AsyncRuntime's destructor
// waits for its thread pool to drain. Afterwards getDefaultAsyncRuntime()
// returns nullptr.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
  getDefaultAsyncRuntimeInstance().reset();
}
}
}