* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* openUBMC is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <algorithm>
#include <boost/asio/post.hpp>
#include <boost/asio/recycling_allocator.hpp>
#include <mc/log.h>
#include <mc/runtime/thread_pool.h>
namespace mc::runtime {
thread_local thread_pool::shard_t* t_current_shard = nullptr;
struct shard_scope {
shard_scope(thread_pool::shard_t* shard) {
t_current_shard = shard;
}
~shard_scope() {
t_current_shard = nullptr;
}
};
namespace {
boost::asio::detail::scheduler& get_or_create_scheduler(boost::asio::execution_context& ctx) {
if (boost::asio::has_service<boost::asio::detail::scheduler>(ctx)) {
return boost::asio::use_service<boost::asio::detail::scheduler>(ctx);
}
auto& s = boost::asio::make_service<boost::asio::detail::scheduler>(ctx, false);
s.init_task();
s.work_started();
return s;
}
std::string make_thread_name(const std::string& base_name, std::size_t index) {
std::string suffix = "-" + std::to_string(index);
std::string tname = base_name;
if (tname.length() + suffix.length() > 15) {
int keep_len = 15 - static_cast<int>(suffix.length()) - 2;
if (keep_len > 0) {
tname = tname.substr(0, keep_len) + "..";
} else {
tname = tname.substr(0, 15 - suffix.length());
}
}
tname += suffix;
return tname;
}
}
thread_pool::thread_pool(std::size_t num_threads, const std::string& name)
: boost::asio::execution_context(), m_scheduler(get_or_create_scheduler(*this)), m_num_threads(num_threads), m_name(name) {
MC_ASSERT_THROW(m_num_threads < NULL_INDEX, mc::runtime_exception, "线程数不能超过 {}", NULL_INDEX - 1);
if (m_num_threads > 0) {
m_shards.reserve(m_num_threads);
for (std::size_t i = 0; i < m_num_threads; ++i) {
auto shard = std::make_unique<shard_t>(*this);
shard->ctx = std::make_unique<io_context>();
shard->work = std::make_unique<boost::asio::executor_work_guard<io_context::executor_type>>(
boost::asio::make_work_guard(shard->ctx->get_executor()));
shard->next.store(NULL_INDEX, std::memory_order_relaxed);
shard->state.store(State::Running, std::memory_order_relaxed);
shard->id = i;
m_shards.push_back(std::move(shard));
}
}
m_idle_head.store(NULL_INDEX, std::memory_order_relaxed);
}
thread_pool::~thread_pool() {
stop();
join();
}
thread_pool::executor_type thread_pool::get_executor() noexcept {
return executor_type(*this);
}
void thread_pool::start() {
if (!m_threads.empty()) {
return;
}
m_threads.reserve(m_num_threads);
for (std::size_t i = 0; i < m_num_threads; ++i) {
m_threads.emplace_back([this, i]() {
set_current_thread_name(make_thread_name(m_name, i));
dlog("${name} 线程 ${idx} 启动", ("name", m_name)("idx", i));
try {
worker_loop(m_shards[i].get());
} catch (const std::exception& e) {
elog("${name} 线程 ${idx} 异常: ${error}", ("name", m_name)("idx", i)("error", e.what()));
}
dlog("${name} 线程 ${idx} 结束", ("name", m_name)("idx", i));
});
}
}
void thread_pool::push_idle(shard_t* shard) {
uint64_t old_head = m_idle_head.load(std::memory_order_relaxed);
uint64_t new_head;
do {
uint16_t old_index = static_cast<uint16_t>(old_head & INDEX_MASK);
shard->next.store(old_index, std::memory_order_relaxed);
uint64_t new_tag = (old_head & ~INDEX_MASK) + TAG_INCREMENT;
new_head = new_tag | shard->id;
} while (!m_idle_head.compare_exchange_weak(old_head, new_head,
std::memory_order_release,
std::memory_order_relaxed));
}
bool thread_pool::has_work(shard_t* shard) {
if (m_pending_tasks.load(std::memory_order_relaxed) > 0) {
return true;
}
if (shard->ctx->poll() > 0) {
return true;
}
return false;
}
bool thread_pool::poll_tasks(shard_t* shard) {
auto& ctx = *shard->ctx;
bool ran_local = ctx.poll() > 0;
boost::system::error_code ec;
bool ran_global = m_scheduler.poll_one(ec) > 0;
if (ran_global) {
int64_t old_val = m_pending_tasks.load(std::memory_order_relaxed);
while (old_val > 0) {
if (m_pending_tasks.compare_exchange_weak(old_val, old_val - 1, std::memory_order_relaxed)) {
break;
}
}
}
return ran_local || ran_global;
}
void thread_pool::worker_loop(shard_t* shard) {
auto& ctx = *shard->ctx;
shard_scope scope(shard);
poll_until(shard, [this]() {
return m_scheduler.stopped();
});
}
thread_pool::shard_t* thread_pool::acquire_idle_worker() {
uint64_t old_head = m_idle_head.load(std::memory_order_acquire);
while (true) {
uint16_t old_index = static_cast<uint16_t>(old_head & INDEX_MASK);
if (old_index == NULL_INDEX) {
return nullptr;
}
shard_t* node = m_shards[old_index].get();
uint16_t next_index = node->next.load(std::memory_order_relaxed);
uint64_t new_tag = (old_head & ~INDEX_MASK) + TAG_INCREMENT;
uint64_t new_head = new_tag | next_index;
if (m_idle_head.compare_exchange_weak(old_head, new_head,
std::memory_order_acquire,
std::memory_order_acquire)) {
node->in_stack.store(false, std::memory_order_release);
State expected = State::Idle;
if (node->state.compare_exchange_strong(expected, State::Notified,
std::memory_order_release,
std::memory_order_relaxed)) {
return node;
}
old_head = m_idle_head.load(std::memory_order_acquire);
}
}
}
static void empty_cb() {
}
void thread_pool::wake_up_workers() {
if (m_shards.empty()) {
return;
}
if (auto* shard = acquire_idle_worker()) {
boost::asio::post(
shard->ctx->get_executor(),
boost::asio::bind_allocator(boost::asio::recycling_allocator<void>(), empty_cb));
return;
}
}
void thread_pool::stop() {
if (stopped()) {
return;
}
m_scheduler.stop();
for (auto& shard : m_shards) {
shard->work.reset();
shard->ctx->stop();
}
}
void thread_pool::join() {
for (auto& t : m_threads) {
if (t.joinable()) {
t.join();
}
}
m_threads.clear();
}
bool thread_pool::stopped() const {
return m_scheduler.stopped();
}
std::size_t thread_pool::poll_one() {
boost::system::error_code ec;
return m_scheduler.poll_one(ec);
}
thread_pool::io_context& thread_pool::get_shard(std::size_t idx) {
if (m_shards.empty()) {
throw std::runtime_error("没有可用的 Worker 分片");
}
return *m_shards[idx % m_shards.size()]->ctx;
}
std::size_t thread_pool::shard_count() const {
return m_shards.size();
}
thread_pool::shard_t* thread_pool::get_current_shard() {
return t_current_shard;
}
void thread_pool::run() {
shard_t* shard = nullptr;
{
std::lock_guard<mc::sync::small_mutex> lock(m_attach_mutex);
if (t_current_shard != nullptr && &t_current_shard->pool == this) {
shard = t_current_shard;
} else {
MC_ASSERT_THROW(m_shards.size() < NULL_INDEX, mc::runtime_exception,
"线程数不能超过 ${max}", ("max", NULL_INDEX - 1));
auto new_shard = std::make_unique<shard_t>(*this);
new_shard->ctx = std::make_unique<io_context>();
new_shard->work = std::make_unique<boost::asio::executor_work_guard<io_context::executor_type>>(
boost::asio::make_work_guard(new_shard->ctx->get_executor()));
new_shard->next.store(NULL_INDEX, std::memory_order_relaxed);
new_shard->state.store(State::Running, std::memory_order_relaxed);
new_shard->id = static_cast<uint16_t>(m_shards.size());
shard = new_shard.get();
m_shards.push_back(std::move(new_shard));
t_current_shard = shard;
dlog("外部线程附加到 ${name},shard ${idx}", ("name", m_name)("idx", shard->id));
}
}
dlog("${name} 外部线程 shard ${idx} 开始运行", ("name", m_name)("idx", shard->id));
try {
worker_loop(shard);
} catch (const std::exception& e) {
elog("${name} 外部线程 shard ${idx} 异常: ${error}",
("name", m_name)("idx", shard->id)("error", e.what()));
}
dlog("${name} 外部线程 shard ${idx} 结束", ("name", m_name)("idx", shard->id));
t_current_shard = nullptr;
}
}