/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <algorithm>

#include "host/shmem_host_def.h"
#include "shmemi_logger.h"
#include "store_message_packer.h"

namespace shm {
namespace store {
std::vector<uint8_t> SmemMessagePacker::Pack(const SmemMessage &message) noexcept
{
    // size + userDef + mt + keyN + vN
    constexpr uint64_t baseSize = 4U * sizeof(uint64_t) + sizeof(MessageType);
    uint64_t totalSize = baseSize;
    for (auto &key : message.keys) {
        totalSize += (sizeof(uint64_t) + key.size());
    }
    for (auto &value : message.values) {
        totalSize += (sizeof(uint64_t) + value.size());
    }

    std::vector<uint8_t> result;
    result.reserve(totalSize);
    PackValue(result, totalSize);
    PackValue(result, message.userDef);
    PackValue(result, message.mt);

    PackValue(result, message.keys.size());
    for (auto &key : message.keys) {
        PackString(result, key);
    }

    PackValue(result, message.values.size());
    for (auto &value : message.values) {
        PackBytes(result, value);
    }

    return result;
}

bool SmemMessagePacker::Full(const uint8_t* buffer, const uint64_t bufferLen) noexcept
{
    constexpr uint64_t baseSize = 4U * sizeof(uint64_t) + sizeof(MessageType);
    if (bufferLen < baseSize) {
        return false;
    }

    auto totalSize = *reinterpret_cast<const uint64_t *>(buffer);
    return bufferLen >= totalSize;
}

int64_t SmemMessagePacker::MessageSize(const std::vector<uint8_t> &buffer) noexcept
{
    if (buffer.size() < sizeof(uint64_t)) {
        return -1L;
    }

    return *reinterpret_cast<const int64_t *>(buffer.data());
}

int64_t SmemMessagePacker::Unpack(const uint8_t* buffer, const uint64_t bufferLen, SmemMessage &message) noexcept
{
    SHM_CHECK_CONDITION_RET(buffer == nullptr, -1);
    SHM_CHECK_CONDITION_RET(!Full(buffer, bufferLen), -1);

    uint64_t length = 0ULL;
    auto totalSize = *reinterpret_cast<const uint64_t *>(buffer + length);
    length += sizeof(uint64_t);

    message.userDef = *reinterpret_cast<const int64_t *>(buffer + length);
    length += sizeof(int64_t);

    message.mt = *reinterpret_cast<const MessageType *>(buffer + length);
    length += sizeof(MessageType);
    SHM_CHECK_CONDITION_RET(message.mt < MessageType::SET || message.mt > MessageType::INVALID_MSG, -1);

    uint64_t keyCount = 0;
    std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &keyCount);
    SHM_CHECK_CONDITION_RET(keyCount > MAX_KEY_COUNT, -1);

    length += sizeof(uint64_t);
    message.keys.reserve(keyCount);

    for (auto i = 0UL; i < keyCount; i++) {
        uint64_t keySize = 0;
        std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &keySize);
        length += sizeof(uint64_t);

        SHM_CHECK_CONDITION_RET(keySize > MAX_KEY_SIZE || length + keySize > bufferLen, -1);
        message.keys.emplace_back(reinterpret_cast<const char *>(buffer + length), keySize);
        length += keySize;
    }

    uint64_t valueCount = 0;
    std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &valueCount);
    SHM_CHECK_CONDITION_RET(valueCount > MAX_VALUE_COUNT, -1);

    length += sizeof(uint64_t);
    message.values.reserve(valueCount);

    for (auto i = 0UL; i < valueCount; i++) {
        uint64_t valueSize = 0;
        std::copy_n(reinterpret_cast<const uint64_t *>(buffer + length), 1, &valueSize);
        length += sizeof(uint64_t);
        SHM_CHECK_CONDITION_RET(valueSize > MAX_VALUE_SIZE || length + valueSize > bufferLen, -1);

        message.values.emplace_back(buffer + length, buffer + length + valueSize);
        length += valueSize;
    }
    SHM_CHECK_CONDITION_RET(totalSize != length, -1);
    return static_cast<int64_t>(totalSize);
}

void SmemMessagePacker::PackString(std::vector<uint8_t> &dest, const std::string &str) noexcept
{
    PackValue(dest, static_cast<uint64_t>(str.size()));
    if (!str.empty()) {
        dest.insert(dest.end(), str.data(), str.data() + str.size());
    }
}

void SmemMessagePacker::PackBytes(std::vector<uint8_t> &dest, const std::vector<uint8_t> &bytes) noexcept
{
    PackValue(dest, static_cast<uint64_t>(bytes.size()));
    dest.insert(dest.end(), bytes.begin(), bytes.end());
}
}  // shm
}  // store