#include <cstdlib>
#include <cstring>

#include "xsched/utils/xassert.h"
#include "xsched/preempt/xqueue/xqueue.h"
#include "xsched/cudla/hal/event_pool.h"
#include "xsched/cudla/hal/cudla_command.h"

using namespace xsched::cudla;
using namespace xsched::preempt;

/// Hands the following event (if one was acquired) back to the shared
/// EventPool so it can be reused by later commands.
CudlaCommand::~CudlaCommand()
{
    // Nothing to give back when no event was ever taken from the pool.
    if (following_event_ != nullptr) {
        EventPool::Instance().Push(following_event_);
    }
}

void CudlaCommand::Synchronize()
{
    XASSERT(following_event_ != nullptr,
            "following_event_ is nullptr, EnableSynchronization() should be called first");
    CUDART_ASSERT(RtDriver::EventSynchronize(following_event_));
}

bool CudlaCommand::Synchronizable()
{
    return following_event_ != nullptr;
}

/// Acquires a following event from the EventPool so that the command can be
/// synchronized on after it has been launched.
///
/// The call is idempotent: if an event has already been acquired it is kept,
/// instead of being overwritten (which would leak it, because the destructor
/// only returns the event currently stored in following_event_).
///
/// @return true if the command holds a following event after the call,
///         false if the pool did not provide one.
bool CudlaCommand::EnableSynchronization()
{
    // Already enabled: keep the existing event rather than popping another one.
    if (following_event_ != nullptr) return true;

    following_event_ = (cudaEvent_t)EventPool::Instance().Pop();
    return following_event_ != nullptr;
}

/// Launches the command on the given stream and, when a following event is
/// present, records that event on the same stream right after the launch.
///
/// @param stream stream passed to Launch() and to the event record.
/// @return the error from Launch() if it failed; otherwise the result of
///         recording the following event, or the (successful) launch status
///         when there is no following event.
cudaError_t CudlaCommand::LaunchWrapper(cudaStream_t stream)
{
    const cudaError_t launch_status = Launch(stream);

    // A failed launch is reported as-is; no event is recorded in that case.
    if (UNLIKELY(launch_status != cudaSuccess)) return launch_status;

    // Synchronization not enabled: the launch status is the final result.
    if (following_event_ == nullptr) return launch_status;

    return RtDriver::EventRecord(following_event_, stream);
}

/// Builds a task command that owns a private copy of the caller's task array.
///
/// @param dev_handle cuDLA device handle stored for later use.
/// @param tasks      array of num_tasks tasks; must not be nullptr. The array
///                   is copied byte-wise, so the caller's buffer does not need
///                   to outlive this constructor call.
/// @param num_tasks  number of entries in tasks.
/// @param flags      flags stored for later use.
CudlaTaskCommand::CudlaTaskCommand(cudlaDevHandle const dev_handle, const cudlaTask * const tasks,
                                   uint32_t const num_tasks, uint32_t const flags)
    : CudlaCommand(kCommandPropertyNone)
    , dev_handle_(dev_handle), num_tasks_(num_tasks), flags_(flags)
{
    XASSERT(tasks != nullptr, "tasks should not be nullptr");

    const size_t bytes = sizeof(cudlaTask) * num_tasks;
    tasks_ = static_cast<cudlaTask *>(malloc(bytes));

    // malloc(0) is allowed to return nullptr, so only a failed non-empty
    // allocation is treated as an error.
    XASSERT(tasks_ != nullptr || bytes == 0, "failed to allocate memory for cudla tasks");

    // Skip the copy for an empty array: memcpy with a null destination is
    // undefined behaviour even when the size is zero.
    if (bytes > 0) memcpy(tasks_, tasks, bytes);
}

/// Releases the task array copy allocated by the constructor.
CudlaTaskCommand::~CudlaTaskCommand()
{
    // No buffer to release.
    if (tasks_ == nullptr) return;

    free(tasks_);
}

/// Creates an event-record command for the given CUDA event.
///
/// The command is constructed with kCommandPropertyIdempotent.
///
/// @param event event stored by this command; must not be nullptr.
CudlaEventRecordCommand::CudlaEventRecordCommand(cudaEvent_t event)
    : CudlaCommand(kCommandPropertyIdempotent)
    , event_(event)
{
    // event_ was just initialised from event, so checking the parameter is equivalent.
    XASSERT(event != nullptr, "cuda event should not be nullptr");
}