* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "common/message_handle/message_client.h"
#include "common/utils/rts_api_utils.h"
#include "common/util/mem_utils.h"
#include "common/data_flow/queue/heterogeneous_exchange_service.h"
#include "common/compile_profiling/ge_call_wrapper.h"
#include "acl/acl.h"
#include "common/df_chk.h"
namespace ge {
namespace {
// Default wait, in seconds, applied by WaitResponseWithMessageId() when the caller passes timeout == -1.
constexpr int64_t kDefaultTimeout = 1200;
// Timeout handed to DequeueMbuf() for one dequeue attempt. Presumably milliseconds: the timeout log in
// WaitResponse() reports retry_times * kDequeueTimeout as "ms" -- TODO confirm against DequeueMbuf().
constexpr int32_t kDequeueTimeout = 3000;
// Divisor WaitResponse() uses to convert a caller-supplied timeout into a number of dequeue attempts.
constexpr int32_t kDequeueTimeoutInSec = 3;
// Value stored in ControlInfo::timeout for every Enqueue() issued from this file
// (unit presumably milliseconds -- TODO confirm against ExchangeService).
constexpr int32_t kEnqueueTimeout = 30 * 1000;
// Number of dequeue attempts WaitResponse() makes when the caller passes timeout == -1.
constexpr int64_t kDefaultRetryTimes = 400;
// Depth configured for both the request and the response queue in CreateMessageQueue().
constexpr uint32_t kMsgQueueDepth = 3U;
// msg_type attached to the message sent by NotifyFinalize().
constexpr uint16_t kMsgTypeFinalize = 2U;
// Short alias for the exchange service; all queue operations below go through HService::GetInstance().
using HService = HeterogeneousExchangeService;
}
// Records the device id and the send mode. No queue is created and no thread is started here.
// @param device_id      device id passed to the queue and runtime calls made by this client.
// @param parallel_send  when true, Initialize() starts a thread running DequeueMessageThread() and
//                       SendRequest() matches responses to requests by message id.
template <class Request, class Response>
MessageClient<Request, Response>::MessageClient(int32_t device_id, bool parallel_send)
: device_id_(device_id), parallel_send_(parallel_send) {}
// Tears the client down by calling Finalize(); the returned status is deliberately discarded.
template <class Request, class Response>
MessageClient<Request, Response>::~MessageClient() {
(void)Finalize();
}
// Asks every loop and waiter of this client to give up.
//
// Clears running_, which is polled by DequeueMessageThread() and WaitResponse() and is part of the wait
// predicate in WaitResponseWithMessageId(). The condition variable is notified so that a thread blocked in
// WaitResponseWithMessageId() re-evaluates its predicate right away instead of sleeping until its timeout.
template <class Request, class Response>
void MessageClient<Request, Response>::Stop() {
  {
    // Change the flag while holding mu_: a waiter that has just evaluated its predicate but has not yet
    // blocked would otherwise miss the notification below.
    const std::lock_guard<std::mutex> lk(mu_);
    running_ = false;
  }
  response_cv_.notify_all();
}
// Shuts the client down and releases both message queues.
//
// Does nothing and returns SUCCESS when pid_ is 0, which is the state this function itself leaves behind,
// so repeated calls are harmless. Otherwise it clears running_, wakes any thread blocked in
// WaitResponseWithMessageId(), joins the response thread (parallel mode only), destroys the request and
// response queues (errors ignored), resets the queue ids and pid_, and drops all buffered responses.
// @return always SUCCESS.
template <class Request, class Response>
Status MessageClient<Request, Response>::Finalize() {
  if (pid_ == 0) {
    return SUCCESS;
  }
  {
    // Flip the flag under mu_ so a waiter cannot miss the wake-up between checking its predicate and blocking.
    const std::lock_guard<std::mutex> lk(mu_);
    running_ = false;
  }
  // The wait predicate in WaitResponseWithMessageId() tests !running_; without this notification a blocked
  // caller would only return once its full timeout had elapsed.
  response_cv_.notify_all();
  if (parallel_send_ && wait_rsp_thread_.joinable()) {
    wait_rsp_thread_.join();
  }
  (void)HService::GetInstance().DestroyQueue(device_id_, req_msg_queue_id_);
  (void)HService::GetInstance().DestroyQueue(device_id_, rsp_msg_queue_id_);
  req_msg_queue_id_ = UINT32_MAX;
  rsp_msg_queue_id_ = UINT32_MAX;
  pid_ = 0;
  {
    // Waiters woken above read this map under mu_, so clear it under the same lock.
    const std::lock_guard<std::mutex> lk(mu_);
    responses_received_.clear();
  }
  return SUCCESS;
}
// Enqueues a one-byte message tagged with kMsgTypeFinalize on the request queue.
// @return SUCCESS when the message was enqueued, otherwise the status returned by Enqueue().
template <class Request, class Response>
Status MessageClient<Request, Response>::NotifyFinalize() const {
  // The payload carries no information; only the message type in msg_info matters here.
  uint8_t data = 0U;
  size_t data_size = sizeof(data);

  ExchangeService::MsgInfo msg_info{};
  msg_info.msg_type = kMsgTypeFinalize;

  ExchangeService::ControlInfo control_info{};
  control_info.msg_info = &msg_info;
  control_info.timeout = kEnqueueTimeout;

  GE_CHK_STATUS_RET(HService::GetInstance().Enqueue(device_id_, req_msg_queue_id_, &data, data_size, control_info),
                    "Failed to enqueue request");
  return SUCCESS;
}
// Creates the request/response queue pair named "request_queue.<suffix>" and "response_queue.<suffix>".
//
// @param name_suffix   suffix appended to both queue names.
// @param request_qid   receives the id of the request queue.
// @param response_qid  receives the id of the response queue.
// @param is_client     forwarded into the queue attributes; when true the device is first selected with
//                      aclrtSetDevice().
// @return SUCCESS when both queues exist and their ids were stored in this object; a failure status
//         otherwise. If only the response queue fails, the request queue is destroyed again and FAILED is
//         returned.
template <class Request, class Response>
Status MessageClient<Request, Response>::CreateMessageQueue(const std::string &name_suffix, uint32_t &request_qid,
                                                            uint32_t &response_qid, bool is_client) {
  if (is_client) {
    DF_CHK_ACL_RET(aclrtSetDevice(device_id_));
  }

  const std::string req_msg_queue_name = "request_queue." + name_suffix;
  const std::string rsp_msg_queue_name = "response_queue." + name_suffix;

  // Both queues share the same attributes.
  MemQueueAttr mem_queue_attr{};
  mem_queue_attr.is_client = is_client;
  mem_queue_attr.work_mode = RT_MQ_MODE_PULL;
  mem_queue_attr.depth = kMsgQueueDepth;

  GE_CHK_STATUS_RET(HService::GetInstance().CreateQueue(device_id_, req_msg_queue_name, mem_queue_attr, request_qid),
                    "[Create][Request] message queue failed, queue name %s.", req_msg_queue_name.c_str());

  if (HService::GetInstance().CreateQueue(device_id_, rsp_msg_queue_name, mem_queue_attr, response_qid) != SUCCESS) {
    GELOGE(FAILED, "[Create][Reponse] message queue failed, queue name %s.", rsp_msg_queue_name.c_str());
    // Roll back the request queue so a half-created pair is not left behind.
    (void)HService::GetInstance().DestroyQueue(device_id_, request_qid);
    return FAILED;
  }

  rsp_msg_queue_id_ = response_qid;
  req_msg_queue_id_ = request_qid;
  GELOGI("[Create][Message] queue success, request qid:%u [%s], response qid:%u [%s]", request_qid,
         req_msg_queue_name.c_str(), response_qid, rsp_msg_queue_name.c_str());
  return SUCCESS;
}
// Binds the client to the peer process and starts message handling.
//
// @param pid            id of the peer process that is granted access to the queues.
// @param get_stat_func  callback consulted before waiting for a response; must not be empty.
// @param waiting_rsp    when true, block until the peer's first response has been received and checked.
// @return SUCCESS on success; otherwise the status of the step that failed (callback check, queue grant,
//         or waiting for the first response).
template <class Request, class Response>
Status MessageClient<Request, Response>::Initialize(int32_t pid, const std::function<Status(void)> &get_stat_func,
                                                    bool waiting_rsp) {
  pid_ = pid;
  get_stat_func_ = get_stat_func;
  GE_CHECK_NOTNULL(get_stat_func_);

  GE_CHK_STATUS_RET_NOLOG(InitMessageQueue());

  running_ = true;
  if (parallel_send_) {
    // In parallel mode a dedicated thread drains the response queue.
    std::thread receiver(&MessageClient::DequeueMessageThread, this);
    wait_rsp_thread_ = std::move(receiver);
  }

  if (waiting_rsp) {
    GE_CHK_STATUS_RET_NOLOG(WaitForProcessInitialized());
  }

  GELOGI("[Initialize] message client succeeded, pid:%d, device_id:%d", pid_, device_id_);
  return SUCCESS;
}
// Grants process pid_ access to both queues: read on the request queue, write on the response queue.
// @return SUCCESS when both rtMemQueueGrant() calls succeed; otherwise the status produced by GE_CHK_RT_RET.
template <class Request, class Response>
Status MessageClient<Request, Response>::InitMessageQueue() const {
  rtMemQueueShareAttr_t req_queue_attr{};
  rtMemQueueShareAttr_t rsp_queue_attr{};
  req_queue_attr.read = 1;
  rsp_queue_attr.write = 1;

  GE_CHK_RT_RET(rtMemQueueGrant(device_id_, req_msg_queue_id_, pid_, &req_queue_attr));
  GE_CHK_RT_RET(rtMemQueueGrant(device_id_, rsp_msg_queue_id_, pid_, &rsp_queue_attr));
  return SUCCESS;
}
// Takes one mbuf from the response queue and parses its payload into a freshly allocated Response.
//
// @param response  set to the parsed message on success; left untouched otherwise.
// @return SUCCESS when a message was dequeued and parsed; the DequeueMbuf() status when nothing could be
//         dequeued; a failure status when reading the mbuf, allocating, or parsing fails.
template <class Request, class Response>
Status MessageClient<Request, Response>::DequeueMessage(std::shared_ptr<Response> &response) {
  rtMbufPtr_t mbuf = nullptr;
  const auto ret = HService::GetInstance().DequeueMbuf(device_id_, rsp_msg_queue_id_, &mbuf, kDequeueTimeout);
  if (ret != SUCCESS) {
    return ret;
  }

  // From here on the mbuf is ours; release it on every exit path.
  GE_MAKE_GUARD(mbuf, [mbuf]() { (void)rtMbufFree(mbuf); });

  void *buffer_addr = nullptr;
  GE_CHK_STATUS_RET_NOLOG(RtsApiUtils::MbufGetBufferAddr(mbuf, &buffer_addr));
  uint64_t buffer_size = 0;
  GE_CHK_STATUS_RET_NOLOG(RtsApiUtils::MbufGetBufferSize(mbuf, buffer_size));

  google::protobuf::io::ArrayInputStream stream(buffer_addr, static_cast<int32_t>(buffer_size));
  auto rsp = MakeShared<Response>();
  GE_CHECK_NOTNULL(rsp);
  GE_CHK_BOOL_RET_STATUS(rsp->ParseFromZeroCopyStream(&stream), FAILED, "Failed to parse message from proto");

  response = rsp;
  return ret;
}
// Body of the response thread started by Initialize() in parallel mode.
//
// While running_ is set, repeatedly dequeues responses, stores each one in responses_received_ under its
// message id and notifies response_cv_. An empty queue is simply retried; any other dequeue failure is
// logged and ends the thread immediately.
template <class Request, class Response>
void MessageClient<Request, Response>::DequeueMessageThread() {
  SET_THREAD_NAME(pthread_self(), "ge_dpl_msgc");
  while (running_) {
    std::shared_ptr<Response> response;
    const auto ret = DequeueMessage(response);
    if (ret != SUCCESS) {
      if (ret == RT_ERROR_TO_GE_STATUS(ACL_ERROR_RT_QUEUE_EMPTY)) {
        // Nothing arrived within the dequeue timeout; poll again.
        continue;
      }
      GELOGE(FAILED, "Failed to dequeue message");
      return;
    }

    const uint64_t message_id = response->message_id();
    std::lock_guard<std::mutex> lk(mu_);
    responses_received_[message_id] = response;
    response_cv_.notify_all();
    GELOGD("[Dequeue][Message] success, queue id:%u, message id:%lu", rsp_msg_queue_id_, message_id);
  }
  GEEVENT("Dequeue Message thread end");
}
// Waits for the peer's first response and checks that it reports success.
// @return SUCCESS when a response arrived with error_code() == SUCCESS; the wait status when waiting
//         failed; FAILED when the response carries an error.
template <class Request, class Response>
Status MessageClient<Request, Response>::WaitForProcessInitialized() {
  Response response;
  // Pick the wait primitive that matches the send mode; both are called with their default arguments.
  if (!parallel_send_) {
    GE_CHK_STATUS_RET(WaitResponse(response), "Failed to wait response");
  } else {
    GE_CHK_STATUS_RET(WaitResponseWithMessageId(response), "Failed to wait response");
  }

  GE_CHK_BOOL_RET_STATUS(response.error_code() == SUCCESS, FAILED, "[Wait][Response] failed, error_message = %s",
                         response.error_message().c_str());
  GEEVENT("[Wait][Process] initialized succeeded, pid:%d", pid_);
  return SUCCESS;
}
// Stamps the request with the current value of message_id_ and then advances the counter (post-increment).
// @param request  message whose message_id field is overwritten.
template <class Request, class Response>
void MessageClient<Request, Response>::SetMessageId(Request &request) {
request.set_message_id(message_id_++);
}
// Serializes the request into the request queue and returns without waiting for any reply.
// @param request  message to send.
// @return SUCCESS when the request was enqueued; otherwise the status returned by Enqueue().
template <class Request, class Response>
Status MessageClient<Request, Response>::SendRequestWithoutResponse(const Request &request) {
  ExchangeService::ControlInfo control_info{};
  control_info.timeout = kEnqueueTimeout;

  auto req_size = request.ByteSizeLong();
  // Called by Enqueue() with the destination buffer; writes the serialized request into it.
  ExchangeService::FillFunc fill_func = [&request](void *buffer, size_t size) -> Status {
    GE_CHK_BOOL_RET_STATUS(request.SerializeToArray(buffer, static_cast<int32_t>(size)), FAILED,
                           "SerializeToArray failed, request size:%zu", size);
    return SUCCESS;
  };

  GE_CHK_STATUS_RET(HService::GetInstance().Enqueue(device_id_, req_msg_queue_id_, req_size, fill_func, control_info),
                    "Failed to enqueue request");
  GELOGD("[Enqueue][Request] without response succeeded, queue id:%u", req_msg_queue_id_);
  return SUCCESS;
}
// Sends one request and blocks until the matching response arrives or the wait fails.
//
// In parallel mode the request is first stamped with a fresh message id (through a const_cast on the
// caller's object) and the response is looked up by that id; otherwise the next message on the response
// queue is taken as the answer.
// @param request   message to send.
// @param response  receives the reply; on a serialization failure it is filled with FAILED and an error text.
// @param timeout   forwarded unchanged to the wait function.
// @return SUCCESS when a response was received; otherwise the status of the enqueue or wait step.
template <class Request, class Response>
Status MessageClient<Request, Response>::SendRequest(const Request &request, Response &response, int64_t timeout) {
  if (parallel_send_) {
    SetMessageId(const_cast<Request &>(request));
  }

  ExchangeService::ControlInfo control_info{};
  control_info.timeout = kEnqueueTimeout;

  auto req_size = request.ByteSizeLong();
  // Called by Enqueue() with the destination buffer; a serialization failure is also recorded in response.
  ExchangeService::FillFunc fill_func = [&request, &response](void *buffer, size_t size) {
    if (!request.SerializeToArray(buffer, static_cast<int32_t>(size))) {
      response.set_error_code(FAILED);
      response.set_error_message("SerializeToArray failed");
      GELOGE(FAILED, "SerializeToArray failed");
      return FAILED;
    }
    return SUCCESS;
  };

  GE_CHK_STATUS_RET(HService::GetInstance().Enqueue(device_id_, req_msg_queue_id_, req_size, fill_func, control_info),
                    "Failed to enqueue request");
  GELOGD("[Enqueue][Request] succeeded, queue id:%u, message id:%lu", req_msg_queue_id_, request.message_id());

  if (!parallel_send_) {
    GE_CHK_STATUS_RET(WaitResponse(response, timeout), "Failed to wait response");
  } else {
    GE_CHK_STATUS_RET(WaitResponseWithMessageId(response, request.message_id(), timeout),
                      "Failed to wait response, message id:%lu", request.message_id());
  }
  GELOGD("[Dequeue][Response] succeeded, queue id:%u, message id:%lu", rsp_msg_queue_id_, response.message_id());
  return SUCCESS;
}
// Blocks until the response thread has stored a response for message_id, the client is stopped, or the
// timeout elapses.
//
// @param response    receives the stored response (moved out of the map) on success.
// @param message_id  id of the request whose reply is awaited.
// @param timeout     wait time in seconds; -1 selects kDefaultTimeout.
// @return SUCCESS when the response was found; the get_stat_func_() status when that check fails;
//         FAILED when no response is present once the wait ends.
template <class Request, class Response>
Status MessageClient<Request, Response>::WaitResponseWithMessageId(Response &response, uint64_t message_id,
                                                                   int64_t timeout) {
  int64_t rsp_timeout = timeout;
  if (timeout == -1) {
    rsp_timeout = kDefaultTimeout;
  }

  std::unique_lock<std::mutex> lk(mu_);
  GE_CHK_STATUS_RET(get_stat_func_(), "Process already exit");

  const auto stopped_or_arrived = [this, message_id]() -> bool {
    return (!running_) || (responses_received_.count(message_id) > 0U);
  };
  (void)response_cv_.wait_for(lk, std::chrono::seconds(rsp_timeout), stopped_or_arrived);

  // The lock is held again here, so a single lookup decides the outcome.
  const auto it = responses_received_.find(message_id);
  if (it == responses_received_.end()) {
    GELOGE(FAILED, "Wait response timeout, wait time = %ld s", rsp_timeout);
    return FAILED;
  }

  GELOGD("[Wait][Message] success, queue id:%u, message id:%lu", rsp_msg_queue_id_, message_id);
  response = std::move(*(it->second));
  (void)responses_received_.erase(it);
  return SUCCESS;
}
// Polls the response queue directly (non-parallel mode) until a response arrives or the attempts run out.
//
// Before each attempt the peer status callback and the running_ flag are checked. An empty queue is
// retried; any other dequeue failure aborts the wait.
// @param response  receives the dequeued response on success.
// @param timeout   wait budget; -1 selects kDefaultRetryTimes attempts, any other value is divided by
//                  kDequeueTimeoutInSec (rounded down, at least one attempt).
// @return SUCCESS when a response was received; the get_stat_func_() status when that check fails;
//         FAILED when the client was stopped, dequeueing failed, or every attempt found the queue empty.
template <class Request, class Response>
Status MessageClient<Request, Response>::WaitResponse(Response &response, int64_t timeout) {
  // Explicit template argument: int64_t is not `long` on every platform, so std::max(int64_t, 1L) need not compile.
  const int64_t retry_times =
      (timeout == -1) ? kDefaultRetryTimes : std::max<int64_t>(timeout / kDequeueTimeoutInSec, 1);
  std::shared_ptr<Response> rsp;
  // 64-bit counter so the comparison with retry_times cannot overflow for very large timeouts.
  for (int64_t i = 0; i < retry_times; ++i) {
    GE_CHK_STATUS_RET(get_stat_func_(), "Process already exit");
    GE_CHK_BOOL_RET_STATUS(running_, FAILED, "Wait response failed as stopped");
    const auto ret = DequeueMessage(rsp);
    if (ret == SUCCESS) {
      // rsp is the only owner of the freshly parsed message, so it can be moved out instead of copied.
      response = std::move(*rsp);
      return SUCCESS;
    }
    GE_CHK_BOOL_RET_STATUS(ret == RT_ERROR_TO_GE_STATUS(ACL_ERROR_RT_QUEUE_EMPTY), FAILED, "[Dequeue][Message] failed");
    GELOGD("Response message queue is empty, retry time = %ld.", i);
  }
  // The product is int64_t, so it is printed with %ld (same convention as WaitResponseWithMessageId).
  GELOGE(FAILED, "Wait response timeout, wait time = %ld ms.", retry_times * kDequeueTimeout);
  return FAILED;
}
// Explicit instantiations for the two request/response pairs used with this client. The member
// definitions above live in this source file, so these lines are what emits them for each pair.
template class MessageClient<deployer::DeployerRequest, deployer::DeployerResponse>;
template class MessageClient<deployer::ExecutorRequest, deployer::ExecutorResponse>;
}