#ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
#define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_H
#include <cassert>
#include <functional>
#include <list>
#include <mutex>
#include <set>
#include <unordered_map>
#include <vector>
#include "Shared/Debug.h"
#include "Shared/Utils.h"
#include "omptarget.h"
/// Abstract interface through which device memory is obtained and released.
/// MemoryManagerTy holds a reference to an implementation of this interface
/// and forwards its device allocations and deallocations to it.
class DeviceAllocatorTy {
public:
virtual ~DeviceAllocatorTy() = default;
/// Allocate \p Size bytes of memory of kind \p Kind and return a pointer to
/// it. \p HstPtr is a host pointer handed through by the caller; how it is
/// used is up to the implementation. MemoryManagerTy treats a nullptr result
/// as an allocation failure.
virtual void *allocate(size_t Size, void *HstPtr,
TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
/// Release the memory \p TgtPtr of kind \p Kind and return an integer status.
/// NOTE(review): presumably OFFLOAD_SUCCESS on success, since
/// MemoryManagerTy::free forwards this value to its own callers - confirm
/// against the implementations.
virtual int free(void *TgtPtr, TargetAllocTy Kind = TARGET_ALLOC_DEFAULT) = 0;
};
/// Caching layer in front of a DeviceAllocatorTy.
///
/// Device allocations of at most SizeThreshold bytes are recorded in
/// PtrToNodeTable. When such an allocation is released through free() it is
/// not handed back to the device; instead its node is parked in a per-bucket
/// free list so that a later allocate() of exactly the same size can reuse it.
/// Requests larger than SizeThreshold bypass the cache and go straight to the
/// device allocator.
class MemoryManagerTy {
  /// Lower bound, in bytes, of every bucket. Bucket I holds nodes whose size
  /// floored to a power of two lies in [BucketSize[I], BucketSize[I + 1]); the
  /// last bucket is open-ended. The table must stay sorted in ascending order
  /// and start with 0 so that every size maps to some bucket.
  static constexpr const size_t BucketSize[] = {
      0,       1U << 2, 1U << 3,  1U << 4,  1U << 5,  1U << 6, 1U << 7,
      1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13};

  /// Number of entries in BucketSize, and therefore the number of free lists.
  static constexpr const int NumBuckets =
      sizeof(BucketSize) / sizeof(BucketSize[0]);

  /// Return the largest power of two that is not greater than \p Num, or 0
  /// when \p Num is 0.
  static size_t floorToPowerOfTwo(size_t Num) {
    // Smear the highest set bit into every lower bit position so that Num
    // becomes 2^k - 1.
    Num |= Num >> 1;
    Num |= Num >> 2;
    Num |= Num >> 4;
    Num |= Num >> 8;
    Num |= Num >> 16;
#if INTPTR_MAX == INT64_MAX
    Num |= Num >> 32;
#elif INTPTR_MAX == INT32_MAX
    // The shifts above already cover all 32 bits.
#else
#error Unsupported architecture
#endif
    // Dropping the lower half of the all-ones pattern leaves only the top bit.
    // Unlike `(Num + 1) >> 1` this cannot wrap around to 0 when the most
    // significant bit of size_t is set.
    return Num - (Num >> 1);
  }

  /// Return the index of the bucket responsible for \p Size, i.e. the last
  /// bucket whose lower bound does not exceed \p Size floored to a power of
  /// two. The result is always in [0, NumBuckets).
  static int findBucket(size_t Size) {
    const size_t F = floorToPowerOfTwo(Size);

    DP("findBucket: Size %zu is floored to %zu.\n", Size, F);

    // Binary search for the last index with BucketSize[index] <= F. The
    // invariant BucketSize[L] <= F holds initially because BucketSize[0] is 0.
    // The midpoint is rounded up so that `L = M` always makes progress; this
    // lets the loop run until the range collapses to a single bucket instead
    // of stopping one step early and skipping the correct bucket.
    int L = 0, H = NumBuckets - 1;
    while (L < H) {
      const int M = L + (H - L + 1) / 2;
      if (BucketSize[M] <= F)
        L = M;
      else
        H = M - 1;
    }

    assert(L >= 0 && L < NumBuckets && "L is out of range");

    DP("findBucket: Size %zu goes to bucket %d\n", Size, L);

    return L;
  }

  /// Bookkeeping record for one device allocation handled by the manager.
  struct NodeTy {
    /// Size in bytes that was requested for the allocation.
    const size_t Size;
    /// Device pointer returned by the device allocator.
    void *Ptr;

    NodeTy(size_t Size, void *Ptr) : Size(Size), Ptr(Ptr) {}
  };

  /// Orders nodes by size only, so that a free list can be searched for any
  /// node of a given size.
  struct NodeCmpTy {
    bool operator()(const NodeTy &LHS, const NodeTy &RHS) const {
      return LHS.Size < RHS.Size;
    }
  };

  /// A free list stores references to nodes owned by PtrToNodeTable, sorted by
  /// size. Several nodes of the same size may be present at once.
  using FreeListTy = std::multiset<std::reference_wrapper<NodeTy>, NodeCmpTy>;

  /// One free list per bucket.
  std::vector<FreeListTy> FreeLists;
  /// FreeListLocks[I] guards FreeLists[I].
  std::vector<std::mutex> FreeListLocks;
  /// Maps a device pointer to its node. This table owns the nodes; the free
  /// lists only refer to them.
  std::unordered_map<void *, NodeTy> PtrToNodeTable;
  /// Guards PtrToNodeTable.
  std::mutex MapTableLock;
  /// Allocator used for the actual device allocations and deallocations.
  DeviceAllocatorTy &DeviceAllocator;
  /// Requests larger than this many bytes are not cached by the manager.
  size_t SizeThreshold = 1U << 13;

  /// Request \p Size bytes of TARGET_ALLOC_DEVICE memory from the device
  /// allocator, passing \p HstPtr through unchanged.
  void *allocateOnDevice(size_t Size, void *HstPtr) const {
    return DeviceAllocator.allocate(Size, HstPtr, TARGET_ALLOC_DEVICE);
  }

  /// Hand \p Ptr back to the device allocator (using the allocator's default
  /// allocation kind) and return its status code.
  int deleteOnDevice(void *Ptr) const { return DeviceAllocator.free(Ptr); }

  /// Release every allocation currently parked in the free lists back to the
  /// device, forget the corresponding nodes, and then retry the allocation of
  /// \p Size bytes. Returns whatever the retried device allocation returns.
  void *freeAndAllocate(size_t Size, void *HstPtr) {
    std::vector<void *> RemoveList;

    // Drain the buckets one at a time, holding only that bucket's lock.
    for (int I = 0; I < NumBuckets; ++I) {
      FreeListTy &List = FreeLists[I];
      std::lock_guard<std::mutex> Lock(FreeListLocks[I]);
      if (List.empty())
        continue;
      for (const NodeTy &N : List) {
        deleteOnDevice(N.Ptr);
        RemoveList.push_back(N.Ptr);
      }
      List.clear();
    }

    // The nodes live in the map table; erase them only after every reference
    // to them has been removed from the free lists above.
    if (!RemoveList.empty()) {
      std::lock_guard<std::mutex> LG(MapTableLock);
      for (void *P : RemoveList)
        PtrToNodeTable.erase(P);
    }

    return allocateOnDevice(Size, HstPtr);
  }

  /// Allocate \p Size bytes on the device. If the first attempt fails, give
  /// all cached memory back to the device and try once more. Returns nullptr
  /// if the second attempt fails as well.
  void *allocateOrFreeAndAllocateOnDevice(size_t Size, void *HstPtr) {
    void *TgtPtr = allocateOnDevice(Size, HstPtr);
    if (TgtPtr == nullptr) {
      DP("Failed to get memory on device. Free all memory in FreeLists and "
         "try again.\n");
      TgtPtr = freeAndAllocate(Size, HstPtr);
    }

    if (TgtPtr == nullptr)
      DP("Still cannot get memory on device probably because the device is "
         "OOM.\n");

    return TgtPtr;
  }

public:
  /// Create a manager on top of \p DeviceAllocator. A \p Threshold of 0 keeps
  /// the built-in default threshold; any other value replaces it.
  MemoryManagerTy(DeviceAllocatorTy &DeviceAllocator, size_t Threshold = 0)
      : FreeLists(NumBuckets), FreeListLocks(NumBuckets),
        DeviceAllocator(DeviceAllocator) {
    if (Threshold)
      SizeThreshold = Threshold;
  }

  // The manager holds mutexes, a reference to the allocator and device
  // pointers that its destructor releases, so it must not be copied.
  MemoryManagerTy(const MemoryManagerTy &) = delete;
  MemoryManagerTy &operator=(const MemoryManagerTy &) = delete;

  /// Release every allocation that is still recorded in the map table,
  /// whether it is currently handed out or parked in a free list.
  ~MemoryManagerTy() {
    for (const auto &Entry : PtrToNodeTable) {
      assert(Entry.second.Ptr && "nullptr in map table");
      deleteOnDevice(Entry.second.Ptr);
    }
  }

  /// Return a device pointer to \p Size bytes, or nullptr if \p Size is 0 or
  /// the device allocation fails. Requests above the threshold are forwarded
  /// to the device allocator and are not tracked. Other requests first try to
  /// reuse a cached allocation of exactly \p Size bytes and fall back to a new
  /// device allocation, which is then recorded in the map table.
  void *allocate(size_t Size, void *HstPtr) {
    if (Size == 0)
      return nullptr;

    DP("MemoryManagerTy::allocate: size %zu with host pointer " DPxMOD ".\n",
       Size, DPxPTR(HstPtr));

    if (Size > SizeThreshold) {
      DP("%zu is greater than the threshold %zu. Allocate it directly from "
         "device\n",
         Size, SizeThreshold);
      void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);

      DP("Got target pointer " DPxMOD ". Return directly.\n", DPxPTR(TgtPtr));

      return TgtPtr;
    }

    NodeTy *NodePtr = nullptr;

    // Look for a cached node of exactly this size in the matching bucket.
    {
      const int B = findBucket(Size);
      FreeListTy &List = FreeLists[B];

      // Search key only; the comparator looks at Size and ignores Ptr.
      NodeTy TempNode(Size, nullptr);
      std::lock_guard<std::mutex> LG(FreeListLocks[B]);
      const auto Itr = List.find(TempNode);

      if (Itr != List.end()) {
        NodePtr = &Itr->get();
        List.erase(Itr);
      }
    }

    // Braces are required here: DP may expand to a bare compound statement.
    if (NodePtr != nullptr) {
      DP("Find one node " DPxMOD " in the bucket.\n", DPxPTR(NodePtr));
      return NodePtr->Ptr;
    }

    DP("Cannot find a node in the FreeLists. Allocate on device.\n");
    void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);

    if (TgtPtr == nullptr)
      return nullptr;

    // Record the new allocation so that free() can recognise and cache it.
    {
      std::lock_guard<std::mutex> Guard(MapTableLock);
      auto Itr = PtrToNodeTable.emplace(TgtPtr, NodeTy(Size, TgtPtr));
      NodePtr = &Itr.first->second;
    }

    DP("Node address " DPxMOD ", target pointer " DPxMOD ", size %zu\n",
       DPxPTR(NodePtr), DPxPTR(TgtPtr), Size);

    assert(NodePtr && "NodePtr should not be nullptr at this point");

    return NodePtr->Ptr;
  }

  /// Release \p TgtPtr. A pointer that is recorded in the map table is parked
  /// in the free list of its bucket for later reuse and OFFLOAD_SUCCESS is
  /// returned. Any other pointer is passed to the device allocator and that
  /// call's status is returned.
  int free(void *TgtPtr) {
    DP("MemoryManagerTy::free: target memory " DPxMOD ".\n", DPxPTR(TgtPtr));

    NodeTy *P = nullptr;

    // Look up the node that belongs to the pointer, if there is one.
    {
      std::lock_guard<std::mutex> G(MapTableLock);
      auto Itr = PtrToNodeTable.find(TgtPtr);
      if (Itr != PtrToNodeTable.end())
        P = &Itr->second;
    }

    // No node: the pointer was not recorded by allocate().
    if (P == nullptr) {
      DP("Cannot find its node. Delete it on device directly.\n");
      return deleteOnDevice(TgtPtr);
    }

    const int B = findBucket(P->Size);

    DP("Found its node " DPxMOD ". Insert it to bucket %d.\n", DPxPTR(P), B);

    {
      std::lock_guard<std::mutex> G(FreeListLocks[B]);
      FreeLists[B].insert(*P);
    }

    return OFFLOAD_SUCCESS;
  }

  /// Read LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD and return a pair of
  /// (threshold, enabled). If the variable is present and its value is 0 the
  /// result is (0, false). Otherwise the result is (value, true), where the
  /// value is whatever the environment variable object reports.
  /// NOTE(review): the literal 0 passed to UInt32Envar is presumably the value
  /// used when the variable is unset - confirm against UInt32Envar.
  static std::pair<size_t, bool> getSizeThresholdFromEnv() {
    static UInt32Envar MemoryManagerThreshold(
        "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD", 0);

    size_t Threshold = MemoryManagerThreshold.get();

    if (MemoryManagerThreshold.isPresent() && Threshold == 0) {
      DP("Disabled memory manager as user set "
         "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0.\n");
      return std::make_pair(0, false);
    }

    return std::make_pair(Threshold, true);
  }
};
// Out-of-class definitions of the static constexpr data members declared in
// MemoryManagerTy. Before C++17 such a definition is required whenever the
// member is odr-used; from C++17 on the in-class declarations are implicitly
// inline and these redeclarations are redundant (deprecated, but accepted).
constexpr const size_t MemoryManagerTy::BucketSize[];
constexpr const int MemoryManagerTy::NumBuckets;
#endif