#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include "streaming/runtime/io/checkpointing/AbstractAlternatingAlignedBarrierHandlerState.h"
#include "streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrier.h"
#include "streaming/runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.h"
#include "runtime/io/network/api/CheckpointBarrier.h"
/// Test double for CheckpointableInput that exposes exactly one channel and
/// records, via public flags, which callbacks were invoked on it.
class TestCheckpointableInput : public CheckpointableInput {
public:
    /// @param gateIndex Input gate index reported by GetInputGateIndex() and used
    ///        for the single channel returned by GetChannelInfos(). Defaults to 0,
    ///        which reproduces the previous hard-coded behaviour; pass distinct
    ///        values when several inputs in one test must not share a gate index.
    explicit TestCheckpointableInput(int gateIndex = 0) : gateIndex_(gateIndex) {}

    void BlockConsumption(const InputChannelInfo&) override { blocked = true; }
    void ResumeConsumption(const InputChannelInfo&) override { resumed = true; }
    void ConvertToPriorityEvent(int, int) override {}

    /// Returns the single channel (this input's gate, channel 0).
    std::vector<InputChannelInfo> GetChannelInfos() override {
        return {InputChannelInfo(GetInputGateIndex(), 0)};
    }

    // Matches the one entry returned by GetChannelInfos().
    // NOTE(review): not marked `override` — confirm whether CheckpointableInput
    // declares this virtual.
    int GetNumberOfInputChannels() { return 1; }

    void CheckpointStarted(const CheckpointBarrier&) override { started = true; }
    void CheckpointStopped(long) override { stopped = true; }
    int GetInputGateIndex() override { return gateIndex_; }

    // Recorded interactions; tests read and reset these directly.
    bool blocked = false;
    bool resumed = false;
    bool started = false;
    bool stopped = false;

private:
    int gateIndex_;
};
/// Controller test double: its decisions are steered by public input flags, and
/// it records whether the handler initialised or triggered a checkpoint.
class TestController : public Controller {
public:
    // Inputs, set by the test to drive the handler's decisions.
    bool allReceived{false};
    bool timedOut{false};

    // Outputs, flipped to true when the matching callback fires.
    bool triggered{false};
    bool initialized{false};

    bool AllBarriersReceived() const override { return allReceived; }

    bool IsTimedOut(const CheckpointBarrier&) override { return timedOut; }

    // This double never reports a pending barrier.
    const CheckpointBarrier* GetPendingCheckpointBarrier() const override { return nullptr; }

    void InitInputsCheckpoint(const CheckpointBarrier&) override { initialized = true; }

    void TriggerGlobalCheckpoint(const CheckpointBarrier&) override { triggered = true; }
};
// Aligned checkpoint flow. The first barrier only blocks its channel. Once every
// barrier has arrived, the test expects the checkpoint to be initialised and
// triggered, and both inputs to be resumed without being stopped.
TEST(BarrierStateTest, AlignedCheckpointInitialization) {
    TestCheckpointableInput input0, input1;
    std::vector<CheckpointableInput*> inputs = {&input0, &input1};
    ChannelState state(inputs);
    TestController controller;
    auto checkpointType = CheckpointType::CHECKPOINT;
    auto targetLocation = CheckpointStorageLocationReference::GetDefault();
    // NOTE(review): nothing in this test deletes `options`; confirm whether
    // CheckpointBarrier takes ownership, otherwise this allocation leaks.
    CheckpointOptions* options = new CheckpointOptions(checkpointType, targetLocation);
    CheckpointBarrier barrier(101, 999999L, options);

    // Each state object is owned by a unique_ptr, so a failing ASSERT_* (which
    // returns from the test body early) cannot leak it. Like the explicit
    // deletes this replaces, it assumes every transition returns a distinct
    // heap-allocated state object.
    auto initialState = std::make_unique<AlternatingWaitingForFirstBarrier>(state);

    // First barrier, on channel (0, 0): only the channel is blocked.
    std::unique_ptr<BarrierHandlerState> nextState(initialState->BarrierReceived(
        &controller, InputChannelInfo(0, 0), &barrier, true));
    // Guard the dereference below: a null transition must fail this test,
    // not crash the whole test binary.
    ASSERT_NE(nextState.get(), nullptr);
    EXPECT_TRUE(input0.blocked);
    EXPECT_FALSE(controller.initialized);
    EXPECT_FALSE(controller.triggered);

    // Last barrier, on channel (1, 0), with all barriers reported as received.
    controller.allReceived = true;
    std::unique_ptr<BarrierHandlerState> finalState(nextState->BarrierReceived(
        &controller, InputChannelInfo(1, 0), &barrier, true));
    EXPECT_TRUE(controller.initialized);
    EXPECT_TRUE(controller.triggered);
    EXPECT_TRUE(input0.resumed);
    EXPECT_TRUE(input1.resumed);
    EXPECT_FALSE(input0.stopped);
    EXPECT_FALSE(input1.stopped);
}
// Disabled test of the fallback to an unaligned checkpoint when the alignment
// timeout expires between the first and the second barrier.
TEST(BarrierStateTest, DISABLED_FallbackToUnalignedOnTimeout) {
    TestCheckpointableInput input0, input1;
    std::vector<CheckpointableInput*> inputs = {&input0, &input1};
    ChannelState state(inputs);
    TestController controller;
    auto checkpointType = CheckpointType::CHECKPOINT;
    auto targetLocation = CheckpointStorageLocationReference::GetDefault();
    // NOTE(review): nothing in this test deletes `options`; confirm whether
    // CheckpointBarrier takes ownership, otherwise this allocation leaks.
    CheckpointOptions* options = new CheckpointOptions(checkpointType, targetLocation);
    CheckpointBarrier barrier(123, 456789L, options);

    // Each state object is owned by a unique_ptr, so a failing ASSERT_* cannot
    // leak it. This assumes, as the explicit deletes it replaces did, that each
    // transition returns a distinct heap-allocated state object.
    std::unique_ptr<BarrierHandlerState> initialState =
        std::make_unique<AlternatingWaitingForFirstBarrier>(state);

    // First barrier, on channel (0, 0): blocked, checkpoint not yet started.
    std::unique_ptr<BarrierHandlerState> nextState(initialState->BarrierReceived(
        &controller, InputChannelInfo(0, 0), &barrier, true));
    ASSERT_NE(nextState.get(), nullptr);  // dereferenced below
    EXPECT_TRUE(input0.blocked);
    EXPECT_FALSE(controller.initialized);
    EXPECT_FALSE(controller.triggered);

    // Second barrier after the timeout: the test expects the checkpoint to be
    // started on every input, and initialised and triggered.
    controller.timedOut = true;
    std::unique_ptr<BarrierHandlerState> unalignedState(nextState->BarrierReceived(
        &controller, InputChannelInfo(1, 0), &barrier, true));
    ASSERT_NE(unalignedState.get(), nullptr);  // dereferenced below
    EXPECT_TRUE(input0.started);
    EXPECT_TRUE(input1.started);
    EXPECT_TRUE(controller.initialized);
    EXPECT_TRUE(controller.triggered);

    // Reset the recorded flags so the final step's effects are observed in isolation.
    controller.initialized = false;
    controller.triggered = false;
    input0.resumed = input1.resumed = false;
    input0.stopped = input1.stopped = false;
    controller.allReceived = true;
    std::unique_ptr<BarrierHandlerState> finalState(unalignedState->BarrierReceived(
        &controller, InputChannelInfo(0, 0), &barrier, false));
    EXPECT_TRUE(input0.resumed);
    EXPECT_TRUE(input1.resumed);
    EXPECT_TRUE(input0.stopped);
    EXPECT_TRUE(input1.stopped);
}
// Disabled test: the controller already reports a timeout when the very first
// barrier arrives. The test expects the checkpoint to be started on every input,
// and initialised and triggered, while no input is resumed.
TEST(BarrierStateTest, DISABLED_WaitingUnalignedOnImmediateTimeout) {
    TestCheckpointableInput input0, input1;
    std::vector<CheckpointableInput*> inputs = {&input0, &input1};
    ChannelState state(inputs);

    TestController controller;
    controller.timedOut = true;

    // NOTE(review): `options` is never deleted here; confirm CheckpointBarrier owns it.
    CheckpointOptions* options =
        new CheckpointOptions(CheckpointType::CHECKPOINT, CheckpointStorageLocationReference::GetDefault());
    CheckpointBarrier barrier(999, 111L, options);

    BarrierHandlerState* initial = new AlternatingWaitingForFirstBarrier(state);
    BarrierHandlerState* afterTimeout =
        initial->BarrierReceived(&controller, InputChannelInfo(0, 0), &barrier, true);

    // Neither input is resumed.
    EXPECT_FALSE(input0.resumed);
    EXPECT_FALSE(input1.resumed);
    // Both inputs see the checkpoint start.
    EXPECT_TRUE(input0.started);
    EXPECT_TRUE(input1.started);
    // The controller was asked to initialise and trigger it.
    EXPECT_TRUE(controller.initialized);
    EXPECT_TRUE(controller.triggered);

    delete initial;
    delete afterTimeout;
}
// Starting from the unaligned waiting state with every barrier already reported
// as received, a single barrier is expected to start and stop the checkpoint on
// both inputs. It should resume only the blocked channel (input0) and hand back
// an AlternatingWaitingForFirstBarrier.
TEST(BarrierStateTest, WaitingUnalignedFinishCheckpoint) {
    TestCheckpointableInput input0, input1;
    std::vector<CheckpointableInput*> inputs = {&input0, &input1};
    ChannelState state(inputs);

    TestController controller;
    controller.allReceived = true;

    // NOTE(review): `options` is never deleted here; confirm CheckpointBarrier owns it.
    CheckpointOptions* options =
        new CheckpointOptions(CheckpointType::CHECKPOINT, CheckpointStorageLocationReference::GetDefault());
    CheckpointBarrier barrier(555, 777L, options);

    AlternatingWaitingForFirstBarrierUnaligned waiting(true, state);
    BarrierHandlerState* successor =
        waiting.BarrierReceived(&controller, InputChannelInfo(0, 0), &barrier, true);

    // Checkpoint lifecycle on both inputs.
    EXPECT_TRUE(input0.started);
    EXPECT_TRUE(input1.started);
    EXPECT_TRUE(input0.stopped);
    EXPECT_TRUE(input1.stopped);
    // Only the channel that was blocked gets resumed.
    EXPECT_TRUE(input0.resumed);
    EXPECT_FALSE(input1.resumed);
    // The handler returns to the aligned waiting state.
    EXPECT_NE(dynamic_cast<AlternatingWaitingForFirstBarrier*>(successor), nullptr);

    delete successor;
}