/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_ROCKSDBINCREMENTALRESTOREOPERATION_H
#define OMNISTREAM_ROCKSDBINCREMENTALRESTOREOPERATION_H
#include <system_error>

#include "core/fs/CloseableRegistry.h"
#include "core/memory/DataInputDeserializer.h"
#include "runtime/state/StreamStateHandle.h"
#include "runtime/state/KeyedStateHandle.h"
#include "runtime/state/IncrementalKeyedStateHandle.h"
#include "runtime/state/KeyGroupRange.h"
#include "runtime/state/UUID.h"
#include "runtime/state/metainfo/StateMetaInfoSnapshot.h"
#include "runtime/state/IncrementalRemoteKeyedStateHandle.h"
#include "runtime/state/IncrementalLocalKeyedStateHandle.h"
#include "runtime/state/rocksdb/RocksDbHandle.h"
#include "runtime/state/RocksDbKvStateInfo.h"
#include "runtime/state/rocksdb/RocksDbOperationUtils.h"
#include "runtime/state/RocksDBIncrementalCheckpointUtils.h"
#include "runtime/state/restore/RocksDBRestoreOperation.h"
#include "runtime/state/restore/RocksDBRestoreResult.h"
#include "runtime/state/rocksdb/RocksDBStateDownloader.h"
#include "runtime/state/CompositeKeySerializationUtils.h"
#include "runtime/state/RocksDBWriteBatchWrapper.h"
template <typename K>
class RocksDBIncrementalRestoreOperation : public RocksDBRestoreOperation {
public:
/**
 * Creates a restore operation for the given set of keyed state handles.
 *
 * @param operatorIdentifier         identifier of the operator; stored as-is.
 * @param keyGroupRange              key-group range this backend is responsible for (not owned;
 *                                   must outlive this object).
 * @param keyGroupPrefixBytes        number of bytes used for the serialized key-group prefix.
 * @param numberOfTransferringThreads thread count handed to RocksDBStateDownloader.
 * @param kvStateInformation         registry of kv-state info, forwarded to RocksDbHandle.
 * @param keySerializer              key serializer; stored as-is.
 * @param instanceBasePath           base directory under which temporary restore directories
 *                                   are created.
 * @param instanceRocksDBPath        DB path forwarded to RocksDbHandle.
 * @param dbOptions                  DB options forwarded to RocksDbHandle.
 * @param columnFamilyOptionsFactory factory forwarded to RocksDbHandle.
 * @param restoreStateHandles        state handles to restore from.
 * @param writeBatchSize             size passed to RocksDBWriteBatchWrapper during rescaling.
 * @param omniTaskBridge             bridge used to read state metadata and download state.
 * @param operatorId                 operator id; stored as-is.
 * @param alternativeIdx             stored as-is.
 *
 * Until a handle is restored, the last completed checkpoint id is -1 and the backend UID is a
 * freshly generated random UUID. The overlap fraction threshold is fixed to 0.
 */
RocksDBIncrementalRestoreOperation(
    std::string operatorIdentifier,
    KeyGroupRange* keyGroupRange,
    int keyGroupPrefixBytes,
    int numberOfTransferringThreads,
    std::unordered_map<std::string, std::shared_ptr<RocksDbKvStateInfo>> *kvStateInformation,
    std::shared_ptr<TypeSerializer> keySerializer,
    std::filesystem::path& instanceBasePath,
    std::filesystem::path& instanceRocksDBPath,
    std::shared_ptr<rocksdb::DBOptions> dbOptions,
    std::function<rocksdb::ColumnFamilyOptions(const std::string&)> columnFamilyOptionsFactory,
    std::vector<std::shared_ptr<KeyedStateHandle>> restoreStateHandles,
    long writeBatchSize,
    std::shared_ptr<OmniTaskBridge> omniTaskBridge,
    std::shared_ptr<OperatorID> operatorId,
    int alternativeIdx)
    : operatorIdentifier_(std::move(operatorIdentifier)),
      restoredSstFiles_(),
      lastCompletedCheckpointId_(-1L),
      backendUID_(UUID::randomUUID()),
      writeBatchSize_(writeBatchSize),
      restoreStateHandles_(std::move(restoreStateHandles)),
      keyGroupRange_(keyGroupRange),
      keySerializer_(std::move(keySerializer)),
      instanceBasePath_(instanceBasePath),
      numberOfTransferringThreads_(numberOfTransferringThreads),
      keyGroupPrefixBytes_(keyGroupPrefixBytes),
      overlapFractionThreshold_(0),
      // env_ was previously left indeterminate; give the raw pointer a defined value.
      env_(nullptr),
      // Sink parameters: move instead of copying to avoid extra atomic ref-count traffic.
      omniTaskBridge_(std::move(omniTaskBridge)),
      operatorId_(std::move(operatorId)),
      alternativeIdx_(alternativeIdx)
{
    this->rocksHandle_ = std::make_unique<RocksDbHandle>(
        kvStateInformation,
        instanceRocksDBPath,
        dbOptions,
        columnFamilyOptionsFactory);
}
// Defaulted destructor: members clean up through their own destructors (rocksHandle_ is a
// std::unique_ptr). The raw pointers keyGroupRange_ and env_ are not deleted here.
~RocksDBIncrementalRestoreOperation() override = default;
/**
 * Runs the restore.
 *
 * Rescaling is assumed when more than one state handle is present, or when the single handle's
 * key-group range differs from this backend's range.
 *
 * @return nullptr when there is nothing to restore from; otherwise a result built from the
 *         restored DB, its default column family handle, the last completed checkpoint id,
 *         the backend UID and the restored SST file map.
 * @throws std::runtime_error wrapping any failure raised while restoring.
 */
std::shared_ptr<RocksDBRestoreResult> restore() override
{
    if (restoreStateHandles_.empty()) {
        return nullptr;
    }

    std::shared_ptr<KeyedStateHandle> leadingHandle = restoreStateHandles_.front();

    // Several handles always mean rescaling; a single handle only does so when its range
    // is not exactly ours.
    bool rescalingRequired = restoreStateHandles_.size() > 1;
    if (!rescalingRequired) {
        const auto &leadingRange = leadingHandle->GetKeyGroupRange();
        rescalingRequired = !(leadingRange == *keyGroupRange_);
    }

    try {
        if (!rescalingRequired) {
            restoreWithoutRescaling(leadingHandle);
        } else {
            restoreWithRescaling(restoreStateHandles_);
        }
    } catch (const std::exception &e) {
        throw std::runtime_error("Restore failed: " + std::string(e.what()));
    } catch (...) {
        throw std::runtime_error("Restore failed with unknown exception");
    }

    return std::make_shared<RocksDBRestoreResult>(
        rocksHandle_->getDb(),
        rocksHandle_->getDefaultColumnFamilyHandle(),
        lastCompletedCheckpointId_,
        backendUID_,
        restoredSstFiles_);
}
private:
// Operator identifier passed to the constructor (not read by the code in this header).
std::string operatorIdentifier_;
// Shared state handles recorded per checkpoint id by restorePreviousIncrementalFilesStatus().
std::map<long, std::vector<IncrementalKeyedStateHandle::HandleAndLocalPath>> restoredSstFiles_;
// Checkpoint id of the restored handle; -1 until restorePreviousIncrementalFilesStatus() runs.
long lastCompletedCheckpointId_;
// Backend identifier; a random UUID until replaced by the restored handle's identifier.
UUID backendUID_;
// Size handed to RocksDBWriteBatchWrapper when migrating data during rescaling.
long writeBatchSize_;
// State handles to restore from.
std::vector<std::shared_ptr<KeyedStateHandle>> restoreStateHandles_;
// Handle wrapping the target RocksDB instance; created in the constructor.
std::unique_ptr<RocksDbHandle> rocksHandle_;
// Key-group range of this backend (not owned).
KeyGroupRange* keyGroupRange_;
// Key serializer passed to the constructor (not read by the code in this header).
std::shared_ptr<TypeSerializer> keySerializer_;
// Base directory under which temporary restore directories are created.
std::filesystem::path instanceBasePath_;
// Thread count handed to RocksDBStateDownloader.
int numberOfTransferringThreads_;
// Length in bytes of the serialized key-group prefix.
int keyGroupPrefixBytes_;
// Threshold passed to chooseTheBestStateHandleForInitial(); the constructor sets it to 0.
double overlapFractionThreshold_;
// NOTE(review): never assigned or read in this header — confirm whether it is still needed.
JNIEnv* env_;
// Bridge used to read state metadata and to download remote state.
std::shared_ptr<OmniTaskBridge> omniTaskBridge_;
// Operator id passed to the constructor (not read by the code in this header).
std::shared_ptr<OperatorID> operatorId_;
// Value passed to the constructor (not read by the code in this header).
int alternativeIdx_;
/**
 * Restores from a single state handle without rescaling.
 *
 * A remote handle is downloaded first and then restored; a local handle is restored directly.
 * In both cases the previous incremental file status is recorded before the DB is opened.
 *
 * @param keyedStateHandle handle to restore from; must be an IncrementalRemoteKeyedStateHandle
 *                         or an IncrementalLocalKeyedStateHandle.
 * @throws std::runtime_error if the handle is of neither type (including a null handle).
 */
void restoreWithoutRescaling(std::shared_ptr<KeyedStateHandle> keyedStateHandle)
{
    if (auto remote = std::dynamic_pointer_cast<IncrementalRemoteKeyedStateHandle>(keyedStateHandle)) {
        restorePreviousIncrementalFilesStatus(remote);
        restoreFromRemoteState(remote);
        return;
    }

    if (auto local = std::dynamic_pointer_cast<IncrementalLocalKeyedStateHandle>(keyedStateHandle)) {
        restorePreviousIncrementalFilesStatus(local);
        restoreFromLocalState(local);
        return;
    }

    throw std::runtime_error("restoreWithoutRescaling error");
}
/**
 * Adopts the bookkeeping of the restored checkpoint: backend identifier, shared state handles
 * keyed by checkpoint id, and the last completed checkpoint id.
 *
 * Note that an existing map entry for the same checkpoint id is left untouched (emplace).
 *
 * @param localKeyedStateHandle incremental handle being restored (remote or local, despite
 *                              the parameter name).
 */
void restorePreviousIncrementalFilesStatus(
    std::shared_ptr<IncrementalKeyedStateHandle> localKeyedStateHandle)
{
    backendUID_ = localKeyedStateHandle->GetBackendIdentifier();

    const auto restoredCheckpointId = localKeyedStateHandle->GetCheckpointId();
    restoredSstFiles_.emplace(restoredCheckpointId, localKeyedStateHandle->GetSharedStateHandles());
    lastCompletedCheckpointId_ = restoredCheckpointId;
}
/**
 * Downloads a remote incremental handle into a fresh, randomly named directory under the
 * instance base path and restores from the resulting local handle.
 *
 * NOTE(review): the download directory is not removed here, whereas restoreWithRescaling()
 * removes its temporary directories via cleanUpPathQuietly() — confirm whether
 * RocksDbHandle::openDB still needs the directory after it returns.
 *
 * @param stateHandle remote handle to download and restore.
 */
void restoreFromRemoteState(std::shared_ptr<IncrementalRemoteKeyedStateHandle> stateHandle)
{
    const std::filesystem::path baseDir = std::filesystem::absolute(instanceBasePath_);
    std::filesystem::path downloadDir = baseDir / UUID::randomUUID().ToString();

    restoreFromLocalState(
        transferRemoteStateToLocalDirectory(omniTaskBridge_, downloadDir, stateHandle));
}
void restoreFromLocalState(std::shared_ptr<IncrementalLocalKeyedStateHandle> localKeyedStateHandle)
{
auto serializerStr = TaskStateSnapshotSerializer::parseIncrementalKeyedStateHandle(localKeyedStateHandle);
std::vector<StateMetaInfoSnapshot> stateMetaInfoSnapshots =
omniTaskBridge_->readMetaData(to_string(serializerStr));
Path path = localKeyedStateHandle->GetDirectoryStateHandle()->getDirectory();
std::filesystem::path restoreSourcePath = path.toString();
LOG("localKeyedStateHandle path!" << restoreSourcePath);
rocksHandle_->openDB(createColumnFamilyDescriptors(stateMetaInfoSnapshots, false),
stateMetaInfoSnapshots,
restoreSourcePath);
}
/**
 * Best-effort recursive removal of a path.
 *
 * Uses the non-throwing std::error_code overload of remove_all, so a filesystem failure is
 * reported through the (ignored) error code rather than through an exception that would have
 * to be caught and discarded. A missing path is not an error: remove_all simply removes
 * nothing, which also makes a separate exists() pre-check unnecessary.
 *
 * @param path file or directory to delete; may not exist.
 */
void cleanUpPathQuietly(const std::filesystem::path& path)
{
    std::error_code ignored;
    std::filesystem::remove_all(path, ignored);
}
/**
 * Downloads all state data of a remote incremental handle into a local directory and wraps the
 * result in a local handle.
 *
 * The returned handle reuses the remote handle's backend identifier, checkpoint id, key-group
 * range, metadata handle and shared state, and points at the download directory.
 *
 * @param omniTaskBridge               bridge handed to the downloader.
 * @param temporaryRestoreInstancePath directory to download into.
 * @param restoreStateHandle           remote handle to download.
 * @return local handle describing the downloaded state.
 */
std::shared_ptr<IncrementalLocalKeyedStateHandle> transferRemoteStateToLocalDirectory(
    std::shared_ptr<omnistream::OmniTaskBridge> omniTaskBridge,
    fs::path temporaryRestoreInstancePath,
    std::shared_ptr<IncrementalRemoteKeyedStateHandle> restoreStateHandle)
{
    RocksDBStateDownloader downloader(numberOfTransferringThreads_);
    downloader.transferAllStateDataToDirectory(*restoreStateHandle,
                                               temporaryRestoreInstancePath,
                                               omniTaskBridge);

    Path downloadedDirectory(temporaryRestoreInstancePath.string());
    return std::make_shared<IncrementalLocalKeyedStateHandle>(
        restoreStateHandle->GetBackendIdentifier(),
        restoreStateHandle->GetCheckpointId(),
        new DirectoryStateHandle(downloadedDirectory, 0),
        restoreStateHandle->GetKeyGroupRange(),
        restoreStateHandle->GetMetaDataStateHandle(),
        restoreStateHandle->GetSharedState());
}
* Recovery from multi incremental states with rescaling. For rescaling, this method creates a
* temporary RocksDB instance for a key-groups shard. All contents from the temporary instance
* are copied into the real restore instance and then the temporary instance is discarded.
*/
void restoreWithRescaling(std::vector<std::shared_ptr<KeyedStateHandle>> restoreStateHandles)
{
std::shared_ptr<KeyedStateHandle> initialPtr = selectInitialStateHandle(restoreStateHandles);
if (initialPtr) {
initDBWithRescaling(initialPtr);
} else {
rocksHandle_->openDB();
}
std::vector<uint8_t> startKeyGroupPrefixBytes(keyGroupPrefixBytes_);
CompositeKeySerializationUtils::serializeKeyGroup(keyGroupRange_->getStartKeyGroup(), startKeyGroupPrefixBytes);
std::vector<uint8_t> stopKeyGroupPrefixBytes(keyGroupPrefixBytes_);
CompositeKeySerializationUtils::serializeKeyGroup(keyGroupRange_->getEndKeyGroup() + 1,
stopKeyGroupPrefixBytes);
for (const auto& rawStateHandle : restoreStateHandles) {
auto remoteHandle = std::dynamic_pointer_cast<IncrementalRemoteKeyedStateHandle>(rawStateHandle);
if (!remoteHandle) {
throw unexpectedStateHandleException(
typeid(IncrementalRemoteKeyedStateHandle),
typeid(*rawStateHandle)
);
}
fs::path temporaryRestoreInstancePath =
std::filesystem::absolute(instanceBasePath_) / UUID::randomUUID().ToString();
try {
auto tmpRestoreDBInfo = restoreDBInstanceFromStateHandle(remoteHandle, temporaryRestoreInstancePath);
RocksDBWriteBatchWrapper writeBatchWrapper(rocksHandle_->getDb(), writeBatchSize_);
std::vector<rocksdb::ColumnFamilyDescriptor> tmpColumnFamilyDescriptors =
tmpRestoreDBInfo->columnFamilyDescriptors_;
std::vector<rocksdb::ColumnFamilyHandle*> tmpColumnFamilyHandles =
tmpRestoreDBInfo->columnFamilyHandles_;
for (size_t i = 0; i < tmpColumnFamilyDescriptors.size(); ++i) {
rocksdb::ColumnFamilyHandle* tmpColumnFamilyHandle = tmpColumnFamilyHandles[i];
rocksdb::ColumnFamilyHandle* targetColumnFamilyHandle =
rocksHandle_->getOrRegisterStateColumnFamilyHandle(
nullptr, tmpRestoreDBInfo->stateMetaInfoSnapshots_[i])->columnFamilyHandle_;
auto iterator = RocksDbOperationUtils::getRocksIterator(
tmpRestoreDBInfo->db_, tmpColumnFamilyHandle, *tmpRestoreDBInfo->readOptions_);
migrateDataBetweenColumnFamilies(iterator, targetColumnFamilyHandle, startKeyGroupPrefixBytes,
stopKeyGroupPrefixBytes, writeBatchWrapper);
}
} catch (...) {
cleanUpPathQuietly(temporaryRestoreInstancePath);
throw;
}
cleanUpPathQuietly(temporaryRestoreInstancePath);
}
}
std::shared_ptr<KeyedStateHandle> selectInitialStateHandle(
std::vector<std::shared_ptr<KeyedStateHandle>>& restoreStateHandles)
{
auto initialHandle = RocksDBIncrementalCheckpointUtils::chooseTheBestStateHandleForInitial(
restoreStateHandles, *keyGroupRange_, overlapFractionThreshold_);
std::shared_ptr<KeyedStateHandle> initialPtr;
if (initialHandle != nullptr) {
auto it = std::find(restoreStateHandles.begin(), restoreStateHandles.end(), initialHandle);
if (it != restoreStateHandles.end()) {
initialPtr = *it;
restoreStateHandles.erase(it);
}
}
return initialPtr;
}
/**
 * Copies every entry in [startKey, stopKey) from the source iterator into the target column
 * family through the given write batch.
 *
 * Iteration starts at startKey and stops at the first key that is not before the stop prefix
 * (as decided by RocksDBIncrementalCheckpointUtils::beforeThePrefixBytes) or when the iterator
 * becomes invalid.
 *
 * @param iterator   iterator over the source column family.
 * @param targetCF   column family receiving the entries.
 * @param startKey   serialized key-group prefix to seek to.
 * @param stopKey    serialized key-group prefix that ends the copy (exclusive).
 * @param writeBatch batch wrapper bound to the target DB.
 */
void migrateDataBetweenColumnFamilies(
    std::unique_ptr<RocksIteratorWrapper>& iterator,
    rocksdb::ColumnFamilyHandle* targetCF,
    const std::vector<uint8_t>& startKey,
    const std::vector<uint8_t>& stopKey,
    RocksDBWriteBatchWrapper& writeBatch)
{
    // Loop-invariant: build the stop prefix once instead of once per entry.
    std::string stopKeyGroupPrefixBytesSliceStr(stopKey.begin(), stopKey.end());

    for (iterator->seek(startKey); iterator->isValid(); iterator->next()) {
        std::string key = iterator->key();
        if (!RocksDBIncrementalCheckpointUtils::beforeThePrefixBytes(
                key, stopKeyGroupPrefixBytesSliceStr)) {
            break;
        }
        writeBatch.Put(targetCF, key, iterator->value());
    }
}
/**
 * Seeds the target DB from the chosen initial handle and clips it to this backend's key-group
 * range.
 *
 * A clipping failure is reported instead of being silently ignored: continuing with an
 * unclipped DB would leave entries in it that belong to key groups outside this backend's
 * range.
 *
 * @param initialPtr handle selected as the starting point; must be an
 *                   IncrementalRemoteKeyedStateHandle.
 * @throws std::invalid_argument if the handle is not a remote incremental handle.
 * @throws std::runtime_error    if clipping the DB fails.
 */
void initDBWithRescaling(std::shared_ptr<KeyedStateHandle> initialPtr)
{
    auto remoteHandle = std::dynamic_pointer_cast<IncrementalRemoteKeyedStateHandle>(initialPtr);
    if (!remoteHandle) {
        throw std::invalid_argument("Initial handle must be IncrementalRemoteKeyedStateHandle");
    }
    restoreFromRemoteState(remoteHandle);
    try {
        RocksDBIncrementalCheckpointUtils::clipDBWithKeyGroupRange(
            rocksHandle_->getDb(),
            rocksHandle_->getColumnFamilyHandles(),
            *keyGroupRange_,
            initialPtr->GetKeyGroupRange(),
            keyGroupPrefixBytes_);
    } catch (const std::exception& e) {
        throw std::runtime_error(
            std::string("Failed to clip DB after initialization: ") + e.what());
    } catch (...) {
        throw std::runtime_error("Failed to clip DB after initialization.");
    }
}
class RestoredDBInstance {
public:
RestoredDBInstance(
rocksdb::DB* db,
std::vector<rocksdb::ColumnFamilyHandle*> columnFamilyHandles,
std::vector<rocksdb::ColumnFamilyDescriptor> columnFamilyDescriptors,
std::vector<StateMetaInfoSnapshot> stateMetaInfoSnapshots
) : db_(db),
columnFamilyDescriptors_(std::move(columnFamilyDescriptors)),
stateMetaInfoSnapshots_(std::move(stateMetaInfoSnapshots)),
columnFamilyHandles_(columnFamilyHandles),
readOptions_(std::make_shared<rocksdb::ReadOptions>()) {
if (columnFamilyHandles_.empty()) {
throw std::invalid_argument("columnFamilyHandles cannot be empty");
}
defaultColumnFamilyHandle_ = columnFamilyHandles.front();
columnFamilyHandles.erase(columnFamilyHandles.begin());
columnFamilyHandles_ = columnFamilyHandles;
}
~RestoredDBInstance() {}
rocksdb::DB* db_;
rocksdb::ColumnFamilyHandle* defaultColumnFamilyHandle_;
std::vector<rocksdb::ColumnFamilyHandle*> columnFamilyHandles_;
std::vector<rocksdb::ColumnFamilyDescriptor> columnFamilyDescriptors_;
std::vector<StateMetaInfoSnapshot> stateMetaInfoSnapshots_;
std::shared_ptr<rocksdb::ReadOptions> readOptions_;
};
std::shared_ptr<RestoredDBInstance> restoreDBInstanceFromStateHandle(
std::shared_ptr<IncrementalRemoteKeyedStateHandle> restoreStateHandle,
fs::path temporaryRestoreInstancePath)
{
RocksDBStateDownloader rocksDbStateDownloader(numberOfTransferringThreads_);
rocksDbStateDownloader.transferAllStateDataToDirectory(*restoreStateHandle,
temporaryRestoreInstancePath,
omniTaskBridge_);
auto serializerStr = TaskStateSnapshotSerializer::parseIncrementalRemoteKeyedStateHandle(restoreStateHandle);
auto stateMetaInfoSnapshots = omniTaskBridge_->readMetaData(to_string(serializerStr));
std::vector<rocksdb::ColumnFamilyDescriptor> columnFamilyDescriptors =
createColumnFamilyDescriptors(stateMetaInfoSnapshots, false);
std::vector<rocksdb::ColumnFamilyHandle*> columnFamilyHandles;
columnFamilyHandles.reserve(stateMetaInfoSnapshots.size() + 1);
rocksdb::ColumnFamilyOptions columnFamilyOptions =
RocksDbOperationUtils::createColumnFamilyOptions(rocksHandle_->getColumnFamilyOptionsFactory(),
"default");
std::shared_ptr<rocksdb::DBOptions> dbOptions = rocksHandle_->getDbOptions();
rocksdb::DB* restoreDb = RocksDbOperationUtils::openDB(
temporaryRestoreInstancePath,
columnFamilyDescriptors,
columnFamilyHandles,
columnFamilyOptions,
*dbOptions);
return std::make_shared<RestoredDBInstance>(restoreDb, columnFamilyHandles,
columnFamilyDescriptors, stateMetaInfoSnapshots);
};
std::vector<rocksdb::ColumnFamilyDescriptor> createColumnFamilyDescriptors (
std::vector<StateMetaInfoSnapshot> stateMetaInfoSnapshots,
bool registerTtlCompactFilter)
{
std::vector<rocksdb::ColumnFamilyDescriptor> columnFamilyDescriptors;
columnFamilyDescriptors.reserve(stateMetaInfoSnapshots.size());
for (const auto& snapshot : stateMetaInfoSnapshots) {
auto metaInfoBase = RegisteredStateMetaInfoBase::fromMetaInfoSnapshot(snapshot);
if (!metaInfoBase) {
continue;
}
rocksdb::ColumnFamilyDescriptor columnFamilyDescriptor =
RocksDbOperationUtils::createColumnFamilyDescriptor(
std::move(metaInfoBase),
rocksHandle_->getColumnFamilyOptionsFactory());
columnFamilyDescriptors.push_back(std::move(columnFamilyDescriptor));
}
return columnFamilyDescriptors;
}
/**
 * Builds (but does not throw) the error reported when a state handle has the wrong dynamic type.
 *
 * @param expected type the caller required.
 * @param actual   type that was found.
 * @return a std::runtime_error whose message names both types via std::type_info::name().
 */
std::runtime_error unexpectedStateHandleException(
    const std::type_info& expected, const std::type_info& actual)
{
    std::string message("Unexpected state handle type: expected ");
    message += expected.name();
    message += ", but got ";
    message += actual.name();
    return std::runtime_error(message);
}
};
#endif