* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_OPERATORSTATERESTOREOPERATION_H
#define OMNISTREAM_OPERATORSTATERESTOREOPERATION_H
#include <unordered_map>
#include <set>
#include <vector>
#include <string>
#include <memory>
#include <stdexcept>
#include <limits>
#include <utility>
#include <nlohmann/json.hpp>
#include "core/include/common.h"
#include "PartitionableListState.h"
#include "RegisteredBroadcastStateBackendMetaInfo.h"
#include "state/bridge/OmniTaskBridge.h"
#include "OperatorStreamStateHandle.h"
#include "runtime/checkpoint/TaskStateSnapshotSerializer.h"
#include "RegisteredOperatorStateBackendMetaInfo.h"
#include "core/memory/DataInputDeserializer.h"
#include "state/memory/ByteStreamStateHandle.h"
/**
 * Restores operator list states from a collection of operator state handles.
 *
 * For every handle that is an OperatorStreamStateHandle, restore() reads the state meta data
 * through the OmniTaskBridge, loads the raw state bytes (from memory when the delegate handle
 * exposes them, otherwise through a bridge-provided input stream) and deserializes the values
 * located at the recorded partition offsets into PartitionableListState objects stored in the
 * registered operator-state map.
 *
 * Only states whose name appears in typeByteStateNames or typeLongStateNames are restored;
 * every other state name is skipped.
 */
class OperatorStateRestoreOperation {
public:
    /**
     * @param registeredOperatorStates  map (state name -> state) that receives the restored list states.
     * @param registeredBroadcastStates map of broadcast states; stored, but not used by the code in this class.
     * @param stateHandles              handles to restore from; null entries and handles that are not
     *                                  OperatorStreamStateHandle instances are ignored by restore().
     * @param omniTaskBridge            bridge used to read state meta data and non in-memory state bytes.
     */
    OperatorStateRestoreOperation(
        std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredOperatorStates,
        std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredBroadcastStates,
        const std::vector<std::shared_ptr<OperatorStateHandle>>& stateHandles,
        std::shared_ptr<OmniTaskBridge> omniTaskBridge)
        : registeredOperatorStates_(std::move(registeredOperatorStates)),
          registeredBroadcastStates_(std::move(registeredBroadcastStates)),
          stateHandles_(stateHandles),
          omniTaskBridge_(std::move(omniTaskBridge)) {}

    ~OperatorStateRestoreOperation() = default;

    /**
     * Restores all supported operator states found in the configured state handles.
     *
     * The raw bytes of a handle are loaded at most once, and only when the handle contains
     * at least one supported state that also has partition offsets.
     *
     * @throws a logic exception (THROW_LOGIC_EXCEPTION) when a stream state handle has to be
     *         restored but no OmniTaskBridge is available; additionally propagates every
     *         exception raised while reading or deserializing the state.
     */
    void restore()
    {
        for (const auto& stateHandle : stateHandles_) {
            if (stateHandle == nullptr) {
                continue;
            }
            auto streamStateHandle = std::dynamic_pointer_cast<OperatorStreamStateHandle>(stateHandle);
            if (streamStateHandle == nullptr) {
                continue;
            }
            // The bridge is required to read the meta data below; fail loudly instead of
            // dereferencing a null pointer.
            if (omniTaskBridge_ == nullptr) {
                INFO_RELEASE("Error: Cannot restore operator state without OmniTaskBridge.");
                THROW_LOGIC_EXCEPTION("Cannot restore operator state without OmniTaskBridge.")
            }

            auto stateNameToPartitionOffsets = stateHandle->getStateNameToPartitionOffsets();
            auto delegate = stateHandle->getDelegateStateHandle();
            auto json = TaskStateSnapshotSerializer::parseOperatorStreamStateHandle(streamStateHandle);
            std::string handleJson = to_string(json);
            auto stateMetaInfoSnapshots = omniTaskBridge_->readOperatorMetaData(handleJson);

            // Loaded lazily: handles without any restorable state never trigger a read.
            std::vector<uint8_t> stateData;
            bool stateDataLoaded = false;
            for (auto& snapshot : stateMetaInfoSnapshots) {
                const std::string& stateName = snapshot.getName();
                if (stateNameToPartitionOffsets.find(stateName) == stateNameToPartitionOffsets.end()) {
                    continue;
                }
                if (!isSupportedRestoredState(stateName)) {
                    continue;
                }
                auto metaInfo = std::make_shared<RegisteredOperatorStateBackendMetaInfo>(snapshot);
                if (!stateDataLoaded) {
                    stateData = readOperatorStateData(delegate, handleJson);
                    stateDataLoaded = true;
                }
                deserializeOperatorStateValues(metaInfo, stateData, stateNameToPartitionOffsets);
            }
        }
    }

    /**
     * Deserializes the values of the state described by @p metaInfo and appends them to the
     * matching registered list state (creating and registering it when it does not exist yet).
     *
     * @param metaInfo                    meta information (name, serializer) of the state to restore.
     * @param stateData                   raw bytes of the whole operator state handle.
     * @param stateNameToPartitionOffsets per-state offset information; only the entry whose key equals
     *                                    the state name of @p metaInfo is used.
     *
     * Returns without restoring anything when the state has no serializer or no offsets entry.
     *
     * @throws a logic exception when @p stateData is larger than INT_MAX bytes or an offset lies
     *         outside of @p stateData.
     * @throws std::runtime_error when @p metaInfo is null, a value cannot be deserialized, or an
     *         already registered state does not match the restored one.
     */
    void deserializeOperatorStateValues(
        std::shared_ptr<RegisteredOperatorStateBackendMetaInfo> metaInfo,
        const std::vector<uint8_t>& stateData,
        std::unordered_map<std::string, OperatorStateHandle::StateMetaInfo> stateNameToPartitionOffsets)
    {
        // The deserializer below is constructed with an int length.
        if (stateData.size() > static_cast<size_t>(std::numeric_limits<int>::max())) {
            INFO_RELEASE("Error: Operator state handle is too large to deserialize in memory: " << stateData.size());
            THROW_LOGIC_EXCEPTION("Operator state handle is too large to deserialize in memory: "
                << stateData.size())
        }
        if (metaInfo == nullptr) {
            INFO_RELEASE("Error: Operator state meta info is null.");
            throw std::runtime_error("Operator state meta info is null.");
        }

        const std::string name(metaInfo->getName());
        TypeSerializer *serializer = metaInfo->getStateSerializer();
        if (serializer == nullptr) {
            // Kept as a non-fatal skip, but no longer silent.
            INFO_RELEASE("Warning: Skipping restore of operator state without serializer: " << name);
            return;
        }

        // Direct lookup instead of scanning every entry of the map.
        auto entry = stateNameToPartitionOffsets.find(name);
        if (entry == stateNameToPartitionOffsets.end()) {
            return;
        }

        DataInputDeserializer in(stateData.empty() ? nullptr : stateData.data(),
            static_cast<int>(stateData.size()), 0);
        const auto& offsets = entry->second.getOffsets();

        if (typeByteStateNames.find(name) != typeByteStateNames.end()) {
            restoreListStateValues<std::vector<uint8_t>>(name, metaInfo, serializer, in, offsets, stateData.size());
        }
        if (typeLongStateNames.find(name) != typeLongStateNames.end()) {
            restoreListStateValues<long>(name, metaInfo, serializer, in, offsets, stateData.size());
        }
    }

private:
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredOperatorStates_;
    std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<State>>> registeredBroadcastStates_;
    std::vector<std::shared_ptr<OperatorStateHandle>> stateHandles_;
    std::shared_ptr<OmniTaskBridge> omniTaskBridge_;

    // Maximum number of bytes requested from the bridge input stream per read call.
    static constexpr size_t READ_CHUNK_SIZE = 4096;

    /**
     * Deserializes one value of type T per offset and adds it to the list state named @p name.
     *
     * @tparam T       element type of the list state.
     * @tparam Offsets iterable container of offsets into the state data.
     * @param stateSize total size of the state data, used to validate every offset.
     *
     * @throws std::runtime_error when the serializer returns a null value.
     */
    template <typename T, typename Offsets>
    void restoreListStateValues(
        const std::string& name,
        const std::shared_ptr<RegisteredOperatorStateBackendMetaInfo>& metaInfo,
        TypeSerializer* serializer,
        DataInputDeserializer& in,
        const Offsets& offsets,
        size_t stateSize)
    {
        auto listState = getOrCreateOperatorListState<T>(name, metaInfo);
        for (const auto& offset : offsets) {
            validateOffset(name, offset, stateSize);
            in.setPosition(static_cast<size_t>(offset));
            // The deserialized object is released once it has been added to the list state.
            std::unique_ptr<T> value(static_cast<T*>(serializer->deserialize(in)));
            if (value == nullptr) {
                INFO_RELEASE("ERROR: Failed to deserialize operator state: " << name << ", offset="
                    << std::to_string(offset));
                throw std::runtime_error(
                    "Failed to deserialize operator state: " + name + ", offset=" + std::to_string(offset));
            }
            listState->add(*value);
        }
    }

    /**
     * Loads the complete raw bytes of an operator state handle.
     *
     * In-memory handles are returned directly. Otherwise exactly GetStateSize() bytes are read
     * in chunks from an input stream opened through the OmniTaskBridge; the stream is closed on
     * success as well as on failure.
     *
     * @param stateHandle delegate handle that holds the state bytes.
     * @param handleJson  JSON description of the handle, passed to the bridge to open the stream.
     * @return the state bytes (empty when the handle reports a state size of 0).
     * @throws a logic exception when the handle or the bridge is null, the reported size is
     *         negative, the stream cannot be opened, ends early, returns zero bytes, or returns
     *         more bytes than requested.
     */
    std::vector<uint8_t> readOperatorStateData(
        const std::shared_ptr<StreamStateHandle>& stateHandle,
        const std::string& handleJson)
    {
        if (stateHandle == nullptr) {
            INFO_RELEASE("Error: Operator state delegate handle is null.");
            THROW_LOGIC_EXCEPTION("Operator state delegate handle is null.")
        }
        auto inMemoryBytes = stateHandle->AsBytesIfInMemory();
        if (inMemoryBytes.has_value()) {
            return std::move(inMemoryBytes.value());
        }
        if (omniTaskBridge_ == nullptr) {
            INFO_RELEASE("Error: Cannot restore remote operator state without OmniTaskBridge.");
            THROW_LOGIC_EXCEPTION("Cannot restore remote operator state without OmniTaskBridge.")
        }

        jobject inputStream = nullptr;
        try {
            long expectedStateSize = stateHandle->GetStateSize();
            if (expectedStateSize < 0) {
                INFO_RELEASE("Error: Operator state handle has negative state size: " << expectedStateSize);
                THROW_LOGIC_EXCEPTION("Operator state handle has negative state size: " << expectedStateSize)
            }
            if (expectedStateSize == 0) {
                return {};
            }
            inputStream = omniTaskBridge_->getSavepointInputStream(handleJson);
            if (inputStream == nullptr) {
                INFO_RELEASE("Error: Failed to open operator state input stream through OmniTaskBridge.");
                THROW_LOGIC_EXCEPTION("Failed to open operator state input stream through OmniTaskBridge.")
            }

            const size_t expectedSize = static_cast<size_t>(expectedStateSize);
            std::vector<uint8_t> data;
            std::vector<uint8_t> chunk(READ_CHUNK_SIZE);
            while (data.size() < expectedSize) {
                size_t remaining = expectedSize - data.size();
                size_t readLength = remaining < READ_CHUNK_SIZE ? remaining : READ_CHUNK_SIZE;
                int read = omniTaskBridge_->ReadSavepointInputStream(inputStream,
                    reinterpret_cast<int8_t *>(chunk.data()), 0, readLength);
                if (read < 0) {
                    INFO_RELEASE("Error: Operator state input stream ended before expected size: expected="
                        << expectedStateSize << ", actual=" << data.size() << ", read=" << read);
                    THROW_LOGIC_EXCEPTION("Operator state input stream ended before expected size: expected="
                        << expectedStateSize << ", actual=" << data.size() << ", read=" << read)
                }
                if (read == 0) {
                    INFO_RELEASE("Error: Operator state input stream returned zero bytes before expected size: expected="
                        << expectedStateSize << ", actual=" << data.size());
                    THROW_LOGIC_EXCEPTION("Operator state input stream returned zero bytes before expected size: expected="
                        << expectedStateSize << ", actual=" << data.size())
                }
                // Never copy more than was requested: a larger count would read past the chunk buffer.
                if (static_cast<size_t>(read) > readLength) {
                    INFO_RELEASE("Error: Operator state input stream returned more bytes than requested: requested="
                        << readLength << ", read=" << read);
                    THROW_LOGIC_EXCEPTION("Operator state input stream returned more bytes than requested: requested="
                        << readLength << ", read=" << read)
                }
                data.insert(data.end(), chunk.begin(), chunk.begin() + read);
            }

            // Clear the local handle before closing so that a throwing close is not retried below.
            jobject inputStreamToClose = inputStream;
            inputStream = nullptr;
            omniTaskBridge_->closeSavepointInputStream(inputStreamToClose);
            return data;
        } catch (...) {
            if (inputStream != nullptr) {
                jobject inputStreamToClose = inputStream;
                inputStream = nullptr;
                // Best-effort close: a failure here must not hide the original exception.
                try {
                    omniTaskBridge_->closeSavepointInputStream(inputStreamToClose);
                } catch (const std::exception& e) {
                    INFO_RELEASE("Error: Failed to close operator state input stream after restore failure: "
                        << e.what());
                } catch (...) {
                    INFO_RELEASE("Error: Failed to close operator state input stream after restore failure.");
                }
            }
            throw;
        }
    }

    /**
     * Checks that @p offset addresses a byte inside state data of @p stateSize bytes.
     *
     * @throws a logic exception when the offset is negative or not smaller than @p stateSize.
     */
    static void validateOffset(const std::string& stateName, long offset, size_t stateSize)
    {
        if (offset < 0 || static_cast<size_t>(offset) >= stateSize) {
            INFO_RELEASE("Error: Invalid operator state offset for state " << stateName
                << ": offset=" << offset << ", stateSize=" << stateSize);
            THROW_LOGIC_EXCEPTION("Invalid operator state offset for state " << stateName
                << ": offset=" << offset << ", stateSize=" << stateSize)
        }
    }

    /** Returns true when @p stateName is one of the byte-array or long state names restored here. */
    static bool isSupportedRestoredState(const std::string& stateName)
    {
        return typeByteStateNames.find(stateName) != typeByteStateNames.end()
            || typeLongStateNames.find(stateName) != typeLongStateNames.end();
    }

    /**
     * Returns the registered list state named @p stateName, creating and registering a new
     * PartitionableListState<T> from @p metaInfo when none exists yet.
     *
     * @throws std::runtime_error when the registered-state map is null, the existing state is not
     *         a PartitionableListState<T>, or its meta information does not match @p metaInfo.
     */
    template <typename T>
    std::shared_ptr<PartitionableListState<T>> getOrCreateOperatorListState(
        const std::string& stateName,
        const std::shared_ptr<RegisteredOperatorStateBackendMetaInfo>& metaInfo)
    {
        if (registeredOperatorStates_ == nullptr) {
            INFO_RELEASE("Error: Registered operator state map is null while restoring state: " << stateName);
            throw std::runtime_error("Registered operator state map is null while restoring state: " + stateName);
        }
        auto existing = registeredOperatorStates_->find(stateName);
        if (existing != registeredOperatorStates_->end()) {
            auto listState = std::dynamic_pointer_cast<PartitionableListState<T>>(existing->second);
            if (listState == nullptr) {
                INFO_RELEASE("Error: Restored operator state type mismatch for state: " << stateName);
                throw std::runtime_error("Restored operator state type mismatch for state: " + stateName);
            }
            validateOperatorStateMetaInfo(stateName, listState->getStateMetaInfo(), metaInfo);
            return listState;
        }
        auto listState = std::make_shared<PartitionableListState<T>>(metaInfo);
        registeredOperatorStates_->emplace(stateName, listState);
        return listState;
    }

    /**
     * Verifies that an already registered state is compatible with the restored one: both meta
     * infos and both serializers must be non-null, and name, assignment mode and serializer
     * backend id must be equal.
     *
     * @throws std::runtime_error on any violation.
     */
    static void validateOperatorStateMetaInfo(
        const std::string& stateName,
        const std::shared_ptr<RegisteredOperatorStateBackendMetaInfo>& existingMetaInfo,
        const std::shared_ptr<RegisteredOperatorStateBackendMetaInfo>& restoredMetaInfo)
    {
        if (existingMetaInfo == nullptr || restoredMetaInfo == nullptr) {
            INFO_RELEASE("ERROR: Restored operator state meta info is null for state: " << stateName);
            throw std::runtime_error("Restored operator state meta info is null for state: " + stateName);
        }
        if (existingMetaInfo->getName() != restoredMetaInfo->getName()
            || existingMetaInfo->getAssignmentMode() != restoredMetaInfo->getAssignmentMode()) {
            INFO_RELEASE("ERROR: Restored operator state meta info mismatch for state: " << stateName);
            throw std::runtime_error("Restored operator state meta info mismatch for state: " + stateName);
        }
        TypeSerializer* existingSerializer = existingMetaInfo->getStateSerializer();
        TypeSerializer* restoredSerializer = restoredMetaInfo->getStateSerializer();
        if (existingSerializer == nullptr || restoredSerializer == nullptr) {
            INFO_RELEASE("ERROR: Restored operator state serializer is null for state: " << stateName);
            throw std::runtime_error("Restored operator state serializer is null for state: " + stateName);
        }
        if (existingSerializer->getBackendId() != restoredSerializer->getBackendId()) {
            INFO_RELEASE("ERROR: Restored operator state serializer mismatch for state: " << stateName);
            throw std::runtime_error("Restored operator state serializer mismatch for state: " + stateName);
        }
    }

    // Names of the states whose values are deserialized as std::vector<uint8_t>.
    inline static const std::set<std::string> typeByteStateNames = {
        "SourceReaderState",
        "writer_raw_states",
        "streaming_committer_raw_states"
    };

    // Names of the states whose values are deserialized as long.
    inline static const std::set<std::string> typeLongStateNames = {
        "watermark",
        "elements-count-state"
    };
};
#endif