#include <gtest/gtest.h>
#include "core/fs/Path.h"
#include "runtime/executiongraph/JobIDPOD.h"
#include "runtime/buffer/ObjectBuffer.h"
#include "runtime/state/CheckpointStorage.h"
#include "runtime/state/filesystem/FsCheckpointStorageAccess.h"
#include "runtime/checkpoint/channel/ChannelStateSerializer.h"
#include "runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.h"
using namespace omnistream;
// Recording fake for ChannelStateSerializerImpl: the write hooks do no
// serialization and only remember that they were invoked, so tests can
// assert on the public flags below.
class ChannelStateSerializerImplTest : public ChannelStateSerializerImpl {
public:
    // Set to true by WriteHeader().
    bool writeHeaderCalled{false};
    // Set to true by WriteData().
    bool writeDataCalled{false};

    // Records the call; nothing is written to the stream.
    void WriteHeader(std::ostringstream& /*dataStream*/) override
    {
        writeHeaderCalled = true;
    }

    // Records the call; the stream and the buffer are left untouched.
    void WriteData(std::ostringstream& /*dataStream*/, std::shared_ptr<Buffer> /*buffers*/) override
    {
        writeDataCalled = true;
    }

    // Fixed header length reported by this fake.
    int64_t GetHeaderLength() const override
    {
        return 999;
    }
};
// CheckpointStorage stub that hands out an FsCheckpointStorageAccess whose
// checkpoint and savepoint directories are both built from an empty path string.
class CheckpointStorageTest : public CheckpointStorage {
public:
    // Builds a fresh FsCheckpointStorageAccess for the given job.
    std::shared_ptr<CheckpointStorageAccess> createCheckpointStorage(const JobIDPOD& jobId) override
    {
        const std::string emptyLocation;
        // NOTE(review): raw allocations are handed to FsCheckpointStorageAccess;
        // presumably it takes ownership of both paths - confirm against its constructor.
        auto* checkpointPath = new Path(emptyLocation);
        auto* savepointPath = new Path(emptyLocation);
        // The two trailing 100s are passed through unchanged
        // (presumably size thresholds - TODO confirm).
        return std::make_shared<FsCheckpointStorageAccess>(checkpointPath, savepointPath, jobId, 100, 100);
    }
};
class ObjectBufferTest : public ObjectBuffer {
public:
bool isBuffer() const override
{
return true;
}
std::shared_ptr<BufferRecycler> GetRecycler() override
{
return DummyObjectBufferRecycler::getInstance();
}
void RecycleBuffer() override
{
recycled = true;
}
bool IsRecycled() const override
{
return recycled;
}
Buffer* RetainBuffer() override
{
return this;
}
Buffer* ReadOnlySlice() override
{
return this;
}
Buffer* ReadOnlySlice(int index, int length) override
{
return this;
}
int GetMaxCapacity() const override
{
return 1024;
}
int GetReaderIndex() const override
{
return 0;
}
void SetReaderIndex(int readerIndex) override
{
}
int GetSize() const override
{
return 0;
}
void SetSize(int writerIndex) override
{
}
int ReadableObjects() const override
{
return 0;
}
bool IsCompressed() const override
{
return false;
}
void SetCompressed(bool isCompressed) override
{
}
ObjectBufferDataType GetDataType() const override
{
return ObjectBufferDataType::DATA_BUFFER;
}
void SetDataType(ObjectBufferDataType dataType) override
{
}
int RefCount() const override
{
return 1;
}
std::string ToDebugString(bool includeHash) const override
{
return "ObjectBufferTest";
}
ObjectSegment* GetObjectSegment() override
{
return std::make_shared<ObjectSegment>(0).get();
}
int GetBufferType() override
{
return 42;
}
std::pair<uint8_t*, size_t> GetBytes() override
{
return {nullptr, 0};
}
private:
bool recycled = false;
};
// Registers a subtask, dispatches a checkpoint start request and verifies that
// the serializer's WriteHeader hook was reached, then releases the subtask.
TEST(ChannelStateWriteRequestDispatcherImplTest, InitialiseRequestDispatcher)
{
    auto checkpointStorage = std::make_shared<CheckpointStorageTest>();
    auto recordingSerializer = std::make_shared<ChannelStateSerializerImplTest>();
    auto requestDispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
        checkpointStorage,
        JobIDPOD(-1, -1),
        recordingSerializer,
        checkpointStorage->createCheckpointStorage(JobIDPOD(-1, -1)));

    JobVertexID vertexId(-1, -1);
    // Deliberately left default-constructed (null) for this test.
    std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> nullResult;
    CheckpointStorageLocationReference locationRef;

    // Second argument is presumably the subtask index - TODO confirm.
    auto subtaskRegistration = std::make_shared<SubtaskRegisterRequest>(vertexId, 1);
    requestDispatcher->dispatch(subtaskRegistration);

    // Arguments 1 and 0 are presumably subtask index and checkpoint id - TODO confirm.
    auto checkpointStart = std::make_shared<CheckpointStartRequest>(vertexId, 1, 0, nullResult, &locationRef);
    requestDispatcher->dispatch(checkpointStart);
    EXPECT_TRUE(recordingSerializer->writeHeaderCalled);

    auto subtaskRelease = std::make_shared<SubtaskReleaseRequest>(vertexId, 1);
    requestDispatcher->dispatch(subtaskRelease);
}
// Dispatches a start request built with 1 as its third argument, then one built
// with 0, and expects the second dispatch not to throw. The third argument is
// presumably the checkpoint id, making the second request an "older"
// checkpoint - TODO confirm against ChannelStateWriteRequest::start.
TEST(ChannelStateWriteRequestDispatcherImplTest, AbortedCheckpointIsCancelledNotThrown)
{
    auto checkpointStorage = std::make_shared<CheckpointStorageTest>();
    auto recordingSerializer = std::make_shared<ChannelStateSerializerImplTest>();
    auto requestDispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
        checkpointStorage,
        JobIDPOD(-1, -1),
        recordingSerializer,
        checkpointStorage->createCheckpointStorage(JobIDPOD(-1, -1)));

    JobVertexID vertexId(-1, -1);
    // NOTE(review): this result object is created but not passed to any request below.
    std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> emptyResult =
        ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();

    requestDispatcher->dispatch(ChannelStateWriteRequest::registerSubtask(vertexId, 1));

    auto currentStart = ChannelStateWriteRequest::start(vertexId, 1, 1, "Start");
    requestDispatcher->dispatch(currentStart);

    auto staleStart = ChannelStateWriteRequest::start(vertexId, 1, 0, "Start");
    EXPECT_NO_THROW(requestDispatcher->dispatch(staleStart));
}
// Drives the dispatcher through a sequence of requests: register, start,
// complete input, complete output, release, followed by further requests for
// the same subtask (a completeInput with 0 as its last argument, a terminate
// carrying an exception, and another start). The test contains no explicit
// expectations; it passes as long as no dispatch call throws.
TEST(ChannelStateWriteRequestDispatcherImplTest, FullRequestFlow)
{
    auto checkpointStorage = std::make_shared<CheckpointStorageTest>();
    auto recordingSerializer = std::make_shared<ChannelStateSerializerImplTest>();
    auto requestDispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
        checkpointStorage,
        JobIDPOD(-1, -1),
        recordingSerializer,
        checkpointStorage->createCheckpointStorage(JobIDPOD(-1, -1)));

    JobVertexID vertexId(-1, -1);
    std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> writeResult =
        ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
    CheckpointStorageLocationReference locationRef;

    // Register the subtask, then start a checkpoint for it.
    auto subtaskRegistration = std::make_shared<SubtaskRegisterRequest>(vertexId, 1);
    requestDispatcher->dispatch(subtaskRegistration);

    auto checkpointStart = std::make_shared<CheckpointStartRequest>(vertexId, 1, 1, writeResult, &locationRef);
    requestDispatcher->dispatch(checkpointStart);

    // Complete both sides of the started checkpoint.
    auto inputDone = ChannelStateWriteRequest::completeInput(vertexId, 1, 1);
    requestDispatcher->dispatch(inputDone);

    auto outputDone = ChannelStateWriteRequest::completeOutput(vertexId, 1, 1);
    requestDispatcher->dispatch(outputDone);

    auto subtaskRelease = std::make_shared<SubtaskReleaseRequest>(vertexId, 1);
    requestDispatcher->dispatch(subtaskRelease);

    // Requests issued after the release.
    auto staleInputDone = ChannelStateWriteRequest::completeInput(vertexId, 1, 0);
    requestDispatcher->dispatch(staleInputDone);

    auto terminateWithError = ChannelStateWriteRequest::terminate(
        vertexId, 1, 1, std::make_exception_ptr(std::runtime_error("Test error")));
    requestDispatcher->dispatch(terminateWithError);

    auto staleStart = std::make_shared<CheckpointStartRequest>(vertexId, 1, 0, writeResult, &locationRef);
    requestDispatcher->dispatch(staleStart);

    // NOTE(review): this request is constructed but never dispatched - confirm
    // whether a dispatch call was intended here.
    auto undispatchedRelease = std::make_shared<SubtaskReleaseRequest>(vertexId, 2);
}