#include "xsched/utils/xassert.h"
#include "xsched/preempt/xqueue/launch_worker.h"
using namespace xsched::utils;
using namespace xsched::preempt;
/// Creates a launch worker bound to a hardware queue and a command buffer and
/// immediately starts its dedicated worker thread.
///
/// @param hwq        Hardware queue that hardware commands are launched onto.
/// @param buf        Command buffer the worker thread dequeues XCommands from.
/// @param level      Preemption level; compared against kPreemptLevelDeactivate
///                   when deciding whether to wait for deactivatable commands.
/// @param threshold  Size of the in-flight command log at which the worker
///                   synchronizes before launching more commands.
/// @param batch_size Minimum index distance between two commands that get
///                   synchronization enabled. Must not exceed @p threshold.
LaunchWorker::LaunchWorker(std::shared_ptr<HwQueue> hwq, std::shared_ptr<CommandBuffer> buf,
                           XPreemptLevel level, int64_t threshold, int64_t batch_size)
    : level_(level), threshold_(threshold), batch_size_(batch_size)
    , kHwq(std::move(hwq)), kCmdBuf(std::move(buf)), mtx_(std::make_unique<MCSLock>())
{
    XASSERT(threshold_ >= batch_size_, "threshold must not be smaller than command batch size");

    // The worker thread first notifies the hardware queue that an XQueue was
    // created, then runs the dispatch loop until it processes a destroy command.
    auto thread_body = [this]() {
        kHwq->OnXQueueCreate();
        WorkerLoop();
    };
    worker_thread_ = std::make_unique<std::thread>(thread_body);
}
/// Joins the worker thread. Blocks until WorkerLoop() returns, which only
/// happens after the worker has processed a kCommandTypeXQueueDestroy command.
LaunchWorker::~LaunchWorker()
{
    std::thread &worker = *worker_thread_;
    worker.join();
}
void LaunchWorker::Pause()
{
mtx_->lock();
state_ = kWorkerStatePaused;
pause_count_ += 1;
mtx_->unlock();
}
void LaunchWorker::Resume()
{
mtx_->lock();
state_ = kWorkerStateRunning;
mtx_->unlock();
cv_.notify_all();
}
/// Discards all tracked in-flight commands and resumes the worker.
///
/// Every command in cmd_log_ is marked completed without synchronizing the
/// hardware queue, and both logs are emptied. Commands whose index is
/// <= @p drop_idx are later skipped by LaunchHwCommand instead of launched.
///
/// @param drop_idx Highest command index that should be dropped.
void LaunchWorker::ResumeAndDrop(int64_t drop_idx)
{
    {
        std::unique_lock<MutexLock> guard(*mtx_);
        sync_cmd_log_.clear();
        for (const auto &in_flight : cmd_log_) {
            in_flight->SetState(kCommandStateCompleted);
        }
        cmd_log_.clear();
        drop_idx_ = drop_idx;
        state_ = kWorkerStateRunning;
    }
    cv_.notify_all();
}
/// Waits until every command in the in-flight log has been synchronized.
/// Takes mtx_ and delegates to SyncAllWithLock; the returned lock is dropped
/// at the end of the statement.
void LaunchWorker::SyncAll()
{
    SyncAllWithLock(std::unique_lock<MutexLock>(*mtx_));
}
/// Waits until @p hw_cmd (which must be synchronizable) has completed.
/// Takes mtx_ and delegates to SyncCmdWithLock; the returned lock is dropped
/// at the end of the statement.
///
/// @param hw_cmd Synchronizable hardware command to wait for.
void LaunchWorker::SyncCmd(std::shared_ptr<HwCommand> hw_cmd)
{
    SyncCmdWithLock(std::unique_lock<MutexLock>(*mtx_), std::move(hw_cmd));
}
/// Body of the worker thread.
///
/// Dequeues XCommands from kCmdBuf one at a time and dispatches them by type.
/// Hardware commands go to LaunchHwCommand. Every other command is moved to
/// kCommandStateInFlight when handling starts and to kCommandStateCompleted
/// when it ends. The loop returns only after a kCommandTypeXQueueDestroy
/// command has been handled.
void LaunchWorker::WorkerLoop()
{
    while (true) {
        XDEBG("worker (%p) waiting for an xcmd", this);
        std::shared_ptr<XCommand> xcmd = kCmdBuf->Dequeue();
        XDEBG("worker (%p) got xcmd (%p)", this, xcmd.get());

        switch (xcmd->GetType())
        {
        case kCommandTypeHardware:
        {
            auto hw_cmd = std::dynamic_pointer_cast<HwCommand>(xcmd);
            XASSERT(hw_cmd != nullptr, "command type mismatch");
            LaunchHwCommand(hw_cmd);
            break;
        }
        case kCommandTypeHostFunction:
        {
            auto cmd = std::dynamic_pointer_cast<HostFunctionCommand>(xcmd);
            XASSERT(cmd != nullptr, "command type mismatch");
            cmd->SetState(kCommandStateInFlight);
            // Drain all in-flight hardware commands first. The lock stays held
            // while the host function executes.
            std::unique_lock<MutexLock> lock(*mtx_);
            lock = SyncAllWithLock(std::move(lock));
            cmd->Execute();
            lock.unlock();
            cmd->SetState(kCommandStateCompleted);
            break;
        }
        case kCommandTypeXQueueWaitAll:
        {
            xcmd->SetState(kCommandStateInFlight);
            SyncAll();
            xcmd->SetState(kCommandStateCompleted);
            break;
        }
        case kCommandTypeBatchSynchronize:
        {
            xcmd->SetState(kCommandStateInFlight);
            std::unique_lock<MutexLock> lock(*mtx_);
            // Sync up to the oldest synchronizable in-flight command, if any.
            // Even when there is nothing to sync, this command must still reach
            // Completed, otherwise anything waiting on it never makes progress.
            if (!sync_cmd_log_.empty()) {
                std::shared_ptr<HwCommand> oldest = sync_cmd_log_.front();
                lock = SyncCmdWithLock(std::move(lock), oldest);
            }
            lock.unlock();
            xcmd->SetState(kCommandStateCompleted);
            break;
        }
        case kCommandTypeXQueueDestroy:
        {
            xcmd->SetState(kCommandStateInFlight);
            std::unique_lock<MutexLock> lock(*mtx_);
            lock = SyncAllWithLock(std::move(lock));
            state_ = kWorkerStateTerminated;
            lock.unlock();
            cv_.notify_all();
            xcmd->SetState(kCommandStateCompleted);
            // Terminate the worker thread; the destructor joins it.
            return;
        }
        default:
            XASSERT(false, "unknown command type: %d", xcmd->GetType());
            break;
        }
    }
}
void LaunchWorker::LaunchHwCommand(std::shared_ptr<HwCommand> hw_cmd)
{
XDEBG("launch hw_cmd (%p) idx " FMT_64D, hw_cmd.get(), hw_cmd->GetIdx());
bool wait_deactivatable = level_ >= kPreemptLevelDeactivate &&
(kCommandPropertyNone == hw_cmd->GetProps(
kCommandPropertyDeactivatable | kCommandPropertyIdempotent));
hw_cmd->BeforeLaunch();
std::unique_lock<MutexLock> lock(*mtx_);
if (wait_deactivatable) {
bool has_deactivatable = false;
std::shared_ptr<HwCommand> command_to_sync = nullptr;
for (auto it = cmd_log_.rbegin(); it != cmd_log_.rend(); ++it) {
if ((*it)->Synchronizable()) command_to_sync = *it;
if ((*it)->GetProps(kCommandPropertyDeactivatable)) {
has_deactivatable = true;
break;
}
}
if (has_deactivatable) {
lock = (command_to_sync == nullptr)
? SyncAllWithLock(std::move(lock))
: SyncCmdWithLock(std::move(lock), command_to_sync);
}
}
if (cmd_log_.size() >= (size_t)threshold_) {
std::shared_ptr<HwCommand> command_to_sync = nullptr;
int64_t front_command_idx = cmd_log_.front()->GetIdx();
for (auto cmd : sync_cmd_log_) {
if (cmd->GetIdx() >= front_command_idx) {
command_to_sync = cmd;
break;
}
}
lock = (command_to_sync == nullptr)
? SyncAllWithLock(std::move(lock))
: SyncCmdWithLock(std::move(lock), command_to_sync);
} else {
while (true) {
if (state_ == kWorkerStateRunning) {
break;
} else if (state_ == kWorkerStatePaused) {
cv_.wait(lock);
} else if (state_ == kWorkerStateTerminated) {
XASSERT(false, "Worker state should not be kWorkerStateTerminated.");
} else {
XASSERT(false, "Invalid worker state");
}
}
}
if (hw_cmd->GetIdx() <= drop_idx_) {
hw_cmd->SetState(kCommandStateCompleted);
return;
}
if (hw_cmd->GetIdx() - last_synchronizable_idx_ >= batch_size_) {
hw_cmd->EnableSynchronization();
}
auto callback = std::dynamic_pointer_cast<HwCallbackCommand>(hw_cmd);
if (callback != nullptr) {
XASSERT(callback->Launch(kHwq->GetHandle()) == kXSchedSuccess,
"failed to launch HwCallbackCommand (%p)", callback.get());
} else {
kHwq->Launch(hw_cmd);
}
hw_cmd->SetState(kCommandStateInFlight);
cmd_log_.emplace_back(hw_cmd);
if (hw_cmd->Synchronizable()) {
sync_cmd_log_.emplace_back(hw_cmd);
last_synchronizable_idx_ = hw_cmd->GetIdx();
}
}
/// Synchronizes every in-flight hardware command on kHwq while honoring pause
/// requests.
///
/// Precondition: `lock` owns mtx_. The lock is released while waiting on cv_
/// and while kHwq->Synchronize() runs. It is owned again on every return path
/// and handed back to the caller.
///
/// Returns early, without synchronizing, if the worker is terminated or if
/// cmd_log_ is empty. If Pause() ran while the lock was released
/// (pause_count_ changed), the wait-for-running / synchronize cycle repeats.
/// Otherwise every command in cmd_log_ is marked completed and both logs are
/// cleared.
///
/// NOTE(review): commands appended to cmd_log_ by another thread while the
/// lock is released are also marked completed here, although they were
/// launched after Synchronize() started. This is only safe if SyncAll() never
/// runs concurrently with LaunchHwCommand() — confirm with callers.
std::unique_lock<MutexLock> LaunchWorker::SyncAllWithLock(std::unique_lock<MutexLock> lock)
{
while (true) {
// Block until the worker is running; give up if it has been terminated.
while (true) {
if (state_ == kWorkerStateRunning) {
break;
} else if (state_ == kWorkerStatePaused) {
cv_.wait(lock);
} else if (state_ == kWorkerStateTerminated) {
return lock;
} else {
XASSERT(false, "Invalid worker state");
}
}
if (cmd_log_.size() == 0) return lock;
// Snapshot the pause counter so a Pause() during the unlocked sync is noticed.
int64_t current_pause_cnt = pause_count_;
lock.unlock();
kHwq->Synchronize();
lock.lock();
// No pause happened meanwhile: the sync result is valid. Otherwise retry.
if (current_pause_cnt == pause_count_) break;
}
sync_cmd_log_.clear();
for (auto hw_cmd : cmd_log_) hw_cmd->SetState(kCommandStateCompleted);
cmd_log_.clear();
return lock;
}
/// Waits until `hw_cmd` has completed, while honoring pause requests.
///
/// Precondition: `lock` owns mtx_, and `hw_cmd` is synchronizable (asserted)
/// and already launched (asserted once the worker is running). The lock is
/// released while waiting on cv_ and during hw_cmd->Synchronize(). It is
/// owned again on every return path and handed back to the caller.
///
/// Returns immediately if `hw_cmd` is already completed or the worker is
/// terminated; in those cases neither log is trimmed. If Pause() ran while
/// the lock was released (pause_count_ changed), the cycle repeats. After a
/// successful sync, every entry with an index <= hw_cmd's index is removed
/// from sync_cmd_log_ and cmd_log_, and the removed cmd_log_ entries are
/// marked completed.
std::unique_lock<MutexLock> LaunchWorker::SyncCmdWithLock(std::unique_lock<utils::MutexLock> lock,
std::shared_ptr<HwCommand> hw_cmd)
{
XASSERT(hw_cmd->Synchronizable(), "The HwCommand should be synchronizable");
while (true) {
// Block until the worker is running, unless the target is already done or
// the worker has been terminated.
while (true) {
if (hw_cmd->GetState() >= kCommandStateCompleted) return lock;
if (state_ == kWorkerStateRunning) {
break;
} else if (state_ == kWorkerStatePaused) {
cv_.wait(lock);
} else if (state_ == kWorkerStateTerminated) {
return lock;
} else {
XASSERT(false, "Invalid worker state");
}
}
XCommandState state = hw_cmd->GetState();
XASSERT(state >= kCommandStateInFlight, "The syncing HwCommand is not launched");
// NOTE(review): the check above already returned for completed commands
// while holding mtx_, so this is presumably only hit if the state is changed
// elsewhere without mtx_. In that case, fall through to trimming the logs.
if (state == kCommandStateCompleted) break;
// Snapshot the pause counter so a Pause() during the unlocked sync is noticed.
int64_t current_pause_cnt = pause_count_;
lock.unlock();
hw_cmd->Synchronize();
lock.lock();
if (current_pause_cnt == pause_count_) break;
}
// Trimming relies on both logs being in increasing index order; they are
// appended in launch order by LaunchHwCommand.
const int64_t current_command_idx = hw_cmd->GetIdx();
while (sync_cmd_log_.size() > 0) {
if (sync_cmd_log_.front()->GetIdx() > current_command_idx) break;
sync_cmd_log_.pop_front();
}
while (cmd_log_.size() > 0) {
auto front_command = cmd_log_.front();
if (front_command->GetIdx() > current_command_idx) break;
front_command->SetState(kCommandStateCompleted);
cmd_log_.pop_front();
}
return lock;
}