#include <gtest/gtest.h>
#include <condition_variable>
#include <mutex>
#include <nlohmann/json.hpp>
#include <string>
#include "streaming/runtime/partitioner/V2/ForwardPartitionerV2.h"
#include "streaming/runtime/streamrecord/StreamRecord.h"
#include "event/EndOfPartitionEvent.h"
#include "runtime/taskexecutor/TaskManagerServices.h"
#include "runtime/io/network/api/writer/V2/RecordWriterBuilderV2.h"
#include "runtime/io/network/api/writer/V2/RecordWriterV2.h"
#include "partition/consumer/InputGate.h"
#include "partition/consumer/SingleInputGate.h"
#include "runtime/executiongraph/descriptor/ExecutionAttemptIDPOD.h"
#include "runtime/io/network/api/serialization/EventSerializer.h"
#include "runtime/partition/ResultPartitionWriter.h"
#include "runtime/shuffle/ShuffleEnvironment.h"
#include "runtime/shuffle/ShuffleIOOwnerContextPOD.h"
#include "runtime/taskexecutor/OmniTaskExecutor.h"
#include "runtime/taskmanager/OmniTask.h"
#include "streaming/runtime/tasks/omni/OmniStreamTask.h"
#include "test/functionaltest/e2e/FrameworkConfig.h"
/**
 * Create 3 threads to simulate the shuffle environment
 * 1. Upstream Task
 *    - Contains RecordWriter that writes data to ResultPartition
 * 2. Downstream Task
 *    - Contains InputGate and InputChannel that read data from ResultPartition
 * 3. OutputFlusher
 *    - In FlinkSQL, OutputFlusher is a thread spawned by RecordWriter that flushes data to ResultPartition
 */
class ShuffleEnvironmentTest : public FrameworkTestConfig {
public:
void SetUp() override {
FrameworkTestConfig::SetUp();
std::ifstream taskMCSConfFile("./input/taskMCSConf.json");
if (!taskMCSConfFile.is_open()) {
throw std::runtime_error("Failed to open taskMCSConf.json");
}
try {
taskMCSConfJson_ = nlohmann::json::parse(taskMCSConfFile);
} catch (const nlohmann::json::parse_error &e) {
throw std::runtime_error("JSON parse error: " + std::string(e.what()));
}
std::ifstream jobFile("./input/job.json");
nlohmann::json jobJson = nlohmann::json::parse(jobFile);
std::ifstream sourceTaskFile("./input/sourceTask.json");
nlohmann::json sourceTaskJson = nlohmann::json::parse(sourceTaskFile);
std::cout << sourceTaskJson << std::endl;
std::ifstream sourceTddFile("./input/sourceTdd.json");
nlohmann::json sourceTddJson = nlohmann::json::parse(sourceTddFile);
std::ifstream sinkTaskFile("./input/sinkTask.json");
nlohmann::json sinkTaskJson = nlohmann::json::parse(sinkTaskFile);
std::cout << sinkTaskJson << std::endl;
std::ifstream sinkTddFile("./input/sinkTdd.json");
nlohmann::json sinkTddJson = nlohmann::json::parse(sinkTddFile);
jobInfo_ = jobJson;
srcTaskInfo_ = sourceTaskJson;
srcTddInfo_ = sourceTddJson;
sinkTaskInfo_ = sinkTaskJson;
sinkTddInfo_ = sinkTddJson;
taskManagerServices_ = std::shared_ptr<omnistream::TaskManagerServices>(omnistream::TaskManagerServices::fromConfiguration(taskMCSConfJson_));
shuffleEnv_ = taskManagerServices_->getShuffleEnvironment();
createUpstreamWriterAndDownstreamInputGate();
std::vector<omniruntime::type::DataTypeId> types = {omniruntime::type::OMNI_LONG, omniruntime::type::OMNI_LONG};
dummyVectorBatch_ = omnistream::createVectorBatch(types, 3, true);
dummyRecord_ = std::make_shared<StreamRecord>(dummyVectorBatch_);
}
void createUpstreamWriterAndDownstreamInputGate() {
auto shuffleIOOwnerContext = shuffleEnv_->createShuffleIOOwnerContext("testShuffleIOOwnerContext", omnistream::ExecutionAttemptIDPOD(), nullptr);
auto resultPartitionWriters = shuffleEnv_->createResultPartitionWriters(shuffleIOOwnerContext, srcTddInfo_.getProducedPartitions(), 1);
upstreamWriter_ = resultPartitionWriters[0];
auto inputGates = shuffleEnv_->createInputGates(shuffleIOOwnerContext, nullptr, sinkTddInfo_.getInputGates(), 1);
downstreamInputGate_ = inputGates[0];
auto partitionerName = getPartitionerName(srcTaskInfo_);
auto partitioner = getPartitionerFromPartitionerName(partitionerName);
auto bufferTimeout = srcTaskInfo_.getStreamConfigPOD().getOutEdgesInOrder()[0].getBufferTimeout();
recordWriter_ = std::shared_ptr<omnistream::RecordWriterV2>(omnistream::RecordWriterBuilderV2().withTaskName("testRecordWriter").withChannelSelector(partitioner.get()).withWriter(upstreamWriter_).withTimeout(bufferTimeout).build());
recordWriter_->postConstruct();
upstreamWriter_->setup();
downstreamInputGate_->setup();
downstreamInputGate_->RequestPartitions();
targetSubpartition_ = srcTddInfo_.getProducedPartitions()[0].getPartitionId().getPartitionNum();
}
void checkVectorBatchMatch(omniruntime::vec::VectorBatch *vecBatch, omniruntime::vec::VectorBatch *expectedVecBatch) {
EXPECT_EQ(vecBatch->GetRowCount(), expectedVecBatch->GetRowCount());
EXPECT_EQ(vecBatch->GetVectorCount(), expectedVecBatch->GetVectorCount());
for (int i = 0; i < vecBatch->GetVectorCount(); i++) {
for (int j = 0; j < vecBatch->GetRowCount(); j++) {
EXPECT_EQ(reinterpret_cast<omniruntime::vec::Vector<int64_t> *>(vecBatch->Get(i))->GetValue(j), j);
}
}
}
void emitRecordsThread() {
for (int i = 0; i < 10; i++) {
upstreamWriter_->emitRecord(dummyRecord_.get(), targetSubpartition_);
}
upstreamWriter_->finish();
recordWriter_->close();
upstreamWriter_->close();
}
int pollRecordsThread() {
int recordCount = 0;
while (true) {
auto bufferOrEvent = downstreamInputGate_->PollNext();
if (!bufferOrEvent) {
continue;
}
if (bufferOrEvent->isBuffer()) {
auto buff = reinterpret_cast<ObjectBuffer*>(bufferOrEvent->getBuffer());
auto size = buff->GetSize();
auto objSegment = buff->GetObjectSegment();
auto offset = buff->GetOffset();
for (int64_t index = offset; index < size; index++) {
StreamRecord *object = static_cast<StreamRecord*>(objSegment->getObject(index));
assert(object->getTag() == StreamElementTag::TAG_REC_WITHOUT_TIMESTAMP);
auto vecBatch = static_cast<omniruntime::vec::VectorBatch *>(object->getValue());
* Expected Output from `omniruntime::vec::VectorHelper::PrintVecBatch(vecBatch);`
* 0, 0
* 1, 1
* 2, 2
*/
checkVectorBatchMatch(vecBatch, dummyVectorBatch_);
recordCount++;
}
}
else {
auto event = bufferOrEvent->getEvent();
if (dynamic_cast<EndOfPartitionEvent *>(event.get())) {
break;
}
}
delete bufferOrEvent;
}
downstreamInputGate_->close();
return recordCount;
}
protected:
std::shared_ptr<omnistream::TaskManagerServices> taskManagerServices_;
std::shared_ptr<omnistream::ShuffleEnvironment> shuffleEnv_;
std::shared_ptr<omnistream::RecordWriterV2> recordWriter_;
std::shared_ptr<omnistream::SingleInputGate> downstreamInputGate_;
std::shared_ptr<omnistream::ResultPartitionWriter> upstreamWriter_;
omniruntime::vec::VectorBatch *dummyVectorBatch_;
std::shared_ptr<StreamRecord> dummyRecord_;
int targetSubpartition_;
std::mutex mtx_;
std::condition_variable cv_;
};
// Runs the producer and the consumer concurrently and expects the consumer to
// see exactly the 10 records emitted by emitRecordsThread().
TEST_F(ShuffleEnvironmentTest, DISABLED_UpstreamSendsRecordsAndDownstreamReceives) {
    std::thread producer(&ShuffleEnvironmentTest::emitRecordsThread, this);

    std::thread consumer([this] {
        const int received = pollRecordsThread();
        EXPECT_EQ(received, 10);
    });

    producer.join();
    consumer.join();
}
// The producer emits 100 records before it lets the consumer start polling,
// then emits 10 more; the consumer must still receive all 110 records.
TEST_F(ShuffleEnvironmentTest, DISABLED_UpstreamSendsTooMuchStreamRecordAndCausesBackpressure) {
    static constexpr int kRecordsBeforeSignal = 100;
    static constexpr int kTotalRecords = 110;
    bool pollingAllowed = false;

    std::thread producer([this, &pollingAllowed] {
        int sent = 0;

        // Phase 1: fill the pipeline while nobody is consuming.
        while (sent < kRecordsBeforeSignal) {
            upstreamWriter_->emitRecord(dummyRecord_.get(), targetSubpartition_);
            ++sent;
        }

        // Release the consumer.
        {
            std::lock_guard<std::mutex> guard(mtx_);
            pollingAllowed = true;
        }
        cv_.notify_one();

        // Phase 2: emit the remaining records while the consumer drains.
        while (sent < kTotalRecords) {
            upstreamWriter_->emitRecord(dummyRecord_.get(), targetSubpartition_);
            ++sent;
        }

        upstreamWriter_->finish();
        upstreamWriter_->close();
    });

    std::thread consumer([this, &pollingAllowed] {
        {
            std::unique_lock<std::mutex> guard(mtx_);
            cv_.wait(guard, [&pollingAllowed] { return pollingAllowed; });
        }
        const int received = pollRecordsThread();
        EXPECT_EQ(received, kTotalRecords);
    });

    producer.join();
    consumer.join();
}
// The upstream side emits all of its records and finishes/closes its writers
// before the downstream side is allowed to start polling; the downstream side
// must still receive all 10 records.
TEST_F(ShuffleEnvironmentTest, DISABLED_UpstreamSendsEndOfInputBeforeDownstreamIsReadyToReceive) {
    // Guarded by mtx_; set once the emitting thread has completely finished.
    bool downstreamReady = false;

    std::thread emitThread([this, &downstreamReady]() {
        this->emitRecordsThread();
        {
            std::lock_guard<std::mutex> lock(mtx_);
            downstreamReady = true;
        }
        cv_.notify_one();
    });

    std::thread pollThread([this, &downstreamReady]() {
        {
            // Block until the emitting thread signals that it is done.
            std::unique_lock<std::mutex> lock(mtx_);
            cv_.wait(lock, [&downstreamReady] { return downstreamReady; });
        }
        int recordCount = pollRecordsThread();
        EXPECT_EQ(recordCount, 10);
    });

    emitThread.join();
    pollThread.join();
}