/**
* Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Description: Thread pool.
*/
#include "datasystem/common/util/thread_pool.h"
#include <system_error>
#include <thread>
#ifdef WITH_TESTS
#include "datasystem/common/inject/inject_point.h"
#endif
#include "datasystem/common/util/format.h"
#include "datasystem/common/util/thread.h"
#include "datasystem/common/util/timer.h"
#include "datasystem/common/log/log.h"
namespace datasystem {
/**
 * @brief Destructor: waits for every still-joinable worker thread via Join().
 */
ThreadWorkers::~ThreadWorkers()
{
    Join();
}
/**
 * @brief Block until every joinable thread stored in this container has finished.
 *
 * Entries whose thread is not joinable are skipped.
 */
void ThreadWorkers::Join()
{
    for (auto &entry : *this) {
        auto &worker = entry.second;
        if (!worker.joinable()) {
            continue;
        }
        worker.join();
    }
}
/**
 * @brief Main loop executed by every worker thread of the pool.
 *
 * Each iteration waits (up to threadIdleTimeoutMs_) for either a queued task or
 * the shutdown flag, pops one task while holding mtx_, then runs it outside the
 * lock and updates the pool counters. The loop ends when the pool is shut down
 * or when this worker has been idle for a full timeout and the pool currently
 * holds more than minThreadNum_ workers.
 */
void ThreadPool::DoThreadWork()
{
// Best effort: a failure to set the nice value is only logged, the worker keeps running.
if (!Thread::SetCurrentThreadNice(0)) {
LOG(WARNING) << "Failed to set nice for thread pool [" << name_ << "], nice=" << 0 << ", errno=" << errno;
}
while (true) {
std::function<void()> task;
{
// Everything in this scope runs with mtx_ held (wait_for releases it only while blocked).
std::unique_lock<std::mutex> lock(this->mtx_);
// wait_for returns false when the timeout expired and the predicate is still false,
// i.e. no shutdown was requested and the queue is still empty.
if (!this->proceedCV_.wait_for(lock, threadIdleTimeoutMs_,
[this] { return this->shutDown_ || !this->taskQ_.empty(); })) {
// Idle timeout: retire this worker only if the pool is above its minimum size.
if (GetThreadsNum() > minThreadNum_) {
auto tid = std::this_thread::get_id();
DestroyUnuseWorker(tid);
return;
}
// At or below the minimum size: keep the worker and wait again.
continue;
}
// On shutdown, a droppable pool exits immediately even if tasks are still queued;
// a non-droppable pool keeps draining until the queue is empty.
if (this->shutDown_ && (droppable_ || this->taskQ_.empty())) {
return;
}
// When no task was running, the timestamp is refreshed before starting this one.
// NOTE(review): presumably so that idle time is not attributed to the running task - confirm.
if (runningThreadsNum_ == 0) {
taskLastFinishTime_ = GetSteadyClockTimeStampUs();
}
runningThreadsNum_++;
// Track the peak number of concurrently running tasks for the current stats period.
UpdateMaxAtomic(maxRunningInPeriod_, runningThreadsNum_.load());
task = std::move(this->taskQ_.front());
this->taskQ_.pop();
}
{
// The task runs without mtx_ held; its wall-clock duration is measured with steady_clock.
auto start = std::chrono::steady_clock::now();
// NOTE(review): there is no try/catch here, so an exception escaping task() would skip
// the bookkeeping below - confirm that submitted tasks cannot throw.
task();
auto elapsed = std::chrono::steady_clock::now() - start;
auto elapsedNs = std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count();
taskLastFinishTime_ = GetSteadyClockTimeStampUs();
runningThreadsNum_--;
// Interval statistics; these are read and reset by GetAndResetIntervalStats().
tasksCompleted_.fetch_add(1, std::memory_order_relaxed);
totalWorkTimeNs_.fetch_add(elapsedNs, std::memory_order_relaxed);
}
}
}
void ThreadPool::AddThread()
{
auto thread = Thread([this] { this->DoThreadWork(); });
thread.set_name(name_);
std::lock_guard<std::shared_timed_mutex> workerLock(workersMtx_);
workers_[thread.get_id()] = std::move(thread);
}
void ThreadPool::DestroyUnuseWorker(std::thread::id tid)
{
std::lock_guard<std::shared_timed_mutex> workerLock(workersMtx_);
if (!shutDown_ && workers_.find(tid) != workers_.end()) {
if (workers_[tid].joinable()) {
workers_[tid].detach();
}
(void)workers_.erase(tid);
}
}
/**
 * @brief Report whether queued plus running tasks have reached maxThreadNum_.
 *
 * NOTE(review): the lock taken here is workersMtx_, whereas DoThreadWork() modifies
 * taskQ_ under mtx_ - confirm callers provide the required synchronization.
 *
 * @return true when taskQ_.size() + runningThreadsNum_ >= maxThreadNum_.
 */
bool ThreadPool::IsPoolFull()
{
    std::shared_lock<std::shared_timed_mutex> guard(workersMtx_);
    return taskQ_.size() + runningThreadsNum_ >= maxThreadNum_;
}
void ThreadPool::TryToAddThreadIfNeeded()
{
{
std::shared_lock<std::shared_timed_mutex> lock(workersMtx_);
auto threadNum = workers_.size();
if (threadNum >= maxThreadNum_ || threadNum >= taskQ_.size() + runningThreadsNum_) {
return;
}
}
try {
AddThread();
} catch (std::system_error &sysErr) {
if (GetThreadsNum() == 0) {
throw;
} else {
std::string errMsg = std::string(sysErr.what()) + ", cannot acquire resources";
LOG(ERROR) << errMsg;
}
}
}
size_t ThreadPool::GetThreadsNum()
{
std::shared_lock<std::shared_timed_mutex> workerLock(workersMtx_);
return workers_.size();
}
/**
 * @brief Current value of runningThreadsNum_, i.e. how many workers are executing a task.
 */
size_t ThreadPool::GetRunningTasksNum()
{
    const size_t running = runningThreadsNum_;
    return running;
}
/**
 * @brief Number of tasks currently sitting in taskQ_.
 *
 * NOTE(review): no lock is taken in this function - confirm whether callers are
 * expected to hold mtx_ or whether an approximate value is acceptable.
 */
size_t ThreadPool::GetWaitingTasksNum()
{
    const size_t pending = taskQ_.size();
    return pending;
}
/**
 * @brief Build a one-line summary of the pool: idle workers, total workers, waiting tasks.
 *
 * @return A string of the form "idle(I),total(T),wait(W)".
 */
std::string ThreadPool::GetStatistics()
{
    const size_t totalNum = GetThreadsNum();
    const size_t runningNum = GetRunningTasksNum();
    // The two counters are sampled at different moments without a common lock, so a
    // worker added in between can make runningNum exceed totalNum. Clamp to zero
    // instead of letting the unsigned subtraction wrap around to a huge value.
    const size_t idleNum = totalNum > runningNum ? totalNum - runningNum : 0;
    return FormatString("idle(%ld),total(%ld),wait(%ld)", idleNum, totalNum, GetWaitingTasksNum());
}
/**
 * @brief Take a snapshot of the pool's current counters.
 *
 * Fills the worker count, running and waiting task counts, the configured
 * maximum, the running/max ratio (0 when the maximum is 0) and the last
 * task-finish timestamp. No interval counter is reset by this call.
 *
 * @return The populated ThreadPoolUsage.
 */
ThreadPool::ThreadPoolUsage ThreadPool::GetThreadPoolUsage()
{
    ThreadPoolUsage snapshot;
    snapshot.currentTotalNum = GetThreadsNum();
    snapshot.runningTasksNum = GetRunningTasksNum();
    snapshot.waitingTaskNum = GetWaitingTasksNum();
    snapshot.maxThreadNum = maxThreadNum_;
    if (maxThreadNum_ > 0) {
        snapshot.threadPoolUsage = snapshot.runningTasksNum / static_cast<float>(maxThreadNum_);
    } else {
        snapshot.threadPoolUsage = 0.0f;
    }
    snapshot.taskLastFinishTime = taskLastFinishTime_;
    return snapshot;
}
/**
 * @brief Raise @p counter to @p value if @p value is larger than what it holds.
 *
 * @param counter Atomic maximum to update.
 * @param value   Candidate new maximum.
 */
void ThreadPool::UpdateMaxAtomic(std::atomic<uint64_t> &counter, uint64_t value)
{
    uint64_t observed = counter.load(std::memory_order_relaxed);
    // A failed compare-exchange refreshes `observed`, so the bound is re-checked
    // against the latest stored value before each retry.
    while (observed < value &&
           !counter.compare_exchange_weak(observed, value, std::memory_order_relaxed)) {
    }
}
/**
 * @brief Snapshot the pool counters and start a new statistics interval.
 *
 * Besides the instantaneous values, this returns the per-interval figures and
 * resets them: the completed-task count and total work time go back to 0, and
 * the two peak trackers are re-seeded with the current running / waiting counts.
 *
 * @return The populated ThreadPoolUsage for the interval that just ended.
 */
ThreadPool::ThreadPoolUsage ThreadPool::GetAndResetIntervalStats()
{
    ThreadPoolUsage stats;
    stats.currentTotalNum = GetThreadsNum();
    stats.maxThreadNum = maxThreadNum_;
    stats.runningTasksNum = GetRunningTasksNum();
    stats.waitingTaskNum = GetWaitingTasksNum();
    if (maxThreadNum_ > 0) {
        stats.threadPoolUsage = stats.runningTasksNum / static_cast<float>(maxThreadNum_);
    } else {
        stats.threadPoolUsage = 0.0f;
    }

    const auto runningNow = static_cast<uint64_t>(stats.runningTasksNum);
    const auto waitingNow = static_cast<uint64_t>(stats.waitingTaskNum);

    // The reported peak is never lower than the value sampled right now.
    const uint64_t prevRunningPeak = maxRunningInPeriod_.exchange(runningNow, std::memory_order_relaxed);
    stats.maxRunningInPeriod = std::max(prevRunningPeak, runningNow);

    stats.tasksCompletedDelta = tasksCompleted_.exchange(0, std::memory_order_relaxed);

    const uint64_t prevWaitingPeak = maxWaitingInPeriod_.exchange(waitingNow, std::memory_order_relaxed);
    stats.maxWaitingInPeriod = std::max(prevWaitingPeak, waitingNow);

    stats.totalWorkTimeNs = totalWorkTimeNs_.exchange(0, std::memory_order_relaxed);
    stats.taskLastFinishTime = taskLastFinishTime_;
    return stats;
}
/**
 * @brief Create a pool and start its initial worker threads.
 *
 * If maxThreadNum is 0 it falls back to minThreadNum. If minThreadNum is larger
 * than maxThreadNum it is lowered to maxThreadNum and a warning is logged.
 * Exactly minThreadNum_ workers are started before the constructor returns.
 *
 * @param minThreadNum        Number of workers started up front; DoThreadWork() does not
 *                            retire idle workers while the pool is at or below this size.
 * @param maxThreadNum        Upper bound on the number of workers; 0 means "same as minThreadNum".
 * @param name                Pool name, applied to worker threads and used in log messages.
 * @param droppable           When true, workers exit on shutdown even if tasks are still queued.
 * @param threadIdleTimeoutMs Idle wait used by workers before checking whether to retire.
 *
 * @throws std::runtime_error when both minThreadNum and maxThreadNum are 0.
 */
ThreadPool::ThreadPool(size_t minThreadNum, size_t maxThreadNum, std::string name, bool droppable,
                       int threadIdleTimeoutMs)
    : shutDown_(false),
      joined_(false),
      droppable_(droppable),
      minThreadNum_(minThreadNum),
      maxThreadNum_(maxThreadNum),
      name_(std::move(name)),
      threadIdleTimeoutMs_(threadIdleTimeoutMs),
      warnLevel_(WarnLevel::HIGH)
{
    if (maxThreadNum_ == 0) {
        if (minThreadNum_ == 0) {
            throw std::runtime_error("ThreadPool: minThreadNum == maxThreadNum == 0, won't create any thread.");
        }
        maxThreadNum_ = minThreadNum_;
    }
    if (minThreadNum_ > maxThreadNum_) {
        // Log before clamping so the message reports the requested minimum rather
        // than the already-adjusted value.
        LOG(WARNING) << FormatString("ThreadPool: minThreadNum %zu > maxThreadNum, adjust minThreadNum to %zu",
                                     minThreadNum_, maxThreadNum_);
        minThreadNum_ = maxThreadNum_;
    }
    workers_.reserve(minThreadNum_);
    for (size_t i = 0; i < minThreadNum_; ++i) {
        AddThread();
    }
}
/**
 * @brief Destructor: shuts the pool down and joins its workers unless that already happened.
 *
 * The shutDown_ and joined_ flags are sampled together under mtx_; ShutDown()
 * and Join() are then invoked without the lock held.
 */
ThreadPool::~ThreadPool()
{
    bool needShutDown = false;
    bool needJoin = false;
    {
        std::unique_lock<std::mutex> guard(mtx_);
        needShutDown = !shutDown_;
        needJoin = !joined_;
    }
    if (needShutDown) {
        ShutDown();
    }
    if (needJoin) {
        Join();
    }
}
void ThreadPool::Join()
{
workers_.Join();
joined_ = true;
}
/**
 * @brief Request shutdown: set shutDown_ under mtx_ and wake every waiting worker.
 *
 * This only signals the workers; it does not wait for them to exit.
 */
void ThreadPool::ShutDown()
{
    {
        std::lock_guard<std::mutex> guard(mtx_);
        shutDown_ = true;
    }
    // Notify after releasing the lock so woken workers can acquire mtx_ right away.
    proceedCV_.notify_all();
}
/**
 * @brief Emit a rate-limited warning when tasks are queued and pool usage is high.
 *
 * Nothing is logged for an unnamed pool, when the warn level is NO_WARN, or when
 * no task is waiting. Usage is running tasks divided by maxThreadNum_, in percent,
 * and warnings are emitted at the 60%, 80% and 100% tiers. WarnLevel::LOW uses
 * LOG_EVERY_T (10 / 60 / 100); every other level uses LOG_FIRST_AND_EVERY_N
 * (50 / 500 / 1000).
 */
void ThreadPool::WarnIfNeed()
{
    if (name_.empty() || warnLevel_ == WarnLevel::NO_WARN) {
        return;
    }
    if (GetWaitingTasksNum() < 1) {
        return;
    }

    const float fullPct = 100.0;
    const float lowTierPct = 60.0;
    const float midTierPct = 80.0;
    float usagePct = static_cast<float>(GetRunningTasksNum()) / maxThreadNum_ * fullPct;
    std::string text =
        FormatString("Thread pool [%s] usage: %.1lf%%, waiting tasks: %d", name_, usagePct, GetWaitingTasksNum());

    // Each tier keeps its own logging statement rather than sharing one call site.
    if (warnLevel_ == WarnLevel::LOW) {
        const int fullInterval = 10;
        const int midInterval = 60;
        const int lowInterval = 100;
        if (usagePct >= fullPct) {
            LOG_EVERY_T(WARNING, fullInterval) << text << ", thread full";
        } else if (usagePct >= midTierPct) {
            LOG_EVERY_T(WARNING, midInterval) << text << ", exceeds 80%";
        } else if (usagePct >= lowTierPct) {
            LOG_EVERY_T(WARNING, lowInterval) << text << ", exceeds 60%";
        }
        return;
    }

    const int fullEveryN = 50;
    const int midEveryN = 500;
    const int lowEveryN = 1000;
    if (usagePct >= fullPct) {
        LOG_FIRST_AND_EVERY_N(WARNING, fullEveryN) << text << ", thread full";
    } else if (usagePct >= midTierPct) {
        LOG_FIRST_AND_EVERY_N(WARNING, midEveryN) << text << ", exceeds 80%";
    } else if (usagePct >= lowTierPct) {
        LOG_FIRST_AND_EVERY_N(WARNING, lowEveryN) << text << ", exceeds 60%";
    }
}
}