#include "base/task/thread_pool/job_task_source.h"

#include <algorithm>
#include <bit>
#include <limits>
#include <type_traits>
#include <utility>

#include "base/check_op.h"
#include "base/feature_list.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/ptr_util.h"
#include "base/notreached.h"
#include "base/task/common/checked_lock.h"
#include "base/task/task_features.h"
#include "base/task/thread_pool/pooled_task_runner_delegate.h"
#include "base/threading/thread_restrictions.h"
#include "base/time/time.h"
#include "base/time/time_override.h"
#include "base/trace_event/trace_event.h"
namespace base::internal {
namespace {

// Cached value of kJobPriorityBoosting, set by
// JobTaskSource::InitializeFeatures(). Read by WillRunTask() and
// DidProcessTask() to decide whether to track worker threads in
// workers_priority_.
bool g_job_priority_boosting = false;

BASE_FEATURE(kJobPriorityBoosting, base::FEATURE_DISABLED_BY_DEFAULT);

// Upper bound on concurrent workers for a single job. GetMaxConcurrency()
// clamps the user-provided max concurrency to this value, and task ids handed
// out by AcquireTaskId() are bits in a bitfield of at least this many bits.
constexpr size_t kMaxWorkersPerJob = 32;

// Every task id must be representable by the type JobDelegate::GetTaskId()
// returns.
static_assert(
    kMaxWorkersPerJob <=
        std::numeric_limits<std::invoke_result_t<
            decltype(&JobDelegate::GetTaskId), JobDelegate>>::max(),
    "JobDelegate::GetTaskId() return type isn't big enough to fit "
    "kMaxWorkersPerJob");

}  // namespace
// State keeps the worker count and the canceled bit in the atomic `value_`,
// manipulated by the accessors below. Default construction and destruction
// are sufficient.
JobTaskSource::State::State() = default;
JobTaskSource::State::~State() = default;
// Atomically ORs kCanceledMask into the packed state and returns the state as
// it was immediately before the bit was set.
JobTaskSource::State::Value JobTaskSource::State::Cancel() {
  const auto previous_value =
      value_.fetch_or(kCanceledMask, std::memory_order_relaxed);
  return {previous_value};
}
// Atomically removes one worker from the packed state and returns the state
// as it was immediately before the removal.
JobTaskSource::State::Value JobTaskSource::State::DecrementWorkerCount() {
  const uint32_t value_before_sub =
      value_.fetch_sub(kWorkerCountIncrement, std::memory_order_relaxed);
  // There must have been at least one worker to remove; otherwise the
  // subtraction would wrap around.
  DCHECK((value_before_sub >> kWorkerCountBitOffset) > 0);
  return {value_before_sub};
}
// Atomically adds one worker to the packed state and returns the state as it
// was immediately before the addition.
JobTaskSource::State::Value JobTaskSource::State::IncrementWorkerCount() {
  uint32_t value_before_add =
      value_.fetch_add(kWorkerCountIncrement, std::memory_order_relaxed);
  // After the addition the worker count must still fit in a uint8_t, i.e. be
  // at most (1 << 8) - 1 == 255.
  DCHECK((value_before_add >> kWorkerCountBitOffset) < ((1 << 8) - 1));
  return {value_before_add};
}
// Returns a relaxed atomic snapshot of the packed worker-count/canceled state.
JobTaskSource::State::Value JobTaskSource::State::Load() const {
  const auto snapshot = value_.load(std::memory_order_relaxed);
  return {snapshot};
}
// JoinFlag wraps the atomic `value_` through which a joining thread asks
// running workers to yield and to signal it. Default construction and
// destruction are sufficient.
JobTaskSource::JoinFlag::JoinFlag() = default;
JobTaskSource::JoinFlag::~JoinFlag() = default;
// Marks that no joining thread is waiting. In this file it is called from
// WaitForParticipationOpportunity(), which runs with worker_lock_ held.
void JobTaskSource::JoinFlag::Reset() {
  value_.store(kNotWaiting, std::memory_order_relaxed);
}
// Marks that a joining thread is waiting for a running worker to yield. The
// next ShouldWorkerYield() call observes this state.
void JobTaskSource::JoinFlag::SetWaiting() {
  value_.store(kWaitingForWorkerToYield, std::memory_order_relaxed);
}
// Returns true if a joining thread was waiting for a worker to yield. The flag
// is masked with kWaitingForWorkerToSignal in the same atomic step.
// NOTE(review): this relies on the flag constants (declared in the header)
// being chosen so the mask turns kWaitingForWorkerToYield into
// kWaitingForWorkerToSignal; confirm there.
bool JobTaskSource::JoinFlag::ShouldWorkerYield() {
  const auto flag_before =
      value_.fetch_and(kWaitingForWorkerToSignal, std::memory_order_relaxed);
  return flag_before == kWaitingForWorkerToYield;
}
// Resets the flag to kNotWaiting. Returns true if it previously held any
// other value, i.e. a joining thread wants worker_released_condition_
// signaled.
bool JobTaskSource::JoinFlag::ShouldWorkerSignal() {
  const auto flag_before =
      value_.exchange(kNotWaiting, std::memory_order_relaxed);
  return flag_before != kNotWaiting;
}
// Caches the kJobPriorityBoosting feature state in g_job_priority_boosting,
// which WillRunTask() and DidProcessTask() read.
void JobTaskSource::InitializeFeatures() {
  const bool priority_boosting_enabled =
      FeatureList::IsEnabled(kJobPriorityBoosting);
  g_job_priority_boosting = priority_boosting_enabled;
}
// Creates a job whose `worker_task` is run by up to GetMaxConcurrency()
// workers at once. `max_concurrency_callback` is queried with the current
// worker count. `delegate` must be non-null.
JobTaskSource::JobTaskSource(const Location& from_here,
                             const TaskTraits& traits,
                             RepeatingCallback<void(JobDelegate*)> worker_task,
                             MaxConcurrencyCallback max_concurrency_callback,
                             PooledTaskRunnerDelegate* delegate)
    : TaskSource(traits, TaskSourceExecutionMode::kJob),
      max_concurrency_callback_(std::move(max_concurrency_callback)),
      worker_task_(std::move(worker_task)),
      // Closure handed to workers by TakeTask(). It runs worker_task_ with a
      // JobDelegate bound to `delegate_`. NOTE(review): `this` is Unretained;
      // presumably the task source outlives every task it hands out. Confirm.
      primary_task_(base::BindRepeating(
          [](JobTaskSource* self) {
            // The worker task runs user code; no CheckedLock may be held.
            CheckedLock::AssertNoLockHeldOnCurrentThread();
            JobDelegate job_delegate{self, self->delegate_};
            self->worker_task_.Run(&job_delegate);
          },
          base::Unretained(this))),
      task_metadata_(from_here),
      ready_time_(TimeTicks::Now()),
      delegate_(delegate) {
  DCHECK(delegate_);
  // -1 means "WillEnqueue() not called yet" (checked there).
  task_metadata_.sequence_num = -1;
}
// Every worker, including a joining thread, must have been deregistered from
// `state_` before the task source is destroyed.
JobTaskSource::~JobTaskSource() {
  DCHECK_EQ(state_.Load().worker_count(), 0U);
}
// Returns an execution environment holding a freshly created SequenceToken on
// every call. NOTE(review): presumably this gives each job task its own
// sequence identity rather than a shared one; confirm.
ExecutionEnvironment JobTaskSource::GetExecutionEnvironment() {
  return {SequenceToken::Create()};
}
// Records `sequence_num` and notifies `annotator` on the first enqueue only.
// A sequence_num of -1 (set by the constructor) means the job has not been
// enqueued yet; later calls are no-ops.
void JobTaskSource::WillEnqueue(int sequence_num, TaskAnnotator& annotator) {
  const bool first_enqueue = task_metadata_.sequence_num == -1;
  if (first_enqueue) {
    task_metadata_.sequence_num = sequence_num;
    annotator.WillQueueTask("ThreadPool_PostJob", &task_metadata_);
  }
}
// Called by a thread that wants to join (participate in) this job. Counts the
// caller as a worker. Returns true right away if there is room for it.
// Otherwise it blocks in WaitForParticipationOpportunity(), which returns true
// once the caller may run, or false when it is the last worker and the job is
// canceled or its max concurrency is 0.
bool JobTaskSource::WillJoin() {
  TRACE_EVENT("base", "Job.WaitForParticipationOpportunity");
  CheckedAutoLock auto_lock(worker_lock_);
  // The condition variable is created here, so this may only be called once.
  DCHECK(!worker_released_condition_);
  worker_lock_.CreateConditionVariableAndEmplace(worker_released_condition_);
  // NOTE(review): presumably keeps Wait() from being treated as a blocking
  // call (which could acquire other locks); confirm against the
  // ConditionVariable docs.
  worker_released_condition_->declare_only_used_while_idle();
  // From here on the joining thread is included in worker_count(), even while
  // it waits.
  const auto state_before_add = state_.IncrementWorkerCount();
  // Fast path: not canceled and there is room for one more worker.
  if (!state_before_add.is_canceled() &&
      state_before_add.worker_count() <
          GetMaxConcurrency(state_before_add.worker_count())) {
    return true;
  }
  // Raise each tracked worker (only populated when g_job_priority_boosting is
  // set) to this thread's type before waiting on them.
  for (auto& [_, worker_priority] : workers_priority_) {
    worker_priority.BoostPriority(PlatformThread::GetCurrentThreadType());
  }
  return WaitForParticipationOpportunity();
}
bool JobTaskSource::RunJoinTask() {
JobDelegate job_delegate{this, nullptr};
worker_task_.Run(&job_delegate);
const auto state = TS_UNCHECKED_READ(state_).Load();
if (!state.is_canceled() &&
state.worker_count() <= GetMaxConcurrency(state.worker_count() - 1)) {
return true;
}
TRACE_EVENT("base", "Job.WaitForParticipationOpportunity");
CheckedAutoLock auto_lock(worker_lock_);
return WaitForParticipationOpportunity();
}
// Marks the job canceled without taking worker_lock_ (`state_` is atomic).
// Afterwards WillRunTask() returns kDisallowed and ShouldYield() returns true.
// `transaction` is unused.
void JobTaskSource::Cancel(TaskSource::Transaction* transaction) {
  TS_UNCHECKED_READ(state_).Cancel();
}
// Called by the joining thread with worker_lock_ held (both callers in this
// file hold it). The thread is already counted in worker_count(). Blocks on
// worker_released_condition_ (created from worker_lock_ in WillJoin()) until:
//  A) the worker count fits within max concurrency and the job is not
//     canceled -> returns true; or
//  B) the joining thread is the only worker left -> it deregisters itself,
//     cancels the job and returns false.
bool JobTaskSource::WaitForParticipationOpportunity() {
  DCHECK(!join_flag_.IsWaiting());
  auto state = state_.Load();
  // `worker_count() - 1` excludes the joining thread, which is counted but not
  // currently running the task.
  size_t max_concurrency = GetMaxConcurrency(state.worker_count() - 1);
  while (!((state.worker_count() <= max_concurrency && !state.is_canceled()) ||
           state.worker_count() == 1)) {
    // Ask a running worker to yield (see ShouldYield()) and to signal the
    // condition when it returns (see DidProcessTask()).
    join_flag_.SetWaiting();
    worker_released_condition_->Wait();
    state = state_.Load();
    max_concurrency = GetMaxConcurrency(state.worker_count() - 1);
  }
  join_flag_.Reset();
  // Case A.
  if (state.worker_count() <= max_concurrency && !state.is_canceled()) {
    return true;
  }
  // Case B: only the joining thread remains, and there is nothing for it to do.
  DCHECK_EQ(state.worker_count(), 1U);
  DCHECK(state.is_canceled() || max_concurrency == 0U);
  state_.DecrementWorkerCount();
  // Cancel so that no further workers are allowed to run the task.
  state_.Cancel();
  return false;
}
// Called when a worker wants to run this job. Clears is_queued_, then:
//  - returns kDisallowed if the job is canceled, or if the worker count is
//    already at or above max concurrency;
//  - otherwise counts the caller as a worker. It returns kAllowedSaturated
//    if the caller took the last slot, or kAllowedNotSaturated if more
//    workers could still run; is_queued_ is set in the latter case.
TaskSource::RunStatus JobTaskSource::WillRunTask() {
  CheckedAutoLock auto_lock(worker_lock_);
  is_queued_ = false;
  auto state_before_add = state_.Load();
  // Case A: the job was canceled.
  if (state_before_add.is_canceled()) {
    return RunStatus::kDisallowed;
  }
  const size_t max_concurrency =
      GetMaxConcurrency(state_before_add.worker_count());
  // Take a slot only if one is available. Every worker-count change in this
  // file happens under worker_lock_, so the count cannot change between Load()
  // and the increment. IncrementWorkerCount() returns the pre-increment state.
  if (state_before_add.worker_count() < max_concurrency) {
    state_before_add = state_.IncrementWorkerCount();
  }
  const size_t worker_count_before_add = state_before_add.worker_count();
  // Case B/C: already at max concurrency, or max concurrency was lowered to or
  // below the current worker count. No slot was taken.
  if (worker_count_before_add >= max_concurrency) {
    return RunStatus::kDisallowed;
  }
  // Track this worker's thread so WillJoin() can boost its priority.
  if (g_job_priority_boosting) {
    auto [_, inserted] = workers_priority_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(PlatformThread::CurrentId()),
        std::forward_as_tuple());
    CHECK(inserted);
  }
  DCHECK_LT(worker_count_before_add, max_concurrency);
  TaskSource::RunStatus status =
      (max_concurrency == worker_count_before_add + 1)
          ? RunStatus::kAllowedSaturated
          : RunStatus::kAllowedNotSaturated;
  is_queued_ = (status == RunStatus::kAllowedNotSaturated);
  return status;
}
size_t JobTaskSource::GetRemainingConcurrency() const {
const auto state = TS_UNCHECKED_READ(state_).Load();
if (state.is_canceled()) {
return 0;
}
const size_t max_concurrency = GetMaxConcurrency(state.worker_count());
if (state.worker_count() > max_concurrency) {
return 0;
}
return max_concurrency - state.worker_count();
}
bool JobTaskSource::IsActive() const {
CheckedAutoLock auto_lock(worker_lock_);
auto state = state_.Load();
return GetMaxConcurrency(state.worker_count()) != 0 ||
state.worker_count() != 0;
}
// Lock-free read of the current worker count; it may be stale by the time the
// caller uses it.
size_t JobTaskSource::GetWorkerCount() const {
  const auto snapshot = TS_UNCHECKED_READ(state_).Load();
  return snapshot.worker_count();
}
// Called when max concurrency may have gone up. Wakes a joining thread that
// asked to be signaled, and enqueues the task source unless it is already
// queued.
void JobTaskSource::NotifyConcurrencyIncrease() {
  // Skip the lock entirely when there is no spare concurrency.
  if (GetRemainingConcurrency() == 0) {
    return;
  }
  bool needs_enqueue = false;
  {
    // worker_lock_ guards join_flag_, is_queued_ and the condition signal.
    CheckedAutoLock auto_lock(worker_lock_);
    if (join_flag_.ShouldWorkerSignal()) {
      worker_released_condition_->Signal();
    }
    needs_enqueue = !is_queued_;
    is_queued_ = true;
  }
  // Enqueue outside the lock.
  if (needs_enqueue) {
    delegate_->EnqueueJobTaskSource(this);
  }
}
// Max concurrency for the current worker count, read without the lock.
size_t JobTaskSource::GetMaxConcurrency() const {
  const auto current_workers = TS_UNCHECKED_READ(state_).Load().worker_count();
  return GetMaxConcurrency(current_workers);
}
// Asks the user callback for max concurrency given `worker_count`, clamped to
// kMaxWorkersPerJob.
size_t JobTaskSource::GetMaxConcurrency(size_t worker_count) const {
  const size_t requested = max_concurrency_callback_.Run(worker_count);
  return std::min(requested, kMaxWorkersPerJob);
}
uint8_t JobTaskSource::AcquireTaskId() {
static_assert(kMaxWorkersPerJob <= sizeof(assigned_task_ids_) * 8,
"TaskId bitfield isn't big enough to fit kMaxWorkersPerJob.");
uint32_t assigned_task_ids =
assigned_task_ids_.load(std::memory_order_relaxed);
uint32_t new_assigned_task_ids = 0;
int task_id = 0;
do {
task_id = std::countr_one(assigned_task_ids);
new_assigned_task_ids = assigned_task_ids | (uint32_t(1) << task_id);
} while (!assigned_task_ids_.compare_exchange_weak(
assigned_task_ids, new_assigned_task_ids, std::memory_order_acquire,
std::memory_order_relaxed));
return static_cast<uint8_t>(task_id);
}
// Returns `task_id`, previously obtained from AcquireTaskId(), by clearing its
// bit. The release ordering pairs with the acquire in AcquireTaskId(), so the
// next holder of this id observes this thread's prior writes.
void JobTaskSource::ReleaseTaskId(uint8_t task_id) {
  uint32_t previous_task_ids = assigned_task_ids_.fetch_and(
      ~(uint32_t(1) << task_id), std::memory_order_release);
  // The id must have been held.
  DCHECK(previous_task_ids & (uint32_t(1) << task_id));
}
// Returns true if a running worker should stop early, either because a
// joining thread is waiting for it or because the job was canceled. Both
// members are atomic and read without the lock, so an update may be observed
// late. The join flag is checked first; checking it may update the flag.
bool JobTaskSource::ShouldYield() {
  if (TS_UNCHECKED_READ(join_flag_).ShouldWorkerYield()) {
    return true;
  }
  return TS_UNCHECKED_READ(state_).Load().is_canceled();
}
// Hands out primary_task_ (which runs worker_task_ with a new JobDelegate)
// together with the job's task metadata. `transaction` is unused. The caller
// must already be counted as a worker (presumably via a successful
// WillRunTask()).
Task JobTaskSource::TakeTask(TaskSource::Transaction* transaction) {
  DCHECK_GT(TS_UNCHECKED_READ(state_).Load().worker_count(), 0U);
  DCHECK(primary_task_);
  return {task_metadata_, primary_task_};
}
// Called after a worker finishes running the task. It deregisters the worker
// and wakes a joining thread if one asked to be signaled. Returns whether the
// task source should be re-enqueued: never once canceled; otherwise only if
// the remaining workers leave room under max concurrency.
bool JobTaskSource::DidProcessTask(TaskSource::Transaction* /*transaction*/) {
  CheckedAutoLock auto_lock(worker_lock_);
  const auto state_before_sub = state_.DecrementWorkerCount();
  if (g_job_priority_boosting) {
    workers_priority_.erase(PlatformThread::CurrentId());
  }
  // Wake a thread blocked in WaitForParticipationOpportunity().
  if (join_flag_.ShouldWorkerSignal()) {
    worker_released_condition_->Signal();
  }
  // A canceled task source is never re-enqueued.
  if (state_before_sub.is_canceled()) {
    return false;
  }
  DCHECK_GT(state_before_sub.worker_count(), 0U);
  // `worker_count() - 1` excludes the returning thread.
  bool reenqueue = state_before_sub.worker_count() <=
                   GetMaxConcurrency(state_before_sub.worker_count() - 1);
  is_queued_ |= reenqueue;
  return reenqueue;
}
// Always returns true; `now` and the transaction are ignored.
// NOTE(review): presumably whether to re-enqueue is already decided by
// DidProcessTask()'s return value.
bool JobTaskSource::WillReEnqueue(TimeTicks now,
                                  TaskSource::Transaction* /*transaction*/) {
  return true;
}
// Always returns false. NOTE(review): presumably because a job task source
// holds no delayed tasks that could become ready; confirm against TaskSource.
bool JobTaskSource::OnBecomeReady() {
  return false;
}
// Sort key built from the racy priority, the creation-time ready_time_, and
// an unlocked read of the current worker count.
TaskSourceSortKey JobTaskSource::GetSortKey() const {
  const auto current_workers = TS_UNCHECKED_READ(state_).Load().worker_count();
  return TaskSourceSortKey(priority_racy(), ready_time_, current_workers);
}
// Returns a default (null) TimeTicks. NOTE(review): presumably because a job
// task source has no delayed run time; confirm against TaskSource.
TimeTicks JobTaskSource::GetDelayedSortKey() const {
  return TimeTicks();
}
// Must never be called on a JobTaskSource.
bool JobTaskSource::HasReadyTasks(TimeTicks now) const {
  NOTREACHED();
}
// Cancels the job and returns no task. Workers that are already running are
// not stopped here; they observe the cancellation through ShouldYield().
std::optional<Task> JobTaskSource::Clear(TaskSource::Transaction* transaction) {
  Cancel();
  return std::nullopt;
}
}