* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include <gtest/gtest.h>
#include <set>
#include <memory>
#include <stdexcept>
#include <vector>
#include "runtime/checkpoint/channel/ChannelStatePendingResult.h"
#include "runtime/checkpoint/channel/ChannelStateWriteRequest.h"
#include "runtime/checkpoint/channel/ChannelStateWriter.h"
#include "runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.h"
#include "runtime/checkpoint/channel/ChannelStateSerializer.h"
#include "runtime/checkpoint/channel/ChannelStateCheckpointWriter.h"
#include "runtime/buffer/ObjectBufferRecycler.h"
#include "runtime/state/CheckpointStorage.h"
#include "runtime/state/filesystem/FsCheckpointStorageAccess.h"
#include "runtime/executiongraph/JobIDPOD.h"
#include "streaming/runtime/io/checkpointing/ChannelState.h"
#include "streaming/runtime/io/checkpointing/AbstractAlternatingAlignedBarrierHandlerState.h"
#include "streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrier.h"
#include "streaming/runtime/io/checkpointing/AlternatingCollectingBarriers.h"
#include "runtime/io/network/api/CheckpointBarrier.h"
#include "runtime/io/checkpointing/CheckpointBarrierHandler.h"
#include "core/utils/threads/CompletableFutureV2.h"
using namespace omnistream;
class MockSerializer : public ChannelStateSerializerImpl {
public:
int64_t GetHeaderLength() const override
{
return 100;
}
};
class MockStorage : public CheckpointStorage {
public:
std::shared_ptr<CheckpointStorageAccess> createCheckpointStorage(const JobIDPOD& jobId) override
{
static Path checkpointDir("");
static Path savepointDir("");
return std::make_shared<FsCheckpointStorageAccess>(&checkpointDir, &savepointDir, jobId, 100, 100);
}
};
class MockBuffer : public ObjectBuffer {
public:
bool isBuffer() const override
{
return true;
}
std::shared_ptr<BufferRecycler> GetRecycler() override
{
return DummyObjectBufferRecycler::getInstance();
}
void RecycleBuffer() override
{
recycled = true;
}
bool IsRecycled() const override
{
return recycled;
}
Buffer* RetainBuffer() override
{
return this;
}
Buffer* ReadOnlySlice() override
{
return this;
}
Buffer* ReadOnlySlice(int, int) override
{
return this;
}
int GetMaxCapacity() const override
{
return 512;
}
int GetReaderIndex() const override
{
return 0;
}
void SetReaderIndex(int) override
{
}
int GetSize() const override
{
return 128;
}
void SetSize(int) override
{
}
int ReadableObjects() const override
{
return 1;
}
bool IsCompressed() const override
{
return false;
}
void SetCompressed(bool) override
{
}
ObjectBufferDataType GetDataType() const override
{
return ObjectBufferDataType::DATA_BUFFER;
}
void SetDataType(ObjectBufferDataType) override
{
}
int RefCount() const override
{
return 1;
}
std::string ToDebugString(bool) const override
{
return "MockBuffer";
}
ObjectSegment* GetObjectSegment() override
{
return nullptr;
}
int GetBufferType() override
{
return 0;
}
std::pair<uint8_t*, size_t> GetBytes() override
{
return {nullptr, 0};
}
bool recycled = false;
};
TEST(ChannelStatePendingResultTest, InitialStateNotDone)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
EXPECT_FALSE(pending.IsDone());
EXPECT_FALSE(pending.IsAllInputsReceived());
EXPECT_FALSE(pending.IsAllOutputsReceived());
EXPECT_EQ(pending.GetSubtaskIndex(), 0);
EXPECT_EQ(pending.GetCheckpointId(), 1);
}
TEST(ChannelStatePendingResultTest, CompleteInputSetsFlag)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
pending.CompleteInput();
EXPECT_TRUE(pending.IsAllInputsReceived());
}
TEST(ChannelStatePendingResultTest, CompleteInputTwiceThrows)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
pending.CompleteInput();
EXPECT_THROW(pending.CompleteInput(), std::logic_error);
}
TEST(ChannelStatePendingResultTest, CompleteOutputSetsFlag)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
pending.CompleteOutput();
EXPECT_TRUE(pending.IsAllOutputsReceived());
}
TEST(ChannelStatePendingResultTest, CompleteOutputTwiceThrows)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
pending.CompleteOutput();
EXPECT_THROW(pending.CompleteOutput(), std::logic_error);
}
TEST(ChannelStatePendingResultTest, FailDoesNotThrow)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
EXPECT_NO_THROW(pending.Fail(std::make_exception_ptr(std::runtime_error("test"))));
}
TEST(ChannelStatePendingResultTest, GetInputChannelOffsetsStartsEmpty)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
EXPECT_TRUE(pending.GetInputChannelOffsets().empty());
}
TEST(ChannelStatePendingResultTest, GetResultSubpartitionOffsetsStartsEmpty)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
EXPECT_TRUE(pending.GetResultSubpartitionOffsets().empty());
}
TEST(ChannelStatePendingResultTest, IsDoneReflectsResultState)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
auto serializer = std::make_shared<MockSerializer>();
ChannelStatePendingResult pending(0, 1, result, serializer);
EXPECT_FALSE(pending.IsDone());
}
TEST(CheckpointInProgressRequestTest, ConstructorWithAction)
{
CheckpointInProgressRequest req(
"Test", JobVertexID(1, 1), 0, 42, [](std::shared_ptr<ChannelStateCheckpointWriter>&) {});
EXPECT_EQ(req.getCheckpointId(), 42);
EXPECT_EQ(req.getSubtaskIndex(), 0);
EXPECT_EQ(req.getName(), "Test");
}
TEST(CheckpointInProgressRequestTest, ExecuteTransitionsState)
{
CheckpointInProgressRequest req(
"Test", JobVertexID(1, 1), 0, 1, [](std::shared_ptr<ChannelStateCheckpointWriter>&) {});
std::shared_ptr<ChannelStateCheckpointWriter> writer = nullptr;
EXPECT_NO_THROW(req.execute(writer));
}
TEST(CheckpointInProgressRequestTest, ExecuteTwiceThrows)
{
CheckpointInProgressRequest req(
"Test", JobVertexID(1, 1), 0, 1, [](std::shared_ptr<ChannelStateCheckpointWriter>&) {});
std::shared_ptr<ChannelStateCheckpointWriter> writer = nullptr;
req.execute(writer);
EXPECT_THROW(req.execute(writer), std::runtime_error);
}
TEST(CheckpointInProgressRequestTest, CancelCallsDiscardAction)
{
bool discarded = false;
CheckpointInProgressRequest req(
"Test",
JobVertexID(1, 1),
0,
1,
[](std::shared_ptr<ChannelStateCheckpointWriter>&) {},
[&](const std::exception_ptr&) { discarded = true; });
req.cancel(std::make_exception_ptr(std::runtime_error("cancel test")));
EXPECT_TRUE(discarded);
}
TEST(CheckpointInProgressRequestTest, GetReadyFutureReturnsCompletedWhenNoFutureProvided)
{
CheckpointInProgressRequest req(
"Test", JobVertexID(1, 1), 0, 1, [](std::shared_ptr<ChannelStateCheckpointWriter>&) {});
auto future = req.getReadyFuture();
EXPECT_TRUE(future->IsDone());
}
TEST(CheckpointInProgressRequestTest, GetReadyFutureReflectsProvidedFuture)
{
auto ready = std::make_shared<CompletableFutureV2<void>>();
CheckpointInProgressRequest req(
"Test",
JobVertexID(1, 1),
0,
1,
[](std::shared_ptr<ChannelStateCheckpointWriter>&) {},
[](const std::exception_ptr&) {},
ready);
EXPECT_FALSE(req.getReadyFuture()->IsDone());
ready->Complete();
EXPECT_TRUE(req.getReadyFuture()->IsDone());
}
TEST(ChannelStateWriteRequestTest, WriteInputDiscardRecyclesBuffers)
{
auto buf = std::make_shared<MockBuffer>();
std::vector<Buffer*> buffers = {buf.get()};
auto req = ChannelStateWriteRequest::writeInput(JobVertexID(1, 1), 0, 1, InputChannelInfo{}, buffers);
EXPECT_FALSE(buf->recycled);
req->cancel(std::make_exception_ptr(std::runtime_error("cancel")));
EXPECT_TRUE(buf->recycled);
}
TEST(ChannelStateWriteRequestTest, WriteOutputDiscardRecyclesBuffers)
{
auto buf = std::make_shared<MockBuffer>();
std::vector<Buffer*> buffers = {buf.get()};
auto req = ChannelStateWriteRequest::writeOutput(JobVertexID(1, 1), 0, 1, ResultSubpartitionInfoPOD{}, buffers);
EXPECT_FALSE(buf->recycled);
req->cancel(std::make_exception_ptr(std::runtime_error("cancel")));
EXPECT_TRUE(buf->recycled);
}
TEST(ChannelStateWriteRequestTest, CompleteInputHasDoneFuture)
{
auto req = ChannelStateWriteRequest::completeInput(JobVertexID(1, 1), 0, 1);
EXPECT_TRUE(req->getReadyFuture()->IsDone());
EXPECT_EQ(req->getName(), "CheckpointCompleteInput");
}
TEST(ChannelStateWriteRequestTest, CompleteOutputHasDoneFuture)
{
auto req = ChannelStateWriteRequest::completeOutput(JobVertexID(1, 1), 0, 1);
EXPECT_TRUE(req->getReadyFuture()->IsDone());
EXPECT_EQ(req->getName(), "CheckpointCompleteOutput");
}
TEST(ChannelStateWriteRequestTest, RegisterSubtaskCreatesCorrectType)
{
auto req = ChannelStateWriteRequest::registerSubtask(JobVertexID(1, 1), 0);
EXPECT_EQ(req->getName(), "Register");
EXPECT_EQ(req->getSubtaskIndex(), 0);
}
TEST(ChannelStateWriteRequestTest, ReleaseSubtaskCreatesCorrectType)
{
auto req = ChannelStateWriteRequest::releaseSubtask(JobVertexID(1, 1), 0);
EXPECT_EQ(req->getName(), "Release");
EXPECT_EQ(req->getSubtaskIndex(), 0);
}
TEST(ChannelStateWriteRequestDispatcherImplTest, AbortedCheckpointIdDetection)
{
auto storage = std::make_shared<MockStorage>();
auto serializer = std::make_shared<MockSerializer>();
auto dispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
storage, JobIDPOD(-1, -1), serializer, storage->createCheckpointStorage(JobIDPOD(-1, -1)));
JobVertexID jvid(1, 1);
auto targetResult = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
dispatcher->dispatch(ChannelStateWriteRequest::registerSubtask(jvid, 1));
auto startReq = ChannelStateWriteRequest::start(jvid, 1, 5, targetResult, "Start");
dispatcher->dispatch(startReq);
auto abortReq =
ChannelStateWriteRequest::terminate(jvid, 1, 5, std::make_exception_ptr(std::runtime_error("abort")));
EXPECT_NO_THROW(dispatcher->dispatch(abortReq));
}
TEST(ChannelStateWriteRequestDispatcherImplTest, SubsumedCheckpointIsCancelled)
{
auto storage = std::make_shared<MockStorage>();
auto serializer = std::make_shared<MockSerializer>();
auto dispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
storage, JobIDPOD(-1, -1), serializer, storage->createCheckpointStorage(JobIDPOD(-1, -1)));
JobVertexID jvid(1, 1);
auto targetResult = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
dispatcher->dispatch(ChannelStateWriteRequest::registerSubtask(jvid, 1));
auto startReq = ChannelStateWriteRequest::start(jvid, 1, 10, targetResult, "Start");
dispatcher->dispatch(startReq);
auto oldStartReq = ChannelStateWriteRequest::start(jvid, 1, 5, targetResult, "Start");
EXPECT_NO_THROW(dispatcher->dispatch(oldStartReq));
}
TEST(ChannelStateWriteRequestDispatcherImplTest, NoStartBeforeAbortDoesNotThrow)
{
auto storage = std::make_shared<MockStorage>();
auto serializer = std::make_shared<MockSerializer>();
auto dispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
storage, JobIDPOD(-1, -1), serializer, storage->createCheckpointStorage(JobIDPOD(-1, -1)));
JobVertexID jvid(1, 1);
dispatcher->dispatch(ChannelStateWriteRequest::registerSubtask(jvid, 1));
auto abortReq =
ChannelStateWriteRequest::terminate(jvid, 1, 1, std::make_exception_ptr(std::runtime_error("early abort")));
EXPECT_NO_THROW(dispatcher->dispatch(abortReq));
}
TEST(ChannelStateWriteRequestDispatcherImplTest, ReleaseUnregistersSubtask)
{
auto storage = std::make_shared<MockStorage>();
auto serializer = std::make_shared<MockSerializer>();
auto dispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
storage, JobIDPOD(-1, -1), serializer, storage->createCheckpointStorage(JobIDPOD(-1, -1)));
JobVertexID jvid(1, 1);
auto registerReq = std::make_shared<SubtaskRegisterRequest>(jvid, 1);
EXPECT_NO_THROW(dispatcher->dispatch(registerReq));
auto releaseReq = std::make_shared<SubtaskReleaseRequest>(jvid, 1);
EXPECT_NO_THROW(dispatcher->dispatch(releaseReq));
}
class MockCheckpointableInput : public CheckpointableInput {
public:
void BlockConsumption(const InputChannelInfo&) override
{
blocked = true;
}
void ResumeConsumption(const InputChannelInfo&) override
{
resumed = true;
}
void ConvertToPriorityEvent(int, int) override
{
prioritized = true;
}
std::vector<InputChannelInfo> GetChannelInfos() override
{
return {InputChannelInfo(GetInputGateIndex(), 0)};
}
void CheckpointStarted(const CheckpointBarrier&) override
{
}
void CheckpointStopped(long) override
{
}
int GetInputGateIndex() override
{
return gateIdx;
}
int gateIdx = 0;
bool blocked = false;
bool resumed = false;
bool prioritized = false;
};
TEST(ChannelStateTest, BlockChannelCallsBlockOnInput)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
EXPECT_TRUE(input.blocked);
}
TEST(ChannelStateTest, UnblockAllChannelsResumesBlockedInputs)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
state.UnblockAllChannels();
EXPECT_TRUE(input.resumed);
}
TEST(ChannelStateTest, UnblockAllDoesNotResumeAlreadyResumed)
{
MockCheckpointableInput input0, input1;
input1.gateIdx = 1;
std::vector<CheckpointableInput*> inputs = {&input0, &input1};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
state.UnblockAllChannels();
EXPECT_TRUE(input0.resumed);
EXPECT_FALSE(input1.resumed);
}
TEST(ChannelStateTest, ChannelFinishedRemovesBlockedState)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
state.ChannelFinished(InputChannelInfo(0, 0));
EXPECT_NO_THROW(state.EmptyState());
}
TEST(ChannelStateTest, EmptyStateThrowsWhenBlockedChannelsExist)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
EXPECT_THROW(state.EmptyState(), std::runtime_error);
}
TEST(ChannelStateTest, EmptyStateSucceedsAfterUnblockAll)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.BlockChannel(InputChannelInfo(0, 0));
state.UnblockAllChannels();
EXPECT_NO_THROW(state.EmptyState());
}
TEST(ChannelStateTest, PrioritizeAllAnnouncementsCallsConvert)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.addSeenAnnouncement(InputChannelInfo(0, 0), 42);
state.PrioritizeAllAnnouncements();
EXPECT_TRUE(input.prioritized);
}
TEST(ChannelStateTest, AnnouncementSeenAndRemoved)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
state.addSeenAnnouncement(InputChannelInfo(0, 0), 7);
state.removeSeenAnnouncement(InputChannelInfo(0, 0));
EXPECT_NO_THROW(state.PrioritizeAllAnnouncements());
}
TEST(ChannelStateTest, GetInputsReturnsCorrect)
{
MockCheckpointableInput input0, input1;
input1.gateIdx = 1;
std::vector<CheckpointableInput*> inputs = {&input0, &input1};
ChannelState state(inputs);
auto result = state.getInputs();
EXPECT_EQ(result.size(), 2u);
}
class MockController : public Controller {
public:
bool AllBarriersReceived() const override
{
return allReceived;
}
const CheckpointBarrier* GetPendingCheckpointBarrier() const override
{
return nullptr;
}
void TriggerGlobalCheckpoint(const CheckpointBarrier&) override
{
triggered = true;
}
void InitInputsCheckpoint(const CheckpointBarrier&) override
{
initialized = true;
}
bool IsTimedOut(const CheckpointBarrier&) override
{
return timedOut;
}
bool allReceived = false;
bool triggered = false;
bool initialized = false;
bool timedOut = false;
};
class TestBarrierHandler : public AbstractAlternatingAlignedBarrierHandlerState {
public:
TestBarrierHandler(ChannelState state) : AbstractAlternatingAlignedBarrierHandlerState(std::move(state))
{
}
protected:
BarrierHandlerState* TransitionAfterBarrierReceived(ChannelState s) override
{
return new AlternatingCollectingBarriers(std::move(s));
}
BarrierHandlerState* AlignedCheckpointTimeout(Controller*, CheckpointBarrier*) override
{
return new AlternatingWaitingForFirstBarrier(state.EmptyState());
}
};
TEST(BarrierHandlerStateTest, SingleChannelBarrierReceivedTriggersCheckpoint)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
MockController controller;
std::unique_ptr<TestBarrierHandler> handler(new TestBarrierHandler(std::move(state)));
controller.allReceived = true;
auto options = std::make_shared<CheckpointOptions>(
CheckpointType::CHECKPOINT, CheckpointStorageLocationReference::GetDefault());
CheckpointBarrier barrier(1, 1000, options);
std::unique_ptr<BarrierHandlerState> next(
handler->BarrierReceived(&controller, InputChannelInfo(0, 0), &barrier, true));
EXPECT_TRUE(controller.initialized);
EXPECT_TRUE(controller.triggered);
EXPECT_TRUE(input.resumed);
}
TEST(BarrierHandlerStateTest, BarrierReceivedBlocksChannelWhenMarked)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
MockController controller;
std::unique_ptr<TestBarrierHandler> handler(new TestBarrierHandler(std::move(state)));
auto options = std::make_shared<CheckpointOptions>(
CheckpointType::CHECKPOINT, CheckpointStorageLocationReference::GetDefault());
CheckpointBarrier barrier(1, 1000, options);
std::unique_ptr<BarrierHandlerState> next(
handler->BarrierReceived(&controller, InputChannelInfo(0, 0), &barrier, true));
EXPECT_TRUE(input.blocked);
}
TEST(BarrierHandlerStateTest, BarrierReceivedDoesNotBlockWhenNotMarked)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
MockController controller;
std::unique_ptr<TestBarrierHandler> handler(new TestBarrierHandler(std::move(state)));
auto options = std::make_shared<CheckpointOptions>(
CheckpointType::CHECKPOINT, CheckpointStorageLocationReference::GetDefault());
CheckpointBarrier barrier(1, 1000, options);
std::unique_ptr<BarrierHandlerState> next(
handler->BarrierReceived(&controller, InputChannelInfo(0, 0), &barrier, false));
EXPECT_FALSE(input.blocked);
}
TEST(BarrierHandlerStateTest, AnnouncementReceivedRecordsState)
{
MockCheckpointableInput input;
std::vector<CheckpointableInput*> inputs = {&input};
ChannelState state(inputs);
MockController controller;
std::unique_ptr<TestBarrierHandler> handler(new TestBarrierHandler(std::move(state)));
auto next = handler->AnnouncementReceived(&controller, InputChannelInfo(0, 0), 99);
EXPECT_EQ(next, handler.get());
}
TEST(ChannelStateWriteResultTest, CreateEmptyIsNotDone)
{
auto result = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
EXPECT_FALSE(result->IsDone());
}
TEST(ChannelStateWriteResultTest, IsDoneAfterBothFuturesComplete)
{
auto inputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::InputChannelStateHandleVecPtr>>();
auto outputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::ResultSubpartitionStateVecPtr>>();
ChannelStateWriter::ChannelStateWriteResult result(inputFuture, outputFuture);
EXPECT_FALSE(result.IsDone());
inputFuture->Complete(std::make_shared<std::vector<std::shared_ptr<InputChannelStateHandle>>>());
EXPECT_FALSE(result.IsDone());
outputFuture->Complete(std::make_shared<std::vector<std::shared_ptr<ResultSubpartitionStateHandle>>>());
EXPECT_TRUE(result.IsDone());
}
TEST(ChannelStateWriteResultTest, FailCancelsBothFutures)
{
auto inputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::InputChannelStateHandleVecPtr>>();
auto outputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::ResultSubpartitionStateVecPtr>>();
ChannelStateWriter::ChannelStateWriteResult result(inputFuture, outputFuture);
result.Fail(std::make_exception_ptr(std::runtime_error("fail")));
EXPECT_TRUE(result.GetInputChannelStateHandles()->IsCancelled());
EXPECT_TRUE(result.GetResultSubpartitionStateHandles()->IsCancelled());
}
TEST(ChannelStateWriteResultTest, IsDoneFalseWhenFuturesNotComplete)
{
auto inputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::InputChannelStateHandleVecPtr>>();
auto outputFuture = std::make_shared<
CompletableFutureV2<ChannelStateWriter::ChannelStateWriteResult::ResultSubpartitionStateVecPtr>>();
ChannelStateWriter::ChannelStateWriteResult result(inputFuture, outputFuture);
EXPECT_FALSE(result.IsDone());
}
TEST(ChannelStateCheckpointWriterTest, ConstructorRequiresNonEmptySubtasks)
{
std::set<SubtaskID> empty;
EXPECT_THROW(
{ ChannelStateCheckpointWriter writer(empty, 1, nullptr, std::make_shared<MockSerializer>(), []() {}); },
std::invalid_argument);
}
TEST(SubtaskIDTest, EqualityAndOrdering)
{
SubtaskID a(JobVertexID(1, 1), 0);
SubtaskID b(JobVertexID(1, 1), 0);
SubtaskID c(JobVertexID(1, 1), 1);
SubtaskID d(JobVertexID(2, 2), 0);
EXPECT_EQ(a, b);
EXPECT_FALSE(a == c);
EXPECT_FALSE(a == d);
EXPECT_LT(a, c);
EXPECT_LT(a, d);
}
TEST(SubtaskIDTest, OfFactoryCreatesCorrectId)
{
auto sid = SubtaskID::Of(JobVertexID(3, 4), 5);
EXPECT_EQ(sid.GetJobVertexID(), JobVertexID(3, 4));
EXPECT_EQ(sid.GetSubtaskIndex(), 5);
}
TEST(SubtaskIDTest, DefaultConstructor)
{
SubtaskID sid;
EXPECT_EQ(sid.GetSubtaskIndex(), -1);
EXPECT_EQ(sid.GetJobVertexID(), JobVertexID(-1, -1));
}
TEST(CheckpointAbortRequestTest, ConstructorStoresCause)
{
auto cause = std::make_exception_ptr(std::runtime_error("abort"));
CheckpointAbortRequest req(JobVertexID(1, 1), 0, 1, cause);
EXPECT_EQ(req.getName(), "Abort");
EXPECT_EQ(req.getCheckpointId(), 1);
EXPECT_EQ(req.getSubtaskIndex(), 0);
}
TEST(CheckpointAbortRequestTest, CancelIsNoOp)
{
auto cause = std::make_exception_ptr(std::runtime_error("abort"));
CheckpointAbortRequest req(JobVertexID(1, 1), 0, 1, cause);
EXPECT_NO_THROW(req.cancel(std::make_exception_ptr(std::runtime_error("cancel"))));
}