* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/thread_pool.h"
#include <algorithm>
#include <exception>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"
namespace mindspore {
namespace common {
#if ENABLE_D || ENABLE_GPU
// Divisor applied to the usable core count by the ThreadPool constructor in
// ENABLE_D / ENABLE_GPU builds.
constexpr size_t kDeviceNum = 8;
#endif
// Upper bound the ThreadPool constructor enforces on the worker thread count.
constexpr size_t kMaxThreadNum = 23;
// Computes max_thread_num_, the cap on worker threads this pool may spawn.
//
// The cap is one less than the reported hardware thread count, divided by
// kDeviceNum in ENABLE_D / ENABLE_GPU builds, and finally clamped to the range
// [1, kMaxThreadNum]. No threads are started here; SyncRun() creates them.
ThreadPool::ThreadPool() {
  // hardware_concurrency() is unsigned and may return 0 when the value cannot
  // be determined. Subtracting 1 before checking would wrap around to a huge
  // number and size the pool at the kMaxThreadNum cap instead of at 1.
  const size_t hardware_threads = std::thread::hardware_concurrency();
  const size_t process_core_num = hardware_threads > 1 ? hardware_threads - 1 : 1;
#if ENABLE_D || ENABLE_GPU
  max_thread_num_ = process_core_num / kDeviceNum;
#else
  max_thread_num_ = process_core_num;
#endif
  // The division above can yield 0 on machines with few cores.
  if (max_thread_num_ < 1) {
    max_thread_num_ = 1;
  }
  if (max_thread_num_ > kMaxThreadNum) {
    max_thread_num_ = kMaxThreadNum;
  }
}
void ThreadPool::SyncRunLoop() {
while (true) {
Task task;
{
std::unique_lock<std::mutex> lock(task_mutex_);
task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
if (exit_run_) {
return;
}
task = task_queue_.front();
task_queue_.pop();
}
try {
task();
} catch (std::exception &e) {
MsException::Instance().SetException();
}
{
std::unique_lock<std::mutex> task_lock(task_mutex_);
task_finished_count_ = task_finished_count_ + 1;
}
finished_cond_var_.notify_one();
}
}
// Runs every task in `tasks` and blocks until all of them have finished.
//
// A single task is executed inline on the calling thread and the result is
// `ret == SUCCESS`. Any other number of tasks is dispatched to the worker
// threads, which are created lazily up to min(max_thread_num_, tasks.size());
// on that path the tasks' return values are not inspected and the function
// always returns true.
//
// pool_mtx_ is held for the whole pooled run, so concurrent callers are
// serialized. NOTE(review): a task that itself calls SyncRun() with more than
// one task would block on pool_mtx_ while this call waits for it to finish.
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
  if (tasks.size() == 1) {
    auto ret = tasks[0]();
    return ret == SUCCESS;
  }
  std::lock_guard<std::mutex> pool_lock(pool_mtx_);
  {
    // Workers read exit_run_ while holding task_mutex_, so write it under the
    // same mutex. It must be cleared before any worker is spawned, otherwise a
    // thread started after ClearThreadPool() would exit immediately.
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    exit_run_ = false;
  }
  const size_t task_num = tasks.size();
  const size_t thread_num = sync_run_threads_.size();
  if (thread_num < max_thread_num_ && thread_num < task_num) {
    const size_t new_thread_num = std::min<size_t>(max_thread_num_, task_num);
    for (size_t i = thread_num; i < new_thread_num; ++i) {
      // Construct the thread in place: a failed container growth then cannot
      // destroy an already-running temporary std::thread.
      sync_run_threads_.emplace_back(&ThreadPool::SyncRunLoop, this);
    }
  }
  {
    // Queue the whole batch under one lock acquisition.
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    for (const auto &task : tasks) {
      task_queue_.push(task);
    }
  }
  task_cond_var_.notify_all();
  {
    // Wait until the workers have reported one completion per task, then
    // reset the counter for the next call.
    std::unique_lock<std::mutex> task_lock(task_mutex_);
    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
    task_finished_count_ = 0;
  }
  return true;
}
// Returns the shared ThreadPool. The object is a function-local static, so it
// is constructed on the first call and the same reference is returned on every
// call after that.
ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool singleton;
  return singleton;
}
void ThreadPool::ClearThreadPool() {
std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
if (exit_run_) {
return;
}
exit_run_ = true;
task_cond_var_.notify_all();
for (auto &it : sync_run_threads_) {
if (it.joinable()) {
it.join();
}
}
sync_run_threads_.clear();
}
// Shuts the pool down through ClearThreadPool() when the object is destroyed.
ThreadPool::~ThreadPool() {
try {
ClearThreadPool();
} catch (...) {
// Deliberately empty: whatever ClearThreadPool() throws is discarded so that
// no exception leaves the destructor.
}
}
}
}