#include "xsched/levelzero/hal/ze_kernel_arg.h"

#include <cstdlib>
#include <cstring>
#include <new>

#include "xsched/levelzero/hal/ze_assert.h"

using namespace xsched::levelzero;

/// Applies the recorded group size to a kernel.
///
/// Forwards groupSizeX_, groupSizeY_ and groupSizeZ_ to
/// Driver::KernelSetGroupSize and hands the returned result to ZE_ASSERT.
///
/// @param hKernel Kernel handle the group size is applied to.
void ZeKernelGroupSize::Set(ze_kernel_handle_t hKernel)
{
    ZE_ASSERT(Driver::KernelSetGroupSize(hKernel, groupSizeX_, groupSizeY_, groupSizeZ_));
}

/// Records a kernel argument, taking a private copy of the argument bytes.
///
/// @param argIndex  Argument index later forwarded by Set().
/// @param argSize   Argument size in bytes; also the number of bytes copied
///                  from pArgValue.
/// @param pArgValue Source bytes to copy. When nullptr, no buffer is
///                  allocated and Set() forwards a null value pointer.
/// @throws std::bad_alloc if a non-empty copy buffer cannot be allocated.
ZeKernelArgumentValue::ZeKernelArgumentValue(uint32_t argIndex, size_t argSize, const void *pArgValue)
    : ZeKernelArg(kArgumentValue), argIndex_(argIndex), argSize_(argSize)
{
    // Put the pointer into a defined state on every path: both the
    // destructor and Set() read it even when the caller supplied no value.
    pArgValue_ = nullptr;
    if (pArgValue == nullptr) return;

    pArgValue_ = malloc(argSize);
    if (pArgValue_ == nullptr) {
        // malloc(0) is allowed to return nullptr; that is not a failure and
        // there is nothing to copy. Any other null result is a real
        // allocation failure and must not fall through to memcpy.
        if (argSize != 0) throw std::bad_alloc();
        return;
    }
    std::memcpy(pArgValue_, pArgValue, argSize);
}

/// Releases the value buffer held in pArgValue_, if there is one.
ZeKernelArgumentValue::~ZeKernelArgumentValue()
{
    // Nothing was stored, so there is nothing to release.
    if (!pArgValue_) {
        return;
    }
    free(pArgValue_);
}

/// Applies the recorded argument value to a kernel.
///
/// Forwards argIndex_, argSize_ and the stored value pointer pArgValue_ to
/// Driver::KernelSetArgumentValue and hands the returned result to ZE_ASSERT.
///
/// @param hKernel Kernel handle the argument is set on.
void ZeKernelArgumentValue::Set(ze_kernel_handle_t hKernel)
{
    ZE_ASSERT(Driver::KernelSetArgumentValue(hKernel, argIndex_, argSize_, pArgValue_));
}

/// Applies the recorded indirect-access flags to a kernel.
///
/// Forwards flags_ to Driver::KernelSetIndirectAccess and hands the returned
/// result to ZE_ASSERT.
///
/// @param hKernel Kernel handle the flags are applied to.
void ZeKernelIndirectAccessFlags::Set(ze_kernel_handle_t hKernel)
{
    ZE_ASSERT(Driver::KernelSetIndirectAccess(hKernel, flags_));
}

/// Applies the recorded cache-configuration flags to a kernel.
///
/// Forwards flags_ to Driver::KernelSetCacheConfig and hands the returned
/// result to ZE_ASSERT.
///
/// @param hKernel Kernel handle the flags are applied to.
void ZeKernelCacheConfigFlags::Set(ze_kernel_handle_t hKernel)
{
    ZE_ASSERT(Driver::KernelSetCacheConfig(hKernel, flags_));
}

/// Applies the recorded global offset to a kernel.
///
/// Forwards offsetX_, offsetY_ and offsetZ_ to
/// Driver::KernelSetGlobalOffsetExp and hands the returned result to
/// ZE_ASSERT.
///
/// @param hKernel Kernel handle the offset is applied to.
void ZeKernelGlobalOffset::Set(ze_kernel_handle_t hKernel)
{
    ZE_ASSERT(Driver::KernelSetGlobalOffsetExp(hKernel, offsetX_, offsetY_, offsetZ_));
}

/// Records a group-size request for a kernel instead of applying it now.
///
/// The request is appended to the kernel's pending group-size list, which is
/// created on first use.
///
/// @param hKernel    Kernel the request belongs to.
/// @param groupSizeX Group size in X.
/// @param groupSizeY Group size in Y.
/// @param groupSizeZ Group size in Z.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::AddGroupSize(ze_kernel_handle_t hKernel, uint32_t groupSizeX, uint32_t groupSizeY, uint32_t groupSizeZ)
{
    auto pending = group_size_.find(hKernel);
    if (pending == group_size_.end()) {
        // First request for this kernel: start with an empty list.
        pending = group_size_.emplace(hKernel, std::make_shared<std::list<ZeKernelGroupSize>>()).first;
    }

    pending->second->emplace_back(groupSizeX, groupSizeY, groupSizeZ);

    XDEBG("KernelSetGroupSize(kernel: %p) deferred", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Records an argument-value request for a kernel instead of applying it now.
///
/// The request is appended to the kernel's pending argument list, which is
/// created on first use. The list element is constructed in place from
/// (argIndex, argSize, pArgValue).
///
/// @param hKernel   Kernel the request belongs to.
/// @param argIndex  Index of the argument.
/// @param argSize   Size of the argument in bytes.
/// @param pArgValue Pointer to the argument bytes; may be nullptr.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::AddArgumentValue(ze_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize, const void *pArgValue)
{
    auto pending = argument_value_.find(hKernel);
    if (pending == argument_value_.end()) {
        // First request for this kernel: start with an empty list.
        pending = argument_value_.emplace(hKernel, std::make_shared<std::list<ZeKernelArgumentValue>>()).first;
    }

    pending->second->emplace_back(argIndex, argSize, pArgValue);

    XDEBG("KernelSetArgumentValue(kernel: %p) deferred", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Records an indirect-access-flags request for a kernel instead of applying
/// it now.
///
/// The request is appended to the kernel's pending indirect-access list,
/// which is created on first use.
///
/// @param hKernel Kernel the request belongs to.
/// @param flags   Indirect access flags to record.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::AddIndirectAccess(ze_kernel_handle_t hKernel, ze_kernel_indirect_access_flags_t flags)
{
    auto pending = indirect_access_flags_.find(hKernel);
    if (pending == indirect_access_flags_.end()) {
        // First request for this kernel: start with an empty list.
        pending = indirect_access_flags_.emplace(hKernel, std::make_shared<std::list<ZeKernelIndirectAccessFlags>>()).first;
    }

    pending->second->emplace_back(flags);

    XDEBG("KernelSetIndirectAccess(kernel: %p) deferred", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Records a cache-configuration request for a kernel instead of applying it
/// now.
///
/// The request is appended to the kernel's pending cache-config list, which
/// is created on first use.
///
/// @param hKernel Kernel the request belongs to.
/// @param flags   Cache configuration flags to record.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::AddCacheConfig(ze_kernel_handle_t hKernel, ze_cache_config_flags_t flags)
{
    auto pending = cache_config_flags_.find(hKernel);
    if (pending == cache_config_flags_.end()) {
        // First request for this kernel: start with an empty list.
        pending = cache_config_flags_.emplace(hKernel, std::make_shared<std::list<ZeKernelCacheConfigFlags>>()).first;
    }

    pending->second->emplace_back(flags);

    XDEBG("KernelSetCacheConfig(kernel: %p) deferred", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Records a global-offset request for a kernel instead of applying it now.
///
/// The request is appended to the kernel's pending global-offset list, which
/// is created on first use.
///
/// @param hKernel Kernel the request belongs to.
/// @param offsetX Global offset in X.
/// @param offsetY Global offset in Y.
/// @param offsetZ Global offset in Z.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::AddGlobalOffsetExp(ze_kernel_handle_t hKernel, uint32_t offsetX, uint32_t offsetY, uint32_t offsetZ)
{
    auto pending = global_offset_.find(hKernel);
    if (pending == global_offset_.end()) {
        // First request for this kernel: start with an empty list.
        pending = global_offset_.emplace(hKernel, std::make_shared<std::list<ZeKernelGlobalOffset>>()).first;
    }

    pending->second->emplace_back(offsetX, offsetY, offsetZ);

    XDEBG("KernelSetGlobalOffsetExp(kernel: %p) deferred", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Replays the oldest frozen argument snapshot of a kernel and discards it.
///
/// If hKernel has queued snapshots, every entry of the front snapshot is
/// applied in this order: group sizes, argument values, indirect access
/// flags, cache config flags, global offsets. The snapshot is then popped,
/// and the kernel's queue is removed once it becomes empty. If nothing is
/// queued for hKernel, nothing is applied.
///
/// @param hKernel Kernel whose frozen arguments are applied.
/// @return Always ZE_RESULT_SUCCESS.
ze_result_t KernelArgsManager::Set(ze_kernel_handle_t hKernel)
{
    auto entry = frozen_args_.find(hKernel);
    if (entry != frozen_args_.end()) {
        auto &snapshots = entry->second;
        auto &oldest = snapshots.front();

        // Categories a snapshot did not capture are left null and skipped.
        if (oldest.group_size) {
            for (auto &item : *oldest.group_size) item.Set(hKernel);
        }
        if (oldest.argument_value) {
            for (auto &item : *oldest.argument_value) item.Set(hKernel);
        }
        if (oldest.indirect_access_flags) {
            for (auto &item : *oldest.indirect_access_flags) item.Set(hKernel);
        }
        if (oldest.cache_config_flags) {
            for (auto &item : *oldest.cache_config_flags) item.Set(hKernel);
        }
        if (oldest.global_offset) {
            for (auto &item : *oldest.global_offset) item.Set(hKernel);
        }

        snapshots.pop_front();
        if (snapshots.empty()) {
            frozen_args_.erase(entry);
        }
    }

    XDEBG("Set kernel %p arguments", hKernel);
    return ZE_RESULT_SUCCESS;
}

/// Captures the pending arguments of a kernel into a new frozen snapshot.
///
/// For each argument category, the kernel's pending list (if any) is handed
/// over to the snapshot and removed from the pending map. The snapshot is
/// appended to the kernel's frozen queue even when no category had anything
/// pending.
///
/// @param hKernel Kernel whose pending arguments are frozen.
void KernelArgsManager::Freeze(ze_kernel_handle_t hKernel)
{
    FrozenArgs snapshot;

    auto group_it = group_size_.find(hKernel);
    if (group_it != group_size_.end()) {
        snapshot.group_size = group_it->second;
        group_size_.erase(group_it);
    }

    auto value_it = argument_value_.find(hKernel);
    if (value_it != argument_value_.end()) {
        snapshot.argument_value = value_it->second;
        argument_value_.erase(value_it);
    }

    auto indirect_it = indirect_access_flags_.find(hKernel);
    if (indirect_it != indirect_access_flags_.end()) {
        snapshot.indirect_access_flags = indirect_it->second;
        indirect_access_flags_.erase(indirect_it);
    }

    auto cache_it = cache_config_flags_.find(hKernel);
    if (cache_it != cache_config_flags_.end()) {
        snapshot.cache_config_flags = cache_it->second;
        cache_config_flags_.erase(cache_it);
    }

    auto offset_it = global_offset_.find(hKernel);
    if (offset_it != global_offset_.end()) {
        snapshot.global_offset = offset_it->second;
        global_offset_.erase(offset_it);
    }

    // Creates the kernel's queue on first freeze, then appends the snapshot.
    frozen_args_[hKernel].emplace_back(std::move(snapshot));

    XDEBG("Freeze kernel %p arguments", hKernel);
}