#include "xsched/utils/xassert.h"
#include "xsched/preempt/sched/agent.h"
#include "xsched/preempt/xqueue/xqueue.h"
#include "xsched/preempt/xqueue/async_xqueue.h"
using namespace xsched::sched;
using namespace xsched::preempt;
/// Constructs an asynchronous XQueue that wraps the hardware queue `hwq`.
///
/// The base XQueue is given the async implementation type and the OR of the
/// feature flags below; kQueueFeatureDynamicLevel is only included when
/// hwq->SupportDynamicLevel() reports true.
///
/// @param hwq        hardware queue this XQueue submits to
/// @param level      initial preempt level; must lie strictly between
///                   kPreemptLevelUnknown and kPreemptLevelMax
/// @param threshold  forwarded to the LaunchWorker and the create event
/// @param batch_size forwarded to the LaunchWorker and the create event
AsyncXQueue::AsyncXQueue(std::shared_ptr<HwQueue> hwq, XPreemptLevel level,
                         int64_t threshold, int64_t batch_size)
    : XQueue(kXQueueImplTypeAsync,
             kQueueFeatureAsyncSubmit |
             (hwq->SupportDynamicLevel() ? kQueueFeatureDynamicLevel : 0) |
             kQueueFeatureDynamicThreshold | kQueueFeatureDynamicBatchSize |
             kQueueFeatureSyncSuspend | kQueueFeatureResumeDropCommands,
             hwq)
    , level_(level)
    , cmd_buf_(std::make_shared<CommandBuffer>(kHandle))
    , launch_worker_(std::make_shared<LaunchWorker>(hwq, cmd_buf_, level, threshold, batch_size))
{
    XASSERT(level_ > kPreemptLevelUnknown && level_ < kPreemptLevelMax,
            "invalid preempt level: %d", level_);

    // Enqueue a wait-all command and block on it before notifying the
    // hardware queue of the level.
    // NOTE(review): presumably this ensures the launch worker is up and
    // draining the buffer before the queue is used - confirm.
    cmd_buf_->EnqueueXQueueWaitAllCommand()->Wait();

    kHwQueue->OnPreemptLevelChange(level_);

    // Announce the new queue and its initial configuration to the scheduler agent.
    SchedAgent::SendEvent(
        std::make_shared<XQueueCreateEvent>(kHandle, kDevice, level, threshold, batch_size));
}
/// Tears the queue down: marks it terminated, resumes it, enqueues a destroy
/// command and waits for it, then reports the destruction to the scheduler agent.
AsyncXQueue::~AsyncXQueue()
{
    // Once set, Suspend() becomes a no-op (it checks terminated_ first).
    terminated_.store(true);

    // NOTE(review): Resume() takes int64_t flags, so `true` converts to 1.
    // This presumably is meant as kQueueResumeFlagDropCommands - confirm the
    // flag's value and consider passing the named constant.
    Resume(true);

    // Push the destroy command through the buffer and block until it is done.
    auto destroy_cmd = std::make_shared<XQueueDestroyCommand>();
    cmd_buf_->Enqueue(destroy_cmd);
    destroy_cmd->Wait();

    SchedAgent::SendEvent(std::make_shared<XQueueDestroyEvent>(kHandle));
}
/// Submits a hardware command to this queue.
///
/// The command is bound to this XQueue, assigned the next command index,
/// reported to the hardware queue, and appended to the command buffer. If the
/// command carries kCommandPropertyBlockingSubmit, the call additionally
/// blocks until the command reaches kCommandStateInFlight.
///
/// @param hw_cmd command to submit
void AsyncXQueue::Submit(std::shared_ptr<HwCommand> hw_cmd)
{
    hw_cmd->OnSubmit(shared_from_this());

    // fetch_add yields the counter value from before the increment.
    const auto cmd_idx = next_hw_cmd_idx_.fetch_add(1);
    hw_cmd->SetIdx(cmd_idx);

    kHwQueue->OnHwCommandSubmit(hw_cmd);
    cmd_buf_->Enqueue(hw_cmd);

    // Non-blocking submissions return as soon as the command is buffered.
    if (!hw_cmd->GetProps(kCommandPropertyBlockingSubmit)) {
        return;
    }
    hw_cmd->WaitUntil(kCommandStateInFlight);
}
/// Enqueues a wait-all command into the command buffer without waiting on it.
///
/// @return the enqueued wait-all command, for the caller to wait on later
std::shared_ptr<XQueueWaitAllCommand> AsyncXQueue::SubmitWaitAll()
{
    std::shared_ptr<XQueueWaitAllCommand> wait_all = cmd_buf_->EnqueueXQueueWaitAllCommand();
    return wait_all;
}
/// Enqueues a wait-all command into the command buffer and blocks until it completes.
void AsyncXQueue::WaitAll()
{
    cmd_buf_->EnqueueXQueueWaitAllCommand()->Wait();
}
/// Blocks until the given hardware command has completed.
///
/// Non-synchronizable commands are waited on via their state alone.
/// Synchronizable commands that are not yet completed are first waited on
/// until in flight and then handed to the launch worker's SyncCmd().
///
/// @param hw_cmd command to wait for
void AsyncXQueue::Wait(std::shared_ptr<HwCommand> hw_cmd)
{
    if (hw_cmd->Synchronizable()) {
        // Already finished: nothing to synchronize.
        if (hw_cmd->GetState() >= kCommandStateCompleted) {
            return;
        }
        // SyncCmd is only invoked once the command is in flight.
        hw_cmd->WaitUntil(kCommandStateInFlight);
        launch_worker_->SyncCmd(hw_cmd);
        return;
    }

    hw_cmd->WaitUntil(kCommandStateCompleted);
}
/// Reports the current queue state as tracked by the command buffer.
///
/// @return the value of cmd_buf_->GetXQueueState()
XQueueState AsyncXQueue::Query()
{
    XQueueState state = cmd_buf_->GetXQueueState();
    return state;
}
/// Returns the next command index minus one.
///
/// NOTE(review): this equals the number of submitted commands only if
/// next_hw_cmd_idx_ starts at 1 - confirm against the member's initializer.
///
/// @return next_hw_cmd_idx_ - 1
int64_t AsyncXQueue::GetHwCommandCount()
{
    const auto next_idx = next_hw_cmd_idx_.load();
    return next_idx - 1;
}
/// Suspends the queue.
///
/// Does nothing when the queue is terminated or already suspended. Otherwise
/// pauses the launch worker and, depending on the preempt level, deactivates
/// and/or interrupts the hardware queue.
///
/// @param flags bitmask; when kQueueSuspendFlagSyncHwQueue is set the hardware
///              queue is synchronized before returning
void AsyncXQueue::Suspend(int64_t flags)
{
    if (terminated_.load()) {
        return;
    }

    // Flip suspended_ from false to true; only the caller that wins the
    // exchange goes on to perform the suspension.
    bool not_suspended = false;
    const bool won = suspended_.compare_exchange_strong(not_suspended, true);
    if (!won) {
        return;
    }

    launch_worker_->Pause();

    // Both level checks are evaluated independently, in this order.
    if (level_ >= kPreemptLevelDeactivate) {
        kHwQueue->Deactivate();
    }
    if (level_ >= kPreemptLevelInterrupt) {
        kHwQueue->Interrupt();
    }

    if (flags & kQueueSuspendFlagSyncHwQueue) {
        kHwQueue->Synchronize();
    }
}
void AsyncXQueue::Resume(int64_t flags)
{
bool expected = true;
if (!suspended_.compare_exchange_strong(expected, false)) return;
if (level_ == kPreemptLevelBlock) {
launch_worker_->Resume();
return;
}
if (!(flags & kQueueResumeFlagDropCommands)) {
const CommandLog &log = launch_worker_->GetCommandLog();
if (level_ >= kPreemptLevelInterrupt) kHwQueue->Restore(log);
if (level_ >= kPreemptLevelDeactivate) kHwQueue->Reactivate(log);
launch_worker_->Resume();
return;
}
CommandLog empty_log = {};
if (level_ >= kPreemptLevelInterrupt) kHwQueue->Restore(empty_log);
if (level_ >= kPreemptLevelDeactivate) kHwQueue->Reactivate(empty_log);
cmd_buf_->DropAll();
int64_t drop_idx = next_hw_cmd_idx_.load() - 1;
launch_worker_->ResumeAndDrop(drop_idx);
}
/// Changes the preempt level of the queue.
///
/// Without kQueueFeatureDynamicLevel the call only asserts that the level is
/// unchanged. A level above the hardware queue's maximum is rejected with a
/// warning. Otherwise the queue is drained via WaitAll(), the hardware queue
/// and launch worker are switched to the new level, and a config update event
/// is sent to the scheduler agent.
///
/// @param level requested level; must lie strictly between
///              kPreemptLevelUnknown and kPreemptLevelMax
void AsyncXQueue::SetPreemptLevel(XPreemptLevel level)
{
    XASSERT(level > kPreemptLevelUnknown && level < kPreemptLevelMax,
            "invalid preempt level: %d", level);

    if (!GetFeatures(kQueueFeatureDynamicLevel)) {
        XASSERT(level == level_, "AsyncXQueue does not support dynamic level");
        return;
    }

    if (kHwQueue->GetMaxSupportedLevel() < level) {
        XWARN("preempt level %d is not supported by the HwQueue, "
              "max supported level is %d", level, kHwQueue->GetMaxSupportedLevel());
        return;
    }

    // All previously enqueued work is waited on before the level is switched.
    WaitAll();

    kHwQueue->OnPreemptLevelChange(level);
    launch_worker_->SetPreemptLevel(level);
    level_ = level;

    SchedAgent::SendEvent(
        std::make_shared<XQueueConfigUpdateEvent>(kHandle, kDevice, level_,
                                                  launch_worker_->GetThreshold(),
                                                  launch_worker_->GetBatchSize()));
}
void AsyncXQueue::SetLaunchConfig(int64_t threshold, int64_t batch_size)
{
if (threshold <= 0 && batch_size <= 0) return;
if (threshold <= 0) {
threshold = launch_worker_->GetThreshold();
} else {
XASSERT(this->GetFeatures(kQueueFeatureDynamicThreshold),
"AsyncXQueue does not support dynamic threshold");
}
if (batch_size <= 0) {
batch_size = launch_worker_->GetBatchSize();
} else {
XASSERT(this->GetFeatures(kQueueFeatureDynamicBatchSize),
"AsyncXQueue does not support dynamic batch size");
}
XASSERT(threshold >= batch_size,
"command threshold (" FMT_64D ") must not be smaller than "
"command batch size (" FMT_64D ")", threshold, batch_size);
this->WaitAll();
launch_worker_->SetLaunchConfig(threshold, batch_size);
SchedAgent::SendEvent(std::make_shared<XQueueConfigUpdateEvent>(kHandle, kDevice, level_,
launch_worker_->GetThreshold(), launch_worker_->GetBatchSize()));
}