/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <algorithm>
#include "host/shmem_host_def.h"
#include "shmemi_logger.h"
#include "store_message_packer.h"
namespace shm {
namespace store {
/**
 * Serialize a message into a contiguous byte buffer.
 *
 * Fields are emitted in this order: total size (uint64), userDef, message type, key count (uint64),
 * every key as a uint64 length followed by its bytes, value count (uint64), and every value as a
 * uint64 length followed by its bytes. The leading total size covers the whole buffer, including
 * the size field itself.
 *
 * NOTE(review): PackValue is declared outside this file; it is presumed to append the raw bytes of
 * its argument to the vector - confirm against the header.
 *
 * @param message message to serialize
 * @return the serialized bytes
 */
std::vector<uint8_t> SmemMessagePacker::Pack(const SmemMessage &message) noexcept
{
    // Fixed part: total size, userDef, key count and value count (8 bytes each) plus the message type.
    constexpr uint64_t baseSize = 4U * sizeof(uint64_t) + sizeof(MessageType);

    uint64_t totalSize = baseSize;
    for (const auto &key : message.keys) {
        totalSize += (sizeof(uint64_t) + key.size());
    }
    for (const auto &value : message.values) {
        totalSize += (sizeof(uint64_t) + value.size());
    }

    std::vector<uint8_t> result;
    result.reserve(totalSize);

    PackValue(result, totalSize);
    PackValue(result, message.userDef);
    PackValue(result, message.mt);

    // Counts are cast to uint64_t so the field width does not depend on the platform's size_t;
    // baseSize above and Unpack both assume 8-byte count fields.
    PackValue(result, static_cast<uint64_t>(message.keys.size()));
    for (const auto &key : message.keys) {
        PackString(result, key);
    }

    PackValue(result, static_cast<uint64_t>(message.values.size()));
    for (const auto &value : message.values) {
        PackBytes(result, value);
    }
    return result;
}
/**
 * Report whether `buffer` holds at least as many bytes as its leading total-size field announces.
 *
 * @param buffer    start of the received bytes; nullptr yields false
 * @param bufferLen number of readable bytes at `buffer`
 * @return true when `bufferLen` covers the fixed-size header and is not smaller than the total size
 *         stored in the first 8 bytes; false otherwise
 */
bool SmemMessagePacker::Full(const uint8_t* buffer, const uint64_t bufferLen) noexcept
{
    // Fixed part of every message: four 8-byte fields plus the message type.
    constexpr uint64_t baseSize = 4U * sizeof(uint64_t) + sizeof(MessageType);
    if (buffer == nullptr || bufferLen < baseSize) {
        return false;
    }

    // Copy byte-wise: `buffer` is a plain byte pointer and need not be aligned for uint64_t.
    uint64_t totalSize = 0;
    std::copy_n(buffer, sizeof(totalSize), reinterpret_cast<uint8_t *>(&totalSize));
    return bufferLen >= totalSize;
}
/**
 * Read the leading size field of a serialized message.
 *
 * @param buffer bytes that start with the 8-byte size field
 * @return the first 8 bytes of `buffer` interpreted as int64_t, or -1 when `buffer` holds fewer
 *         than 8 bytes
 */
int64_t SmemMessagePacker::MessageSize(const std::vector<uint8_t> &buffer) noexcept
{
    const bool hasSizeField = buffer.size() >= sizeof(uint64_t);
    if (!hasSizeField) {
        return -1L;
    }

    const auto *sizeField = reinterpret_cast<const int64_t *>(buffer.data());
    return *sizeField;
}
int64_t SmemMessagePacker::Unpack(const uint8_t* buffer, const uint64_t bufferLen, SmemMessage &message) noexcept
{
SHM_CHECK_CONDITION_RET(buffer == nullptr, -1);
SHM_CHECK_CONDITION_RET(!Full(buffer, bufferLen), -1);
uint64_t length = 0ULL;
auto totalSize = *reinterpret_cast<const uint64_t *>(buffer + length);
length += sizeof(uint64_t);
message.userDef = *reinterpret_cast<const int64_t *>(buffer + length);
length += sizeof(int64_t);
message.mt = *reinterpret_cast<const MessageType *>(buffer + length);
length += sizeof(MessageType);
SHM_CHECK_CONDITION_RET(message.mt < MessageType::SET || message.mt > MessageType::INVALID_MSG, -1);
uint64_t keyCount = 0;
std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &keyCount);
SHM_CHECK_CONDITION_RET(keyCount > MAX_KEY_COUNT, -1);
length += sizeof(uint64_t);
message.keys.reserve(keyCount);
for (auto i = 0UL; i < keyCount; i++) {
uint64_t keySize = 0;
std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &keySize);
length += sizeof(uint64_t);
SHM_CHECK_CONDITION_RET(keySize > MAX_KEY_SIZE || length + keySize > bufferLen, -1);
message.keys.emplace_back(reinterpret_cast<const char *>(buffer + length), keySize);
length += keySize;
}
uint64_t valueCount = 0;
std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &valueCount);
SHM_CHECK_CONDITION_RET(valueCount > MAX_VALUE_COUNT, -1);
length += sizeof(uint64_t);
message.values.reserve(valueCount);
for (auto i = 0UL; i < valueCount; i++) {
uint64_t valueSize = 0;
std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &valueSize);
length += sizeof(uint64_t);
SHM_CHECK_CONDITION_RET(valueSize > MAX_VALUE_SIZE || length + valueSize > bufferLen, -1);
message.values.emplace_back(buffer + length, buffer + length + valueSize);
length += valueSize;
}
SHM_CHECK_CONDITION_RET(totalSize != length, -1);
return static_cast<int64_t>(totalSize);
}
/**
 * Append a length-prefixed string to `dest`: the byte count as a uint64, then the characters.
 * An empty string contributes only its (zero) length.
 *
 * @param dest output buffer that is appended to
 * @param str  string to serialize
 */
void SmemMessagePacker::PackString(std::vector<uint8_t> &dest, const std::string &str) noexcept
{
    const uint64_t byteCount = str.size();
    PackValue(dest, byteCount);
    if (byteCount == 0) {
        return;
    }

    const char *first = str.data();
    dest.insert(dest.end(), first, first + byteCount);
}
/**
 * Append a length-prefixed byte sequence to `dest`: the byte count as a uint64, then the bytes.
 *
 * @param dest  output buffer that is appended to
 * @param bytes payload to serialize
 */
void SmemMessagePacker::PackBytes(std::vector<uint8_t> &dest, const std::vector<uint8_t> &bytes) noexcept
{
    const uint64_t byteCount = bytes.size();
    PackValue(dest, byteCount);
    dest.insert(dest.end(), bytes.cbegin(), bytes.cend());
}
}
}