#include "xsched/levelzero/hal/pool.h"

using namespace xsched::levelzero;

// Number of event slots requested for each event pool; also used as the upper
// bound checked by EventPool::Create() before handing out a new event index.
#define POOL_SIZE 16384

// Builds an event pool in context `ctx`.
// The pool is created with the HOST_VISIBLE flag and POOL_SIZE event slots;
// the create call is given a device count of 0 and a null device list.
EventPool::EventPool(ze_context_handle_t ctx): kContext(ctx)
{
    // Start from a zeroed descriptor and fill in every field explicitly.
    ze_event_pool_desc_t event_pool_desc{};
    event_pool_desc.stype = ZE_STRUCTURE_TYPE_EVENT_POOL_DESC;
    event_pool_desc.pNext = nullptr;
    event_pool_desc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
    event_pool_desc.count = POOL_SIZE;

    ZE_ASSERT(Driver::EventPoolCreate(kContext, &event_pool_desc, 0, nullptr, &event_pool_));
}

// Releases the underlying Level Zero event pool handle held in event_pool_.
// The driver result is checked through ZE_ASSERT.
EventPool::~EventPool()
{
    ZE_ASSERT(Driver::EventPoolDestroy(event_pool_));
}

// Returns the EventPool associated with `ctx`, creating and caching it on the
// first request for that context. The cache is a function-local static map
// and every lookup/insert happens while holding a function-local static mutex.
std::shared_ptr<EventPool> EventPool::Instance(ze_context_handle_t ctx)
{
    static std::mutex mtx;
    static std::map<ze_context_handle_t, std::shared_ptr<EventPool>> pools;

    std::lock_guard<std::mutex> lock(mtx);

    auto it = pools.find(ctx);
    if (it == pools.end()) {
        // First request for this context: build the pool and remember it.
        it = pools.emplace(ctx, std::make_shared<EventPool>(ctx)).first;
    }
    return it->second;
}

// Creates a new event in this pool and returns its handle as an opaque pointer.
//
// Each call takes the next free slot index from a running counter. If the
// counter has already reached POOL_SIZE, XERRO is invoked with an error
// message before any event is created.
//
// NOTE(review): the counter is a function-local static, so it is shared by all
// EventPool instances and is not synchronized inside this function — confirm
// that callers serialize Create() or that a shared counter is intended.
void *EventPool::Create()
{
    static uint32_t event_count = 0;
    if (event_count >= POOL_SIZE) XERRO("event count exceeds limit (%d)", POOL_SIZE);

    ze_event_handle_t event = nullptr;
    // The descriptor must be rebuilt on every call: a `static` descriptor is
    // initialized only once, which would pin .index to 0 for every event and
    // stop event_count from ever advancing past 1.
    const ze_event_desc_t event_desc = {
        .stype = ZE_STRUCTURE_TYPE_EVENT_DESC,
        .pNext = nullptr,
        .index = event_count++,
        .signal = ZE_EVENT_SCOPE_FLAG_HOST,
        .wait = ZE_EVENT_SCOPE_FLAG_HOST,
    };
    ZE_ASSERT(Driver::EventCreate(event_pool_, &event_desc, &event));
    return event;
}

// Out-of-class definitions of FencePool's static members:
//   mtx_   - mutex held by Instance() and Destroy(cmdq) while touching pools_.
//   pools_ - cache mapping each command queue handle to its FencePool.
std::mutex FencePool::mtx_;
std::map<ze_command_queue_handle_t, std::shared_ptr<FencePool>> FencePool::pools_;

// Returns the FencePool bound to command queue `cmdq`, creating and caching it
// in pools_ on the first request. The lookup and insert run under mtx_.
std::shared_ptr<FencePool> FencePool::Instance(ze_command_queue_handle_t cmdq)
{
    std::lock_guard<std::mutex> lock(mtx_);

    auto it = pools_.find(cmdq);
    if (it == pools_.end()) {
        // No pool for this queue yet: create one and store it in the cache.
        it = pools_.emplace(cmdq, std::make_shared<FencePool>(cmdq)).first;
    }
    return it->second;
}

// Drops the cached FencePool entry for command queue `cmdq`, if one exists.
// Only the cache's shared_ptr is released here; other holders of the pool
// keep it alive. Runs under mtx_.
void FencePool::Destroy(ze_command_queue_handle_t cmdq)
{
    std::lock_guard<std::mutex> lock(mtx_);

    auto it = pools_.find(cmdq);
    if (it != pools_.end()) {
        pools_.erase(it);
    }
}

// Creates a new fence on this pool's command queue (kCmdq) and returns its
// handle as an opaque pointer. The driver result is checked through ZE_ASSERT.
void *FencePool::Create()
{
    ze_fence_handle_t fence = nullptr;
    // No flags set, i.e. ZE_FENCE_FLAG_SIGNALED is deliberately left clear.
    // Written as an explicit 0 rather than a logical NOT of the flag bit,
    // which only happened to evaluate to 0.
    static const ze_fence_desc_t fence_desc = {
        .stype = ZE_STRUCTURE_TYPE_FENCE_DESC,
        .pNext = nullptr,
        .flags = 0,
    };
    ZE_ASSERT(Driver::FenceCreate(kCmdq, &fence_desc, &fence));
    return fence;
}

void FencePool::Destroy(void *fence)
{
    ZE_ASSERT(Driver::FenceDestroy((ze_fence_handle_t)fence));
}

// Convenience overload: queries the context and device that `cmd_list`
// belongs to and forwards to Instance(ctx, dev).
std::shared_ptr<CommandListPool> CommandListPool::Instance(ze_command_list_handle_t cmd_list)
{
    // Context owning the command list.
    ze_context_handle_t ctx;
    ZE_ASSERT(Driver::CommandListGetContextHandle(cmd_list, &ctx));

    // Device the command list was created for.
    ze_device_handle_t dev;
    ZE_ASSERT(Driver::CommandListGetDeviceHandle(cmd_list, &dev));

    return Instance(ctx, dev);
}

// Returns the CommandListPool for the (ctx, dev) pair, creating and caching it
// on the first request. The cache is a two-level function-local static map
// (context -> device -> pool) accessed under a function-local static mutex.
std::shared_ptr<CommandListPool> CommandListPool::Instance(ze_context_handle_t ctx, ze_device_handle_t dev)
{
    using DevicePools = std::map<ze_device_handle_t, std::shared_ptr<CommandListPool>>;

    static std::mutex mtx;
    static std::map<ze_context_handle_t, DevicePools> pools;

    std::lock_guard<std::mutex> lock(mtx);

    // operator[] inserts an empty per-device map the first time a context is seen.
    DevicePools &per_device = pools[ctx];

    auto dev_it = per_device.find(dev);
    if (dev_it == per_device.end()) {
        dev_it = per_device.emplace(dev, std::make_shared<CommandListPool>(ctx, dev)).first;
    }
    return dev_it->second;
}

void *CommandListPool::Create()
{
    ze_command_list_handle_t cmd_list;
    ze_command_list_desc_t desc = {
        .stype = ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
        .pNext = nullptr,
        .commandQueueGroupOrdinal = 0,
        .flags = 0,
    };
    ZE_ASSERT(Driver::CommandListCreate(kContext, kDevice, &desc, &cmd_list));
    return cmd_list;
}