* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "grpc_communicator.h"
#include <grpcpp/channel.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <algorithm>
#include <experimental/filesystem>
#include <future>
using grpc::Server;
using grpc::ServerBuilder;
namespace fs = std::experimental::filesystem;
namespace mindie_llm {
// Static handle to the communicator; GetInstance() stores its function-local instance here on every call.
std::shared_ptr<GRPCCommunicator> GRPCCommunicator::grpcCommunicatorSingleton = nullptr;
// Maximum age (ms) of a slave NPU report before GetSlaveMaxNpuUtilizationPercent() counts it as stale.
static constexpr uint64_t SLAVE_NPU_REPORT_STALE_MS = 7000;
// Minimum spacing (ms) between repeated "stale/missing" warnings while the timeout condition persists.
static constexpr uint64_t SLAVE_NPU_TIMEOUT_LOG_INTERVAL_MS = 60000;
// Minimum spacing (ms) between diagnostics lines emitted by ConsumeSlaveNpuReportTimeoutFlag().
static constexpr uint64_t NPU_REPORT_DIAG_LOG_INTERVAL_MS = 30000;
// Reads the entire file at `filePath` into `outContent`.
// The path string is first checked with CanonicalPath(); the (possibly updated) string is what gets opened.
// Returns false after logging when the path check fails or the file cannot be opened; true otherwise.
bool ReadFileToString(const fs::path &filePath, std::string &outContent) {
    std::string resolvedPath = filePath.string();
    if (!CanonicalPath(resolvedPath)) {
        MINDIE_LLM_LOG_ERROR("Invalid Path: " + resolvedPath);
        return false;
    }

    std::ifstream input(resolvedPath);
    if (!input) {
        MINDIE_LLM_LOG_ERROR("Cannot open file: " + resolvedPath);
        return false;
    }

    // Drain the stream buffer in one go rather than reading line by line.
    std::stringstream contents;
    contents << input.rdbuf();
    outContent = contents.str();
    return true;
}
// NOTE(review): not referenced in the visible part of this file; presumably an upper bound on
// communication threads — confirm against other users before changing or removing.
constexpr auto MAX_CONTACT_THREAD_NUM = 100;
std::shared_ptr<GRPCCommunicator> GRPCCommunicator::GetInstance(
const std::unordered_map<std::string, std::string> &modelConfig) {
static std::shared_ptr<GRPCCommunicator> instance = std::make_shared<GRPCCommunicator>(modelConfig);
GRPCCommunicator::grpcCommunicatorSingleton = instance;
return instance;
}
// Builds a communicator from the string-keyed model configuration.
// Required keys: "isMaster", "slaveIPs" (comma separated), "masterIP", "multiNodesInferPort", "localIP".
// Optional keys: "is_dmi_infer", "interNodeTLSEnabled". When TLS is enabled, the interNodeTls* keys
// become required and are resolved relative to the home path.
// Throws std::out_of_range (from at()) for a missing required key and std::runtime_error when
// LoadCertificates() reports failure.
GRPCCommunicator::GRPCCommunicator(const std::unordered_map<std::string, std::string> &modelConfig) {
    isMaster_ = (modelConfig.at("isMaster") == "1");

    std::vector<std::string> slaveAddressList;
    mindie_llm::Split(modelConfig.at("slaveIPs"), ",", slaveAddressList);
    slaveCount_ = slaveAddressList.size();

    masterIP_ = modelConfig.at("masterIP");
    multiNodesInferPort_ = modelConfig.at("multiNodesInferPort");
    slaveIp_ = modelConfig.at("localIP");

    const auto dmiEntry = modelConfig.find("is_dmi_infer");
    isDmiInfer_ = (dmiEntry != modelConfig.end()) && (dmiEntry->second == "1");

    std::string homePath;
    GetHomePath(homePath);

    const auto tlsEntry = modelConfig.find("interNodeTLSEnabled");
    interNodeTLSEnabled_ = (tlsEntry != modelConfig.end()) && (tlsEntry->second == "1");
    if (!interNodeTLSEnabled_) {
        return;
    }

    // All TLS material locations are configured relative to the home directory.
    const auto underHome = [&homePath, &modelConfig](const char *key) {
        return fs::path(homePath) / modelConfig.at(key);
    };
    interNodeTlsCaPath_ = underHome("interNodeTlsCaPath");
    mindie_llm::Split(modelConfig.at("interNodeTlsCaFiles"), ",", interNodeTlsCaFiles_);
    interNodeTlsCert_ = underHome("interNodeTlsCert");
    interNodeTlsPk_ = underHome("interNodeTlsPk");
    interNodeTlsCrlPath_ = underHome("interNodeTlsCrlPath");
    mindie_llm::Split(modelConfig.at("interNodeTlsCrlFiles"), ",", interNodeTlsCrlFiles_);

    if (!LoadCertificates()) {
        MINDIE_LLM_LOG_ERROR("Failed to load TLS certificates. Shutting down.");
        throw std::runtime_error("Failed to load TLS certificates");
    }
}
// Master-side shutdown: asks the gRPC server (if one was built) to shut down, then joins the
// thread that has been blocking in server_->Wait().
void GRPCCommunicator::StopServer() {
    if (server_) {
        server_->Shutdown();
    }

    const bool workerRunning = masterWorkerThread_.joinable();
    if (workerRunning) {
        masterWorkerThread_.join();
    }

    MINDIE_LLM_LOG_INFO("gRPC server shutdown complete");
}
void GRPCCommunicator::StopClient() {
if (slaveStream_) {
slaveStream_->WritesDone();
grpc::Status status = slaveStream_->Finish();
if (!status.ok()) {
MINDIE_LLM_LOG_ERROR("Stream shutdown error: " + status.error_message());
}
}
StopModelInitHandlerThreads();
if (slaveWorkerThread_.joinable()) {
slaveWorkerThread_.join();
}
slaveStream_.reset();
stub_.reset();
channel_.reset();
MINDIE_LLM_LOG_INFO("gRPC connection shutdown complete");
}
// Tears down whichever side this instance was running: the server on the master, the client
// connection on a slave.
GRPCCommunicator::~GRPCCommunicator() {
    MINDIE_LLM_LOG_INFO("GRPCCommunicator Starting destruction");
    isMaster_ ? StopServer() : StopClient();
    MINDIE_LLM_LOG_INFO("GRPCCommunicator destruction completed");
}
bool GRPCCommunicator::Init(int initCount) {
grpcCommunicatorNum_ = initCount;
int oldCallInitCount = callInitCount_.fetch_add(1, std::memory_order_acq_rel);
if (oldCallInitCount == initCount - 1) {
MINDIE_LLM_LOG_INFO("Start to init GRPCCommunicator");
if (isMaster_) {
return InitMaster(initCount);
} else {
return InitSlave();
}
} else {
if (isMaster_) {
WaitForAllSlavesConnected();
}
}
return true;
}
bool GRPCCommunicator::InitMaster(int respHandlerThreadCount) {
MINDIE_LLM_LOG_INFO("GRPCCommunicator: Start to init as Master");
service_ = std::make_shared<MasterServiceImpl>(this, respHandlerThreadCount);
masterWorkerThread_ = std::thread([this]() {
std::string localAddr = FormatGrpcAddress(masterIP_, multiNodesInferPort_);
pthread_setname_np(pthread_self(), "GRPCServer");
ServerBuilder builder;
builder.AddChannelArgument(GRPC_ARG_MAX_CONCURRENT_STREAMS, maxConcurrentStreams);
builder.SetMaxReceiveMessageSize(grpcSendReceiveBufSize);
builder.SetMaxSendMessageSize(grpcSendReceiveBufSize);
std::shared_ptr<grpc::ServerCredentials> creds;
if (interNodeTLSEnabled_) {
std::vector<grpc::experimental::IdentityKeyCertPair> identityKeyCertPairList = {
{tlsCertPrivateKey_, tlsCert_}};
std::shared_ptr<grpc::experimental::CertificateProviderInterface> certificateProvider =
std::make_shared<grpc::experimental::StaticDataCertificateProvider>(caCert_, identityKeyCertPairList);
grpc::experimental::TlsServerCredentialsOptions tlsServerOpts(certificateProvider);
tlsServerOpts.set_cert_request_type(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
if (!interNodeTlsCrlPath_.empty() && !interNodeTlsCrlFiles_.empty()) {
std::vector<std::string> crlContentVec;
for (const auto &crlFile : interNodeTlsCrlFiles_) {
fs::path crlPath = fs::path(interNodeTlsCrlPath_) / crlFile;
std::string crlContent;
ReadFileToString(crlPath, crlContent);
crlContentVec.emplace_back(crlContent);
}
if (!crlContentVec.empty()) {
auto crlProviderSpan = grpc_core::experimental::CreateStaticCrlProvider(crlContentVec);
auto crlProvider = crlProviderSpan.value_or(nullptr);
if (crlProvider == nullptr) {
MINDIE_LLM_LOG_ERROR("Failed to create crl provider");
return false;
}
tlsServerOpts.set_crl_provider(crlProvider);
}
}
tlsServerOpts.watch_root_certs();
tlsServerOpts.watch_identity_key_cert_pairs();
creds = grpc::experimental::TlsServerCredentials(tlsServerOpts);
} else {
creds = grpc::InsecureServerCredentials();
}
builder.AddListeningPort(localAddr, creds);
builder.RegisterService(service_.get());
server_ = builder.BuildAndStart();
if (!server_) {
MINDIE_LLM_LOG_ERROR("Failed to start gRPC server on port " + multiNodesInferPort_);
return false;
}
MINDIE_LLM_LOG_INFO("gRPC server started on port " + multiNodesInferPort_ + " with " +
(interNodeTLSEnabled_ ? "TLS" : "no encryption"));
server_->Wait();
return true;
});
std::shared_ptr<MasterServiceImpl> masterService = std::static_pointer_cast<MasterServiceImpl>(service_);
for (int dpRankIdx : responseHandlers_.KeySet()) {
masterService->DPRankIdxToSyncResp().Insert(dpRankIdx, std::make_shared<ExecRespBlockingQueue>());
}
MINDIE_LLM_LOG_INFO("GRPCCommunicator: wait slave connecting...");
WaitForAllSlavesConnected();
MINDIE_LLM_LOG_INFO("GRPCCommunicator: All " + std::to_string(slaveCount_) + " slaves connected");
return true;
}
bool GRPCCommunicator::InitSlave() {
MINDIE_LLM_LOG_INFO("GRPCCommunicator: Start to init as Slave (IP=" + slaveIp_ + ")");
int retryCount = 0;
int sleepInterval = 1;
int maxRetries = 120;
bool connected = false;
while (retryCount++ < maxRetries) {
try {
MINDIE_LLM_LOG_INFO("GRPCCommunicator: attempting connection to server at IP = " + masterIP_ +
", port = " + multiNodesInferPort_);
grpc::ChannelArguments channelArgs;
channelArgs.SetInt(GRPC_ARG_MAX_CONCURRENT_STREAMS, maxConcurrentStreams);
channelArgs.SetInt(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, grpcSendReceiveBufSize);
channelArgs.SetInt(GRPC_ARG_MAX_SEND_MESSAGE_LENGTH, grpcSendReceiveBufSize);
std::shared_ptr<grpc::ChannelCredentials> creds;
if (interNodeTLSEnabled_) {
std::vector<grpc::experimental::IdentityKeyCertPair> identityKeyCertPairList = {
{tlsCertPrivateKey_, tlsCert_}};
std::shared_ptr<grpc::experimental::CertificateProviderInterface> certificateProvider =
std::make_shared<grpc::experimental::StaticDataCertificateProvider>(caCert_,
identityKeyCertPairList);
auto tlsChannelOpts = std::make_unique<grpc::experimental::TlsChannelCredentialsOptions>();
tlsChannelOpts->set_certificate_provider(certificateProvider);
if (!interNodeTlsCrlPath_.empty() && !interNodeTlsCrlFiles_.empty()) {
std::vector<std::string> crlContentVec;
for (const auto &crlFile : interNodeTlsCrlFiles_) {
fs::path crlPath = fs::path(interNodeTlsCrlPath_) / crlFile;
std::string crlContent;
ReadFileToString(crlPath, crlContent);
crlContentVec.emplace_back(crlContent);
}
if (!crlContentVec.empty()) {
auto crlProviderSpan = grpc_core::experimental::CreateStaticCrlProvider(crlContentVec);
auto crlProvider = crlProviderSpan.value_or(nullptr);
if (crlProvider == nullptr) {
MINDIE_LLM_LOG_ERROR("Failed to create crl provider");
return false;
}
tlsChannelOpts->set_crl_provider(crlProvider);
}
}
tlsChannelOpts->watch_root_certs();
tlsChannelOpts->watch_identity_key_cert_pairs();
creds = grpc::experimental::TlsCredentials(*tlsChannelOpts);
} else {
creds = grpc::InsecureChannelCredentials();
}
channel_ =
grpc::CreateCustomChannel(FormatGrpcAddress(masterIP_, multiNodesInferPort_), creds, channelArgs);
stub_ = MasterService::NewStub(channel_);
context_ = std::make_unique<grpc::ClientContext>();
slaveStream_ = stub_->RegisterAndCommunicate(context_.get());
if (!slaveStream_) {
MINDIE_LLM_LOG_WARN("Failed to establish bidirectional stream to master. "
<< "The master may not be ready yet. Retrying in 1 second...");
std::this_thread::sleep_for(std::chrono::seconds(sleepInterval));
continue;
}
MINDIE_LLM_LOG_INFO("Successfully connected to master and obtained slaveStream");
if (SendRegistration()) {
MINDIE_LLM_LOG_INFO("Registration succeeded");
connected = true;
break;
} else {
slaveStream_->WritesDone();
slaveStream_->Finish();
MINDIE_LLM_LOG_WARN("Send registration message to master failed. Retrying in 1 second...");
std::this_thread::sleep_for(std::chrono::seconds(sleepInterval));
}
} catch (const std::exception &e) {
MINDIE_LLM_LOG_ERROR("gRPC CreateChannel Error: " + std::string(e.what()));
} catch (...) {
MINDIE_LLM_LOG_ERROR("gRPC CreateChannel failed with unknown exception");
}
}
if (!connected) {
MINDIE_LLM_LOG_ERROR("Failed to establish connection to master after maximum retries");
return false;
}
StartWorkerThread();
return true;
}
// Blocks the caller until AllSlavesConnected() becomes true.
//
// All waiters share one mutex: std::condition_variable requires every concurrent waiter to use
// the same mutex, and Init() can call this from several threads at once.
// The wait is bounded and re-checked because NotifyAll() signals without holding any lock, so a
// notification can arrive between the predicate check and the actual wait; polling makes a missed
// notification cost at most one interval instead of a permanent hang.
void GRPCCommunicator::WaitForAllSlavesConnected() {
    static std::mutex waitMutex;
    constexpr auto recheckInterval = std::chrono::milliseconds(100);

    std::unique_lock<std::mutex> lock(waitMutex);
    while (!cv_.wait_for(lock, recheckInterval, [this] { return AllSlavesConnected(); })) {
        // Timed out with the predicate still false: keep waiting.
    }
}
bool GRPCCommunicator::SendRegistration() {
SlaveToMasterMsg msg;
auto *reg = msg.mutable_register_request();
reg->set_slave_ip(slaveIp_);
MINDIE_LLM_LOG_INFO("Sent registration to master: slave_ip=" + slaveIp_);
if (slaveStream_->Write(msg)) {
return true;
} else {
return false;
}
}
void GRPCCommunicator::StartWorkerThread() {
if (isDmiInfer_ && !modelInitHandlerActive_.load(std::memory_order_relaxed)) {
int modelInitThreadCount = grpcCommunicatorNum_;
modelInitHandlerActive_.store(true, std::memory_order_relaxed);
modelInitHandlerThreads_.reserve(modelInitThreadCount);
for (int i = 0; i < modelInitThreadCount; ++i) {
modelInitHandlerThreads_.emplace_back([this] { ModelInitHandlerLoop(); });
pthread_setname_np(modelInitHandlerThreads_.back().native_handle(), "GRPCModelInit");
}
}
slaveWorkerThread_ = std::thread([this] {
pthread_setname_np(pthread_self(), "GRPCSlave");
MasterToSlaveMsg task;
try {
while (slaveStream_->Read(&task)) {
int targetDPRank = task.target_dp_rank();
ExecuteRequest request = task.execute_request();
if (request.execute_type() == REMOTE_MODEL_INIT) {
pendingModelInitQueue_.push(std::make_shared<MasterToSlaveMsg>(std::move(task)));
} else {
HandleRequestFromMaster(request, targetDPRank);
}
}
MINDIE_LLM_LOG_INFO("gRPC Slave Worker Thread: stream closed by server");
} catch (const std::exception &e) {
MINDIE_LLM_LOG_ERROR("gRPC Slave Worker Thread Exception: " + std::string(e.what()));
} catch (...) {
MINDIE_LLM_LOG_ERROR("gRPC Slave Worker Thread unknown exception");
}
});
}
// Body of each model-init handler thread: pulls queued master messages and dispatches them until
// StopModelInitHandlerThreads() clears the active flag and closes the queue.
//
// This is a thread entry point, so no exception may escape it: a failure in pull() ends the loop,
// a failure in a handler is logged and the loop continues.
// NOTE(review): the behaviour of pull() on a closed queue is not visible here; both a null result
// and an exception are tolerated.
void GRPCCommunicator::ModelInitHandlerLoop() {
    while (modelInitHandlerActive_.load(std::memory_order_relaxed)) {
        std::shared_ptr<MasterToSlaveMsg> task;
        try {
            task = pendingModelInitQueue_.pull();
        } catch (const std::exception &e) {
            MINDIE_LLM_LOG_ERROR("Slave ModelInitHandlerLoop: pull failed: " + std::string(e.what()));
            break;
        } catch (...) {
            MINDIE_LLM_LOG_ERROR("Slave ModelInitHandlerLoop: pull failed with unknown exception");
            break;
        }
        if (!task) {
            // Nothing to process (e.g. queue closed); re-evaluate the active flag.
            continue;
        }
        try {
            int targetDPRank = task->target_dp_rank();
            ExecuteRequest request = task->execute_request();
            HandleRequestFromMaster(request, targetDPRank);
        } catch (const std::exception &e) {
            MINDIE_LLM_LOG_ERROR("Slave ModelInitHandlerLoop: handler exception: " + std::string(e.what()));
        } catch (...) {
            MINDIE_LLM_LOG_ERROR("Slave ModelInitHandlerLoop: handler unknown exception");
        }
    }
    MINDIE_LLM_LOG_ERROR("Slave ModelInitHandlerLoop exit.");
}
// Stops the model-init handler threads: clears the active flag, closes the pending queue, joins
// every joinable thread and drops the thread objects.
void GRPCCommunicator::StopModelInitHandlerThreads() {
    modelInitHandlerActive_.store(false, std::memory_order_relaxed);
    pendingModelInitQueue_.close();

    for (auto &handlerThread : modelInitHandlerThreads_) {
        if (!handlerThread.joinable()) {
            continue;
        }
        handlerThread.join();
    }
    modelInitHandlerThreads_.clear();
}
// Writes `msg` to `stream` while holding streamWriteMutex_.
// Returns false (after logging) when `stream` is null or the write fails; true on success.
template <typename StreamType, typename MsgType>
bool GRPCCommunicator::SafeWriteMsgToStream(StreamType stream, const MsgType &msg) {
    if (stream) {
        std::lock_guard<std::mutex> writeGuard(streamWriteMutex_);
        const bool written = stream->Write(msg);
        if (!written) {
            MINDIE_LLM_LOG_ERROR("SafeWriteMsgToStream: failed to write message to stream");
        }
        return written;
    }

    MINDIE_LLM_LOG_ERROR("SafeWriteMsgToStream: stream is null (cannot write message)");
    return false;
}
// Sends `request` from the master to slaves.
// With an empty `slaveIp` the message is written to every registered slave stream, stopping at
// the first failure; otherwise only the stream registered for `slaveIp` is used (a missing entry
// is treated as a null stream).
// Returns false for negative ranks or any failed write; true otherwise.
bool GRPCCommunicator::SendRequest(ExecuteRequest &request, int sourceDPRank, int targetDPRank,
                                   const std::string &slaveIp) {
    const bool ranksValid = (sourceDPRank >= 0) && (targetDPRank >= 0);
    if (!ranksValid) {
        MINDIE_LLM_LOG_ERROR("SendRequest: sourceDPRank and targetDPRank must be non-negative integers.");
        return false;
    }

    MasterToSlaveMsg outbound;
    outbound.set_source_dp_rank(sourceDPRank);
    outbound.set_target_dp_rank(targetDPRank);
    *outbound.mutable_execute_request() = request;

    if (!slaveIp.empty()) {
        // Unicast to a single slave.
        std::optional<SlaveStreamPtr> target = slaveIpToStream_.Get(slaveIp);
        return SafeWriteMsgToStream(target.value_or(nullptr), outbound);
    }

    // Broadcast to every registered slave.
    for (std::optional<SlaveStreamPtr> target : slaveIpToStream_.Values()) {
        const bool delivered = SafeWriteMsgToStream(target.value_or(nullptr), outbound);
        if (!delivered) {
            return false;
        }
    }
    return true;
}
// Fetches the next synchronous response for `sourceDPRank` by delegating to MasterServiceImpl::Take.
bool GRPCCommunicator::GetSyncResponse(ExecuteResponse &response, int sourceDPRank) {
    const auto masterImpl = std::static_pointer_cast<MasterServiceImpl>(service_);
    const bool taken = masterImpl->Take(sourceDPRank, response);
    return taken;
}
// Sends `response` from this slave to the master over the slave stream.
// Returns false for negative ranks or a failed write; true otherwise.
bool GRPCCommunicator::SendResponse(ExecuteResponse &response, int sourceDPRank, int targetDPRank) {
    const bool ranksValid = (sourceDPRank >= 0) && (targetDPRank >= 0);
    if (!ranksValid) {
        MINDIE_LLM_LOG_ERROR("SendResponse: sourceDPRank and targetDPRank must be non-negative integers.");
        return false;
    }

    SlaveToMasterMsg outbound;
    outbound.set_source_dp_rank(sourceDPRank);
    outbound.set_target_dp_rank(targetDPRank);
    *outbound.mutable_execute_response() = response;

    const bool delivered = SafeWriteMsgToStream(slaveStream_.get(), outbound);
    if (delivered) {
        return true;
    }
    MINDIE_LLM_LOG_ERROR("SendResponse: failed to write response to slave stream.");
    return false;
}
bool GRPCCommunicator::SendNpuUtilizationReport(uint32_t maxAicoreUtilizationPercent) {
if (isMaster_) {
return false;
}
SlaveToMasterMsg msg;
msg.set_source_dp_rank(0);
msg.set_target_dp_rank(0);
msg.mutable_npu_util_report()->set_max_aicore_utilization_percent(maxAicoreUtilizationPercent);
std::lock_guard<std::mutex> lock(streamWriteMutex_);
if (!slaveStream_) {
return false;
}
return slaveStream_->Write(msg);
}
// Stores the latest utilisation sample for `slaveIp`, time-stamped with the steady clock, and
// bumps the received-report counter. Guarded by slaveNpuMutex_.
void GRPCCommunicator::RecordSlaveNpuUtil(const std::string &slaveIp, uint32_t maxAicoreUtilizationPercent) {
    std::lock_guard<std::mutex> guard(slaveNpuMutex_);
    const auto receivedAt = std::chrono::steady_clock::now();
    slaveIpToMaxNpuUtil_[slaveIp] = {maxAicoreUtilizationPercent, receivedAt};
    ++slaveNpuReportRxCount_;
}
// Returns the highest utilisation among slave reports no older than SLAVE_NPU_REPORT_STALE_MS
// (0 when there is none).
// As a side effect it updates the timeout state: the timeout flag is set when any cached report
// is stale or fewer fresh reports than expected slaves exist. A warning is logged when that state
// is first entered and then at most once per SLAVE_NPU_TIMEOUT_LOG_INTERVAL_MS.
uint32_t GRPCCommunicator::GetSlaveMaxNpuUtilizationPercent() const {
    std::lock_guard<std::mutex> guard(slaveNpuMutex_);
    const auto currentTime = std::chrono::steady_clock::now();
    const auto freshnessWindow = std::chrono::milliseconds(SLAVE_NPU_REPORT_STALE_MS);
    slaveNpuReportTimeout_ = false;

    uint32_t peakUtil = 0;
    uint32_t freshSamples = 0;
    uint32_t staleSamples = 0;
    for (const auto &entry : slaveIpToMaxNpuUtil_) {
        const auto &sample = entry.second;
        if (currentTime - sample.reportTime > freshnessWindow) {
            ++staleSamples;
            continue;
        }
        peakUtil = std::max(peakUtil, sample.maxAicoreUtilizationPercent);
        ++freshSamples;
    }

    const bool allReportsFresh = (staleSamples == 0) && !(freshSamples < slaveCount_);
    if (allReportsFresh) {
        slaveNpuTimeoutActive_ = false;
        return peakUtil;
    }

    slaveNpuReportTimeout_ = true;
    const bool justEnteredTimeout = !slaveNpuTimeoutActive_;
    const bool logIntervalElapsed =
        (currentTime - lastSlaveNpuTimeoutLogTime_) >= std::chrono::milliseconds(SLAVE_NPU_TIMEOUT_LOG_INTERVAL_MS);
    if (justEnteredTimeout || logIntervalElapsed) {
        lastSlaveNpuTimeoutLogTime_ = currentTime;
        MINDIE_LLM_LOG_WARN("Slave NPU reports stale/missing on master. stale="
                            << staleSamples << ", fresh_reported=" << freshSamples
                            << ", cached_samples=" << slaveIpToMaxNpuUtil_.size() << ", expected=" << slaveCount_
                            << ", registered_streams=" << slaveIpToStream_.Size());
    }
    slaveNpuTimeoutActive_ = true;
    return peakUtil;
}
// Returns the current slave-report timeout flag and clears it.
// Also emits a diagnostics line (stream count, expected slaves, total and delta of received
// reports) at most once per NPU_REPORT_DIAG_LOG_INTERVAL_MS. Guarded by slaveNpuMutex_.
bool GRPCCommunicator::ConsumeSlaveNpuReportTimeoutFlag() const {
    std::lock_guard<std::mutex> guard(slaveNpuMutex_);

    const auto currentTime = std::chrono::steady_clock::now();
    const bool diagDue =
        (currentTime - lastMasterNpuDiagLogTime_) >= std::chrono::milliseconds(NPU_REPORT_DIAG_LOG_INTERVAL_MS);
    if (diagDue) {
        lastMasterNpuDiagLogTime_ = currentTime;
        const uint64_t rxDelta = slaveNpuReportRxCount_ - lastSlaveNpuReportRxCountLog_;
        lastSlaveNpuReportRxCountLog_ = slaveNpuReportRxCount_;
        MINDIE_LLM_LOG_INFO("Master NPU report diagnostics: registered_streams="
                            << slaveIpToStream_.Size() << ", expected_slaves=" << slaveCount_
                            << ", total_rx=" << slaveNpuReportRxCount_ << ", rx_delta_since_last=" << rxDelta);
    }

    const bool timedOut = slaveNpuReportTimeout_;
    slaveNpuReportTimeout_ = false;
    return timedOut;
}
// Adds `handler` to `handlers` under `dpRankIdx`.
// Returns false (after logging) when the handler is null or the rank already has a handler;
// true once inserted.
template <typename HandlerType>
bool RegisterHandler(ConcurrentMap<int, HandlerType> &handlers, int dpRankIdx, HandlerType handler) {
    if (handler == nullptr) {
        MINDIE_LLM_LOG_ERROR("GRPC RegisterHandler: handler is null.");
        return false;
    }

    const bool alreadyRegistered = handlers.Count(dpRankIdx) > 0;
    if (!alreadyRegistered) {
        handlers.Insert(dpRankIdx, handler);
        return true;
    }

    MINDIE_LLM_LOG_ERROR("GRPC RegisterHandler: handler for dpRankIdx " << dpRankIdx << " is already registered.");
    return false;
}
// Registers the request handler for `dpRankIdx`; see RegisterHandler for the failure conditions.
bool GRPCCommunicator::RegisterRequestHandler(RequestHandler handler, int dpRankIdx) {
    const bool registered = RegisterHandler(requestHandlers_, dpRankIdx, handler);
    return registered;
}
// Registers the recover/control-command request handler for `dpRankIdx`; see RegisterHandler for
// the failure conditions.
bool GRPCCommunicator::RegisterRecoverRequestHandler(RequestHandler handler, int dpRankIdx) {
    const bool registered = RegisterHandler(recoverRequestHandlers_, dpRankIdx, handler);
    return registered;
}
// Registers the response handler for `dpRankIdx`; see RegisterHandler for the failure conditions.
bool GRPCCommunicator::RegisterResponseHandler(ResponseHandler handler, int dpRankIdx) {
    const bool registered = RegisterHandler(responseHandlers_, dpRankIdx, handler);
    return registered;
}
// Invokes the response handler registered for `targetDPRank` with `response`.
// Returns false when no handler is registered or the handler throws (the exception is logged and
// swallowed); true otherwise.
bool GRPCCommunicator::HandleResponseFromSlave(ExecuteResponse &response, int targetDPRank) {
    std::optional<ResponseHandler> registeredHandler = responseHandlers_.Get(targetDPRank);
    if (!registeredHandler) {
        MINDIE_LLM_LOG_ERROR("HandleResponseFromSlave: response handler for targetDPRank "
                             << targetDPRank << " is not set or does not exist.");
        return false;
    }

    bool handled = false;
    try {
        (*registeredHandler)(response);
        handled = true;
    } catch (const std::exception &e) {
        MINDIE_LLM_LOG_ERROR("HandleResponseFromSlave: exception occurred while handling response: " +
                             std::string(e.what()));
    } catch (...) {
        MINDIE_LLM_LOG_ERROR("HandleResponseFromSlave: unknown exception occurred while handling response.");
    }
    return handled;
}
// Dispatches a request received from the master.
//  - MODEL_INFER: every registered request handler is invoked.
//  - RECOVER/START/PAUSE/CLEAR_COMMAND_EXEC: every registered recover handler is invoked.
//  - anything else: only the request handler registered for `targetDPRank` is invoked; a missing
//    handler is logged and the request dropped.
void GRPCCommunicator::HandleRequestFromMaster(ExecuteRequest &request, int targetDPRank) {
    const auto requestType = request.execute_type();

    if (requestType == MODEL_INFER) {
        for (const auto &requestHandler : requestHandlers_.Values()) {
            requestHandler(request);
        }
        return;
    }

    const bool isControlCommand = requestType == RECOVER_COMMAND_EXEC || requestType == START_COMMAND_EXEC ||
                                  requestType == PAUSE_COMMAND_EXEC || requestType == CLEAR_COMMAND_EXEC;
    if (isControlCommand) {
        for (const auto &recoverHandler : recoverRequestHandlers_.Values()) {
            recoverHandler(request);
        }
        return;
    }

    std::optional<RequestHandler> rankHandler = requestHandlers_.Get(targetDPRank);
    if (!rankHandler.has_value()) {
        MINDIE_LLM_LOG_ERROR("GRPCCommunicator: request handler for targetDPRank "
                             << targetDPRank << " is not set or does not exist.");
        return;
    }
    rankHandler.value()(request);
}
// True once at least `slaveCount_` slave streams have been registered.
bool GRPCCommunicator::AllSlavesConnected() {
    const bool enoughStreams = slaveIpToStream_.Size() >= slaveCount_;
    return enoughStreams;
}
// Wakes every thread currently waiting on cv_ (see WaitForAllSlavesConnected).
void GRPCCommunicator::NotifyAll() {
    cv_.notify_all();
}
// Mutable access to the slave-IP -> stream map.
ConcurrentMap<std::string, SlaveStreamPtr> &GRPCCommunicator::SlaveIpToStream() {
    return slaveIpToStream_;
}
// Creates the master service bound to `comm` and starts `respHandlerThreadCount` threads, each
// running RespHandlerLoop() and named "GRPCResponseHandler".
MasterServiceImpl::MasterServiceImpl(GRPCCommunicator *comm, int respHandlerThreadCount) : gRPCCommunicator_(comm) {
    respHandlerThreads_.reserve(respHandlerThreadCount);
    for (int started = 0; started < respHandlerThreadCount; ++started) {
        respHandlerThreads_.emplace_back([this] { RespHandlerLoop(); });
        auto &newestThread = respHandlerThreads_.back();
        pthread_setname_np(newestThread.native_handle(), "GRPCResponseHandler");
    }
}
// Stops and joins the response-handler threads before the service is destroyed.
MasterServiceImpl::~MasterServiceImpl() {
    StopRespHandlerThreads();
}
void MasterServiceImpl::RespHandlerLoop() {
while (respHandlerThreadActive_.load(std::memory_order_relaxed)) {
std::shared_ptr<SlaveResponseTask> task = pendingRespFromSlaveQueue_.pull();
gRPCCommunicator_->HandleResponseFromSlave(*task->response, task->targetDPRank);
}
}
// Stops the response-handler threads: clears the active flag, closes the pending queue, joins
// every joinable thread and drops the thread objects.
void MasterServiceImpl::StopRespHandlerThreads() {
    respHandlerThreadActive_.store(false, std::memory_order_relaxed);
    pendingRespFromSlaveQueue_.close();

    for (auto &handlerThread : respHandlerThreads_) {
        if (!handlerThread.joinable()) {
            continue;
        }
        handlerThread.join();
    }
    respHandlerThreads_.clear();
}
// Server-side handler for one slave's bidirectional stream. Runs until the slave closes the
// stream and processes three message kinds:
//  - register_request: records the stream under the slave's IP and wakes waiters once all
//    expected slaves are connected;
//  - execute_response: synchronous message types go to the per-rank blocking queue consumed by
//    Take(); all others are queued for the response-handler threads;
//  - npu_util_report: recorded against the IP this stream registered with, ignored before
//    registration.
// Always returns grpc::Status::OK.
// NOTE(review): the stream stays in the communicator's slave map after this function returns;
// confirm the stream pointer's lifetime and whether it should be removed here.
grpc::Status MasterServiceImpl::RegisterAndCommunicate(ServerContext *context, SlaveStreamPtr stream) {
    (void)context;
    SlaveToMasterMsg client_msg;
    std::string slaveIpFromStream;
    while (stream->Read(&client_msg)) {
        if (client_msg.has_register_request()) {
            const auto &register_request = client_msg.register_request();
            slaveIpFromStream = register_request.slave_ip();
            gRPCCommunicator_->SlaveIpToStream().Insert(register_request.slave_ip(), stream);
            if (gRPCCommunicator_->AllSlavesConnected()) {
                gRPCCommunicator_->NotifyAll();
            }
        } else if (client_msg.has_execute_response()) {
            int targetDPRank = client_msg.target_dp_rank();
            ExecuteResponse executeResponse = client_msg.execute_response();
            const auto msgType = executeResponse.msg_type();
            // These message types are consumed synchronously through Take().
            const bool isSyncResponse = msgType == REMOTE_MODEL_INIT || msgType == PD_LINK_STATUS_QUERY ||
                                        msgType == LORA_OPERATION || msgType == RECOVER_COMMAND_EXEC ||
                                        msgType == START_COMMAND_EXEC || msgType == PAUSE_COMMAND_EXEC ||
                                        msgType == CLEAR_COMMAND_EXEC;
            if (isSyncResponse) {
                auto syncQueueOpt = dpRankIdxToSyncResp_.Get(targetDPRank);
                // An unknown target rank must not tear down the stream with bad_optional_access.
                if (!syncQueueOpt.has_value() || !syncQueueOpt.value()) {
                    MINDIE_LLM_LOG_ERROR("MasterService: no sync response queue for targetDPRank "
                                         << targetDPRank << ", response dropped.");
                    continue;
                }
                syncQueueOpt.value()->push(std::make_shared<ExecuteResponse>(std::move(executeResponse)));
            } else {
                std::shared_ptr<SlaveResponseTask> respTask = std::make_shared<SlaveResponseTask>();
                respTask->targetDPRank = targetDPRank;
                respTask->response = std::make_shared<ExecuteResponse>(std::move(executeResponse));
                pendingRespFromSlaveQueue_.push(std::move(respTask));
            }
        } else if (client_msg.has_npu_util_report()) {
            if (slaveIpFromStream.empty()) {
                MINDIE_LLM_LOG_WARN("MasterService: npu_util_report received before register_request, ignored.");
                continue;
            }
            uint32_t pct = client_msg.npu_util_report().max_aicore_utilization_percent();
            gRPCCommunicator_->RecordSlaveNpuUtil(slaveIpFromStream, pct);
        }
    }
    return grpc::Status::OK;
}
// Pulls the next synchronous response queued for `targetDPRank` into `response`, using the
// queue's pull().
// Returns false (after logging) when no queue exists for the rank or pull() yields no element;
// true when `response` has been filled.
bool MasterServiceImpl::Take(int targetDPRank, ExecuteResponse &response) {
    auto blockingQueueOpt = dpRankIdxToSyncResp_.Get(targetDPRank);
    if (!blockingQueueOpt.has_value() || !blockingQueueOpt.value()) {
        MINDIE_LLM_LOG_ERROR("No blocking queue found for targetDPRank " << targetDPRank);
        return false;
    }

    auto pulled = blockingQueueOpt.value()->pull();
    if (!pulled) {
        // Guard against dereferencing an empty element (e.g. if the queue was closed).
        MINDIE_LLM_LOG_ERROR("No sync response available for targetDPRank " << targetDPRank);
        return false;
    }
    response = *pulled;
    return true;
}
ConcurrentMap<int, std::shared_ptr<ExecRespBlockingQueue>> &MasterServiceImpl::DPRankIdxToSyncResp() {
return dpRankIdxToSyncResp_;
}
bool GRPCCommunicator::LoadCertificates() {
MINDIE_LLM_LOG_INFO("Loading TLS certificates for mutual authentication...");
std::string homePath;
GetHomePath(homePath);
caCert_.clear();
for (const auto &caFile : interNodeTlsCaFiles_) {
fs::path caPath = fs::path(interNodeTlsCaPath_) / caFile;
ReadFileToString(caPath, caCert_);
caCert_ += "\n";
MINDIE_LLM_LOG_INFO("Loaded CA certificate: " + caPath.string());
}
fs::path certPath = interNodeTlsCert_;
ReadFileToString(certPath, tlsCert_);
MINDIE_LLM_LOG_INFO("Loaded server/client certificate: " + certPath.string());
fs::path keyPath = fs::path(interNodeTlsPk_);
std::string keyContent;
ReadFileToString(keyPath, keyContent);
tlsCertPrivateKey_.assign(keyContent.data(), keyContent.size());
return true;
}
}