#include "xsched/utils/xassert.h"
#include "xsched/cuda/hal/level3/cuda_queue.h"
#include "xsched/cuda/hal/common/cuda_assert.h"
using namespace xsched::cuda;
using namespace xsched::preempt;
CudaLv3Implementation xsched::cuda::GetCudaLv3Implementation()
{
static CudaLv3Implementation impl = []()->CudaLv3Implementation {
char *str = std::getenv(XSCHED_CUDA_LV3_IMPL_ENV_NAME);
if (str == nullptr) return kCudaLv3ImplementationTrap;
if (strcmp(str, "TSG") == 0) return kCudaLv3ImplementationTsg;
return kCudaLv3ImplementationTrap;
}();
return impl;
}
CudaQueueLv3Trap::CudaQueueLv3Trap(CUstream stream): CudaQueueLv2(stream)
{
interrupt_context_ = InterruptContext::Instance(context_);
}
void CudaQueueLv3Trap::Interrupt()
{
XASSERT(level_ >= kPreemptLevelInterrupt, "Interrupt() not supported on level-%d", level_);
CUDA_ASSERT(Driver::StreamSynchronize(instrument_manager_->OpStream()));
interrupt_context_->Interrupt();
}
void CudaQueueLv3Trap::Restore(const CommandLog &)
{
XASSERT(level_ >= kPreemptLevelInterrupt, "Restore() not supported on level-%d", level_);
}
void CudaQueueLv3Trap::OnPreemptLevelChange(XPreemptLevel level)
{
XASSERT(level <= kPreemptLevelInterrupt, "unsupported level: %d", level);
if (level == kPreemptLevelInterrupt) {
instrument_manager_->NotifyTrapInstrumented();
interrupt_context_->InstrumentTrapHandler();
}
level_ = level;
}
void CudaQueueLv3Trap::OnHwCommandSubmit(std::shared_ptr<preempt::HwCommand> cmd)
{
if (level_ < kPreemptLevelDeactivate) return;
auto kernel = std::dynamic_pointer_cast<CudaKernelCommand>(cmd);
if (kernel == nullptr) return;
instrument_manager_->Instrument(kernel);
kernel->killable = true;
}
CudaQueueLv3Tsg::CudaQueueLv3Tsg(CUstream stream): CudaQueueLv1(stream)
{
tsg_context_ = TsgContext::Instance(context_);
}
void CudaQueueLv3Tsg::Interrupt()
{
XASSERT(level_ >= kPreemptLevelInterrupt, "Interrupt() not supported on level-%d", level_);
tsg_context_->Interrupt();
}
void CudaQueueLv3Tsg::Restore(const CommandLog &)
{
XASSERT(level_ >= kPreemptLevelInterrupt, "Restore() not supported on level-%d", level_);
tsg_context_->Restore();
}
void CudaQueueLv3Tsg::OnPreemptLevelChange(XPreemptLevel level)
{
XASSERT(level <= kPreemptLevelInterrupt, "unsupported level: %d", level);
level_ = level;
}