* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_CHANNEL_STATE_WRITE_REQUEST_DISPATCHER_IMPL_H
#define OMNISTREAM_CHANNEL_STATE_WRITE_REQUEST_DISPATCHER_IMPL_H
#include <map>
#include <mutex>
#include <set>
#include <exception>
#include <stdexcept>
#include "ChannelStateCheckpointWriter.h"
#include "ChannelStateWriteRequest.h"
#include "ChannelStateWriteRequestDispatcher.h"
#include "runtime/executiongraph/JobIDPOD.h"
#include "runtime/state/CheckpointStorage.h"
#include "runtime/state/CheckpointStorageWorkerView.h"
#include "runtime/io/checkpointing/CheckpointException.h"
namespace omnistream {
class ChannelStateWriteRequestDispatcherImpl : public ChannelStateWriteRequestDispatcher {
public:
ChannelStateWriteRequestDispatcherImpl(
std::shared_ptr<CheckpointStorage> checkpointStorage,
const JobIDPOD &jobID,
std::shared_ptr<ChannelStateSerializer> serializer,
std::shared_ptr<CheckpointStorageWorkerView> streamFactoryResolver);
void dispatch(std::shared_ptr<ChannelStateWriteRequest> request) override;
void fail(const std::exception_ptr &cause) override;
private:
std::shared_ptr<CheckpointStorage> checkpointStorage;
JobIDPOD jobID;
std::shared_ptr<ChannelStateSerializer> serializer;
std::set<SubtaskID> registeredSubtasks;
std::mutex mutex;
long ongoingCheckpointId;
long maxAbortedCheckpointId;
SubtaskID abortedSubtaskID;
std::exception_ptr abortedCause;
std::unordered_map<uint64_t, std::shared_ptr<ChannelStateCheckpointWriter>> writers;
void RemoveWriter(long checkpointId)
{
writers.erase(checkpointId);
}
std::shared_ptr<CheckpointStorageWorkerView> streamFactoryResolver;
void dispatchInternal(std::shared_ptr<ChannelStateWriteRequest> request);
bool isAbortedCheckpoint(long checkpointId);
void handleAbortedRequest(std::shared_ptr<ChannelStateWriteRequest> request);
void handleCheckpointStartRequest(std::shared_ptr<CheckpointStartRequest> request);
void handleCheckpointInProgressRequest(std::shared_ptr<CheckpointInProgressRequest> request);
void handleCheckpointAbortRequest(std::shared_ptr<CheckpointAbortRequest> request);
void failAndClearWriter(const std::exception_ptr &e);
void failAndClearWriter(const JobVertexID &jvid, int idx, const std::exception_ptr &e);
std::shared_ptr<ChannelStateCheckpointWriter> buildWriter(std::shared_ptr<CheckpointStartRequest> request);
std::shared_ptr<CheckpointStorageWorkerView> getStreamFactoryResolver();
};
}
#endif