* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#pragma once
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cstdint>
#include <filesystem>
#include <memory>
#include <mutex>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>
#include <nlohmann/json.hpp>
#include "UUID.h"
#include "runtime/executiongraph/JobIDPOD.h"
#include "runtime/state/StateBackend.h"
#include "runtime/execution/OmniEnvironment.h"
#include "RocksDBKeyedStateBackendBuilder.h"
#include "RocksDBMemoryControllerUtils.h"
#include "DefaultOperatorStateBackendBuilder.h"
using json = nlohmann::json;
/**
 * Three-valued flag for a setting that is explicitly disabled (FALSE), explicitly
 * enabled (TRUE), or left unset (UNDEFINED).
 */
enum class TernaryBoolean : int {
    FALSE = 0,
    TRUE = 1,
    UNDEFINED = 2
};
class EmbeddedRocksDBStateBackend : public StateBackend {
public:
explicit EmbeddedRocksDBStateBackend(TernaryBoolean boolean) {};
~EmbeddedRocksDBStateBackend() {};
fs::path getNextStoragePath()
{
size_t ni = nextDirectory + 1;
ni = ni >= initializedDbBasePaths.size() ? 0 : ni;
nextDirectory = ni;
return initializedDbBasePaths[ni];
}
template <typename K>
AbstractKeyedStateBackend<K>* createKeyedStateBackend(
omnistream::EnvironmentV2* env,
std::string operatorIdentifier,
std::set<std::shared_ptr<KeyedStateHandle>> stateHandles,
KeyGroupRange* keyGroupRange,
TypeSerializer* keySerializer,
int numberOfKeyGroups,
int alternativeIdx)
{
std::string tempDir = env->taskConfiguration().getTmpWorkingDirectory().string();
std::string fileCompatibleIdentifier = operatorIdentifier;
std::replace_if(
fileCompatibleIdentifier.begin(),
fileCompatibleIdentifier.end(),
[](char c) { return !std::isalnum(c) && c != '-'; },
'_');
lazyInitializeForJob(env, fileCompatibleIdentifier);
fs::path instanceBasePath =
getNextStoragePath() / ("job_" + jobID.AbstractIDPOD::toString() + "_op_" + fileCompatibleIdentifier +
"_uuid_" + UUID::randomUUID().ToString());
auto localRecoveryConfig = env->getTaskStateManager()->createLocalRecoveryConfig();
auto bridge = env->getTaskStateManager()->getTaskStateManagerBridge();
auto omniTaskBridge = env->getTaskStateManager()->getOmniTaskBridge();
auto operatorIdStr = env->taskConfiguration().getStreamConfigPOD().getOperatorDescription().getOperatorId();
std::string upperHex = operatorIdStr.substr(0, 16);
std::string lowerHex = operatorIdStr.substr(16);
uint64_t uLower = std::stoull(lowerHex, nullptr, 16);
uint64_t uUpper = std::stoull(upperHex, nullptr, 16);
auto operatorId = std::make_shared<OperatorID>(uUpper, uLower);
auto sharedResources = RocksDBMemoryControllerUtils::allocateRocksDBSharedResources(env->taskConfiguration());
auto resourceContainer = std::make_shared<RocksDBResourceContainer>(sharedResources, instanceBasePath, false);
std::vector<std::shared_ptr<KeyedStateHandle>> stateVec(stateHandles.begin(), stateHandles.end());
auto priorityQueueStateType = env->taskConfiguration().getPriorityQueueStateType() == "ROCKSDB"
? RocksDBKeyedStateBackendBuilder<K>::PriorityQueueStateType::ROCKSDB
: RocksDBKeyedStateBackendBuilder<K>::PriorityQueueStateType::HEAP;
RocksDBKeyedStateBackendBuilder<K> builder(
operatorIdentifier,
instanceBasePath,
resourceContainer,
keySerializer,
numberOfKeyGroups,
keyGroupRange,
localRecoveryConfig,
priorityQueueStateType,
stateVec,
bridge,
omniTaskBridge,
operatorId,
alternativeIdx);
bool incrementalCheckpointing = enableIncrementalCheckpointing == TernaryBoolean::TRUE ? true : false;
builder.setEnableIncrementalCheckpointing(incrementalCheckpointing)
.setNumberOfTransferringThreads(numberOfTransferThreads)
.setWriteBatchSize(writeBatchSize);
return builder.build();
}
OperatorStateBackend* createOperatorStateBackend(
omnistream::EnvironmentV2* env,
std::string operatorIdentifier,
std::set<std::shared_ptr<OperatorStateHandle>> stateHandles)
{
std::vector<std::shared_ptr<OperatorStateHandle>> stateVector(stateHandles.begin(), stateHandles.end());
auto bridge = env->getTaskStateManager()->getTaskStateManagerBridge();
auto omniTaskBridge = env->getTaskStateManager()->getOmniTaskBridge();
const bool asynchronousSnapshots = true;
DefaultOperatorStateBackendBuilder builder(
asynchronousSnapshots, operatorIdentifier, stateVector, bridge, omniTaskBridge);
return builder.build();
}
explicit EmbeddedRocksDBStateBackend(TaskInformationPOD taskConfiguration)
: enableIncrementalCheckpointing(TernaryBoolean::UNDEFINED),
numberOfTransferThreads(UNDEFINED_NUMBER_OF_TRANSFER_THREADS),
nextDirectory(0),
isInitialized(false),
writeBatchSize(UNDEFINED_WRITE_BATCH_SIZE),
overlapFractionThreshold(UNDEFINED_OVERLAP_FRACTION_THRESHOLD)
{
enableIncrementalCheckpointing = taskConfiguration.getCheckpointConfig().getIncrementalCheckpoints()
? TernaryBoolean::TRUE
: TernaryBoolean::FALSE;
configureStoragePaths(taskConfiguration.getRocksdbStorePaths());
configureOtherParameters(taskConfiguration);
}
void handleDirectories()
{
std::vector<fs::path> validDirs;
std::string errorMessage;
for (const auto& dir : localRocksDbDirectories) {
fs::path testDir = dir / UUID::randomUUID().ToString();
try {
if (!fs::create_directories(testDir)) {
std::string msg =
"Local DB files directory '" + dir.string() + "' does not exist and cannot be created.";
errorMessage += msg + "\n";
} else {
validDirs.push_back(dir);
}
fs::remove_all(testDir);
} catch (const fs::filesystem_error& e) {
std::string msg = "Failed to create test directory in '" + dir.string() + "': " + e.what();
errorMessage += msg + "\n";
}
}
if (validDirs.empty()) {
throw std::runtime_error("No valid local storage directories available. " + errorMessage);
} else {
initializedDbBasePaths = std::move(validDirs);
}
}
void lazyInitializeForJob(omnistream::EnvironmentV2* env, const std::string& operatorIdentifier)
{
if (isInitialized) {
return;
}
jobID = env->getTaskStateManager()->getJobId();
if (localRocksDbDirectories.empty()) {
initializedDbBasePaths = {env->taskConfiguration().getTmpWorkingDirectory()};
} else {
handleDirectories();
}
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dist(0, initializedDbBasePaths.size() - 1);
nextDirectory = dist(gen);
isInitialized = true;
}
private:
void setDbStoragePaths(const std::vector<std::string>& paths)
{
if (paths.empty()) {
return;
}
std::vector<fs::path> pp;
pp.reserve(paths.size());
for (const auto& rawPath : paths) {
if (rawPath.empty()) {
throw std::invalid_argument("null path");
}
std::string processedPath = rawPath;
const std::string filePrefix = "file://";
if (rawPath.find(filePrefix) == 0) {
processedPath = rawPath.substr(filePrefix.length());
while (!processedPath.empty() && processedPath[0] == '/' && processedPath[1] == '/') {
processedPath = processedPath.substr(1);
}
if (!processedPath.empty() && processedPath[0] != '/') {
processedPath.insert(0, 1, '/');
}
if (processedPath.empty()) {
throw std::invalid_argument("Invalid file URI: " + rawPath);
}
} else if (rawPath.find("://") != std::string::npos) {
throw std::invalid_argument("Path " + rawPath + " has a non-local scheme");
}
fs::path pathObj(processedPath);
if (!pathObj.is_absolute()) {
throw std::invalid_argument("Relative paths are not supported: " + processedPath);
}
pp.emplace_back(std::move(pathObj));
}
localRocksDbDirectories = std::move(pp);
}
void configureStoragePaths(std::vector<std::string> rocksDbDirectories)
{
setDbStoragePaths(rocksDbDirectories);
}
void configureOtherParameters(TaskInformationPOD taskConfiguration)
{
numberOfTransferThreads = taskConfiguration.getNumberOfTransferThreads();
if (numberOfTransferThreads <= 0) {
THROW_LOGIC_EXCEPTION("Invalid number of transfer threads");
}
writeBatchSize = 2 * 1024 * 1024;
overlapFractionThreshold = 0.0;
}
static const long serialVersionUID = 1L;
static constexpr int ROCKSDB_LIB_LOADING_ATTEMPTS = 3;
static std::atomic<bool> rocksDbInitialized;
static constexpr int UNDEFINED_NUMBER_OF_TRANSFER_THREADS = -1;
static constexpr long UNDEFINED_WRITE_BATCH_SIZE = -1;
static constexpr double UNDEFINED_OVERLAP_FRACTION_THRESHOLD = -1;
std::vector<fs::path> localRocksDbDirectories;
TernaryBoolean enableIncrementalCheckpointing;
int numberOfTransferThreads;
std::vector<fs::path> initializedDbBasePaths;
JobIDPOD jobID;
int nextDirectory;
bool isInitialized;
long writeBatchSize;
double overlapFractionThreshold;
std::once_flag rocksdb_init_flag_;
bool rocksDbInitialized_ = false;
};