/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_CHANNEL_STATE_WRITE_REQUEST_H
#define OMNISTREAM_CHANNEL_STATE_WRITE_REQUEST_H
#include <atomic>
#include <exception>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include "runtime/jobgraph/JobVertexID.h"
#include "runtime/partition/consumer/InputChannelInfo.h"
#include "runtime/partition/ResultSubpartitionInfoPOD.h"
#include "runtime/state/CheckpointStorageLocationReference.h"
#include "ChannelStateWriter.h"
#include "ChannelStateCheckpointWriter.h"
#include "core/utils/threads/CompletableFutureV2.h"
#include "runtime/buffer/ObjectBuffer.h"
namespace omnistream {
/**
 * State value held by a CheckpointInProgressRequest.
 *
 * CheckpointInProgressRequest keeps one of these values in an atomic member
 * that is initialised to NEW. The transitions between values are implemented
 * outside this header, so the per-value notes below are inferred from the
 * enumerator names only and should be confirmed against that implementation.
 */
enum class CheckpointInProgressRequestState {
    NEW,        // initial value of CheckpointInProgressRequest::state_
    EXECUTING,  // presumably: execute() is in progress -- TODO confirm
    COMPLETED,  // presumably: execute() finished successfully -- TODO confirm
    FAILED,     // presumably: execute() ended with an error -- TODO confirm
    CANCELLED,  // presumably: cancel() was invoked -- TODO confirm
};
/**
 * Abstract base class of the channel-state write requests declared in this
 * header.
 *
 * Every request carries a JobVertexID, a subtask index, a checkpoint id and a
 * name. Concrete subclasses must implement cancel() and execute(). The static
 * member functions are factories that hand back requests as
 * std::shared_ptr<ChannelStateWriteRequest>; which concrete subclass each
 * factory instantiates is decided in the out-of-line definitions, which are
 * not part of this header.
 */
class ChannelStateWriteRequest {
public:
    /**
     * Creates a request with a default-constructed JobVertexID, subtask index
     * 0, checkpoint id 0 and an empty name. The numeric members have default
     * member initializers so that the getters never read an indeterminate
     * value on a default-constructed request.
     */
    ChannelStateWriteRequest() = default;

    /**
     * @param jobVertexID  job vertex the request belongs to
     * @param subtaskIndex index of the subtask
     * @param checkpointId id of the checkpoint
     * @param name         name of the request
     */
    ChannelStateWriteRequest(JobVertexID jobVertexID, int subtaskIndex, long checkpointId, const std::string &name);

    virtual ~ChannelStateWriteRequest() = default;

    // Accessors; presumably they return the corresponding private members
    // (the definitions live outside this header).
    JobVertexID getJobVertexID() const;
    int getSubtaskIndex() const;
    long getCheckpointId() const;
    std::string getName() const;

    /**
     * Returns a future associated with this request. The base implementation
     * is defined outside this header; CheckpointInProgressRequest overrides it.
     */
    virtual std::shared_ptr<CompletableFutureV2<void>> getReadyFuture();

    /**
     * Cancels the request.
     * @param cause exception describing why the request is cancelled
     */
    virtual void cancel(const std::exception_ptr &cause) = 0;

    /**
     * Executes the request against the given checkpoint writer.
     * @param writer writer the request operates on
     */
    virtual void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) = 0;

    // ---- Factories -------------------------------------------------------
    // NOTE(review): the factory names suggest the request kinds of the
    // channel-state writer protocol (start / write / complete / terminate /
    // register / release); confirm the exact behaviour in the definitions.

    static std::shared_ptr<ChannelStateWriteRequest> completeInput(
        JobVertexID jobVertexID, int subtaskIndex, long checkpointId);

    static std::shared_ptr<ChannelStateWriteRequest> completeOutput(
        JobVertexID jobVertexID, int subtaskIndex, long checkpointId);

    /**
     * @param info    input channel the buffers belong to
     * @param buffers raw buffer pointers; ownership is not expressed by the
     *                type -- TODO confirm who releases them
     */
    static std::shared_ptr<ChannelStateWriteRequest> writeInput(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        InputChannelInfo info,
        std::vector<Buffer*> buffers);

    /**
     * @param info    result subpartition the buffers belong to
     * @param buffers raw buffer pointers; ownership is not expressed by the
     *                type -- TODO confirm who releases them
     */
    static std::shared_ptr<ChannelStateWriteRequest> writeOutput(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        ResultSubpartitionInfoPOD info,
        std::vector<Buffer*> buffers);

    /**
     * Variant of writeOutput() that receives the buffers through a future
     * instead of directly.
     * @param dataFuture future yielding the buffers
     */
    static std::shared_ptr<ChannelStateWriteRequest> writeOutputFuture(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        ResultSubpartitionInfoPOD info,
        std::shared_ptr<CompletableFutureV2<std::vector<Buffer*>>> dataFuture);

    /**
     * @param targetResult result object handed to the created request
     * @param name         passed by value as a string
     */
    static std::shared_ptr<ChannelStateWriteRequest> start(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> targetResult,
        const std::string name);

    /**
     * @param cause exception describing why the checkpoint is terminated
     */
    static std::shared_ptr<ChannelStateWriteRequest> terminate(
        JobVertexID jobVertexID, int subtaskIndex, long checkpointId, const std::exception_ptr &cause);

    static std::shared_ptr<ChannelStateWriteRequest> registerSubtask(
        JobVertexID jobVertexID, int subtaskIndex);

    static std::shared_ptr<ChannelStateWriteRequest> releaseSubtask(
        JobVertexID jobVertexID, int subtaskIndex);

private:
    JobVertexID jobVertexID_;
    // Zero-initialised so the defaulted default constructor does not leave
    // them indeterminate.
    int subtaskIndex_{0};
    long checkpointId_{0};
    std::string name_;
};
class CheckpointStartRequest : public ChannelStateWriteRequest {
public:
CheckpointStartRequest() = default;
virtual ~CheckpointStartRequest() = default;
CheckpointStartRequest(
JobVertexID jobVertexID,
int subtaskIndex,
long checkpointId,
std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> targetResult,
std::shared_ptr<CheckpointStorageLocationReference> locationReference);
std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> getTargetResult();
std::shared_ptr<CheckpointStorageLocationReference> getLocationReference();
void cancel(const std::exception_ptr &cause) override;
void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) override;
private:
std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> targetResult_;
std::shared_ptr<CheckpointStorageLocationReference> locationReference_;
};
/**
 * Request built from two callbacks: an Action that receives the checkpoint
 * writer, and a DiscardAction that receives an exception.
 *
 * The request also holds an optional "ready" future, an optional data future
 * and an atomic CheckpointInProgressRequestState that starts as NEW. How these
 * members interact is implemented outside this header.
 */
class CheckpointInProgressRequest : public ChannelStateWriteRequest {
public:
    /** Callback invoked with the checkpoint writer (passed by non-const reference). */
    using Action = std::function<void(std::shared_ptr<ChannelStateCheckpointWriter> &)>;

    /** Callback invoked with an exception; presumably used when the request is cancelled -- TODO confirm. */
    using DiscardAction = std::function<void(const std::exception_ptr &)>;

    /**
     * @param name          name of the request
     * @param jobVertexID   job vertex the request belongs to
     * @param subtaskIndex  index of the subtask
     * @param checkpointId  id of the checkpoint
     * @param action        callback receiving the writer
     * @param discardAction callback receiving an exception
     * @param readyFuture   optional future; defaults to nullptr
     *
     * NOTE(review): a call that passes a literal nullptr as the seventh
     * argument is ambiguous between this overload and the dataFuture overload
     * below, because both parameter types are constructible from nullptr.
     */
    CheckpointInProgressRequest(
        const std::string &name,
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        Action action,
        DiscardAction discardAction,
        std::shared_ptr<CompletableFutureV2<void>> readyFuture = nullptr);

    /** Overload that takes only the Action, without a DiscardAction or future. */
    CheckpointInProgressRequest(
        const std::string &name,
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        Action action);

    /**
     * Overload that takes a data future instead of a ready future.
     *
     * NOTE(review): the element type here is ObjectBuffer, whereas
     * ChannelStateWriteRequest::writeOutputFuture() takes a future of
     * std::vector<Buffer*>; confirm the two are meant to differ.
     */
    CheckpointInProgressRequest(
        const std::string &name,
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        Action action,
        DiscardAction discardAction,
        std::shared_ptr<CompletableFutureV2<std::vector<ObjectBuffer>>> dataFuture);

    // ChannelStateWriteRequest interface.
    std::shared_ptr<CompletableFutureV2<void>> getReadyFuture() override;
    void cancel(const std::exception_ptr &cause) override;
    void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) override;

private:
    Action action_;
    DiscardAction discardAction_;
    std::shared_ptr<CompletableFutureV2<void>> readyFuture_;
    std::shared_ptr<CompletableFutureV2<std::vector<ObjectBuffer>>> dataFuture_;
    // Every request starts in the NEW state.
    std::atomic<CheckpointInProgressRequestState> state_{CheckpointInProgressRequestState::NEW};
};
/**
 * Request that carries the exception describing why a checkpoint is aborted.
 */
class CheckpointAbortRequest : public ChannelStateWriteRequest {
public:
    /**
     * @param jobVertexID  job vertex the request belongs to
     * @param subtaskIndex index of the subtask
     * @param checkpointId id of the checkpoint
     * @param cause        exception describing the reason for the abort
     */
    CheckpointAbortRequest(
        JobVertexID jobVertexID,
        int subtaskIndex,
        long checkpointId,
        const std::exception_ptr &cause);

    /**
     * Returns a reference to an exception_ptr, presumably the stored cause_
     * (defined outside this header).
     */
    const std::exception_ptr &getCause() const;

    // ChannelStateWriteRequest interface; the argument of cancel() is unnamed
    // in this declaration.
    void cancel(const std::exception_ptr &) override;
    void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) override;

private:
    std::exception_ptr cause_;
};
/**
 * Request identified only by a job vertex and a subtask index; its constructor
 * takes no checkpoint id.
 *
 * NOTE(review): by its name this registers a subtask with the writer; confirm
 * in the implementation file.
 */
class SubtaskRegisterRequest : public ChannelStateWriteRequest {
public:
    /**
     * @param jobVertexID  job vertex the subtask belongs to
     * @param subtaskIndex index of the subtask
     */
    SubtaskRegisterRequest(JobVertexID jobVertexID, int subtaskIndex);

    // ChannelStateWriteRequest interface; the argument of cancel() is unnamed
    // in this declaration.
    void cancel(const std::exception_ptr &) override;
    void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) override;
};
/**
 * Request identified only by a job vertex and a subtask index; its constructor
 * takes no checkpoint id.
 *
 * NOTE(review): by its name this releases a previously registered subtask;
 * confirm in the implementation file.
 */
class SubtaskReleaseRequest : public ChannelStateWriteRequest {
public:
    /**
     * @param jobVertexID  job vertex the subtask belongs to
     * @param subtaskIndex index of the subtask
     */
    SubtaskReleaseRequest(JobVertexID jobVertexID, int subtaskIndex);

    // ChannelStateWriteRequest interface; the argument of cancel() is unnamed
    // in this declaration.
    void cancel(const std::exception_ptr &) override;
    void execute(std::shared_ptr<ChannelStateCheckpointWriter> writer) override;
};
}
#endif