// Commit c8945568, created on 2021-09-17 (commit history).
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "common/thread_pool.h"
#include <algorithm>
#include <exception>
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_exception.h"

namespace mindspore {
namespace common {
#if ENABLE_D || ENABLE_GPU
// Divisor applied to the usable core count in the ThreadPool constructor for
// ENABLE_D / ENABLE_GPU builds.
// NOTE(review): presumably the number of devices sharing one host - confirm.
constexpr size_t kDeviceNum{8};
#endif
// Hard upper bound on the number of worker threads the pool may own.
constexpr size_t kMaxThreadNum{23};

// Derives max_thread_num_ from the hardware concurrency reported by the host:
// one hardware thread is left out of the budget, the remainder is divided by
// kDeviceNum on ENABLE_D / ENABLE_GPU builds, and the result is clamped to the
// range [1, kMaxThreadNum].
ThreadPool::ThreadPool() {
  // hardware_concurrency() returns an unsigned value and is allowed to return 0
  // when the count cannot be determined. Subtracting 1 from 0 would wrap around
  // to a huge number and defeat the lower-bound check, so test before subtracting.
  const size_t hardware_threads = std::thread::hardware_concurrency();
  const size_t process_core_num = hardware_threads > 1 ? hardware_threads - 1 : 1;
#if ENABLE_D || ENABLE_GPU
  size_t thread_limit = process_core_num / kDeviceNum;
#else
  size_t thread_limit = process_core_num;
#endif
  // Always allow at least one worker, never more than kMaxThreadNum.
  thread_limit = std::max<size_t>(thread_limit, 1);
  max_thread_num_ = std::min<size_t>(thread_limit, kMaxThreadNum);
}

// Body of every worker thread started by SyncRun().
// Repeatedly waits for a queued task (or for exit_run_ to be set), runs the
// task outside the queue lock, then bumps task_finished_count_ and wakes the
// thread waiting in SyncRun(). Returns as soon as exit_run_ is observed, even
// if tasks are still queued. The value returned by the task is not inspected.
void ThreadPool::SyncRunLoop() {
  while (true) {
    Task task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
      if (exit_run_) {
        return;
      }
      task = task_queue_.front();
      task_queue_.pop();
    }
    try {
      task();
    } catch (...) {
      // An exception of any type that escapes a thread's entry function ends
      // the whole process via std::terminate, so catch everything here rather
      // than only std::exception.
      // NOTE(review): SetException() takes no argument, so it presumably
      // records the in-flight exception for the caller - confirm in
      // utils/ms_exception.h.
      MsException::Instance().SetException();
    }
    {
      // The task counts as finished whether or not it threw, so that SyncRun()
      // is never left waiting on a failed task.
      std::lock_guard<std::mutex> task_lock(task_mutex_);
      ++task_finished_count_;
    }
    finished_cond_var_.notify_one();
  }
}

bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
  if (tasks.size() == 1) {
    auto ret = tasks[0]();
    return ret == SUCCESS;
  }
  std::unique_lock<std::mutex> lock(pool_mtx_);
  exit_run_ = false;
  size_t task_num = tasks.size();
  size_t thread_num = sync_run_threads_.size();
  if (thread_num < max_thread_num_ && thread_num < task_num) {
    auto new_thread_num = max_thread_num_;
    if (task_num < max_thread_num_) {
      new_thread_num = task_num;
    }
    for (size_t i = thread_num; i < new_thread_num; ++i) {
      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
    }
  }

  for (auto &task : tasks) {
    std::lock_guard<std::mutex> task_lock(task_mutex_);
    task_queue_.push(task);
    task_cond_var_.notify_one();
  }
  {
    std::unique_lock<std::mutex> task_lock(task_mutex_);
    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
    task_finished_count_ = 0;
  }
  return true;
}

// Returns the single pool object, which is a function-local static constructed
// the first time this function is called.
ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool pool;
  return pool;
}

void ThreadPool::ClearThreadPool() {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  if (exit_run_) {
    return;
  }
  exit_run_ = true;
  task_cond_var_.notify_all();
  for (auto &it : sync_run_threads_) {
    if (it.joinable()) {
      it.join();
    }
  }
  sync_run_threads_.clear();
}

// Stops and joins the worker threads when the pool is destroyed.
ThreadPool::~ThreadPool() {
  try {
    ClearThreadPool();
  } catch (...) {
    // Deliberately swallowed: an exception must not escape a destructor, and
    // the pool is going away regardless.
  }
}
}  // namespace common
}  // namespace mindspore