/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
 */

/*
 * Description: Zmq Message Transport Protocol.
 */
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include "datasystem/common/rpc/zmq/zmq_msg_decoder.h"
#include "datasystem/common/util/raii.h"
#include "datasystem/common/util/status_helper.h"
#include "datasystem/common/util/strings_util.h"
namespace datasystem {
// Serializes every access to regAlloc_ made in this file (RegisterAllocation, DeregisterAllocation and
// FindRegisteredAlloc).
std::mutex ZmqMsgDecoder::allocMux_;
// Registered caller-owned receive buffers: start address -> size in bytes.
// Keys are kept in descending order (std::greater), so regAlloc_.lower_bound(addr) yields the entry with
// the largest start address that is <= addr; FindRegisteredAlloc relies on this ordering.
std::map<uint64_t, uint64_t, std::greater<uint64_t>> ZmqMsgDecoder::regAlloc_{};
/**
 * @brief Number of bytes already received into the work area but not yet consumed.
 * @return bytesReceived_ - pos_, computed in size_t.
 */
size_t ZmqMsgDecoder::NumUnRead() const
{
    const auto received = static_cast<size_t>(bytesReceived_);
    return received - pos_;
}
bool ZmqMsgDecoder::Empty() const
{
return NumUnRead() == 0;
}
/**
 * @brief Perform one recv() on the socket, appending the bytes to the work area.
 *
 * Before reading, any already-consumed prefix of the work area is discarded by sliding the unread
 * bytes to the front, so the whole remaining capacity is available for the recv() call.
 * @return OK when at least one byte was received; the status mapped from errno when recv() fails;
 *         K_RPC_CANCELLED when recv() returns 0; K_RUNTIME_ERROR when the work area is already full
 *         or the compaction fails.
 */
Status ZmqMsgDecoder::Recv()
{
    auto *buf = wa_.get();
    if (pos_ > 0) {
        // Compact: move the unread tail to the beginning of the work area.
        auto moveAmount = static_cast<size_t>(bytesReceived_) - pos_;
        if (moveAmount > 0) {
            VLOG(RPC_LOG_LEVEL) << FormatString("Decoder: move %zu bytes from %zu to the beginning", moveAmount, pos_);
            // memmove_s reports failure through its return value, not through errno.
            auto err = memmove_s(buf, K_WA_SIZE, buf + pos_, moveAmount);
            CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(err == 0, K_RUNTIME_ERROR, FormatString("Memmove failed %d", err));
        }
        bytesReceived_ = static_cast<ssize_t>(moveAmount);
        pos_ = 0;
    }
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(K_WA_SIZE >= bytesReceived_, K_RUNTIME_ERROR,
                                         FormatString("Invalid bytesReceived_ %zd", bytesReceived_));
    // A zero-length recv() can return 0, which would be misreported below as the peer closing the
    // connection; fail explicitly instead when there is no room left.
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(K_WA_SIZE > bytesReceived_, K_RUNTIME_ERROR,
                                         FormatString("Decoder work area is full (%zd bytes)", bytesReceived_));
    ssize_t bytesReceived = recv(pSockFd_->GetFd(), buf + bytesReceived_, K_WA_SIZE - bytesReceived_, 0);
    if (bytesReceived == -1) {
        return UnixSockFd::ErrnoToStatus(errno, pSockFd_->GetFd());
    }
    if (bytesReceived == 0) {
        RETURN_STATUS(StatusCode::K_RPC_CANCELLED, "bytesReceived is 0");
    }
    bytesReceived_ += bytesReceived;
    VLOG(RPC_LOG_LEVEL) << FormatString("Bytes received %zd. Head %zu. Tail %zd", bytesReceived, pos_, bytesReceived_);
    return Status::OK();
}
/**
 * @brief Copy already-buffered bytes from the work area into dest and mark them consumed.
 * @param[in] dest Destination with room for at least sz bytes.
 * @param[in] sz Number of bytes the caller still wants.
 * @param[out] bytesReceived Bytes actually copied: min(NumUnRead(), sz); may be 0.
 * @return K_RUNTIME_ERROR if the copy fails or advancing pos_ would overflow; OK otherwise.
 */
Status ZmqMsgDecoder::TransferFromWA(void *dest, size_t sz, size_t &bytesReceived)
{
    bytesReceived = std::min<size_t>(NumUnRead(), sz);
    if (bytesReceived == 0) {
        return Status::OK();
    }
    VLOG(RPC_LOG_LEVEL) << FormatString("Copy %zu bytes from workarea pos %zu to msg buffer", bytesReceived, pos_);
    auto *buf = wa_.get() + pos_;
    // Only bytesReceived (<= sz) bytes are written, so that is the destMax that matters. Passing the
    // (possibly huge) sz instead can trip the upper limit bounds-checked memcpy_s implementations place
    // on destMax. memcpy_s reports failure through its return value, not errno.
    auto err = memcpy_s(dest, bytesReceived, buf, bytesReceived);
    CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(err == 0, K_RUNTIME_ERROR,
                                         FormatString("Memcpy %zu bytes failed. Errno %d", bytesReceived, err));
    if (SIZE_MAX - bytesReceived < pos_) {
        RETURN_STATUS_LOG_ERROR(K_RUNTIME_ERROR, FormatString("Computation overflow: bytesReceived=%zu , pos_=%zu.",
                                                              bytesReceived, pos_));
    }
    pos_ += bytesReceived;
    return Status::OK();
}
/**
 * @brief HDR_LEN_READY step: obtain the MultiMsgHdrPb header of the next message.
 *
 * With nothing buffered, the whole header is read straight from the socket via RecvProtobuf() and the
 * state jumps to MTP_DETECT. Otherwise the 4-byte little-endian header length is decoded from the work
 * area into rpcHdrSz_ and the state advances to HDR_BODY_READY.
 * @param[in,out] state Must be HDR_LEN_READY on entry.
 * @return K_TRY_AGAIN if fewer than 4 bytes are available after one Recv().
 */
Status ZmqMsgDecoder::DecodeHdrLen(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::HDR_LEN_READY, K_RUNTIME_ERROR, "Wrong state");
    if (Empty()) {
        RETURN_IF_NOT_OK(pSockFd_->RecvProtobuf(hdr_));
        state = MsgState::MTP_DETECT;
        return Status::OK();
    }
    uint32_t hdrLen = 0;
    const bool needMoreBytes = NumUnRead() < sizeof(hdrLen);
    if (needMoreBytes) {
        RETURN_IF_NOT_OK(Recv());
        CHECK_FAIL_RETURN_STATUS(NumUnRead() >= sizeof(hdrLen), K_TRY_AGAIN, "Waiting for more network data to arrive");
    }
    {
        google::protobuf::io::ArrayInputStream rawIn(wa_.get() + pos_, sizeof(hdrLen), sizeof(hdrLen));
        google::protobuf::io::CodedInputStream coded(&rawIn);
        CHECK_FAIL_RETURN_STATUS(coded.ReadLittleEndian32(&hdrLen), K_RUNTIME_ERROR, "Google read error");
    }
    pos_ += sizeof(hdrLen);
    rpcHdrSz_ = hdrLen;
    state = MsgState::HDR_BODY_READY;
    return Status::OK();
}
/**
 * @brief HDR_BODY_READY step: parse rpcHdrSz_ bytes from the work area into hdr_.
 * @param[in,out] state Must be HDR_BODY_READY on entry; set to MTP_DETECT on success.
 * @return K_INVALID if the announced size cannot fit in the work area or parsing fails;
 *         K_TRY_AGAIN if not enough bytes are buffered yet after one Recv().
 */
Status ZmqMsgDecoder::DecodeHdrBody(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::HDR_BODY_READY, K_RUNTIME_ERROR, "Wrong state");
    // rpcHdrSz_ comes off the wire. A header larger than the work area can never be accumulated: Recv()
    // would keep returning K_TRY_AGAIN until the buffer is full and then fail misleadingly.
    CHECK_FAIL_RETURN_STATUS(static_cast<size_t>(rpcHdrSz_) <= static_cast<size_t>(K_WA_SIZE), K_INVALID,
                             FormatString("RPC header size %zu exceeds the decoder work area",
                                          static_cast<size_t>(rpcHdrSz_)));
    if (NumUnRead() < rpcHdrSz_) {
        RETURN_IF_NOT_OK(Recv());
        CHECK_FAIL_RETURN_STATUS(NumUnRead() >= rpcHdrSz_, K_TRY_AGAIN, "Waiting for more network data to arrive");
    }
    hdr_.Clear();
    uint8_t *arr = wa_.get() + pos_;
    bool success = hdr_.ParseFromArray(arr, rpcHdrSz_);
    if (success) {
        state = MsgState::MTP_DETECT;
        pos_ += rpcHdrSz_;
        return Status::OK();
    }
    RETURN_STATUS(StatusCode::K_INVALID, "Failed to parse MultiMsgHdrPb");
}
/**
 * @brief MTP_DETECT step: choose the wire format from the header just received.
 *
 * A header that lists frame sizes (msg_size) is V1 and goes to DOWNLEVEL_CLIENT. A header without them
 * is V2 and goes to FLAGS_READY. The result is recorded in newFormat_.
 * @param[in,out] state Must be MTP_DETECT on entry.
 */
Status ZmqMsgDecoder::DetectMTP(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::MTP_DETECT, K_RUNTIME_ERROR, "Wrong state");
    const bool isV2 = (hdr_.msg_size_size() == 0);
    newFormat_ = isV2;
    if (!isV2) {
        VLOG(RPC_LOG_LEVEL) << FormatString("V1 format detected for fd %d", pSockFd_->GetFd());
        state = MsgState::DOWNLEVEL_CLIENT;
        return Status::OK();
    }
    VLOG(RPC_LOG_LEVEL) << FormatString("V2 format detected for fd %d", pSockFd_->GetFd());
    state = MsgState::FLAGS_READY;
    return Status::OK();
}
/**
 * @brief DOWNLEVEL_CLIENT step: receive every frame announced by a V1 header into v1Frames_.
 *
 * Each frame is filled first from bytes already in the work area, then from the socket for the rest.
 * @param[in,out] state Must be DOWNLEVEL_CLIENT on entry; reset to HDR_LEN_READY after all frames arrive.
 */
Status ZmqMsgDecoder::V1Client(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::DOWNLEVEL_CLIENT, K_RUNTIME_ERROR, "Wrong state");
    v1Frames_.clear();
    const int frameCount = hdr_.msg_size_size();
    VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %d frames from fd %d using V1 format", frameCount,
                                        pSockFd_->GetFd());
    for (int idx = 0; idx < frameCount; ++idx) {
        const auto frameSize = hdr_.msg_size(idx);
        ZmqMessage frame;
        RETURN_IF_NOT_OK(frame.AllocMem(frameSize));
        size_t buffered = 0;
        RETURN_IF_NOT_OK(TransferFromWA(frame.Data(), frameSize, buffered));
        if (buffered < frameSize) {
            // Whatever was not already buffered is read straight from the socket.
            auto *rest = reinterpret_cast<uint8_t *>(frame.Data()) + buffered;
            RETURN_IF_NOT_OK(pSockFd_->Recv(rest, frameSize - buffered, true));
        }
        VLOG(RPC_LOG_LEVEL) << "Frame (" << idx << ") received. Size " << frame.Size() << " ... " << frame;
        v1Frames_.push_back(std::move(frame));
    }
    state = MsgState::HDR_LEN_READY;
    return Status::OK();
}
/**
 * @brief FLAGS_READY step: read the one-byte MTP flag that starts every V2 frame.
 *
 * Resets the per-frame state (inProcess_, msgSize_). Then it picks the length encoding: an 8-byte
 * length when MTP_LONG is set, otherwise a 1-byte length.
 * @param[in,out] state Must be FLAGS_READY on entry.
 */
Status ZmqMsgDecoder::DecodeFlag(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::FLAGS_READY, K_RUNTIME_ERROR, "Wrong state");
    inProcess_ = ZmqMessage();
    msgSize_ = 0;
    if (Empty()) {
        RETURN_IF_NOT_OK(Recv());
    }
    uint8_t *buf = wa_.get();
    flag_ = static_cast<MTP_PROTOCOL>(buf[pos_]);
    const int DISPLAY_LENGTH = 2;
    // Streams do not interpret printf specifiers. Cast to an integer so a uint8_t-backed enum is not
    // printed as a character.
    VLOG(RPC_LOG_LEVEL) << "Flag = 0x" << std::hex << std::setfill('0') << std::setw(DISPLAY_LENGTH)
                        << static_cast<unsigned int>(flag_);
    if (flag_ & MTP_PROTOCOL::MTP_LONG) {
        state = MsgState::EIGHT_BYTE_SIZE_READY;
    } else {
        state = MsgState::ONE_BYTE_SIZE_READY;
    }
    pos_ += 1;
    return Status::OK();
}
/**
 * @brief ONE_BYTE_SIZE_READY step: read a short frame's payload length (a single unsigned byte).
 * @param[in,out] state Must be ONE_BYTE_SIZE_READY on entry; set to MESSAGE_READY on success.
 */
Status ZmqMsgDecoder::DecodeOneByteLength(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::ONE_BYTE_SIZE_READY, K_RUNTIME_ERROR, "Wrong state");
    if (Empty()) {
        RETURN_IF_NOT_OK(Recv());
    }
    auto *buf = wa_.get() + pos_;
    msgSize_ = static_cast<uint8_t>(buf[0]);
    // msgSize_ is 64-bit; %d would be a printf-format mismatch.
    VLOG(RPC_LOG_LEVEL) << FormatString("Message length: %zu", msgSize_);
    state = MsgState::MESSAGE_READY;
    pos_ += 1;
    return Status::OK();
}
/**
 * @brief EIGHT_BYTE_SIZE_READY step: read a long frame's payload length (8 bytes, little-endian).
 * @param[in,out] state Must be EIGHT_BYTE_SIZE_READY on entry; set to MESSAGE_READY on success.
 * @return K_TRY_AGAIN if fewer than K_EIGHT_BYTE bytes are buffered after one Recv().
 */
Status ZmqMsgDecoder::DecodeEightByteLength(MsgState &state)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::EIGHT_BYTE_SIZE_READY, K_RUNTIME_ERROR, "Wrong state");
    if (NumUnRead() < K_EIGHT_BYTE) {
        RETURN_IF_NOT_OK(Recv());
        CHECK_FAIL_RETURN_STATUS(NumUnRead() >= K_EIGHT_BYTE, K_TRY_AGAIN, "Waiting for more network data to arrive");
    }
    {
        auto *buf = wa_.get() + pos_;
        google::protobuf::io::ArrayInputStream osWrapper(buf, K_EIGHT_BYTE, K_EIGHT_BYTE);
        google::protobuf::io::CodedInputStream input(&osWrapper);
        CHECK_FAIL_RETURN_STATUS(input.ReadLittleEndian64(&msgSize_), K_RUNTIME_ERROR, "Google read error");
    }
    // msgSize_ is 64-bit; %d would be a printf-format mismatch.
    VLOG(RPC_LOG_LEVEL) << FormatString("Message length: %zu", msgSize_);
    state = MsgState::MESSAGE_READY;
    pos_ += K_EIGHT_BYTE;
    return Status::OK();
}
void ZmqMsgDecoder::RegisterAllocation(void *dest, uint64_t sz)
{
uint64_t addr = reinterpret_cast<uint64_t>(dest);
std::lock_guard<std::mutex> lock(ZmqMsgDecoder::allocMux_);
ZmqMsgDecoder::regAlloc_.emplace(addr, sz);
}
void ZmqMsgDecoder::DeregisterAllocation(void *dest)
{
uint64_t addr = reinterpret_cast<uint64_t>(dest);
std::lock_guard<std::mutex> lock(ZmqMsgDecoder::allocMux_);
ZmqMsgDecoder::regAlloc_.erase(addr);
}
/**
 * @brief Check that [dest, dest + sz) lies inside a single registered allocation.
 * @param[in] dest Start of the range to validate.
 * @param[in] sz Length of the range in bytes.
 * @return OK if the registration with the largest start address <= dest fully contains the range;
 *         K_OUT_OF_RANGE otherwise, including when dest + sz would wrap around.
 */
Status ZmqMsgDecoder::FindRegisteredAlloc(void *dest, uint64_t sz)
{
    uint64_t addr = reinterpret_cast<uint64_t>(dest);
    Status err = Status(K_OUT_OF_RANGE, "No registered allocation found for the destination addr");
    if (std::numeric_limits<uint64_t>::max() - addr < sz) {
        return err;
    }
    std::lock_guard<std::mutex> lock(ZmqMsgDecoder::allocMux_);
    // regAlloc_ is ordered by std::greater, so lower_bound() yields the entry with the largest start
    // address that is <= addr.
    auto it = ZmqMsgDecoder::regAlloc_.lower_bound(addr);
    if (it == ZmqMsgDecoder::regAlloc_.end()) {
        return err;
    }
    const uint64_t regBegin = it->first;
    const uint64_t regEnd = it->first + it->second;
    if ((addr < regBegin) || ((addr + sz) > regEnd)) {
        return err;
    }
    // %x expects unsigned int; these are 64-bit values, so print them as unsigned long long.
    VLOG(RPC_LOG_LEVEL) << FormatString("(0x%llx, 0x%llx) inside (0x%llx, 0x%llx)",
                                        static_cast<unsigned long long>(addr),
                                        static_cast<unsigned long long>(addr + sz),
                                        static_cast<unsigned long long>(regBegin),
                                        static_cast<unsigned long long>(regEnd));
    return Status::OK();
}
/**
 * @brief MESSAGE_READY step: receive the msgSize_-byte payload of the current frame into inProcess_.
 *
 * With dest == nullptr a buffer of msgSize_ bytes is allocated. Otherwise dest, which must lie inside a
 * registered allocation and hold at least msgSize_ bytes, is adopted by inProcess_. Buffered bytes are
 * copied first, and the remainder is read from the socket.
 * @param[in,out] state Must be MESSAGE_READY. On success it becomes FLAGS_READY if MTP_MORE was set,
 *                otherwise HDR_LEN_READY.
 * @param[in] dest Optional caller-provided destination.
 * @param[in] sz Size of dest in bytes.
 */
Status ZmqMsgDecoder::ReadMessage(MsgState &state, void *dest, size_t sz)
{
    CHECK_FAIL_RETURN_STATUS(state == MsgState::MESSAGE_READY, K_RUNTIME_ERROR, "Wrong state");
    const MsgState nextState = (flag_ & MTP_MORE) ? MsgState::FLAGS_READY : MsgState::HDR_LEN_READY;
    if (msgSize_ == 0) {
        state = nextState;
        return Status::OK();
    }
    if (dest == nullptr) {
        VLOG(RPC_LOG_LEVEL) << FormatString("Allocate %zu bytes for incoming message", msgSize_);
        RETURN_IF_NOT_OK(inProcess_.AllocMem(msgSize_));
    } else {
        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(ZmqMsgDecoder::FindRegisteredAlloc(dest, sz),
                                         "The destination is not existing allocated memory.");
        CHECK_FAIL_RETURN_STATUS_PRINT_ERROR(
            sz >= msgSize_, K_RUNTIME_ERROR,
            FormatString("User provided buffer is too small %zu. Expect %zu", sz, msgSize_));
        inProcess_.TransferOwnership(dest, sz, nullptr, nullptr);
    }
    size_t copied = 0;
    RETURN_IF_NOT_OK(TransferFromWA(inProcess_.Data(), msgSize_, copied));
    if (copied < msgSize_) {
        auto *tail = reinterpret_cast<uint8_t *>(inProcess_.Data()) + copied;
        RETURN_IF_NOT_OK(pSockFd_->Recv(tail, msgSize_ - copied, true));
    }
    state = nextState;
    return Status::OK();
}
/**
 * Drive the receive state machine through one forward pass over its steps, in this order:
 * HDR_LEN_READY, HDR_BODY_READY, MTP_DETECT, DOWNLEVEL_CLIENT, FLAGS_READY,
 * ONE_BYTE_SIZE_READY / EIGHT_BYTE_SIZE_READY, MESSAGE_READY.
 * A step runs only if msgState_ matches it when its check is reached. Because the checks are not
 * repeated, one call decodes at most one V2 frame, or one complete V1 message via V1Client().
 * Any non-OK status from a step (including K_TRY_AGAIN) is returned immediately, leaving msgState_ at
 * the failing step.
 * @param dest Optional caller-owned payload destination, forwarded to ReadMessage() (nullptr => allocate).
 * @param sz Size of dest in bytes.
 */
Status ZmqMsgDecoder::Decode(void *dest, size_t sz)
{
if (msgState_ == MsgState::HDR_LEN_READY) {
RETURN_IF_NOT_OK(DecodeHdrLen(msgState_));
}
if (msgState_ == MsgState::HDR_BODY_READY) {
RETURN_IF_NOT_OK(DecodeHdrBody(msgState_));
}
if (msgState_ == MsgState::MTP_DETECT) {
RETURN_IF_NOT_OK(DetectMTP(msgState_));
}
// V1 path: receives every frame and returns to HDR_LEN_READY, so none of the V2 checks below fire.
if (msgState_ == MsgState::DOWNLEVEL_CLIENT) {
RETURN_IF_NOT_OK(V1Client(msgState_));
}
if (msgState_ == MsgState::FLAGS_READY) {
RETURN_IF_NOT_OK(DecodeFlag(msgState_));
}
if (msgState_ == MsgState::ONE_BYTE_SIZE_READY) {
RETURN_IF_NOT_OK(DecodeOneByteLength(msgState_));
}
if (msgState_ == MsgState::EIGHT_BYTE_SIZE_READY) {
RETURN_IF_NOT_OK(DecodeEightByteLength(msgState_));
}
// Leaves msgState_ at FLAGS_READY (MTP_MORE) or HDR_LEN_READY; the pass ends here either way.
if (msgState_ == MsgState::MESSAGE_READY) {
RETURN_IF_NOT_OK(ReadMessage(msgState_, dest, sz));
}
return Status::OK();
}
/**
 * @brief Decode the next frame, calling Decode() again immediately each time it reports K_TRY_AGAIN.
 * @param[out] outMsg Receives the decoded frame when the peer uses V2. It is left untouched for V1,
 *             whose frames stay in v1Frames_.
 * @param[out] more True if the frame carried MTP_MORE; always false for V1.
 * @return The first Decode() error other than K_TRY_AGAIN, otherwise OK.
 */
Status ZmqMsgDecoder::GetMessage(ZmqMessage &outMsg, bool &more)
{
    more = false;
    Status rc;
    for (;;) {
        rc = Decode();
        RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN);
        if (rc.GetCode() != K_TRY_AGAIN) {
            break;
        }
    }
    if (newFormat_) {
        more = (flag_ & MTP_MORE) != 0;
        outMsg = std::move(inProcess_);
        if (flag_ & MTP_PROTOCOL::MTP_DECODER) {
            outMsg.SetType(ZmqMessage::ZmqMsgType::DECODER);
        }
    }
    return rc;
}
/**
 * @brief Receive one V1 message directly from the socket.
 *
 * A V1 message is a MultiMsgHdrPb listing every frame size, followed by the raw bytes of each frame.
 * @param[out] frames Each received frame is appended in order.
 */
Status ZmqMsgDecoder::ReceiveMsgFramesV1(ZmqMsgFrames &frames)
{
    MultiMsgHdrPb v1Hdr;
    RETURN_IF_NOT_OK(pSockFd_->RecvProtobuf(v1Hdr));
    const int frameCount = v1Hdr.msg_size_size();
    VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %d frames from fd %d using V1 format", frameCount,
                                        pSockFd_->GetFd());
    for (int idx = 0; idx < frameCount; ++idx) {
        ZmqMessage frame;
        RETURN_IF_NOT_OK(frame.AllocMem(v1Hdr.msg_size(idx)));
        RETURN_IF_NOT_OK(pSockFd_->Recv(frame.Data(), frame.Size(), true));
        VLOG(RPC_LOG_LEVEL) << "Frame (" << idx << ") received. Size " << frame.Size() << " ... " << frame;
        frames.emplace_back(std::move(frame));
    }
    return Status::OK();
}
/**
 * @brief Receive one complete message using the state-machine decoder, accepting V1 or V2 peers.
 *
 * V2 frames are collected until one arrives without MTP_MORE. A DECODER-typed frame carries a
 * PayloadDirectGetRspPb. The payload that follows it is decoded straight into the address/size named
 * there, and that frame ends the message. If the peer turns out to be V1, the frames buffered by the
 * decoder are moved out instead.
 * @param[out] frames Received frames are appended in order.
 */
Status ZmqMsgDecoder::ReceiveMsgFramesV2(ZmqMsgFrames &frames)
{
    curFrame_ = 0;
    bool more = false;
    do {
        ZmqMessage frame;
        RETURN_IF_NOT_OK(GetMessage(frame, more));
        if (!newFormat_) {
            for (auto &v1Frame : v1Frames_) {
                frames.push_back(std::move(v1Frame));
            }
            v1Frames_.clear();
            break;
        }
        VLOG(RPC_LOG_LEVEL) << "Frame (" << curFrame_ << ") received. Size " << frame.Size() << " ... " << frame;
        const bool isDecoderFrame = frame.GetType() == ZmqMessage::ZmqMsgType::DECODER;
        if (isDecoderFrame) {
            PayloadDirectGetRspPb directRsp;
            RETURN_IF_NOT_OK(ParseFromZmqMessage(frame, directRsp));
            RETURN_IF_NOT_OK(ReceivePayloadIntoMemory(reinterpret_cast<void *>(directRsp.addr()), directRsp.sz()));
            CHECK_FAIL_RETURN_STATUS(msgState_ == MsgState::HDR_LEN_READY, K_RUNTIME_ERROR,
                                     FormatString("Unexpected state %d", static_cast<int>(msgState_)));
            more = false;
        }
        frames.push_back(std::move(frame));
        ++curFrame_;
    } while (more);
    return Status::OK();
}
/**
 * @brief Decode frames directly into caller memory until sz bytes have been placed at dest.
 *
 * Each decoded frame's payload lands at the current write position, which then advances by msgSize_.
 * Decode() is retried immediately on K_TRY_AGAIN.
 * @param[in] dest Start of the destination; must not be null and must lie in a registered allocation.
 * @param[in] sz Total number of payload bytes expected.
 * @return K_RUNTIME_ERROR for a null destination or a frame larger than the remaining space; any other
 *         Decode() error is propagated.
 */
Status ZmqMsgDecoder::ReceivePayloadIntoMemory(void *dest, size_t sz)
{
    CHECK_FAIL_RETURN_STATUS(dest != nullptr, K_RUNTIME_ERROR, "Null destination");
    // %p is the correct specifier for a pointer (%x expects unsigned int).
    VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to receive %zu bytes into user provided memory at %p", sz, dest);
    auto *ptr = reinterpret_cast<uint8_t *>(dest);
    while (sz > 0) {
        Status rc = Decode(ptr, sz);
        RETURN_IF_NOT_OK_EXCEPT(rc, K_TRY_AGAIN);
        if (rc.GetCode() == K_TRY_AGAIN) {
            continue;
        }
        // Never let the remaining-size counter wrap if msgSize_ does not describe a frame that fit.
        CHECK_FAIL_RETURN_STATUS(msgSize_ <= sz, K_RUNTIME_ERROR,
                                 FormatString("Decoded frame size %zu exceeds remaining space %zu", msgSize_, sz));
        ptr += msgSize_;
        sz -= msgSize_;
        // Keep the counter update out of the log statement: streamed VLOG arguments are not evaluated
        // when the log level is disabled.
        VLOG(RPC_LOG_LEVEL) << "Frame (" << curFrame_ << ") received. Size " << msgSize_;
        ++curFrame_;
    }
    return Status::OK();
}
/**
 * Construct a decoder around its own UnixSockFd member (sockFd_), built from the descriptor fd.
 * All decoder I/O goes through pSockFd_, which this constructor points at sockFd_.
 * Allocates the K_WA_SIZE-byte receive work area. Decoding starts in HDR_LEN_READY with newFormat_
 * set to true; DetectMTP() updates it once a header is seen.
 */
ZmqMsgDecoder::ZmqMsgDecoder(int fd)
: sockFd_(fd),
curFrame_(0),
msgState_(MsgState::HDR_LEN_READY),
flag_(MTP_PROTOCOL::MTP_NONE),
bytesReceived_(0),
pos_(0),
msgSize_(0),
rpcHdrSz_(0),
newFormat_(true)
{
wa_ = std::make_unique<uint8_t[]>(K_WA_SIZE);
// Route all socket I/O through the member wrapper constructed above.
pSockFd_ = &sockFd_;
}
/**
 * Construct a decoder that performs all I/O through a caller-supplied UnixSockFd.
 * Only the pointer is stored, so sockFdRef presumably must outlive the decoder (confirm with callers).
 * sockFd_ is left default-constructed; the code in this file only reaches the socket via pSockFd_.
 * Allocates the K_WA_SIZE-byte work area and starts in HDR_LEN_READY with newFormat_ set to true.
 */
ZmqMsgDecoder::ZmqMsgDecoder(UnixSockFd *sockFdRef)
: pSockFd_(sockFdRef),
curFrame_(0),
msgState_(MsgState::HDR_LEN_READY),
flag_(MTP_PROTOCOL::MTP_NONE),
bytesReceived_(0),
pos_(0),
msgSize_(0),
rpcHdrSz_(0),
newFormat_(true)
{
wa_ = std::make_unique<uint8_t[]>(K_WA_SIZE);
}
// Members (including the wa_ work area) release their own resources; nothing extra to do.
ZmqMsgDecoder::~ZmqMsgDecoder() = default;
/**
 * @brief Send one V2 frame: a flag byte, the payload length, then the payload itself.
 *
 * Payloads of at most 255 bytes use a single length byte. Larger ones set MTP_LONG and use an 8-byte
 * little-endian length. MTP_MORE signals that another frame follows, and MTP_DECODER marks
 * DECODER-typed messages.
 * @param[in] msg Frame to send.
 * @param[in] more Whether further frames of the same message follow.
 */
Status ZmqMsgEncoder::SendMessage(const ZmqMessage &msg, bool more) const
{
    struct {
        uint8_t flag_;
        char len_[K_EIGHT_BYTE];
    } hdr{};
    static_assert(sizeof(hdr) == K_EIGHT_BYTE + 1, "Doesn't expect a gap");
    const auto payloadSize = msg.Size();
    const bool needsLongLength = payloadSize > std::numeric_limits<uint8_t>::max();
    hdr.flag_ = more ? MTP_MORE : 0;
    if (!needsLongLength) {
        hdr.len_[0] = static_cast<char>(payloadSize);
    } else {
        hdr.flag_ |= MTP_LONG;
        // The streams live only in this branch, so they are destroyed (flushing anything they buffered
        // into hdr.len_) before the header is sent.
        google::protobuf::io::ArrayOutputStream rawOut(hdr.len_, K_EIGHT_BYTE);
        google::protobuf::io::CodedOutputStream coded(&rawOut);
        coded.WriteLittleEndian64(payloadSize);
    }
    if (msg.GetType() == ZmqMessage::ZmqMsgType::DECODER) {
        hdr.flag_ |= MTP_DECODER;
    }
    const int SHORT_LENGTH = 2;
    const auto hdrLen = (hdr.flag_ & MTP_LONG) ? K_EIGHT_BYTE + 1 : SHORT_LENGTH;
    MemView hdrView(&hdr, hdrLen);
    RETURN_IF_NOT_OK(pSockFd_->Send(hdrView));
    if (payloadSize == 0) {
        return Status::OK();
    }
    MemView payloadView(msg.Data(), payloadSize);
    RETURN_IF_NOT_OK(pSockFd_->Send(payloadView));
    return Status::OK();
}
/**
 * @brief Send all frames in que using the V1 format, consuming the queue.
 *
 * V1 sends a MultiMsgHdrPb listing every frame size up front, then the raw frame bytes back to back.
 * @param[in,out] que Frames to send; each is popped once it has been written.
 */
Status ZmqMsgEncoder::SendMsgFramesV1(ZmqMsgFrames &que)
{
    MultiMsgHdrPb hdr;
    for (const auto &frame : que) {
        hdr.mutable_msg_size()->Add(frame.Size());
    }
    VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to send %d frames to fd %d using V1 format", hdr.msg_size_size(),
                                        pSockFd_->GetFd());
    RETURN_IF_NOT_OK_PRINT_ERROR_MSG(pSockFd_->SendProtobuf(hdr), FormatString("Errno = %d", errno));
    for (int frameIdx = 0; !que.empty(); ++frameIdx) {
        auto &frame = que.front();
        MemView view(frame.Data(), frame.Size());
        RETURN_IF_NOT_OK(pSockFd_->Send(view));
        VLOG(RPC_LOG_LEVEL) << "Frame (" << frameIdx << ") sent. Size " << frame.Size() << " ... " << frame;
        que.pop_front();
    }
    return Status::OK();
}
/**
 * @brief Send all frames in que using the V2 (MTP) format, consuming the queue.
 *
 * V2 sends an empty MultiMsgHdrPb (no msg_size entries, which the decoder reads as "V2"), then each
 * frame via SendMessage(). Every frame except the last carries MTP_MORE.
 * @param[in,out] que Frames to send; must not be empty.
 * @return K_INVALID if que is empty, before anything is written.
 */
Status ZmqMsgEncoder::SendMsgFramesV2(ZmqMsgFrames &que)
{
    // front() on an empty container is undefined behaviour. Sending only the header would also leave the
    // peer waiting for a flag byte, so reject an empty frame list up front.
    CHECK_FAIL_RETURN_STATUS(!que.empty(), K_INVALID, "No frames to send");
    VLOG(RPC_LOG_LEVEL) << FormatString("Prepare to send %zu frames to fd %d using V2 format", que.size(),
                                        pSockFd_->GetFd());
    {
        MultiMsgHdrPb hdr;
        RETURN_IF_NOT_OK_PRINT_ERROR_MSG(pSockFd_->SendProtobuf(hdr), FormatString("Errno = %d", errno));
    }
    int i = 0;
    bool more = true;
    do {
        auto msg = std::move(que.front());
        que.pop_front();
        more = !que.empty();
        RETURN_IF_NOT_OK(SendMessage(msg, more));
        VLOG(RPC_LOG_LEVEL) << "Frame (" << i << ") sent. Size " << msg.Size() << " ... " << msg;
        ++i;
    } while (more);
    return Status::OK();
}
/**
 * @brief Send frames using the wire format selected by type.
 * @param[in] type V2MTP for the MTP format, V1MTP for the legacy format.
 * @param[in,out] frames Frames to send; consumed by the chosen sender.
 * @return K_INVALID for any other type.
 */
Status ZmqMsgEncoder::SendMsgFrames(EventType type, ZmqMsgFrames &frames)
{
    if (type != V2MTP && type != V1MTP) {
        RETURN_STATUS(K_INVALID, FormatString("Unsupported type %d", type));
    }
    if (type == V2MTP) {
        return SendMsgFramesV2(frames);
    }
    VLOG(RPC_LOG_LEVEL) << FormatString("Fall back to V1 format for fd %d", pSockFd_->GetFd());
    return SendMsgFramesV1(frames);
}
}