* @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* @Description: Spanning Wrapper for DataStream
*/
#ifndef FLINK_TNEL_SPANNINGWRAPPER_H
#define FLINK_TNEL_SPANNINGWRAPPER_H
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <new>
#include <stdexcept>
#include <vector>
#include "core/utils/ByteBuffer.h"
#include "core/memory/DataInputDeserializer.h"
#include "NonSpanningWrapper.h"
#include "core/utils/utils.h"
#include "core/include/common.h"
namespace omnistream::datastream {
class SpanningWrapper {
public:
ByteBuffer* lengthBuffer_;
SpanningWrapper();
~SpanningWrapper();
inline bool hasFullRecord() const;
inline int getNumGatheredBytes() const;
inline void clear();
inline DataInputView& getInputView();
inline void transferLeftOverTo(NonSpanningWrapper& nonSpanningWrapper);
inline void transferFrom(NonSpanningWrapper& partial, int nextRecordLength);
inline void addNextChunkFromMemoryBuffer(const uint8_t* buffer, int numBytes);
Buffer* GetUnconsumedSegment()
{
LOG("SpanningWrapper GetUnconsumedSegment position: " << lengthBuffer_->position());
if (lengthBuffer_->position() > 0) {
uint8_t* data = reinterpret_cast<uint8_t*>(malloc(lengthBuffer_->position()));
MemorySegment* memorySegment = new MemorySegment(data, lengthBuffer_->position());
memorySegment->put(0, lengthBuffer_->getValue(), 0, lengthBuffer_->position());
::datastream::NetworkBuffer* networkBuffer = new ::datastream::NetworkBuffer(
memorySegment,
lengthBuffer_->position(),
0,
std::make_shared<OriginalNetworkBufferRecycler>(),
ObjectBufferDataType::DATA_BUFFER,
true);
return networkBuffer;
} else if (recordLength_ == -1) {
return nullptr;
} else {
return CopyDataBuffer();
}
}
::datastream::NetworkBuffer* CopyDataBuffer()
{
int leftOverSize = leftOverLimit_ - leftOverStart_;
int unconsumedSize = LENGTH_BYTES + accumulatedRecordBytes_ + leftOverSize;
auto serializer = std::make_shared<DataOutputSerializer>(unconsumedSize);
serializer->writeInt(recordLength_);
serializer->write(buffer_.data(), accumulatedRecordBytes_, 0, accumulatedRecordBytes_);
if (leftOverData_ != nullptr) {
serializer->write(const_cast<uint8_t*>(leftOverData_), 0, leftOverStart_, leftOverSize);
}
uint8_t* data = reinterpret_cast<uint8_t*>(malloc(unconsumedSize));
MemorySegment* memorySegment = new MemorySegment(data, unconsumedSize);
memorySegment->put(0, serializer->getData(), 0, unconsumedSize);
::datastream::NetworkBuffer* networkBuffer = new ::datastream::NetworkBuffer(
memorySegment,
unconsumedSize,
0,
std::make_shared<OriginalNetworkBufferRecycler>(),
ObjectBufferDataType::DATA_BUFFER,
true);
return networkBuffer;
}
private:
std::vector<uint8_t> buffer_;
int recordLength_;
int accumulatedRecordBytes_;
DataInputDeserializer* serializationReadBuffer_;
const uint8_t* leftOverData_;
int leftOverStart_;
int leftOverLimit_;
inline bool isReadingLength() const;
inline void updateLength(int length);
inline int readLength(const uint8_t* buffer, int remaining);
inline void ensureBufferCapacity(int minLength);
inline void copyIntoBuffer(const uint8_t* buffer, int offset, int length);
};
/**
 * Tells whether the record being reassembled is complete.
 *
 * @return true once the record length is known (non-negative) and at least that many payload
 *         bytes have been gathered into buffer_.
 */
inline bool SpanningWrapper::hasFullRecord() const
{
    const bool lengthKnown = recordLength_ >= 0;
    const bool complete = lengthKnown && !(accumulatedRecordBytes_ < recordLength_);
#ifdef DEBUG
    if (complete) {
        PRINT_HEX(const_cast<uint8_t*>(buffer_.data()), 0, recordLength_);
    }
#endif
    return complete;
}
/**
 * Exposes the deserializer used to read the reassembled record.
 *
 * NOTE(review): the deserializer is pointed at buffer_ only when a record completes, so callers
 * presumably check hasFullRecord() before reading through this view.
 */
inline DataInputView& SpanningWrapper::getInputView()
{
    DataInputDeserializer& deserializer = *serializationReadBuffer_;
    return deserializer;
}
/**
 * Hands the bytes that followed the reassembled record over to `nonSpanningWrapper`, then
 * resets this wrapper.
 *
 * `nonSpanningWrapper` is always cleared first. It is re-initialized only when a chunk left
 * trailing bytes behind.
 */
inline void SpanningWrapper::transferLeftOverTo(NonSpanningWrapper& nonSpanningWrapper)
{
    nonSpanningWrapper.clear();
    const bool hasLeftOver = leftOverData_ != nullptr;
    if (hasLeftOver) {
        const uint8_t* leftOverBegin = leftOverData_ + leftOverStart_;
        const int leftOverSize = leftOverLimit_ - leftOverStart_;
        nonSpanningWrapper.initializeFromMemoryBuffer(leftOverBegin, leftOverSize);
    }
    clear();
}
/**
 * Returns the wrapper to its empty state.
 *
 * After this call there is no left-over region, no gathered length-prefix bytes, no gathered
 * payload, and recordLength_ is -1 (unknown). buffer_'s storage is kept for reuse.
 */
inline void SpanningWrapper::clear()
{
    leftOverData_ = nullptr;
    leftOverStart_ = 0;
    leftOverLimit_ = 0;
    lengthBuffer_->clear();
    accumulatedRecordBytes_ = 0;
    recordLength_ = -1;
}
/**
 * Starts reassembling a record whose beginning is still held by `partial`.
 *
 * Steps:
 * 1. Records `nextRecordLength` and sizes buffer_ for it.
 * 2. Copies what `partial.copyContentTo()` yields into buffer_ (presumably partial's unread
 *    remainder — confirm against NonSpanningWrapper).
 * 3. Clears `partial`.
 *
 * @param partial          wrapper holding the first bytes of the record's payload.
 * @param nextRecordLength already-decoded payload length of that record.
 */
inline void SpanningWrapper::transferFrom(NonSpanningWrapper& partial, int nextRecordLength)
{
    updateLength(nextRecordLength);
    const auto copiedBytes = partial.copyContentTo(buffer_.data());
    accumulatedRecordBytes_ = copiedBytes;
    partial.clear();
}
/**
 * Feeds the next `numBytes` bytes of input into the record being reassembled.
 *
 * Processing order:
 * 1. If a length prefix is only partially gathered, its missing bytes are taken first.
 * 2. Then as many payload bytes as the record still lacks are copied into buffer_.
 * 3. Bytes beyond the end of the record are not copied. leftOverData_ keeps a pointer into
 *    `buffer` itself, so `buffer` presumably must stay valid until the left-over bytes are
 *    consumed (transferLeftOverTo / GetUnconsumedSegment).
 *
 * @param buffer   start of the chunk.
 * @param numBytes number of valid bytes in the chunk.
 */
inline void SpanningWrapper::addNextChunkFromMemoryBuffer(const uint8_t* buffer, int numBytes)
{
#ifdef DEBUG
    ByteBuffer::showInternalInfo(lengthBuffer_);
#endif
    int cursor = 0;
    if (isReadingLength()) {
        cursor = readLength(buffer, numBytes);
    }

    const int available = numBytes - cursor;
    if (available == 0) {
        return;
    }

    // recordLength_ must be read after readLength(), which may have just decoded it.
    const int missing = recordLength_ - accumulatedRecordBytes_;
    const int toCopy = std::min(missing, available);
    if (toCopy > 0) {
        copyIntoBuffer(buffer, cursor, toCopy);
    }

    const bool hasTrailingBytes = available > toCopy;
    if (hasTrailingBytes) {
        leftOverData_ = buffer;
        leftOverStart_ = cursor + toCopy;
        leftOverLimit_ = numBytes;
    }
}
/**
 * Tells whether a length prefix is partially gathered.
 *
 * @return true when lengthBuffer_ already holds some prefix bytes (position > 0).
 */
inline bool SpanningWrapper::isReadingLength() const
{
    const auto gatheredPrefixBytes = lengthBuffer_->position();
    return gatheredPrefixBytes > 0;
}
/**
 * Adopts `length` as the payload size of the record being reassembled.
 *
 * Also discards the gathered length-prefix bytes and makes sure buffer_ can hold the payload.
 *
 * @param length decoded payload length in bytes.
 */
inline void SpanningWrapper::updateLength(int length)
{
    recordLength_ = length;
    lengthBuffer_->clear();
    ensureBufferCapacity(length);
}
/**
 * Grows buffer_ so that at least `minLength` payload bytes can be written through buffer_.data().
 *
 * Why resize() and not reserve():
 * - Callers (copyIntoBuffer, transferFrom) write through buffer_.data() directly. That is only
 *   valid within size(); reserve() leaves size() unchanged, so writing past it is undefined
 *   behavior.
 * - resize() also preserves the bytes already stored if the storage is reallocated.
 *
 * Growth at least doubles the size to amortize reallocations. It is computed in size_t so large
 * buffers cannot overflow an int. Non-positive lengths (e.g. from corrupt input) request no space
 * and leave buffer_ untouched.
 *
 * @param minLength number of payload bytes the current record needs.
 */
inline void SpanningWrapper::ensureBufferCapacity(int minLength)
{
    if (minLength <= 0) {
        return;
    }
    const size_t required = static_cast<size_t>(minLength);
    if (required > buffer_.size()) {
        buffer_.resize(std::max(required, buffer_.size() * 2));
    }
}
/**
 * Moves length-prefix bytes from the front of `buffer` into lengthBuffer_.
 *
 * At most as many bytes as the prefix still lacks, and at most `remaining`, are taken. Once the
 * prefix is complete it is decoded as a big-endian int and becomes the record length.
 *
 * @param buffer    chunk whose first bytes continue the prefix.
 * @param remaining number of valid bytes in `buffer`.
 * @return number of bytes consumed from `buffer`.
 */
inline int SpanningWrapper::readLength(const uint8_t* buffer, int remaining)
{
    const int prefixBytesMissing = lengthBuffer_->remaining();
    const int consumed = std::min(prefixBytesMissing, remaining);
    lengthBuffer_->putBytes(buffer, consumed);

    const bool prefixComplete = !lengthBuffer_->hasRemaining();
    if (prefixComplete) {
        updateLength(lengthBuffer_->getIntBigEndian(0));
    }
    return consumed;
}
/**
 * Appends `length` payload bytes, taken from `buffer + offset`, after the bytes already gathered.
 *
 * If this completes the record, the read deserializer is pointed at the first recordLength_
 * bytes of buffer_.
 *
 * The destination-size argument of memcpy_s is the room actually left in buffer_'s storage.
 * Passing `length` (as the previous code did) makes the bounds check vacuous. capacity() is used
 * because it bounds the memory buffer_ owns regardless of how ensureBufferCapacity grows it.
 *
 * @throws std::runtime_error if memcpy_s reports an error (e.g. the copy would overrun buffer_).
 */
inline void SpanningWrapper::copyIntoBuffer(const uint8_t* buffer, int offset, int length)
{
    const size_t used = static_cast<size_t>(accumulatedRecordBytes_);
    const size_t room = buffer_.capacity() > used ? buffer_.capacity() - used : 0;
    auto ret = memcpy_s(buffer_.data() + accumulatedRecordBytes_, room, buffer + offset, length);
    if (ret != EOK) {
        throw std::runtime_error("memcpy_s failed");
    }
    accumulatedRecordBytes_ += length;
    if (hasFullRecord()) {
        serializationReadBuffer_->setBuffer(buffer_.data(), accumulatedRecordBytes_, 0, recordLength_);
    }
}
/**
 * Counts every input byte this wrapper has absorbed for the current record.
 *
 * Once the record length is known, the prefix counts as LENGTH_BYTES. Before that, only the
 * prefix bytes gathered in lengthBuffer_ count.
 *
 * @return gathered payload bytes plus length-prefix bytes.
 */
inline int SpanningWrapper::getNumGatheredBytes() const
{
    int prefixBytes = 0;
    if (recordLength_ >= 0) {
        prefixBytes = LENGTH_BYTES;
    } else {
        prefixBytes = lengthBuffer_->position();
    }
    return accumulatedRecordBytes_ + prefixBytes;
}
}
#endif