* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIR_RUNTIME_HETEROGENEOUS_COMMON_MESSAGE_HANDLE_MESSAGE_CLIENT_H
#define AIR_RUNTIME_HETEROGENEOUS_COMMON_MESSAGE_HANDLE_MESSAGE_CLIENT_H
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include "proto/deployer.pb.h"
#include "common/config/device_debug_config.h"
#include "ge/ge_api_error_codes.h"
namespace ge {
/**
 * @brief Client that sends Request messages and receives Response messages over a pair of
 *        message queues (one request queue, one response queue) for a single device.
 *
 * Members suggest that responses are consumed by a background thread (wait_rsp_thread_,
 * DequeueMessageThread) and parked in responses_received_ keyed by message id until a waiter
 * wakes on response_cv_.
 * NOTE(review): confirm this flow against the implementation file.
 *
 * The class is neither copyable nor movable because it owns a std::thread and a std::mutex.
 *
 * @tparam Request  request message type (e.g. deployer::DeployerRequest).
 * @tparam Response response message type (e.g. deployer::DeployerResponse).
 */
template <class Request, class Response>
class MessageClient {
 public:
  /**
   * @param device_id     device this client talks to.
   * @param parallel_send presumably allows concurrent SendRequest calls from several threads
   *                      (matched by message id) — TODO confirm.
   */
  explicit MessageClient(int32_t device_id, bool parallel_send = false);
  virtual ~MessageClient();

  // Owns a std::thread and a std::mutex: copying or moving would duplicate/steal them.
  MessageClient(const MessageClient &) = delete;
  MessageClient &operator=(const MessageClient &) = delete;
  MessageClient(MessageClient &&) = delete;
  MessageClient &operator=(MessageClient &&) = delete;

  /**
   * @brief Creates the request/response queues.
   * @param name_suffix  suffix used when naming the queues (presumably appended to a fixed prefix).
   * @param request_qid  [out] id of the request queue.
   * @param response_qid [out] id of the response queue.
   * @param is_client    NOTE(review): semantics not visible here — confirm.
   */
  Status CreateMessageQueue(const std::string &name_suffix, uint32_t &request_qid, uint32_t &response_qid,
                            bool is_client = false);

  /**
   * @param pid           process id of the peer (stored in pid_).
   * @param get_stat_func callback stored in get_stat_func_; presumably polled to check the peer
   *                      process is still healthy — TODO confirm.
   * @param waiting_rsp   presumably whether to wait for the peer's initialization response — confirm.
   */
  Status Initialize(int32_t pid, const std::function<Status(void)> &get_stat_func, bool waiting_rsp = true);
  Status NotifyFinalize() const;
  void Stop();
  Status Finalize();

  /** @brief Sends @p request without waiting for a response. */
  virtual Status SendRequestWithoutResponse(const Request &request);

  /**
   * @brief Sends @p request and waits for the matching response.
   * @param timeout wait limit; -1 presumably means "wait indefinitely" — confirm unit and semantics.
   */
  virtual Status SendRequest(const Request &request, Response &response, int64_t timeout = -1);

 protected:
  virtual Status WaitForProcessInitialized();
  virtual Status WaitResponseWithMessageId(Response &response, uint64_t message_id = 0UL, int64_t timeout = -1);
  virtual Status WaitResponse(Response &response, int64_t timeout = -1);
  virtual void DequeueMessageThread();
  virtual Status DequeueMessage(std::shared_ptr<Response> &response);

 private:
  Status InitMessageQueue() const;
  void SetMessageId(Request &request);

  // Run flag shared with the background dequeue thread. It must be std::atomic, not volatile:
  // volatile gives no inter-thread ordering or atomicity, so a volatile flag is a data race (UB).
  std::atomic<bool> running_{false};
  int32_t pid_ = -1;
  int32_t device_id_ = -1;
  bool parallel_send_ = false;
  // Presumably guards responses_received_ and pairs with response_cv_ — confirm in implementation.
  std::mutex mu_;
  std::condition_variable response_cv_;
  uint32_t req_msg_queue_id_ = UINT32_MAX;
  uint32_t rsp_msg_queue_id_ = UINT32_MAX;
  std::function<Status(void)> get_stat_func_;
  std::thread wait_rsp_thread_;
  // Responses received but not yet claimed, keyed by message id.
  std::map<uint64_t, std::shared_ptr<Response>> responses_received_;
  // Source of message ids handed out by SetMessageId; atomic because senders may run concurrently.
  std::atomic<uint64_t> message_id_{0UL};
};
// MessageClient instantiated for deployer request/response messages.
using DeployerMessageClient =
    MessageClient<deployer::DeployerRequest, deployer::DeployerResponse>;
// MessageClient instantiated for executor request/response messages.
using ExecutorMessageClient =
    MessageClient<deployer::ExecutorRequest, deployer::ExecutorResponse>;
}
#endif