* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "JoinSource.h"
#include <algorithm>
#include <chrono>
#include <climits>
#include <iostream>
#include <mutex>
#include <random>
#include <thread>
#include <utility>
/**
 * Returns a uniformly distributed random value over the full range of `long`
 * (LONG_MIN..LONG_MAX inclusive).
 *
 * The engine is thread_local rather than a plain function-local static: the
 * previous shared static engine was mutated on every call without any
 * synchronisation, which is a data race if several JoinSource instances run
 * on different threads of the same process.
 */
long JoinSource::getRandomLong()
{
    thread_local std::mt19937_64 generator{std::random_device{}()};
    std::uniform_int_distribution<long> dist(LONG_MIN, LONG_MAX);
    return dist(generator);
}
/**
 * Builds a random string of `length` characters drawn from [A-Za-z0-9].
 *
 * @param length number of characters to produce; a value <= 0 yields an empty
 *               string (previously a negative value was converted to a huge
 *               size_t by reserve() and threw std::length_error).
 * @return the random string.
 */
std::string JoinSource::getRandomAlphanumeric(int length)
{
    static const std::string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    if (length <= 0) {
        return std::string();
    }
    // Per-thread generator state: sharing one mutable engine/distribution between threads
    // without a lock would be a data race.
    thread_local std::mt19937 generator{std::random_device{}()};
    thread_local std::uniform_int_distribution<int> dist(0, static_cast<int>(chars.size()) - 1);
    std::string result;
    result.reserve(static_cast<std::size_t>(length));
    for (int i = 0; i < length; ++i) {
        result += chars[dist(generator)];
    }
    return result;
}
/**
 * Creates a join source from explicit generator settings.
 *
 * @param keysPerCheck          number of new keys generated per loop iteration of run().
 * @param checkInterval         target duration of one run() iteration, in milliseconds.
 * @param minLeftRecordsPerKey  lower bound of left-side records generated per key.
 * @param maxLeftRecordsPerKey  upper bound of left-side records generated per key.
 * @param minRightRecordsPerKey lower bound of right-side records generated per key.
 * @param maxRightRecordsPerKey upper bound of right-side records generated per key.
 * @param recordValueSize       length of the random value string of each record.
 * @param recordKeySize         target length of the generated key string.
 * @param leftMaxDelay          bound on the random timestamp offset of left records (ms).
 * @param rightMaxDelay         bound on the random timestamp offset of right records (ms).
 * @param sleepTime             if > 0, run() cancels itself after this many milliseconds.
 */
JoinSource::JoinSource(int keysPerCheck, int checkInterval,
                       int minLeftRecordsPerKey, int maxLeftRecordsPerKey,
                       int minRightRecordsPerKey, int maxRightRecordsPerKey,
                       int recordValueSize, int recordKeySize,
                       int leftMaxDelay, int rightMaxDelay,
                       long sleepTime)
    : running(true),
      keysPerCheck(keysPerCheck), checkInterval(checkInterval),
      minLeftRecordsPerKey(minLeftRecordsPerKey), maxLeftRecordsPerKey(maxLeftRecordsPerKey),
      minRightRecordsPerKey(minRightRecordsPerKey), maxRightRecordsPerKey(maxRightRecordsPerKey),
      recordValueSize(recordValueSize), recordKeySize(recordKeySize),
      leftMaxDelay(leftMaxDelay), rightMaxDelay(rightMaxDelay),
      sleepTime(sleepTime),
      currentSubtaskIndex(0), currentKeyId(0)
{
    // All state is set up by the member-initializer list above.
}
/**
 * Creates a join source from a JSON description.
 *
 * Generator settings are read from the optional object `configuration["configMap"]`.
 * Each recognised key overwrites the member of the same name. Keys that are absent
 * leave the member untouched; NOTE(review): those members then rely on in-class
 * defaults from JoinSource.h — confirm the header provides them.
 *
 * The runtime state (`running`, `currentSubtaskIndex`, `currentKeyId`) is now
 * initialised exactly as in the parameterised constructor. Previously it was
 * never set on this path.
 *
 * @param configuration JSON object that may contain a "configMap" object.
 */
JoinSource::JoinSource(const nlohmann::json& configuration)
    : running(true),
      currentSubtaskIndex(0),
      currentKeyId(0)
{
    if (!configuration.contains("configMap") || configuration["configMap"].is_null()) {
        return;
    }
    // Bind by reference; the previous code copied the whole sub-object.
    const nlohmann::json &configMap = configuration["configMap"];
    // Assigns `field` from configMap[name] only when that key is present.
    auto readOption = [&configMap](const char *name, auto &field) {
        if (configMap.contains(name)) {
            field = configMap.at(name);
        }
    };
    readOption("checkInterval", checkInterval);
    readOption("minLeftRecordsPerKey", minLeftRecordsPerKey);
    readOption("maxLeftRecordsPerKey", maxLeftRecordsPerKey);
    readOption("minRightRecordsPerKey", minRightRecordsPerKey);
    readOption("maxRightRecordsPerKey", maxRightRecordsPerKey);
    readOption("keysPerCheck", keysPerCheck);
    readOption("rightMaxDelay", rightMaxDelay);
    readOption("leftMaxDelay", leftMaxDelay);
    readOption("recordKeySize", recordKeySize);
    readOption("recordValueSize", recordValueSize);
    readOption("sleepTime", sleepTime);
}
/**
 * Initialises per-subtask state before run() is called.
 *
 * Records this subtask's index, which generateRecordsForKey() uses as the key
 * prefix, and allocates the map of pending records.
 *
 * The parameter name was mis-encoded as "¶meters" (an HTML "&para;" artifact)
 * and did not compile; it is restored to "&parameters".
 *
 * @param parameters configuration forwarded to AbstractRichFunction::open.
 */
void JoinSource::open(const Configuration &parameters)
{
    AbstractRichFunction::open(parameters);
    currentSubtaskIndex = this->getRuntimeContext()->getIndexOfThisSubtask();
    // NOTE(review): calling open() more than once would leak the previously allocated map.
    recordsToCollect = new std::unordered_map<long, std::vector<OriginalRecord *>>();
}
/**
 * Main generation loop of the source.
 *
 * Each iteration does the following while holding the checkpoint lock:
 *  - generates `keysPerCheck` new keys;
 *  - writes every pending record whose timestamp has been reached into per-key
 *    left/right batches;
 *  - emits those batches through `ctx`;
 *  - drops keys that have no pending records left.
 *
 * It then sleeps for the rest of `checkInterval` milliseconds. The sleep happens
 * outside the lock, so checkpoints are no longer blocked for the whole interval.
 * When `sleepTime` > 0, the loop cancels itself once that many milliseconds have
 * elapsed since run() started.
 *
 * Emitted records are deleted after being copied into their batch; previously
 * they were erased from the pending vectors without being freed and leaked.
 *
 * @param ctx source context providing the checkpoint lock and the collect() sink.
 */
void JoinSource::run(SourceContext *ctx)
{
    const auto startTime = std::chrono::steady_clock::now();
    while (running) {
        const auto loopStartTime = std::chrono::steady_clock::now();
        {
            // RAII: the lock is released even if generation or collect() throws.
            std::lock_guard checkpointGuard(ctx->getCheckpointLock()->mutex);
            for (int i = 0; i < keysPerCheck; i++) {
                generateRecordsForKey();
            }
            std::unordered_map<std::string, std::pair<omnistream::VectorBatch *, omnistream::VectorBatch *>> batchesToCollect;
            const auto currentTimestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now().time_since_epoch()).count();
            for (auto entryIt = recordsToCollect->begin(); entryIt != recordsToCollect->end();) {
                std::vector<OriginalRecord *> &pending = entryIt->second;
                for (OriginalRecord *&record : pending) {
                    if (record->getTimestamp() > currentTimestamp) {
                        continue;  // not due yet; stays pending for a later iteration
                    }
                    auto key = record->getKey();
                    auto batchIt = batchesToCollect.find(key);
                    if (batchIt == batchesToCollect.end()) {
                        // NOTE(review): batches are sized by the key's total left/right counts, but only
                        // the records due in this pass are written; other rows stay unset — confirm intended.
                        batchIt = batchesToCollect.emplace(key, std::make_pair(createBatch(record->getLeftTotalCount()),
                                                                               createBatch(record->getRightTotalCount())))
                                      .first;
                    }
                    if (record->isLeft()) {
                        originalRecordToBatch(record, batchIt->second.first, record->getCurrentLeftId() - 1);
                    } else {
                        originalRecordToBatch(record, batchIt->second.second, record->getCurrentRightId() - 1);
                    }
                    // The record's data has been written into the batch, and once it leaves `pending`
                    // nothing references it any more, so free it here.
                    delete record;
                    record = nullptr;
                }
                pending.erase(std::remove(pending.begin(), pending.end(), nullptr), pending.end());
                if (pending.empty()) {
                    entryIt = recordsToCollect->erase(entryIt);
                } else {
                    ++entryIt;
                }
            }
            for (auto &entry : batchesToCollect) {
                ctx->collect(entry.second.first);
                ctx->collect(entry.second.second);
            }
        }  // checkpoint lock released here, before sleeping
        const auto elapsed = std::chrono::steady_clock::now() - loopStartTime;
        const auto millisToSleep = checkInterval - std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
        if (millisToSleep > 0) {
            std::this_thread::sleep_for(std::chrono::milliseconds(millisToSleep));
        }
        const auto elapsedTime = std::chrono::steady_clock::now() - startTime;
        if (sleepTime > 0 && std::chrono::duration_cast<std::chrono::milliseconds>(elapsedTime).count() >= sleepTime) {
            cancel();
        }
    }
}
void JoinSource::cancel()
{
running = false;
}
void JoinSource::generateRecordsForKey()
{
std::string key = std::to_string(currentSubtaskIndex) + "_" + std::to_string(currentKeyId);
if (static_cast<int>(key.size()) < recordKeySize) {
key += ("_" + getRandomAlphanumeric(recordKeySize - key.size()));
}
long leftRecords = minLeftRecordsPerKey + std::abs(getRandomLong()) % (maxLeftRecordsPerKey - minLeftRecordsPerKey + 1);
long rightRecords = minRightRecordsPerKey + std::abs(getRandomLong()) % (maxRightRecordsPerKey - minRightRecordsPerKey + 1);
std::vector<OriginalRecord *> records;
long baseTimestamp = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
for (long i = 0; i < leftRecords; ++i) {
OriginalRecord *record = new OriginalRecord();
record->setKey(key);
record->setLeft(true);
record->setLeftTotalCount(leftRecords);
record->setRightTotalCount(rightRecords);
record->setValue(getRandomAlphanumeric(recordValueSize));
record->setTimestamp(baseTimestamp + std::abs(getRandomLong()) % leftMaxDelay);
records.push_back(record);
}
for (long i = 0; i < rightRecords; ++i) {
OriginalRecord *record = new OriginalRecord();
record->setKey(key);
record->setLeft(false);
record->setLeftTotalCount(leftRecords);
record->setRightTotalCount(rightRecords);
record->setValue(getRandomAlphanumeric(recordValueSize));
record->setTimestamp(baseTimestamp + std::abs(getRandomLong()) % rightMaxDelay);
records.push_back(record);
}
std::sort(records.begin(), records.end(), [](OriginalRecord *a, OriginalRecord *b) { return *a < *b; });
long currentLeftId = 1;
long currentRightId = 1;
for (auto &record : records) {
if (record->isLeft()) {
record->setCurrentLeftId(currentLeftId);
currentLeftId++;
} else {
record->setCurrentRightId(currentRightId);
currentRightId++;
}
}
recordsToCollect->emplace(currentKeyId, records);
currentKeyId++;
if (currentKeyId == LONG_MAX) {
currentKeyId = 0;
}
}
std::unordered_map<long, std::vector<OriginalRecord *>> &JoinSource::getRecordsToCollect()
{
return *recordsToCollect;
}
/**
 * Writes one record into row `index` of `batch`.
 *
 * Column 0 receives the key, column 1 the value; the row's timestamp and
 * RowKind::INSERT are set as well.
 *
 * @param record source record (not modified, not freed).
 * @param batch  batch created by createBatch() with two string columns.
 * @param index  row to fill.
 */
void JoinSource::originalRecordToBatch(OriginalRecord *record, omnistream::VectorBatch *batch, int index)
{
    using StringColumn = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    // Bind the accessor results by const reference before viewing them. If getKey()/getValue()
    // return by value, this extends the temporaries' lifetime, whereas the previous
    // `std::string_view key = record->getKey();` would dangle before SetValue read it.
    const auto &keyData = record->getKey();
    const auto &valueData = record->getValue();
    std::string_view key = keyData;
    std::string_view value = valueData;
    static_cast<StringColumn *>(batch->Get(0))->SetValue(index, key);
    static_cast<StringColumn *>(batch->Get(1))->SetValue(index, value);
    batch->setTimestamp(index, record->getTimestamp());
    batch->setRowKind(index, RowKind::INSERT);
}
/**
 * Allocates an empty batch of `size` rows with two string columns.
 *
 * Per originalRecordToBatch, column 0 holds keys and column 1 holds values.
 * The caller takes ownership of the returned batch.
 *
 * @param size number of rows for the batch and for each column.
 * @return the newly allocated batch.
 */
omnistream::VectorBatch *JoinSource::createBatch(int size)
{
    using StringColumn = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;
    auto *result = new omnistream::VectorBatch(size);
    constexpr int kStringColumnCount = 2;
    for (int column = 0; column < kStringColumnCount; ++column) {
        result->Append(new StringColumn(size));
    }
    return result;
}