/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
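/**
 * Description: RPC generic server streaming API.
 * Covers both stream RPC and non-blocking unary RPC.
 * In stream RPC, there are three combinations of client/server streaming modes:
 * ServerWriter (only the server side streams), ServerReader (only the client side streams)
 * and ServerWriterReader (both sides stream).
 * In non-blocking unary RPC, there is ServerUnaryWriterReader (and its in-process variant
 * LocalServerUnaryWriterReader).
 */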
#ifndef DATASYSTEM_COMMON_RPC_SERVER_STREAM_BASE_H
#define DATASYSTEM_COMMON_RPC_SERVER_STREAM_BASE_H
#include <variant>
#include "datasystem/common/rpc/rpc_message.h"
#include "datasystem/common/rpc/zmq/zmq_server_stream_base.h"
#include "datasystem/common/log/log_helper.h"
namespace datasystem {
/**
 * @brief Merges several callables into a single overload set so that one object can be handed to std::visit,
 * typically with one specific handler per variant alternative plus a generic fallback lambda.
 * @tparam Fns Callable types (usually lambdas) whose call operators are all brought into scope.
 */
template <class... Fns>
struct overloaded : Fns... {
    using Fns::operator()...;
};

// Deduction guide: lets `overloaded{ lambdaA, lambdaB }` deduce its base classes from the initializers.
template <class... Fns>
overloaded(Fns...) -> overloaded<Fns...>;
* Only server side is streaming.
* @tparam W Stream RPC mode, WritePb type.
*/
template <typename W>
class ServerWriter {
public:
explicit ServerWriter(std::unique_ptr<ServerWriterImpl<W>> &&impl) : pimpl_(std::move(impl))
{
}
~ServerWriter() = default;
Status SendStatus(const Status &rc)
{
return std::visit([&rc](auto &pimpl) { return pimpl->SendStatus(rc); }, pimpl_);
}
template <typename R>
Status ReadPb(R &pb)
{
return std::visit(
overloaded{ [&pb](std::unique_ptr<ServerWriterImpl<W>> &pimpl) { return pimpl->ReadPb(pb); },
[](auto &pimpl) { return Status(K_RUNTIME_ERROR, "Unexpected ReadPb function call."); }
},
pimpl_);
}
Status Write(const W &pb)
{
return std::visit([&pb](auto &pimpl) { return pimpl->Write(pb); }, pimpl_);
}
Status Finish()
{
return std::visit([](auto &pimpl) { return pimpl->Finish(); }, pimpl_);
}
Status SendPayload(std::vector<RpcMessage> &buffer)
{
return std::visit([&buffer](auto &pimpl) { return pimpl->SendPayload(buffer); }, pimpl_);
}
Status SendPayload(const std::vector<MemView> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->SendPayload(payload); }, pimpl_);
}
Status ReceivePayload(std::vector<RpcMessage> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->ReceivePayload(payload); }, pimpl_);
}
private:
std::variant<std::unique_ptr<ServerWriterImpl<W>>> pimpl_;
};
* Only the client side is streaming.
*/
template <typename R>
class ServerReader {
public:
explicit ServerReader(std::unique_ptr<ServerReaderImpl<R>> &&impl) : pimpl_(std::move(impl))
{
}
~ServerReader() = default;
Status SendStatus(const Status &rc)
{
return std::visit(
overloaded{ [&rc](std::unique_ptr<ServerReaderImpl<R>> &pimpl) { return pimpl->SendStatus(rc); },
[](auto &pimpl) { return Status(K_RUNTIME_ERROR, "Unexpected SendStatus function call."); }
},
pimpl_);
}
Status Read(R &pb)
{
return std::visit([&pb](auto &pimpl) { return pimpl->Read(pb); }, pimpl_);
}
Status ReceivePayload(std::vector<RpcMessage> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->ReceivePayload(payload); }, pimpl_);
}
template <typename W>
Status WritePb(const W &pb)
{
return std::visit(
overloaded{ [&pb](std::unique_ptr<ServerReaderImpl<R>> &pimpl) { return pimpl->WritePb(pb); },
[](auto &pimpl) { return Status(K_RUNTIME_ERROR, "Unexpected WritePb function call."); }
},
pimpl_);
}
Status SendPayload(std::vector<RpcMessage> &buffer)
{
return std::visit(
overloaded{ [&buffer](std::unique_ptr<ServerReaderImpl<R>> &pimpl) { return pimpl->SendPayload(buffer); },
[](auto &pimpl) { return Status(K_RUNTIME_ERROR, "Unexpected SendPayload function call."); }
},
pimpl_);
}
Status SendPayload(const std::vector<MemView> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->SendPayload(payload); }, pimpl_);
}
Status Finish()
{
return std::visit([](auto &pimpl) { return pimpl->Finish(); }, pimpl_);
}
private:
std::variant<std::unique_ptr<ServerReaderImpl<R>>> pimpl_;
};
* Both sides are streaming.
* @tparam W Stream RPC mode, WritePb type.
* @tparam R Stream RPC mode, ReadPb type.
*/
template <typename W, typename R>
class ServerWriterReader {
public:
explicit ServerWriterReader(std::unique_ptr<ServerWriterReaderImpl<W, R>> &&impl) : pimpl_(std::move(impl))
{
}
~ServerWriterReader() = default;
Status SendStatus(const Status &rc)
{
return std::visit([&rc](auto &pimpl) { return pimpl->SendStatus(rc); }, pimpl_);
}
Status Read(R &pb)
{
return std::visit([&pb](auto &pimpl) { return pimpl->Read(pb); }, pimpl_);
}
Status Write(const W &pb)
{
return std::visit([&pb](auto &pimpl) { return pimpl->Write(pb); }, pimpl_);
}
Status Finish()
{
return std::visit([](auto &pimpl) { return pimpl->Finish(); }, pimpl_);
}
Status SendPayload(std::vector<RpcMessage> &buffer)
{
return std::visit([&buffer](auto &pimpl) { return pimpl->SendPayload(buffer); }, pimpl_);
}
Status SendPayload(const std::vector<MemView> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->SendPayload(payload); }, pimpl_);
}
Status ReceivePayload(std::vector<RpcMessage> &payload)
{
return std::visit([&payload](auto &pimpl) { return pimpl->ReceivePayload(payload); }, pimpl_);
}
private:
std::variant<std::unique_ptr<ServerWriterReaderImpl<W, R>>> pimpl_;
};
/**
 * @brief Server-side handle for a non-blocking unary RPC.
 * Every call is forwarded to the implementation object held in pimpl_. The methods are virtual so that a
 * subclass can replace the forwarding behavior.
 * @tparam W Message type written back to the client.
 * @tparam R Message type read from the client.
 */
template <typename W, typename R>
class ServerUnaryWriterReader {
public:
    /**
     * @brief Take ownership of the implementation all calls are forwarded to.
     * @param[in] impl Implementation object.
     */
    explicit ServerUnaryWriterReader(std::unique_ptr<ServerUnaryWriterReaderImpl<W, R>> &&impl)
        : pimpl_(std::move(impl))
    {
    }

    virtual ~ServerUnaryWriterReader() = default;

    virtual Status SendStatus(const Status &rc)
    {
        return Dispatch([&rc](auto &impl) { return impl->SendStatus(rc); });
    }

    virtual Status Read(R &pb)
    {
        return Dispatch([&pb](auto &impl) { return impl->Read(pb); });
    }

    virtual Status Write(const W &pb)
    {
        return Dispatch([&pb](auto &impl) { return impl->Write(pb); });
    }

    virtual Status Finish()
    {
        return Dispatch([](auto &impl) { return impl->Finish(); });
    }

    virtual Status ReceivePayload(std::vector<RpcMessage> &payload)
    {
        return Dispatch([&payload](auto &impl) { return impl->ReceivePayload(payload); });
    }

    virtual Status SendAndTagPayload(std::vector<datasystem::RpcMessage> &buffer, bool tagPayloadFrame)
    {
        return Dispatch([&buffer, tagPayloadFrame](auto &impl) {
            return impl->SendAndTagPayload(buffer, tagPayloadFrame);
        });
    }

    virtual Status SendPayload(std::vector<datasystem::RpcMessage> &buffer)
    {
        return Dispatch([&buffer](auto &impl) { return impl->SendPayload(buffer); });
    }

    virtual Status SendAndTagPayload(const std::vector<MemView> &payload, bool tagPayloadFrame)
    {
        return Dispatch([&payload, tagPayloadFrame](auto &impl) {
            return impl->SendAndTagPayload(payload, tagPayloadFrame);
        });
    }

    virtual Status SendPayload(const std::vector<MemView> &payload)
    {
        return Dispatch([&payload](auto &impl) { return impl->SendPayload(payload); });
    }

    virtual Status GetOutMsg(ZmqMsgFrames &outMsg)
    {
        return Dispatch([&outMsg](auto &impl) { return impl->GetOutMsg(outMsg); });
    }

    virtual bool EnableMsgQ()
    {
        return Dispatch([](auto &impl) { return impl->EnableMsgQ(); });
    }

    virtual void SetRequestInProgress()
    {
        Dispatch([](auto &impl) { return impl->SetRequestInProgress(); });
    }

    virtual void SetRequestComplete()
    {
        Dispatch([](auto &impl) { return impl->SetRequestComplete(); });
    }

private:
    /**
     * @brief Invoke a visitor on whichever implementation pimpl_ currently holds.
     * @param[in] fn Visitor callable with the held unique_ptr.
     * @return Whatever the visitor returns (Status, bool or void depending on the caller).
     */
    template <typename Fn>
    decltype(auto) Dispatch(Fn fn)
    {
        return std::visit(fn, pimpl_);
    }

    std::variant<std::unique_ptr<ServerUnaryWriterReaderImpl<W, R>>> pimpl_;
};
/**
 * @brief In-process variant of ServerUnaryWriterReader.
 *
 * The request message is handed over in the constructor, and the response (message or status) is delivered
 * through a std::promise instead of a transport. The base class is constructed with a null implementation,
 * so every virtual method of the base must be overridden here: falling through to a base implementation
 * would dereference that null pointer.
 * @tparam W Response message type.
 * @tparam R Request message type.
 */
template <typename W, typename R>
class LocalServerUnaryWriterReader : public ServerUnaryWriterReader<W, R> {
public:
    /**
     * @brief Construct from an already-available request.
     * @param[in,out] pb Request message. It is moved from, so the caller's object is left moved-from.
     * @param[in] promise Promise fulfilled with the (response, status) pair by Write() or SendStatus().
     */
    explicit LocalServerUnaryWriterReader(R &pb, std::promise<std::pair<W, Status>> promise)
        : ServerUnaryWriterReader<W, R>(std::unique_ptr<ServerUnaryWriterReaderImpl<W, R>>(nullptr)),
          pb_(std::move(pb)),
          promise_(std::move(promise))
    {
    }

    ~LocalServerUnaryWriterReader() override = default;

    /**
     * @brief Read message into proto buffer. May only succeed once.
     * @param[out] pb Buffer to read into.
     * @return Status of the call; K_RUNTIME_ERROR if called more than once.
     */
    Status Read(R &pb) override
    {
        bool expected = false;
        if (readOnce_.compare_exchange_strong(expected, true)) {
            VLOG(RPC_LOG_LEVEL) << "Server uses unary socket reading" << std::endl;
            pb = std::move(pb_);
        } else {
            RETURN_STATUS(StatusCode::K_RUNTIME_ERROR,
                          "The Read method of LocalServerUnaryWriterReader is only supposed to be used once!");
        }
        return Status::OK();
    }

    /**
     * @brief Fulfil the promise with the given response and an OK status.
     * Write() and SendStatus() share a single "response sent" flag, so only one of them can succeed.
     * @param[in] pb Response message (copied into the promise).
     * @return Status of the call; K_RUNTIME_ERROR if a response was already sent.
     */
    Status Write(const W &pb) override
    {
        bool expected = false;
        if (writeOnce_.compare_exchange_strong(expected, true)) {
            VLOG(RPC_LOG_LEVEL) << "Server uses unary socket sending rc " << Status::OK() << " message "
                                << LogHelper::IgnoreSensitive(pb) << " back to client" << std::endl;
            promise_.set_value(std::make_pair(pb, Status::OK()));
        } else {
            RETURN_STATUS(StatusCode::K_RUNTIME_ERROR,
                          "The Write method of LocalServerUnaryWriterReader is only supposed to be used once!");
        }
        return Status::OK();
    }

    /**
     * @brief Fulfil the promise with a default-constructed response and the given status.
     * Write() and SendStatus() share a single "response sent" flag, so only one of them can succeed.
     * @param[in] rc The status to be sent.
     * @return The status of this send action; K_RUNTIME_ERROR if a response was already sent.
     */
    Status SendStatus(const Status &rc) override
    {
        bool expected = false;
        if (writeOnce_.compare_exchange_strong(expected, true)) {
            // Log the status actually being delivered, not a hard-coded OK.
            VLOG(RPC_LOG_LEVEL) << "Server uses unary socket sending rc " << rc;
            promise_.set_value(std::make_pair(W(), rc));
        } else {
            RETURN_STATUS(StatusCode::K_RUNTIME_ERROR,
                          "The SendStatus method of LocalServerUnaryWriterReader is only supposed to be used once!");
        }
        return Status::OK();
    }

    /** @brief Not supported locally; always returns K_INVALID. */
    Status SendAndTagPayload(std::vector<datasystem::RpcMessage> &buffer, bool tagPayloadFrame) override
    {
        (void)buffer;
        (void)tagPayloadFrame;
        return { StatusCode::K_INVALID, "LocalServerUnaryWriterReader doesn't support SendAndTagPayload()!" };
    }

    /**
     * @brief Store the payload in this object (by move), replacing any previously stored payload.
     * The stored payload is what ReceivePayload() hands out.
     */
    Status SendPayload(std::vector<RpcMessage> &buffer) override
    {
        payloads_ = std::move(buffer);
        return Status::OK();
    }

    /** @brief Not supported locally; always returns K_INVALID. */
    Status SendAndTagPayload(const std::vector<MemView> &payload, bool tagPayloadFrame) override
    {
        (void)payload;
        (void)tagPayloadFrame;
        return { StatusCode::K_INVALID, "LocalServerUnaryWriterReader doesn't support SendAndTagPayload()!" };
    }

    /** @brief Not supported locally; always returns K_INVALID. */
    Status SendPayload(const std::vector<MemView> &payload) override
    {
        (void)payload;
        return { StatusCode::K_INVALID, "LocalServerUnaryWriterReader doesn't support SendPayload()!" };
    }

    /**
     * @brief Move the payload stored by SendPayload() out into @p payload.
     * @param[out] payload Receives the stored payload.
     */
    Status ReceivePayload(std::vector<RpcMessage> &payload) override
    {
        payload = std::move(payloads_);
        return Status::OK();
    }

    /**
     * @brief Not supported locally; always returns K_INVALID.
     * Overridden because the base implementation would dereference the null implementation pointer.
     */
    Status GetOutMsg(ZmqMsgFrames &outMsg) override
    {
        (void)outMsg;
        return { StatusCode::K_INVALID, "LocalServerUnaryWriterReader doesn't support GetOutMsg()!" };
    }

    /** @brief Not supported locally; always returns K_INVALID. */
    Status Finish() override
    {
        return { StatusCode::K_INVALID, "LocalServerUnaryWriterReader doesn't support Finish()!" };
    }

    /** @brief No-op in local mode. */
    void SetRequestInProgress() override
    {
    }

    /** @brief No-op in local mode. */
    void SetRequestComplete() override
    {
    }

    /** @brief The local variant never uses a message queue. */
    bool EnableMsgQ() override
    {
        return false;
    }

private:
    R pb_;                                         // Request handed over at construction; moved out by Read().
    std::promise<std::pair<W, Status>> promise_;   // Fulfilled exactly once by Write() or SendStatus().
    std::atomic<bool> writeOnce_{ false };         // Set once a response (Write/SendStatus) has been delivered.
    std::atomic<bool> readOnce_{ false };          // Set once Read() has consumed pb_.
    std::vector<RpcMessage> payloads_;             // Payload stored by SendPayload(), handed out by ReceivePayload().
};
}
#endif