* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "RemoteInputChannel.h"

#include <utility>

#include "table/utils/VectorBatchDeserializationUtils.h"
#include "common.h"
#include "runtime/buffer/NetworkBuffer.h"
#include "runtime/io/checkpointing/CheckpointBarrierHandler.h"
#include <buffer/ReadOnlySlicedNetworkBuffer.h>
namespace omnistream {
/**
 * Builds a remote input channel on top of the LocalInputChannel base.
 *
 * Every argument except networkBuffersPerChannel is forwarded unchanged to the base constructor;
 * networkBuffersPerChannel becomes this channel's initialCredit.
 */
RemoteInputChannel::RemoteInputChannel(std::shared_ptr<SingleInputGate> inputGate, int channelIndex,
                                       ResultPartitionIDPOD partitionId,
                                       std::shared_ptr<ResultPartitionManager> partitionManager,
                                       int initialBackoff, int maxBackoff, int networkBuffersPerChannel,
                                       std::shared_ptr<Counter> numBytesIn,
                                       std::shared_ptr<Counter> numBuffersIn,
                                       std::shared_ptr<ChannelStateWriter> stateWriter)
    : LocalInputChannel(inputGate, channelIndex, partitionId, partitionManager,
                        initialBackoff, maxBackoff,
                        numBytesIn, numBuffersIn, stateWriter),
      initialCredit(networkBuffersPerChannel)
{
    // No construction work beyond the base-class call and initialCredit.
}
/**
 * No-op in this implementation: the requested subpartition index is ignored.
 * NOTE(review): presumably the subpartition request is issued elsewhere (e.g. by the remote fetcher) — confirm.
 */
void RemoteInputChannel::requestSubpartition(int subpartitionIndex)
{
    static_cast<void>(subpartitionIndex);
}
void RemoteInputChannel::notifyRemoteDataAvailableForVectorBatch(long bufferAddress, int bufferLength,
int sequenceNumber)
{
if (bufferAddress == -1) {
int eventType = bufferLength;
LOG("remote got an event data:::: event type: " << eventType)
INFO_RELEASE("remote got an event data:::: event type: " << eventType)
auto eventData = new VectorBatchBuffer(eventType);
std::lock_guard<std::recursive_mutex> lock(queueMutex);
if (eventData != nullptr) {
this->dataQueue.push(eventData);
}
} else {
uint8_t* buffer = reinterpret_cast<uint8_t*>(bufferAddress);
std::shared_ptr<ObjectSegment> objectSegment = this->DoDataDeserializationResult(buffer, bufferLength);
auto vectorBatchBuffer = new VectorBatchBuffer(objectSegment);
if (vectorBatchBuffer != nullptr) {
vectorBatchBuffer->SetSize(objectSegment->getSize());
std::lock_guard<std::recursive_mutex> lock(queueMutex);
this->dataQueue.push(vectorBatchBuffer);
LOG("remote got an buffer "<< vectorBatchBuffer->ToDebugString(true));
}
}
this->notifyDataAvailable();
}
std::optional<BufferAndAvailability> RemoteInputChannel::getNextBuffer()
{
std::lock_guard<std::recursive_mutex> lock(queueMutex);
if (this->dataQueue.size() == 0) {
return std::nullopt;
}
auto buffer = this->dataQueue.front();
this->dataQueue.pop();
ObjectBufferDataType dataType = ObjectBufferDataType::NONE;
int backlogSize = static_cast<int>(this->dataQueue.size());
if (backlogSize > 0) {
dataType = ObjectBufferDataType::DATA_BUFFER;
}
return BufferAndAvailability{buffer, dataType, backlogSize, expectSequenceNumber++};
}
std::shared_ptr<ObjectSegment> RemoteInputChannel::DoDataDeserializationResult(uint8_t*& buffer, int bufferLength)
{
LOG("----DoDataDeserializationResult start 1:: " << buffer << " bufferLength:: " <<
bufferLength)
int32_t elementNum;
memcpy_s(&elementNum, sizeof(int32_t), buffer, sizeof(int32_t));
buffer += sizeof(int32_t);
std::shared_ptr<ObjectSegment> objectSegment = std::make_shared<ObjectSegment>(elementNum);
LOG("----DoDataDeserializationResult start 2:: " << buffer << " bufferLength:: " <<
bufferLength)
for (int32_t i = 0; i < elementNum; i++) {
int8_t dataType;
memcpy_s(&dataType, sizeof(int8_t), buffer, sizeof(int8_t));
buffer += sizeof(int8_t);
LOG("----DoDataDeserializationResult start 3:: " << (buffer) << " bufferLength:: " <<
bufferLength)
StreamElementTag tagType = static_cast<StreamElementTag>(dataType);
switch (tagType) {
case StreamElementTag::TAG_WATERMARK: {
long timestamp = VectorBatchDeserializationUtils::derializeWatermark(buffer);
LOG("RemoteInputChannel::DoDataDeserializationResult:: deserialize watermark :: "<< timestamp)
Watermark* watermark = new Watermark(timestamp);
objectSegment->putObject(i, watermark);
break;
}
case StreamElementTag::VECTOR_BATCH: {
VectorBatch* vb = VectorBatchDeserializationUtils::deserializeVectorBatch(buffer);
StreamRecord* streamRecord = new StreamRecord(vb);
objectSegment->putObject(i, streamRecord);
break;
}
default:
break;
}
}
return objectSegment;
}
void RemoteInputChannel::notifyRemoteDataAvailableForNetworkBuffer(long bufferAddress, int bufferLength,
int readIndex, int sequenceNumber,
std::shared_ptr<OriginalNetworkBufferRecycler>
originalNetworkBufferRecycler, bool isBuffer,
int bufferType)
{
LOG("notifyRemoteDataAvailableForDataStream bufferAddress: " << bufferAddress
<< " bufferLength: " << bufferLength << " sequenceNumber: " << sequenceNumber);
MemorySegment *memorySegment = new MemorySegment(
reinterpret_cast<uint8_t*>(bufferAddress), bufferLength, this);
datastream::NetworkBuffer *networkBuffer = new datastream::NetworkBuffer(
memorySegment, bufferLength, readIndex, originalNetworkBufferRecycler, bufferType, true);
datastream::ReadOnlySlicedNetworkBuffer* readOnlyBuffer =
new datastream::ReadOnlySlicedNetworkBuffer(networkBuffer, readIndex, bufferLength);
std::unique_lock<std::recursive_mutex> lock(queueMutex);
bool wasEmpty = this->dataQueue.empty();
if (readOnlyBuffer != nullptr) {
this->dataQueue.push(readOnlyBuffer);
}
lock.unlock();
if (wasEmpty) {
this->notifyDataAvailable();
}
}
/**
 * Stores the bridge that resumeConsumption uses to forward resume requests to the Java-side fetcher.
 *
 * @param remoteDataFetcherBridge bridge to keep. A null bridge is tolerated: resumeConsumption checks for it.
 */
void RemoteInputChannel::SetRemoteDataFetcherBridge(
    std::shared_ptr<RemoteDataFetcherBridge> remoteDataFetcherBridge)
{
    // The parameter is a by-value sink. Move it into the member instead of copying, which would cost
    // an extra atomic reference-count increment/decrement pair.
    this->remoteDataFetcherBridge = std::move(remoteDataFetcherBridge);
}
void RemoteInputChannel::resumeConsumption()
{
if (!forwardResumeToJava_) {
return;
}
if (this->remoteDataFetcherBridge == nullptr) {
LOG("RemoteInputChannel::resumeConsumption: remoteDataFetcherBridge is null");
return;
}
int gateIndex = this->getChannelInfo().getGateIdx();
int channelIndex = this->getChannelInfo().getInputChannelIdx();
this->remoteDataFetcherBridge->InvokeJavaRemoteDataFetcherResumeConsumption(gateIndex, channelIndex);
}
/**
 * Starts persisting this channel's in-flight buffers for the checkpoint carried by barrier.
 *
 * - A barrier older than lastBarrierId_ is logged and ignored.
 * - A newer barrier first resets the last-barrier state via ResetLastBarrier().
 *
 * The whole operation holds queueMutex, so the in-flight snapshot matches the queue contents.
 */
void RemoteInputChannel::CheckpointStarted(const CheckpointBarrier &barrier)
{
    std::lock_guard<std::recursive_mutex> guard(queueMutex);
    const auto checkpointId = barrier.GetId();
    if (checkpointId < lastBarrierId_) {
        LOG("Barrier id is too small");
        return;
    }
    if (checkpointId > lastBarrierId_) {
        ResetLastBarrier();
    }
    channelStatePersister->StartPersisting(checkpointId, GetInflightBuffersUnsafe(checkpointId));
}
void RemoteInputChannel::CheckpointStopped(long checkpointId)
{
std::lock_guard<std::recursive_mutex> lock(queueMutex);
channelStatePersister->StopPersisting(checkpointId);
if (lastBarrierId_ == checkpointId) {
ResetLastBarrier();
}
}
std::vector<Buffer*> RemoteInputChannel::GetInflightBuffersUnsafe(long checkpointId)
{
std::lock_guard<std::recursive_mutex> lock(queueMutex);
std::vector<Buffer*> inflightBuffers;
std::queue<Buffer*> tmpQueue = dataQueue;
while (!tmpQueue.empty()) {
Buffer* buffer = tmpQueue.front();
if ((buffer != nullptr) && (buffer->isBuffer())) {
inflightBuffers.push_back(buffer->RetainBuffer());
}
tmpQueue.pop();
}
LOG("RemoteInputChannel get inflight buffers success, buffer num:" << inflightBuffers.size()
<< ", checkpointId: " << checkpointId);
return inflightBuffers;
}
}