* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef GRPC_COMMUNICATOR_H
#define GRPC_COMMUNICATOR_H
#include <grpcpp/server_builder.h>
#include <grpcpp/server_context.h>
#include <openssl/bio.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <boost/thread/sync_queue.hpp>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include "common_util.h"
#include "concurrent_map.h"
#include "executor/executor_interface.h"
#include "log.h"
#include "model_execute_data.grpc.pb.h"
#include "model_execute_data.pb.h"
#include "string_utils.h"
namespace mindie_llm {
using grpc::ServerContext;
using grpc::ServerReaderWriter;
using model_execute_data::MasterService;
using model_execute_data::MasterToSlaveMsg;
using model_execute_data::SlaveToMasterMsg;
using SlaveStreamPtr = ServerReaderWriter<MasterToSlaveMsg, SlaveToMasterMsg> *;
using ExecRespBlockingQueue = boost::sync_queue<std::shared_ptr<ExecuteResponse>>;
/**
 * gRPC communicator linking a master node with its slave nodes.
 *
 * A single class serves both roles; IsMaster() reports which one this process
 * plays. The member types suggest which side uses what: server_, service_ and
 * slaveIpToStream_ for the master's gRPC server, and channel_, stub_,
 * slaveStream_ and context_ for a slave's client stream.
 * NOTE(review): this role split is inferred from member types; confirm it in
 * the implementation file.
 *
 * A process-wide instance is held in grpcCommunicatorSingleton and exposed via
 * GetInstance() / TryGetInstance(). Copy operations are deleted and no move
 * operations are declared, so instances are neither copyable nor movable.
 */
class GRPCCommunicator {
public:
    /// Returns the process-wide instance built from modelConfig.
    /// NOTE(review): create-once vs. re-create semantics are defined in the .cpp.
    static std::shared_ptr<GRPCCommunicator> GetInstance(
        const std::unordered_map<std::string, std::string> &modelConfig);
    /// Returns a copy of the current singleton pointer without creating one
    /// (presumably empty until GetInstance() has run).
    /// NOTE(review): this is an unsynchronized read of a static shared_ptr. It
    /// is a data race if GetInstance() can assign grpcCommunicatorSingleton
    /// concurrently; confirm the calling discipline.
    static std::shared_ptr<GRPCCommunicator> TryGetInstance() { return grpcCommunicatorSingleton; }
    /// True when this process acts as the master node (returns isMaster_).
    [[nodiscard]] bool IsMaster() const noexcept { return isMaster_; }

    // --- NPU utilization reporting (presumably slave -> master) ---
    // Per-slave samples live in slaveIpToMaxNpuUtil_ (keyed by slave IP);
    // slaveNpuMutex_ is presumably their guard. Confirm in the .cpp.
    bool SendNpuUtilizationReport(uint32_t maxAicoreUtilizationPercent);
    uint32_t GetSlaveMaxNpuUtilizationPercent() const;
    /// Declared const; the timeout/diagnostic state it can touch is `mutable`.
    bool ConsumeSlaveNpuReportTimeoutFlag() const;
    void RecordSlaveNpuUtil(const std::string &slaveIp, uint32_t maxAicoreUtilizationPercent);

    GRPCCommunicator(const GRPCCommunicator &) = delete;
    GRPCCommunicator &operator=(const GRPCCommunicator &) = delete;
    explicit GRPCCommunicator(const std::unordered_map<std::string, std::string> &modelConfig);
    ~GRPCCommunicator();

    // --- Lifecycle ---
    bool Init(int initCount);

    // --- Request/response exchange between DP ranks ---
    // NOTE(review): the meaning of the default empty slaveIp is decided in the .cpp.
    bool SendRequest(ExecuteRequest &request, int sourceDPRank, int targetDPRank,
                     const std::string &slaveIp = "");
    bool GetSyncResponse(ExecuteResponse &response, int sourceDPRank);
    bool SendResponse(ExecuteResponse &response, int sourceDPRank, int targetDPRank);
    // Handlers are registered per dpRankIdx; presumably stored in
    // requestHandlers_ / recoverRequestHandlers_ / responseHandlers_.
    bool RegisterRequestHandler(RequestHandler handler, int dpRankIdx);
    bool RegisterRecoverRequestHandler(RequestHandler handler, int dpRankIdx);
    bool RegisterResponseHandler(ResponseHandler handler, int dpRankIdx);
    void HandleRequestFromMaster(ExecuteRequest &request, int targetDPRank);
    bool HandleResponseFromSlave(ExecuteResponse &response, int targetDPRank);
    bool AllSlavesConnected();
    void NotifyAll();
    void StopServer();
    void StopClient();
    /// Mutable access to the slave-IP -> server-stream map (slaveIpToStream_).
    ConcurrentMap<std::string, SlaveStreamPtr> &SlaveIpToStream();

private:
    static std::shared_ptr<GRPCCommunicator> grpcCommunicatorSingleton;
    // 256 MiB (256 * 1024 * 1024 bytes) send/receive buffer limit.
    static constexpr int grpcSendReceiveBufSize = 256 * 1024 * 1024;
    static constexpr int maxConcurrentStreams = 128;

    bool InitMaster(int respHandlerThreadCount);
    bool InitSlave();
    void WaitForAllSlavesConnected();
    void StartWorkerThread();
    bool SendRegistration();
    void ModelInitHandlerLoop();
    void StopModelInitHandlerThreads();
    bool LoadCertificates();
    template <typename StreamType, typename MsgType>
    bool SafeWriteMsgToStream(StreamType stream, const MsgType &msg);

    // --- Inter-node TLS configuration ---
    // Default to disabled so the flag is never read indeterminate. Every other
    // scalar flag in this class carries an in-class initializer as well.
    bool interNodeTLSEnabled_{false};
    std::string interNodeTlsCaPath_;
    std::vector<std::string> interNodeTlsCaFiles_;
    std::string interNodeTlsCert_;
    std::string interNodeTlsPk_;
    std::string interNodeTlsCrlPath_;
    std::vector<std::string> interNodeTlsCrlFiles_;
    // Loaded credential contents, presumably filled by LoadCertificates().
    std::string caCert_;
    std::string tlsCert_;
    std::string tlsCertPrivateKey_;

    // --- Role / addressing ---
    bool isMaster_ = false;
    std::string masterIP_;
    std::string multiNodesInferPort_;
    std::atomic<int> callInitCount_{0};

    // --- Master-side state ---
    std::mutex streamWriteMutex_;
    ConcurrentMap<std::string, SlaveStreamPtr> slaveIpToStream_;
    uint32_t slaveCount_{0};
    std::condition_variable cv_;
    std::thread masterWorkerThread_;
    ConcurrentMap<int, ResponseHandler> responseHandlers_;
    std::unique_ptr<grpc::Server> server_;
    std::shared_ptr<MasterService::Service> service_;

    // --- Slave-side state ---
    std::string slaveIp_;
    std::unique_ptr<
        ::grpc::ClientReaderWriter<::model_execute_data::SlaveToMasterMsg, ::model_execute_data::MasterToSlaveMsg>>
        slaveStream_;
    std::unique_ptr<MasterService::Stub> stub_;
    std::shared_ptr<grpc::Channel> channel_;
    std::thread slaveWorkerThread_;
    std::unique_ptr<grpc::ClientContext> context_;
    ConcurrentMap<int, RequestHandler> requestHandlers_;
    ConcurrentMap<int, RequestHandler> recoverRequestHandlers_;

    // One NPU utilization sample reported by a slave, with its arrival time.
    struct SlaveNpuSample {
        uint32_t maxAicoreUtilizationPercent{0};
        std::chrono::steady_clock::time_point reportTime{};
    };

    // --- Model-init message handling ---
    boost::sync_queue<std::shared_ptr<MasterToSlaveMsg>> pendingModelInitQueue_;
    std::vector<std::thread> modelInitHandlerThreads_;
    std::atomic<bool> modelInitHandlerActive_{false};
    uint32_t grpcCommunicatorNum_{0};
    bool isDmiInfer_{false};

    // --- NPU utilization bookkeeping (mutable: touched from const accessors) ---
    mutable std::mutex slaveNpuMutex_;
    std::unordered_map<std::string, SlaveNpuSample> slaveIpToMaxNpuUtil_;
    mutable bool slaveNpuReportTimeout_{false};
    mutable bool slaveNpuTimeoutActive_{false};
    mutable std::chrono::steady_clock::time_point lastSlaveNpuTimeoutLogTime_{};
    mutable uint64_t slaveNpuReportRxCount_{0};
    mutable uint64_t lastSlaveNpuReportRxCountLog_{0};
    mutable std::chrono::steady_clock::time_point lastMasterNpuDiagLogTime_{};
};
class MasterServiceImpl final : public MasterService::Service {
public:
explicit MasterServiceImpl(GRPCCommunicator *comm, int respHandlerThreadCount);
~MasterServiceImpl() override;
grpc::Status RegisterAndCommunicate(ServerContext *context, SlaveStreamPtr stream) override;
bool Take(int targetDPRank, ExecuteResponse &response);
ConcurrentMap<int, std::shared_ptr<ExecRespBlockingQueue>> &DPRankIdxToSyncResp();
private:
GRPCCommunicator *gRPCCommunicator_;
ConcurrentMap<int, std::shared_ptr<ExecRespBlockingQueue>> dpRankIdxToSyncResp_;
struct SlaveResponseTask {
int targetDPRank;
std::shared_ptr<ExecuteResponse> response;
};
boost::sync_queue<std::shared_ptr<SlaveResponseTask>> pendingRespFromSlaveQueue_;
std::vector<std::thread> respHandlerThreads_;
std::atomic<bool> respHandlerThreadActive_{true};
void RespHandlerLoop();
void StopRespHandlerThreads();
};
}
#endif