/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_HEAPSINGLESTATEITERATOR_H
#define OMNISTREAM_HEAPSINGLESTATEITERATOR_H
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <type_traits>
#include <vector>
#include "runtime/state/rocksdb/iterator/SingleStateIterator.h"
#include "runtime/state/CompositeKeySerializationUtils.h"
#include "core/memory/DataOutputSerializer.h"
#include "core/typeutils/TypeSerializer.h"
#include "core/typeutils/MapSerializer.h"
#include "core/typeutils/ListSerializer.h"
#include "core/typeutils/LongSerializer.h"
#include "basictypes/Object.h"
#include "StateTable.h"
#include "CopyOnWriteStateMap.h"
#include "table/utils/VectorBatchSerializationUtils.h"
#include "table/data/vectorbatch/VectorBatch.h"
#include "../../../core/include/common.h"
/**
 * Compile-time predicate: `value` is true only when the candidate type is a raw pointer
 * to an emhash7::HashMap<UK, UV> specialization, and false for every other type.
 */
template <typename Candidate>
struct IsEmhashMapPtr : std::integral_constant<bool, false> {};
template <typename MapKey, typename MapValue>
struct IsEmhashMapPtr<emhash7::HashMap<MapKey, MapValue>*> : std::integral_constant<bool, true> {};
/**
 * Compile-time predicate: `value` is true only when the candidate type is a raw pointer
 * to std::vector<V> (default allocator), and false for every other type.
 */
template <typename Candidate>
struct IsVectorPtr : std::integral_constant<bool, false> {};
template <typename Element>
struct IsVectorPtr<std::vector<Element>*> : std::integral_constant<bool, true> {};
/**
* A SingleStateIterator that iterates over a Heap CopyOnWriteStateTable,
* serializing each entry into byte arrays compatible with the RocksDB key format:
* key = [keyGroupPrefix] + [serialized key] + [serialized namespace]
* value = [serialized state value]
* For VectorBatch-valued tables the key is [keyGroupPrefix] + [int64 batch id] (no namespace bytes)
* and the value is the serialized VectorBatch.
*
* Entries are materialized eagerly during construction and iterated in
* key-group order (ascending) so that RocksStatesPerKeyGroupMergeIterator can
* merge them correctly without touching live state in the async phase.
*/
template <typename K, typename N, typename S>
class HeapSingleStateIterator : public SingleStateIterator {
public:
struct VbDataTag {};
HeapSingleStateIterator(StateTable<K, N, S>* stateTable, int kvStateId, int keyGroupPrefixBytes)
: stateTable_(stateTable),
kvStateId_(kvStateId),
keyGroupPrefixBytes_(keyGroupPrefixBytes)
{
collectAndSerializeEntries();
currentIndex_ = 0;
valid_ = !entries_.empty();
refreshKeyGroup();
}
HeapSingleStateIterator(
StateTable<int, VoidNamespace, omnistream::VectorBatch*>* vbTable,
int kvStateId,
int keyGroupPrefixBytes,
VbDataTag)
: stateTable_(reinterpret_cast<StateTable<K, N, S>*>(vbTable)),
kvStateId_(kvStateId),
keyGroupPrefixBytes_(keyGroupPrefixBytes)
{
collectVbEntries();
currentIndex_ = 0;
valid_ = !entries_.empty();
refreshKeyGroup();
}
void next() override
{
if (valid_) {
currentIndex_++;
valid_ = (currentIndex_ < entries_.size());
refreshKeyGroup();
}
}
bool isValid() const override
{
return valid_;
}
ByteView key() const override
{
const auto& key = entries_[currentIndex_].serializedKey;
return ByteView::fromBuffer(key.data(), key.size());
}
ByteView value() const override
{
const auto& value = entries_[currentIndex_].serializedValue;
return ByteView::fromBuffer(value.data(), value.size());
}
int keyGroup() const override
{
return currentKeyGroup_;
}
int getKvStateId() const override
{
return kvStateId_;
}
size_t getEntryCount() const override
{
return entries_.size();
}
void close() override
{
entries_.clear();
valid_ = false;
}
private:
struct SerializedEntry {
std::vector<int8_t> serializedKey;
std::vector<int8_t> serializedValue;
};
StateTable<K, N, S>* stateTable_;
int kvStateId_;
int keyGroupPrefixBytes_;
std::vector<SerializedEntry> entries_;
size_t currentIndex_ = 0;
int currentKeyGroup_ = -1;
bool valid_ = false;
void refreshKeyGroup()
{
currentKeyGroup_ = -1;
if (!valid_ || currentIndex_ >= entries_.size()) {
return;
}
const auto& key = entries_[currentIndex_].serializedKey;
if (key.size() < static_cast<size_t>(keyGroupPrefixBytes_)) {
return;
}
int result = 0;
for (int i = 0; i < keyGroupPrefixBytes_; ++i) {
result <<= 8;
result |= static_cast<int>(static_cast<uint8_t>(key[i]));
}
currentKeyGroup_ = result;
}
void collectAndSerializeEntries()
{
auto* stateMaps = stateTable_->getState();
int keyGroupOffset = stateTable_->getKeyGroupOffset();
TypeSerializer* keySerializer = stateTable_->getKeySerializer();
TypeSerializer* namespaceSerializer = stateTable_->getNamespaceSerializer();
TypeSerializer* stateSerializer = stateTable_->getStateSerializer();
for (size_t i = 0; i < stateMaps->size(); i++) {
int keyGroup = keyGroupOffset + static_cast<int>(i);
auto* stateMap = (*stateMaps)[i];
if (stateMap == nullptr || stateMap->size() == 0) {
continue;
}
serializeStateMap(stateMap, keyGroup, keySerializer, namespaceSerializer, stateSerializer);
}
std::sort(entries_.begin(), entries_.end(), [this](const SerializedEntry& a, const SerializedEntry& b) -> bool {
for (int i = 0; i < keyGroupPrefixBytes_ && i < static_cast<int>(a.serializedKey.size()) &&
i < static_cast<int>(b.serializedKey.size());
i++) {
if (static_cast<uint8_t>(a.serializedKey[i]) != static_cast<uint8_t>(b.serializedKey[i])) {
return static_cast<uint8_t>(a.serializedKey[i]) < static_cast<uint8_t>(b.serializedKey[i]);
}
}
return false;
});
}
void collectVbEntries()
{
auto* stateMaps = stateTable_->getState();
int keyGroupOffset = stateTable_->getKeyGroupOffset();
for (size_t i = 0; i < stateMaps->size(); i++) {
int keyGroup = keyGroupOffset + static_cast<int>(i);
auto* stateMap = (*stateMaps)[i];
if (stateMap == nullptr || stateMap->size() == 0) {
continue;
}
auto* cowMap = dynamic_cast<omnistream::CopyOnWriteStateMap<K, N, S>*>(stateMap);
if (cowMap == nullptr) {
continue;
}
for (auto it = cowMap->begin(); it != cowMap->end(); ++it) {
SerializedEntry entry;
try {
entry.serializedKey = serializeVbKey(keyGroup, it->first, it->third);
entry.serializedValue = serializeVbValue(it->second);
} catch (const std::exception& e) {
INFO_RELEASE(
"Error:HeapSingleStateIterator: collectVbEntries EXCEPTION at keyGroup="
<< keyGroup << ", error=" << e.what());
throw;
}
entries_.push_back(std::move(entry));
}
}
std::sort(entries_.begin(), entries_.end(), [this](const SerializedEntry& a, const SerializedEntry& b) -> bool {
for (int i = 0; i < keyGroupPrefixBytes_ && i < static_cast<int>(a.serializedKey.size()) &&
i < static_cast<int>(b.serializedKey.size());
i++) {
if (static_cast<uint8_t>(a.serializedKey[i]) != static_cast<uint8_t>(b.serializedKey[i])) {
return static_cast<uint8_t>(a.serializedKey[i]) < static_cast<uint8_t>(b.serializedKey[i]);
}
}
return false;
});
}
std::vector<int8_t> serializeVbKey(int keyGroup, const int64_t& batchId, const VoidNamespace&)
{
DataOutputSerializer outputSerializer;
OutputBufferStatus outputBufferStatus;
outputSerializer.setBackendBuffer(&outputBufferStatus);
outputSerializer.writeByte(static_cast<uint32_t>(keyGroup));
LongSerializer longSerializer;
longSerializer.serialize(const_cast<int64_t*>(&batchId), outputSerializer);
std::vector<int8_t> result(outputSerializer.getPosition());
memcpy(result.data(), outputSerializer.getData(), outputSerializer.getPosition());
return result;
}
static std::vector<int8_t> serializeVbValue(omnistream::VectorBatch* vectorBatch)
{
if (vectorBatch == nullptr) {
return {};
}
int32_t batchSize = VectorBatchSerializationUtils::calculateVectorBatchSerializableSize(vectorBatch);
if (batchSize <= 0) {
return {};
}
uint8_t* buffer = new uint8_t[batchSize];
uint8_t* cursor = buffer;
VectorBatchSerializationUtils::serializeVectorBatch(vectorBatch, batchSize, cursor);
std::vector<int8_t> result(batchSize);
for (int32_t i = 0; i < batchSize; i++) {
result[i] = static_cast<int8_t>(buffer[i]);
}
delete[] buffer;
return result;
}
struct RawSnapshotEntry {
RawSnapshotEntry(const K& snapshotKey, const N& snapshotNamespace, const S& snapshotValue)
: key(snapshotKey),
nmspace(snapshotNamespace),
value(snapshotValue),
ownsRefs_(true)
{
retainObjectRef(key);
retainObjectRef(nmspace);
retainObjectRef(value);
}
RawSnapshotEntry(const RawSnapshotEntry&) = delete;
RawSnapshotEntry& operator=(const RawSnapshotEntry&) = delete;
RawSnapshotEntry(RawSnapshotEntry&& other) noexcept
: key(other.key),
nmspace(other.nmspace),
value(other.value),
ownsRefs_(other.ownsRefs_)
{
other.ownsRefs_ = false;
}
RawSnapshotEntry& operator=(RawSnapshotEntry&& other) noexcept
{
if (this != &other) {
releaseRefs();
key = other.key;
nmspace = other.nmspace;
value = other.value;
ownsRefs_ = other.ownsRefs_;
other.ownsRefs_ = false;
}
return *this;
}
~RawSnapshotEntry()
{
releaseRefs();
}
K key;
N nmspace;
S value;
private:
template <typename T>
static void retainObjectRef(const T& ptr)
{
if constexpr (std::is_same_v<std::decay_t<T>, Object*>) {
if (ptr != nullptr) {
ptr->getRefCount();
}
}
}
template <typename T>
static void releaseObjectRef(const T& ptr)
{
if constexpr (std::is_same_v<std::decay_t<T>, Object*>) {
if (ptr != nullptr) {
ptr->putRefCount();
}
}
}
void releaseRefs()
{
if (!ownsRefs_) {
return;
}
releaseObjectRef(key);
releaseObjectRef(nmspace);
releaseObjectRef(value);
ownsRefs_ = false;
}
bool ownsRefs_;
};
void serializeStateMap(
StateMap<K, N, S>* stateMap,
int keyGroup,
TypeSerializer* keySerializer,
TypeSerializer* namespaceSerializer,
TypeSerializer* stateSerializer)
{
auto* cowMap = dynamic_cast<omnistream::CopyOnWriteStateMap<K, N, S>*>(stateMap);
if (cowMap == nullptr) {
return;
}
std::vector<RawSnapshotEntry> snapshot;
snapshot.reserve(cowMap->size());
for (auto it = cowMap->begin(); it != cowMap->end(); ++it) {
snapshot.emplace_back(it->first, it->third, it->second);
}
int mapEntryCount = 0;
for (auto& raw : snapshot) {
SerializedEntry entry;
try {
entry.serializedKey = serializeKey(keyGroup, raw.key, raw.nmspace, keySerializer, namespaceSerializer);
entry.serializedValue = serializeValue(raw.value, stateSerializer);
} catch (const std::exception& e) {
INFO_RELEASE(
"Error:HeapSingleStateIterator: serializeStateMap EXCEPTION at keyGroup="
<< keyGroup << ", entryIndex=" << mapEntryCount << ", error=" << e.what());
throw;
}
entries_.push_back(std::move(entry));
mapEntryCount++;
}
}
std::vector<int8_t> serializeKey(
int keyGroup,
const K& key,
const N& nmspace,
TypeSerializer* keySerializer,
TypeSerializer* namespaceSerializer)
{
DataOutputSerializer outputSerializer;
OutputBufferStatus outputBufferStatus;
outputSerializer.setBackendBuffer(&outputBufferStatus);
if (keyGroupPrefixBytes_ == 1) {
outputSerializer.writeByte(static_cast<uint32_t>(keyGroup));
} else {
outputSerializer.writeByte(static_cast<uint32_t>((keyGroup >> 8) & 0xFF));
outputSerializer.writeByte(static_cast<uint32_t>(keyGroup & 0xFF));
}
if constexpr (std::is_pointer_v<K>) {
keySerializer->serialize(const_cast<K>(key), outputSerializer);
} else if constexpr (is_shared_ptr_v<K>) {
if (!key) {
THROW_LOGIC_EXCEPTION("Heap snapshot cannot serialize a null shared_ptr key");
}
keySerializer->serialize(key.get(), outputSerializer);
} else {
K mutableKey = key;
keySerializer->serialize(&mutableKey, outputSerializer);
}
if constexpr (std::is_pointer_v<N>) {
namespaceSerializer->serialize(const_cast<N>(nmspace), outputSerializer);
} else if constexpr (is_shared_ptr_v<N>) {
if (!nmspace) {
THROW_LOGIC_EXCEPTION("Heap snapshot cannot serialize a null shared_ptr key");
}
namespaceSerializer->serialize(nmspace.get(), outputSerializer);
} else {
N mutableNs = nmspace;
namespaceSerializer->serialize(&mutableNs, outputSerializer);
}
auto* data = outputSerializer.getData();
size_t len = outputSerializer.length();
std::vector<int8_t> result(len);
for (size_t i = 0; i < len; i++) {
result[i] = static_cast<int8_t>(data[i]);
}
return result;
}
* Serializes a single emhash7::HashMap entry-by-entry using the MapSerializer's
* sub-serializers. Format: [int size] [key + bool isNull + value per entry].
*
* For Object* types, uses serialize(Object*,...) since PojoSerializer's void* path is NOT_IMPL.
* For other pointer types (RowData*, etc.), uses serialize(void*,...).
* For value types (int, int64_t, etc.), uses serialize(void*,...) with address.
*/
template <typename UK, typename UV>
static void serializeEmhashMap(
const emhash7::HashMap<UK, UV>& map, TypeSerializer* keySer, TypeSerializer* valSer, DataOutputSerializer& out)
{
out.writeInt(static_cast<int>(map.size()));
int idx = 0;
for (const auto& pair : map) {
if constexpr (std::is_same_v<UK, Object*>) {
if (pair.first == nullptr) {
INFO_RELEASE("Error:serializeEmhashMap: WARNING null Object* key at index=" << idx);
}
keySer->serialize(const_cast<Object*>(pair.first), out);
} else if constexpr (std::is_pointer_v<UK>) {
keySer->serialize(const_cast<UK>(pair.first), out);
} else {
UK mk = pair.first;
keySer->serialize(&mk, out);
}
if constexpr (std::is_pointer_v<UV>) {
if (pair.second == nullptr) {
out.writeBoolean(true);
} else {
out.writeBoolean(false);
if constexpr (std::is_same_v<UV, Object*>) {
valSer->serialize(const_cast<Object*>(pair.second), out);
} else {
valSer->serialize(const_cast<UV>(pair.second), out);
}
}
} else {
out.writeBoolean(false);
UV mv = pair.second;
valSer->serialize(&mv, out);
}
idx++;
}
}
* Serializes a std::vector entry-by-entry using the ListSerializer's element serializer.
* Format matches ListSerializer::serialize(Object*,...): [int size] [elem_1] [elem_2] ...
*/
template <typename V>
static void serializeVector(const std::vector<V>& vec, TypeSerializer* elemSer, DataOutputSerializer& out)
{
out.writeInt(static_cast<int>(vec.size()));
for (const auto& elem : vec) {
if constexpr (std::is_pointer_v<V>) {
elemSer->serialize(const_cast<V>(elem), out);
} else {
V me = elem;
elemSer->serialize(&me, out);
}
}
}
std::vector<int8_t> serializeValue(const S& state, TypeSerializer* stateSerializer)
{
DataOutputSerializer outputSerializer;
OutputBufferStatus outputBufferStatus;
outputSerializer.setBackendBuffer(&outputBufferStatus);
if constexpr (IsEmhashMapPtr<S>::value) {
auto* mapSer = dynamic_cast<MapSerializer*>(stateSerializer);
if (mapSer && state != nullptr) {
serializeEmhashMap(*state, mapSer->getKeySerializer(), mapSer->getValueSerializer(), outputSerializer);
}
} else if constexpr (IsVectorPtr<S>::value) {
auto* listSer = dynamic_cast<ListSerializer*>(stateSerializer);
if (listSer && state != nullptr) {
serializeVector(*state, listSer->getElementSerializer(), outputSerializer);
} else {
stateSerializer->serialize(const_cast<S>(state), outputSerializer);
}
} else if constexpr (std::is_pointer_v<S>) {
stateSerializer->serialize(const_cast<S>(state), outputSerializer);
} else {
S mutableState = state;
stateSerializer->serialize(&mutableState, outputSerializer);
}
auto* data = outputSerializer.getData();
size_t len = outputSerializer.length();
std::vector<int8_t> result(len);
for (size_t i = 0; i < len; i++) {
result[i] = static_cast<int8_t>(data[i]);
}
return result;
}
};
#endif