* Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef RECOVERED_CHANNEL_STATE_HANDLER_H
#define RECOVERED_CHANNEL_STATE_HANDLER_H
#include "runtime/buffer/BufferBuilder.h"
#include "ChannelStateSerializer.h"
#include "runtime/partition/CheckpointedResultPartition.h"
#include "runtime/partition/PipelinedSubpartition.h"
#include "core/utils/lang/AutoCloseable.h"
#include "partition/ResultPartitionWriter.h"
#include "partition/consumer/InputGate.h"
#include "partition/consumer/RecoveredInputChannel.h"
#include "runtime/checkpoint/InflightDataRescalingDescriptor.h"
#include "runtime/partition/ResultSubpartitionInfoPOD.h"
#include "runtime/checkpoint/RescaleMappings.h"
namespace omnistream {
template <typename Info, typename Context>
class RecoveredChannelStateHandler: public AutoCloseable {
public:
struct BufferWithContext {
std::shared_ptr<ChannelStateByteBuffer> buffer_;
Context context_;
BufferWithContext(std::shared_ptr<ChannelStateByteBuffer> buffer, Context context) : buffer_(std::move(buffer)), context_(std::move(context)) {}
~BufferWithContext() = default;
void close() { buffer_->close(); }
};
virtual ~RecoveredChannelStateHandler() = default;
virtual BufferWithContext getBuffer(const Info& info) = 0;
* Recover the data from buffer. This method is taking over the ownership of the
* bufferWithContext and is fully responsible for cleaning it up both on the happy path and in
* case of an error.
*/
virtual void recover(const Info& info, int oldSubtaskIndex, const BufferWithContext& bufferWithContext) = 0;
virtual void close() = 0;
};
class ResultSubpartitionRecoveredStateHandler: public RecoveredChannelStateHandler<ResultSubpartitionInfoPOD, BufferBuilder *> {
public:
ResultSubpartitionRecoveredStateHandler(std::vector<std::shared_ptr<ResultPartitionWriter>> writers, bool notifyAndBlockOnCompletion, std::shared_ptr<InflightDataRescalingDescriptor> channelMapping);
~ResultSubpartitionRecoveredStateHandler() override;
BufferWithContext getBuffer(const ResultSubpartitionInfoPOD& subpartitionInfo) override;
void recover(const ResultSubpartitionInfoPOD& subpartitionInfo, int oldSubtaskIndex, const BufferWithContext& bufferWithContext) override;
void close() override;
private:
std::vector<std::shared_ptr<ResultPartitionWriter>> writers_;
bool notifyAndBlockOnCompletion_;
std::shared_ptr<InflightDataRescalingDescriptor> channelMapping_;
std::unordered_map<int, RescaleMappings> oldToNewMappings_;
std::unordered_map<ResultSubpartitionInfoPOD, std::vector<std::shared_ptr<CheckpointedResultSubpartition>>,
ResultSubpartitionInfoPODHash, IResultSubpartitionInfoPODEqual> rescaledChannels_;
std::shared_ptr<CheckpointedResultSubpartition> getSubpartition(int partitionIndex, int subPartitionIdx);
std::vector<std::shared_ptr<CheckpointedResultSubpartition>>
getMappedChannels(const ResultSubpartitionInfoPOD& subpartitionInfo);
std::vector<std::shared_ptr<CheckpointedResultSubpartition>>
calculateMapping(const ResultSubpartitionInfoPOD& info);
};
class InputChannelRecoveredStateHandler : public RecoveredChannelStateHandler<InputChannelInfo, Buffer *> {
public:
InputChannelRecoveredStateHandler(const std::vector<std::shared_ptr<InputGate>> &inputGates,
const std::shared_ptr<InflightDataRescalingDescriptor> &channelMapping)
: inputGates(inputGates), channelMapping(channelMapping) {}
~InputChannelRecoveredStateHandler() override;
BufferWithContext getBuffer(const InputChannelInfo &inputChannelInfo) override;
void recover(const InputChannelInfo &inputChannelInfo, int oldSubtaskIndex, const BufferWithContext &bufferWithContext) override;
void close() override;
private:
std::shared_ptr<RecoveredInputChannel> getChannel(int gateIndex, int subPartitionIndex);
std::vector<std::shared_ptr<RecoveredInputChannel>> calculateMapping(InputChannelInfo info);
std::vector<std::shared_ptr<RecoveredInputChannel>> getMappedChannels(InputChannelInfo channelInfo);
std::vector<std::shared_ptr<InputGate>> inputGates;
std::shared_ptr<InflightDataRescalingDescriptor> channelMapping;
std::unordered_map<InputChannelInfo, std::vector<std::shared_ptr<RecoveredInputChannel>>, InputChannelInfoHash, InputChannelInfoEqual> rescaledChannels;
std::unordered_map<int, RescaleMappings> oldToNewMappings;
};
}
#endif