* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#include "protocol.h"
#include <cstdint>
#include "record_defs.h"
#include "record_format.h"
#include "utility/cpp_future.h"
#include "utility/log.h"
#include "utility/serializer.h"
namespace Sanitizer {
/**
 * Byte buffer that accumulates incoming message data and hands it out sequentially.
 *
 * Bytes are appended with Feed() and consumed from the front with Read(); `offset_`
 * marks the first unread byte inside `bytes_`. Read(), Feed() and Swap() take `mutex_`
 * of the object they are called on; Size() and the copy constructor do not.
 */
class MemCheckProtocol::Extractor {
public:
    Extractor() = default;
    // Copies the buffered bytes and read offset; the new object gets its own mutex.
    inline Extractor(Extractor const &rhs);
    ~Extractor() = default;
    // Copy-and-swap assignment: `rhs` is taken by value and swapped into *this.
    inline Extractor &operator=(Extractor rhs);
    // Number of buffered bytes that have not been read yet.
    inline uint64_t Size(void) const;
    // Reads sizeof(T) bytes and deserializes them into `val`; false if not enough bytes.
    template<typename T>
    inline bool Read(T &val);
    // Copies the next `size` unread bytes into `buffer`; false if fewer are buffered.
    inline bool Read(uint64_t size, std::string &buffer);
    // Appends `msg` to the buffer unless that would reach MAX_STREAM_LEN.
    inline void Feed(const std::string &msg);
    // Exchanges buffered bytes and read offset with `rhs`.
    inline void Swap(Extractor &rhs);
private:
    // Once more than this many bytes have been consumed, the consumed prefix is discarded.
    static constexpr uint64_t BYTES_DROP_THRESHOLD = 1024UL;
    // Upper bound (100 GiB) on the buffered stream length. The factors are `unsigned long long`
    // so the product is evaluated in at least 64 bits; with `UL` literals it wraps to 0 on
    // platforms where `unsigned long` is 32 bits wide, which would make Feed() reject all data.
    static constexpr uint64_t MAX_STREAM_LEN = 100ULL * 1024ULL * 1024ULL * 1024ULL;
    // Discards the already-consumed prefix of `bytes_` and resets `offset_` to 0.
    inline void DropUsedBytes(void);
private:
    std::string bytes_;                     // buffered stream data
    std::string::size_type offset_{0UL};    // index of the first unread byte in bytes_
    std::mutex mutex_;                      // guards bytes_ / offset_ in Read, Feed and Swap
};
MemCheckProtocol::Extractor::Extractor(Extractor const &rhs)
: bytes_(rhs.bytes_), offset_{rhs.offset_} { }
// Copy-and-swap assignment: the by-value parameter already holds the copy, so
// swapping it into *this completes the assignment; the old state dies with `rhs`.
MemCheckProtocol::Extractor &MemCheckProtocol::Extractor::operator=(Extractor rhs)
{
    this->Swap(rhs);
    return *this;
}
// Returns how many buffered bytes are still unread (total length minus read offset).
// Note: this accessor reads the members without taking `mutex_`.
uint64_t MemCheckProtocol::Extractor::Size(void) const
{
    const std::string::size_type unread = bytes_.size() - offset_;
    return unread;
}
// Reads exactly sizeof(T) bytes from the stream and deserializes them into `val`.
// Returns false when fewer than sizeof(T) bytes are buffered (nothing is consumed in
// that case) or when Deserialize reports failure.
template <typename T>
bool MemCheckProtocol::Extractor::Read(T &val)
{
    std::string raw;
    return Read(sizeof(T), raw) && Deserialize<T>(raw, val);
}
/**
 * Copies the next `size` unread bytes into `buffer` and advances the read position.
 *
 * @param size   number of bytes requested
 * @param buffer receives the bytes on success; left untouched on failure
 * @return true if `size` bytes were available and have been consumed, false otherwise
 */
bool MemCheckProtocol::Extractor::Read(uint64_t size, std::string &buffer)
{
    std::lock_guard<std::mutex> guard(mutex_);
    // Compare `size` with the number of remaining bytes rather than testing
    // `offset_ + size > bytes_.size()`: that sum wraps around for very large `size`,
    // which would let an oversized request through and move `offset_` backwards.
    if (offset_ > bytes_.size() || size > bytes_.size() - offset_) {
        return false;
    }
    buffer = bytes_.substr(offset_, size);
    offset_ += size;
    // Reclaim the consumed prefix once it has grown past the threshold.
    if (offset_ > BYTES_DROP_THRESHOLD) {
        DropUsedBytes();
    }
    return true;
}
// Appends `msg` to the buffered stream while holding `mutex_`.
// If the resulting length would reach MAX_STREAM_LEN the message is discarded
// without any notification to the caller.
void MemCheckProtocol::Extractor::Feed(const std::string &msg)
{
    std::lock_guard<std::mutex> guard(mutex_);
    if (bytes_.size() + msg.size() >= MAX_STREAM_LEN) {
        return;
    }
    bytes_.append(msg);
}
// Exchanges the buffered bytes and read offset with `rhs`.
// Only this object's `mutex_` is held during the exchange; `rhs.mutex_` is not taken.
void MemCheckProtocol::Extractor::Swap(Extractor &rhs)
{
    std::lock_guard<std::mutex> guard(mutex_);
    bytes_.swap(rhs.bytes_);
    std::swap(offset_, rhs.offset_);
}
// Discards the already-consumed prefix of the buffer so that only unread bytes remain,
// then rewinds the read offset to the start. Does not lock; Read() calls it with
// `mutex_` already held.
void MemCheckProtocol::Extractor::DropUsedBytes(void)
{
    // Clamp first so the offset can never point past the end of the buffer.
    if (offset_ > bytes_.size()) {
        offset_ = bytes_.size();
    }
    // Build a fresh string from the unread tail and make it the new buffer;
    // the old storage is released when `unread` goes out of scope.
    std::string unread(bytes_, offset_);
    bytes_.swap(unread);
    offset_ = 0UL;
}
// Creates the protocol with a fresh, empty Extractor to buffer incoming bytes.
MemCheckProtocol::MemCheckProtocol()
    : extractor_(std::make_shared<Extractor>())
{
}
// Tries to read one fixed-size payload of type TriviallyCopyableT from the stream.
// Returns a Packet wrapping the value on success, or a default-constructed Packet
// when the extractor could not provide it.
template <typename TriviallyCopyableT>
Packet MemCheckProtocol::GetTriviallyCopyable(void)
{
    TriviallyCopyableT payload;
    return extractor_->Read(payload) ? Packet(payload) : Packet();
}
// Tries to read one variable-length binary payload via GetBinaryData() and wraps it
// in a BinaryDataT. Returns a default-constructed Packet when the payload is not
// available.
template <typename BinaryDataT>
Packet MemCheckProtocol::GetBinaryDataPacket(void)
{
    std::string payload;
    if (!GetBinaryData(payload)) {
        return Packet();
    }
    return Packet(BinaryDataT{payload});
}
// Reads the next packet header from the extractor into `head`.
// Returns false when the extractor cannot supply a complete header.
bool MemCheckProtocol::GetPacketHead(PacketHead &head) const
{
    const bool headRead = extractor_->Read(head);
    return headRead;
}
// Reads one length-prefixed binary blob: a uint64_t byte count followed by that many
// bytes, which are stored in `data`.
//
// The byte count is kept in a thread_local static between calls, so when the body has
// not fully arrived the function returns false and a later call on the same thread
// resumes with the remembered length instead of reading a new prefix. Being a
// function-local thread_local, that remembered length is per thread, not per
// MemCheckProtocol instance.
bool MemCheckProtocol::GetBinaryData(std::string &data)
{
    thread_local static uint64_t pendingLen{};
    // No length remembered yet: try to pull the prefix from the stream.
    if (pendingLen == 0UL && !extractor_->Read(pendingLen)) {
        return false;
    }
    // Body not complete (or could not be read): keep pendingLen for the next attempt.
    if (extractor_->Size() < pendingLen || !extractor_->Read(pendingLen, data)) {
        return false;
    }
    pendingLen = 0UL;
    return true;
}
// Reads the payload that belongs to `head` and returns it wrapped in a Packet.
// The packet type selects which payload reader is used; PacketType::INVALID and any
// type not listed below yield a default-constructed Packet.
Packet MemCheckProtocol::GetPayLoad(PacketHead head)
{
    switch (head.type) {
        // Payloads read through GetBinaryDataPacket (length-prefixed binary data).
        case PacketType::KERNEL_RECORD:
            return GetBinaryDataPacket<Packet::KernelRecord>();
        case PacketType::KERNEL_BINARY:
            return GetBinaryDataPacket<Packet::KernelBinary>();
        case PacketType::LOG_STRING:
            return GetBinaryDataPacket<Packet::LogString>();

        // Payloads read through GetTriviallyCopyable (sizeof(T) bytes each).
        case PacketType::DEVICE_SUMMARY:
            return GetTriviallyCopyable<DeviceInfoSummary>();
        case PacketType::KERNEL_SUMMARY:
            return GetTriviallyCopyable<KernelSummary>();
        case PacketType::MEMORY_RECORD:
            return GetTriviallyCopyable<HostMemRecord>();
        case PacketType::IPC_RECORD:
            return GetTriviallyCopyable<IPCMemRecord>();
        case PacketType::SANITIZER_RECORD:
            return GetTriviallyCopyable<SanitizerRecord>();
        case PacketType::MEM_REGION_PERMISSION:
            return GetTriviallyCopyable<MemRegionPermissionDesc>();
        case PacketType::GM_ADDR_OUT_OF_BOUND_RECORD:
            return GetTriviallyCopyable<GMAddrOutOfBoundRecord>();

        case PacketType::INVALID:
        default:
            return Packet{};
    }
}
// Hands newly received bytes to the extractor; empty messages are ignored.
void MemCheckProtocol::Feed(std::string const &msg)
{
    if (msg.empty()) {
        return;
    }
    extractor_->Feed(msg);
}
// Extracts the next complete packet (header + payload) from the buffered stream.
//
// The header is kept in a thread_local static between calls: when the header has been
// read but the payload is not available yet, an INVALID packet is returned and a later
// call on the same thread retries the payload with the remembered header. The header
// is cleared only after a payload was produced.
// NOTE(review): a header whose type GetPayLoad does not recognise also yields an
// INVALID packet, so it appears it would never be cleared here - confirm whether the
// sender can ever emit such a type.
Packet MemCheckProtocol::GetPacket(void)
{
    thread_local static PacketHead pendingHead{PacketType::INVALID};
    // No header pending: try to read one; without it there is nothing to return.
    if (pendingHead.type == PacketType::INVALID && !GetPacketHead(pendingHead)) {
        return Packet{};
    }
    Packet packet = GetPayLoad(pendingHead);
    const bool payloadComplete = (packet.GetType() != PacketType::INVALID);
    if (payloadComplete) {
        // Packet fully consumed; the next call starts with a fresh header.
        pendingHead.type = PacketType::INVALID;
    }
    return packet;
}
}