#include "xsched/utils/xassert.h"
#include "xsched/cuda/hal/level2/cuda_queue.h"
#include "xsched/cuda/hal/common/cuda_assert.h"
using namespace xsched::cuda;
using namespace xsched::preempt;
// Constructs a level-2 queue for `stream`.
// The stream is forwarded to the CudaQueueLv1 base constructor, after which an
// InstrumentManager is created from context_ and cudevice_.
// NOTE(review): context_ and cudevice_ are not declared in this file;
// presumably they are initialized by the base class -- confirm.
CudaQueueLv2::CudaQueueLv2(CUstream stream)
    : CudaQueueLv1(stream)
{
    this->instrument_manager_ =
        std::make_unique<InstrumentManager>(context_, cudevice_);
}
// Launches one hardware command on kStream.
//
// - If hw_cmd is a CudaKernelCommand, it is handed to the InstrumentManager
//   together with kStream and the current preempt level.
// - Otherwise hw_cmd must be a CudaCommand (asserted); its LaunchWrapper is
//   invoked on kStream and the result is checked with CUDA_ASSERT.
void CudaQueueLv2::Launch(std::shared_ptr<HwCommand> hw_cmd)
{
    // Kernel commands take the instrumented launch path.
    if (auto kernel_cmd = std::dynamic_pointer_cast<CudaKernelCommand>(hw_cmd)) {
        instrument_manager_->Launch(kernel_cmd, kStream, level_);
        return;
    }

    // Everything else goes through the command's own launch wrapper.
    auto cuda_cmd = std::dynamic_pointer_cast<CudaCommand>(hw_cmd);
    XASSERT(cuda_cmd != nullptr, "hw_cmd is not a CudaCommand");
    CUDA_ASSERT(cuda_cmd->LaunchWrapper(kStream));
}
void CudaQueueLv2::Deactivate()
{
XASSERT(level_ >= kPreemptLevelDeactivate, "Deactivate() not supported on level-%d", level_);
instrument_manager_->Deactivate();
}
// Reactivates the queue and re-launches logged commands.
//
// Steps, in order:
//   1. Assert that the current preempt level is at least
//      kPreemptLevelDeactivate.
//   2. Synchronize() this queue.
//   3. Ask the InstrumentManager to reactivate; it returns a command index.
//   4. If that index is 0, nothing is re-launched. Otherwise every command in
//      `log` whose GetIdx() is not below the returned index is passed to
//      Launch(), in log order.
//
// log: the command log to replay from.
void CudaQueueLv2::Reactivate(const preempt::CommandLog &log)
{
    XASSERT(level_ >= kPreemptLevelDeactivate, "Reactivate() not supported on level-%d", level_);

    this->Synchronize();

    const uint64_t resume_from = instrument_manager_->Reactivate();
    if (resume_from == 0) {
        return;
    }

    // Convert once; the comparison below is done against the signed value.
    const int64_t first_idx = (int64_t)resume_from;
    for (auto cmd : log) {
        if (cmd->GetIdx() >= first_idx) {
            this->Launch(cmd);
        }
    }
}
// Records a new preempt level for this queue.
// Asserts that `level` does not exceed kPreemptLevelDeactivate before storing
// it in level_.
void CudaQueueLv2::OnPreemptLevelChange(XPreemptLevel level)
{
    XASSERT(level <= kPreemptLevelDeactivate, "unsupported level: %d", level);

    this->level_ = level;
}
// Hook invoked when a hardware command is submitted.
// When the current preempt level is at least kPreemptLevelDeactivate and the
// command is a CudaKernelCommand, the kernel is passed to
// InstrumentManager::Instrument(). In every other case this is a no-op.
void CudaQueueLv2::OnHwCommandSubmit(std::shared_ptr<preempt::HwCommand> hw_cmd)
{
    if (level_ >= kPreemptLevelDeactivate) {
        if (auto kernel_cmd = std::dynamic_pointer_cast<CudaKernelCommand>(hw_cmd)) {
            instrument_manager_->Instrument(kernel_cmd);
        }
    }
}
// Launches `kernel` on `stream` through the InstrumentContext obtained for
// `ctx`, passing kKernelLaunchOriginal as the launch mode.
// Returns the CUresult produced by InstrumentContext::Launch().
CUresult CudaQueueLv2::DirectLaunch(std::shared_ptr<CudaKernelCommand> kernel,
                                    CUcontext ctx, CUstream stream)
{
    return InstrumentContext::Instance(ctx)->Launch(kernel, stream, kKernelLaunchOriginal);
}