/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_DEMULTIPLEXINGRECORDDESERIALIZER_H
#define OMNISTREAM_DEMULTIPLEXINGRECORDDESERIALIZER_H
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <vector>
#include <algorithm>
#include <stdexcept>
#include <sstream>
#include "streaming/runtime/streamrecord/StreamElement.h"
#include "runtime/plugable/DeserializationDelegate.h"
#include "runtime/io/network/api/serialization/RecordDeserializer.h"
#include "runtime/streamrecord/StreamRecord.h"
#include "api/watermark/Watermark.h"
#include "watermark/WatermarkStatus.h"
#include "runtime/event/SubtaskConnectionDescriptor.h"
#include "partition/consumer/InputChannelInfo.h"
#include "checkpoint/InflightDataRescalingDescriptor.h"
#include "runtime/io/recover/RecordFilter.h"
namespace omnistream {
/**
 * A RecordDeserializer that fronts a set of "virtual channels", each with its own
 * underlying RecordDeserializer and record filter, keyed by the complex id of a
 * SubtaskConnectionDescriptor.
 *
 * Usage, as far as this header shows: call select() to choose the virtual channel,
 * hand it a buffer with SetNextBuffer(), then pull elements with getNextRecord().
 * Watermarks and watermark statuses are combined across all virtual channels before
 * they are exposed through the delegate.
 */
class DemultiplexingRecordDeserializer : public RecordDeserializer {
public:
    // Declared here, defined elsewhere.
    // NOTE(review): presumably the shared instance for channels without a mapping - confirm at the definition.
    static DemultiplexingRecordDeserializer *UNMAPPED;

    /**
     * State of one virtual channel: its deserializer, the filter applied to records,
     * and the most recent watermark / watermark status / deserialization result seen.
     */
    struct VirtualChannel {
        std::shared_ptr<RecordDeserializer> deserializer;
        std::function<bool(StreamRecord &)> recordFilter;
        Watermark lastWatermark = Watermark::UNINITIALIZED;
        // NOTE(review): after a status element is read this points at the object held by the
        // delegate; its lifetime is not visible from this header - confirm it outlives later reads.
        WatermarkStatus *watermarkStatus = WatermarkStatus::active();
        // Null until the first getNextRecord() call; hasPartialData() relies on that.
        DeserializationResult *lastResult = nullptr;

        /**
         * @param deser  deserializer that decodes this channel's buffers
         * @param filter predicate; only records for which it returns true are surfaced
         */
        VirtualChannel(std::shared_ptr<RecordDeserializer> deser, std::function<bool(StreamRecord &)> filter)
            : deserializer(std::move(deser)),
              recordFilter(std::move(filter))
        {
        }

        /**
         * Pulls elements from the underlying deserializer until one should be surfaced.
         *
         * Records are surfaced only if they pass recordFilter; watermarks and watermark
         * statuses are remembered in this channel and always surfaced. Any other full
         * element is skipped.
         *
         * @return the result of the surfaced element, or a pointer to
         *         DeserializationResult_PARTIAL_RECORD once the buffer is consumed
         *         without anything to surface.
         */
        DeserializationResult *getNextRecord(DeserializationDelegate &delegate)
        {
            do {
                lastResult = &deserializer->getNextRecord(delegate);
                if (lastResult->isFullRecord()) {
                    auto *element = static_cast<StreamElement *>(delegate.getInstance());
                    if (auto *record = dynamic_cast<StreamRecord *>(element)) {
                        if (recordFilter(*record)) {
                            return lastResult;
                        }
                        // Filtered-out record: fall through and try the next element.
                    } else if (auto *watermark = dynamic_cast<Watermark *>(element)) {
                        lastWatermark = *watermark;
                        return lastResult;
                    } else if (auto *status = dynamic_cast<WatermarkStatus *>(element)) {
                        watermarkStatus = status;
                        return lastResult;
                    }
                }
            } while (!lastResult->isBufferConsumed());
            return &DeserializationResult_PARTIAL_RECORD;
        }

        /** Forwards the buffer to the underlying deserializer. */
        void SetNextBuffer(ReadOnlySlicedNetworkBuffer* buffer)
        {
            deserializer->SetNextBuffer(buffer);
        }

        /** Clears the underlying deserializer. */
        void clear()
        {
            deserializer->clear();
        }

        /** @return true if a result was produced and it reports the buffer as not yet consumed. */
        bool hasPartialData() const
        {
            return lastResult != nullptr && !lastResult->isBufferConsumed();
        }
    };

    using ChannelMap = std::map<long, std::shared_ptr<VirtualChannel>>;

    /** @param channelMap virtual channels keyed by SubtaskConnectionDescriptor complex id */
    explicit DemultiplexingRecordDeserializer(
        std::map<long, std::shared_ptr<VirtualChannel>> channelMap)
        : channels(std::move(channelMap))
    {
    }

    /**
     * Makes the virtual channel identified by the descriptor the current one.
     *
     * @throws std::runtime_error if descriptor is null or no channel is registered for its id
     */
    void select(SubtaskConnectionDescriptor *descriptor)
    {
        if (descriptor == nullptr) {
            throw std::runtime_error("Cannot select a null SubtaskConnectionDescriptor");
        }
        const auto it = channels.find(descriptor->getComplexId());
        if (it == channels.end()) {
            std::ostringstream oss;
            oss << "Cannot select " << descriptor->toString() << "; known channels are ";
            for (const auto &pair : channels) {
                oss << pair.first << " ";
            }
            const std::string message = oss.str();
            INFO_RELEASE(message.c_str());
            throw std::runtime_error(message);
        }
        currentVirtualChannel = it->second;
    }

    /** @return true if at least one virtual channel is registered. */
    bool hasMappings() const
    {
        return !channels.empty();
    }

    /**
     * Hands the buffer to the currently selected virtual channel.
     *
     * @throws std::runtime_error if no channel has been selected
     */
    void SetNextBuffer(ReadOnlySlicedNetworkBuffer* buffer) override
    {
        selectedChannel().SetNextBuffer(buffer);
    }

    /** Not supported by this deserializer; always throws std::runtime_error. */
    std::vector<omnistream::Buffer *> GetUnconsumedBuffer() override
    {
        throw std::runtime_error("Cannot checkpoint while recovering");
    }

    /** @return true if any virtual channel reports partial data. */
    bool hasPartialData() const
    {
        return std::any_of(channels.begin(), channels.end(), [](const ChannelMap::value_type &entry) {
            return entry.second->hasPartialData();
        });
    }

    /**
     * Reads the next element from the selected virtual channel into the delegate.
     *
     * - Records are returned as-is.
     * - For a watermark, the minimum lastWatermark over all virtual channels is exposed
     *   instead; while that minimum still equals Watermark::UNINITIALIZED the watermark
     *   is suppressed and reading continues.
     * - For a watermark status, the delegate is switched to WatermarkStatus::active()
     *   if any virtual channel is active.
     *
     * @param target must actually be a DeserializationDelegate (it is static_cast to one)
     * @return the result of the surfaced element, or DeserializationResult_PARTIAL_RECORD
     *         when the buffer is consumed without anything to surface
     * @throws std::runtime_error if no channel has been selected
     */
    DeserializationResult &getNextRecord(IOReadableWritable &target) override
    {
        VirtualChannel &channel = selectedChannel();
        auto *delegate = static_cast<DeserializationDelegate *>(&target);
        DeserializationResult *result = nullptr;
        do {
            result = channel.getNextRecord(*delegate);
            if (!result->isFullRecord()) {
                continue;
            }
            auto *element = static_cast<StreamElement *>(delegate->getInstance());
            if (dynamic_cast<StreamRecord *>(element)) {
                return *result;
            }
            if (dynamic_cast<Watermark *>(element)) {
                Watermark minWatermark = Watermark::MAX_WATERMARK;
                for (const auto &pair : channels) {
                    Watermark &candidate = pair.second->lastWatermark;
                    if (candidate.getTimestamp() < minWatermark.getTimestamp()) {
                        minWatermark = candidate;
                    }
                }
                // At least one virtual channel has not seen a watermark yet: emit nothing.
                if (minWatermark.getTimestamp() == Watermark::UNINITIALIZED.getTimestamp()) {
                    continue;
                }
                // Publish a member copy: the delegate keeps the pointer after this function
                // returns, so it must not point at a local.
                // NOTE(review): the delegate is assumed not to take ownership of the instance
                // (the previous code passed a stack address) - confirm.
                emittedWatermark = minWatermark;
                delegate->setInstance(static_cast<void *>(&emittedWatermark));
                return *result;
            }
            if (dynamic_cast<WatermarkStatus *>(element)) {
                const bool anyActive =
                    std::any_of(channels.begin(), channels.end(), [](const ChannelMap::value_type &entry) {
                        return entry.second->watermarkStatus->IsActive();
                    });
                if (anyActive) {
                    delegate->setInstance(WatermarkStatus::active());
                }
                return *result;
            }
        } while (!result->isBufferConsumed());
        return DeserializationResult_PARTIAL_RECORD;
    }

    /** Clears every virtual channel. */
    void clear() override
    {
        for (auto &pair : channels) {
            pair.second->clear();
        }
    }

    /**
     * Builds a deserializer with one virtual channel per (old subtask, old channel) pair
     * that the rescaling descriptor maps onto channelInfo.
     *
     * If there are no old subtasks or no mapped old channels, the result has no channels.
     * Channels of subtasks that the descriptor reports as ambiguous get a filter from
     * recordFilterFactory; all others get RecordFilter::all().
     *
     * @param deserializerFactory called once per virtual channel with the total number
     *                            of virtual channels
     */
    static std::unique_ptr<DemultiplexingRecordDeserializer> create(
        const InputChannelInfo &channelInfo, const InflightDataRescalingDescriptor &rescalingDescriptor,
        std::function<std::shared_ptr<RecordDeserializer>(int)> deserializerFactory,
        std::function<std::function<bool(StreamRecord &)>(const InputChannelInfo &)> recordFilterFactory)
    {
        std::vector<int> oldSubtaskIndexes = rescalingDescriptor.GetOldSubtaskIndexes(channelInfo.getGateIdx());
        if (oldSubtaskIndexes.empty()) {
            return std::make_unique<DemultiplexingRecordDeserializer>(ChannelMap());
        }
        auto channelMapping = rescalingDescriptor.GetChannelMapping(channelInfo.getGateIdx());
        std::vector<int> oldChannelIndexes = channelMapping->getMappedIndexes(channelInfo.getInputChannelIdx());
        if (oldChannelIndexes.empty()) {
            INFO_RELEASE("DemultiplexingRecordDeserializer create old channel is empty:"<<channelInfo.toString() <<",channel:"<< channelMapping->ToString() );
            return std::make_unique<DemultiplexingRecordDeserializer>(ChannelMap());
        }
        const int totalChannels = static_cast<int>(oldSubtaskIndexes.size() * oldChannelIndexes.size());
        ChannelMap virtualChannels;
        for (int subtask : oldSubtaskIndexes) {
            for (int channel : oldChannelIndexes) {
                // Only needed to compute the key, so an automatic object is enough.
                SubtaskConnectionDescriptor descriptor(subtask, channel);
                virtualChannels[descriptor.getComplexId()] =
                    std::make_shared<VirtualChannel>(deserializerFactory(totalChannels),
                                                     rescalingDescriptor.IsAmbiguous(channelInfo.getGateIdx(),
                                                                                     subtask) ?
                                                         recordFilterFactory(channelInfo) :
                                                         RecordFilter::all());
            }
        }
        INFO_RELEASE("DemultiplexingRecordDeserializer create channel size:" << virtualChannels.size());
        return std::make_unique<DemultiplexingRecordDeserializer>(std::move(virtualChannels));
    }

    /** @return a description listing the ids of all registered virtual channels. */
    std::string toString() const
    {
        std::ostringstream oss;
        oss << "DemultiplexingRecordDeserializer{channels=";
        for (const auto &pair : channels) {
            oss << pair.first << " ";
        }
        oss << "}";
        return oss.str();
    }

private:
    /** @return the selected channel; throws std::runtime_error if select() has not succeeded yet. */
    VirtualChannel &selectedChannel() const
    {
        if (!currentVirtualChannel) {
            throw std::runtime_error("No virtual channel selected; call select() first");
        }
        return *currentVirtualChannel;
    }

    std::map<long, std::shared_ptr<VirtualChannel>> channels;
    std::shared_ptr<VirtualChannel> currentVirtualChannel;
    // Backing storage for the combined watermark published through the delegate.
    Watermark emittedWatermark = Watermark::UNINITIALIZED;
};
}
#endif