* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "dummy_quota_manager.h"
#include "log.h"
using namespace mindie_llm;
using namespace std;
using namespace boost::concurrent;
// Builds a quota manager for one rank.
//
// @param rank   rank id of this process; only used in log output within this file.
// @param quota  initial quota; also the ceiling that quotaLeft_ is refilled towards.
//
// The available quota is seeded with `quota` and the cross-process synchronization
// thread is launched before the constructor returns.
DummyQuotaManager::DummyQuotaManager(int rank, int quota) : initQuota_(quota), rank_(rank)
{
    quotaLeft_.store(quota);
    // Run AcrossProcSyncThreadEntry_ on a dedicated thread bound to this instance.
    acrossProcSyncThread_ = thread(&DummyQuotaManager::AcrossProcSyncThreadEntry_, this);
    MINDIE_LLM_LOG_INFO("DummyQuotaManager init successfully. rank:" << rank_ << ", initQuota:" << initQuota_);
}
// Stops the synchronization thread and waits for it to finish.
//
// NOTE(review): the loop only observes the stop flag between iterations. If the thread
// is blocked inside dummyBatchQueue_.pull() (see QuotaAllGather_) or inside the
// collective call, join() waits until that call returns - confirm shutdown ordering
// with the owner of this object.
DummyQuotaManager::~DummyQuotaManager()
{
    threadNeedStop_.store(true);
    const bool syncThreadRunning = acrossProcSyncThread_.joinable();
    if (syncThreadRunning) {
        acrossProcSyncThread_.join();
    }
    MINDIE_LLM_LOG_INFO("DummyQuotaManager sync thread stopped.");
}
bool DummyQuotaManager::AcquireQuota(bool isDummy) {
size_t reqWaitLoop = 0;
if (threadNeedStop_.load()) {
return true;
}
if (!isDummy) {
while (quotaLeft_.load() == 0 && !threadNeedStop_.load()) {
Wakeup();
this_thread::sleep_for(chrono::milliseconds(1));
if (reqWaitLoop++ % REQ_WAIT_LOG_LOOP == 0) {
MINDIE_LLM_LOG_WARN(
"no quota available. wait wakup to restore quota. If you keep seeing this, there might "
"be bug in DummyTokenManager. reqWaitLoop:"
<< reqWaitLoop << "; dummy queue size:" << dummyBatchQueue_.size());
}
}
dummyBatchQueue_.push(isDummy);
quotaLeft_.fetch_sub(1, std::memory_order_seq_cst);
return true;
}
bool succ = false;
if (quotaLeft_.load() != 0) {
dummyBatchQueue_.push(isDummy);
quotaLeft_.fetch_sub(1, std::memory_order_seq_cst);
succ = true;
}
return succ;
}
// Signals the sync thread that a real request is waiting for quota.
//
// At most one marker is kept pending: if the wake-up queue already holds an entry the
// call does nothing. The emptiness check and the push are two separate queue
// operations.
void DummyQuotaManager::Wakeup()
{
    if (!reqCommingQueue_.empty()) {
        return;
    }
    const bool requestPending = true;
    reqCommingQueue_.push(requestPending);
}
void DummyQuotaManager::AllGather_(bool iAmDummy, bool &allDummy, int &maxQuota, bool dummyBatchSync) {
vector<torch::Tensor> batchInfo;
batchInfo.emplace_back(torch::tensor({static_cast<int>(iAmDummy), quotaLeft_.load(),
static_cast<int>(dummyBatchSync), static_cast<int>(dummyBatchQueue_.size())},
torch::dtype(torch::kInt32).device(c10::kCPU)));
vector<vector<torch::Tensor>> batchInfos;
size_t cost = 0;
static size_t allCost = 0;
static size_t count = 0;
try {
auto start = high_resolution_clock::now();
batchInfos = ProcessGroup::GetInstance().AllGather(batchInfo);
auto end = high_resolution_clock::now();
cost = static_cast<size_t>(duration_cast<milliseconds>(end - start).count());
} catch (const std::exception &e) {
MINDIE_LLM_LOG_WARN("allgather standard exception caught: " << e.what() << ". downgrade to dummy-always.");
quotaLeft_.store(initQuota_);
threadNeedStop_.store(true);
return;
} catch (...) {
MINDIE_LLM_LOG_WARN(
"allgather got unknown exception, DummyTokenManager cannot work, downgrade "
"to dummy-always.");
quotaLeft_.store(initQuota_);
threadNeedStop_.store(true);
return;
}
int preGatherType = 0;
int cnt = 0;
int maxDummyBatchQueueSize = 0;
for (const auto &tensor : batchInfos[0]) {
bool isDummy = tensor[0].item<bool>();
allDummy = allDummy && isDummy;
maxDummyBatchQueueSize = max(maxDummyBatchQueueSize, tensor[3].item<int>());
maxQuota = max(maxQuota, tensor[1].item<int>());
if ((cnt++ > 0 && preGatherType != tensor[2].item<int>()) ||
maxDummyBatchQueueSize > MAX_ALLOWED_DUMMY_QUEUE_PENDINGS) {
MINDIE_LLM_LOG_WARN("All-gather message type is not same or dummy batch too many. rank"
<< rank_ << ", is dummyBatchSync:" << tensor[2].item<int>() << ", err rank:" << cnt--
<< ", maxDummyBatchQueueSize:" << maxDummyBatchQueueSize);
quotaLeft_.store(initQuota_);
threadNeedStop_.store(true);
return;
}
preGatherType = tensor[2].item<int>();
}
allCost += cost;
count++;
if (cost > 200) {
MINDIE_LLM_LOG_INFO_REQUEST("dummy all gather cost too long, dp:"
<< rank_ << ", time:" << cost << ", quotaLeft_:" << quotaLeft_.load()
<< ", iAmDummy:" << iAmDummy << ", dummyQueue size:" << dummyBatchQueue_.size()
<< ", allDummy:" << allDummy << ", maxQuota:" << maxQuota
<< ", average cost:" << (allCost / count));
}
}
// Runs one quota-sync round.
//
// Takes the oldest batch marker from dummyBatchQueue_ (pull() waits until an element is
// available), reports it as this rank's "dummy" flag and gathers with the message type
// set to quota sync.
//
// @param iAmDummy  out: the marker taken from the queue.
// @param allDummy  in/out: forwarded to AllGather_.
// @param maxQuota  in/out: forwarded to AllGather_.
void DummyQuotaManager::QuotaAllGather_(bool &iAmDummy, bool &allDummy, int &maxQuota)
{
    constexpr bool isDummyBatchSync = true;
    iAmDummy = dummyBatchQueue_.pull();
    AllGather_(iAmDummy, allDummy, maxQuota, isDummyBatchSync);
}
// Runs one wake-up round.
//
// This rank reports itself as dummy exactly when no wake-up marker is pending. Pending
// markers are drained before gathering, and the gather uses the wake-up message type.
//
// @param iAmDummy  out: true when reqCommingQueue_ was empty on entry.
// @param allDummy  in/out: forwarded to AllGather_.
// @param maxQuota  in/out: forwarded to AllGather_.
void DummyQuotaManager::WakeupAllGather_(bool &iAmDummy, bool &allDummy, int &maxQuota)
{
    constexpr bool isDummyBatchSync = false;
    iAmDummy = reqCommingQueue_.empty();
    if (!iAmDummy) {
        // Discard every queued marker; only their presence matters.
        for (bool discarded = false; reqCommingQueue_.try_pull(discarded) == queue_op_status::success;) {
        }
    }
    AllGather_(iAmDummy, allDummy, maxQuota, isDummyBatchSync);
}
// Body of the cross-process synchronization thread.
//
// Repeats until threadNeedStop_ is set:
//  1. Every 1000th iteration (starting with the first) emit a liveness log line.
//  2. Pick the round type: a quota-sync round when quota is left or batch markers are
//     queued, otherwise a wake-up round.
//  3. After the gather, if at least one rank was not dummy, top the local quota up by
//     (initQuota_ - largest quota reported by any rank), never by a negative amount.
void DummyQuotaManager::AcrossProcSyncThreadEntry_()
{
    // The heartbeat counter wraps explicitly; an ever-growing signed counter would
    // eventually overflow in a long-running process.
    constexpr unsigned int heartbeatPeriod = 1000;
    unsigned int loopsSinceHeartbeat = 0;

    while (!threadNeedStop_.load()) {
        if (loopsSinceHeartbeat == 0) {
            MINDIE_LLM_LOG_INFO_REQUEST("DummyQuotaManager across process synchronization thread is live. dprank:"
                                        << rank_ << "; quotaLeft_:" << quotaLeft_.load()
                                        << "; dummyBatchQueue_ size:" << dummyBatchQueue_.size());
        }
        loopsSinceHeartbeat = (loopsSinceHeartbeat + 1) % heartbeatPeriod;

        // Read the quota once: two separate loads inside one condition could disagree
        // if AcquireQuota() decrements in between and select the wrong round type.
        const auto quotaSnapshot = quotaLeft_.load();
        const bool needDummySync = quotaSnapshot != 0 || dummyBatchQueue_.size() != 0;

        bool allDummy = true;
        bool iAmDummy = true;
        int maxQuota = 0;
        if (needDummySync) {
            QuotaAllGather_(iAmDummy, allDummy, maxQuota);
        } else {
            WakeupAllGather_(iAmDummy, allDummy, maxQuota);
        }

        const int quotaToFill = max(0, initQuota_ - maxQuota);
        if (!allDummy) {
            quotaLeft_.fetch_add(quotaToFill, memory_order_seq_cst);
        }
    }
    MINDIE_LLM_LOG_INFO("DummyQuotaManager across process synchronization thread exits.");
}
void DummyQuotaManager::WaitForQuota_() {
unique_lock<mutex> lock(notifyMutex_);
notifyCv_.wait(lock, [this] { return quotaLeft_.load() != 0 || threadNeedStop_.load(); });
}
// Wakes one thread blocked in WaitForQuota_, if any. notifyMutex_ is not taken here.
void DummyQuotaManager::NotifyQuotaAvailable_()
{
    notifyCv_.notify_one();
}