/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef FLINK_TNEL_STATETABLE_H
#define FLINK_TNEL_STATETABLE_H
#include <functional>
#include <limits>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <vector>
#include "StateMap.h"
#include "core/typeutils/TypeSerializer.h"
#include "../StateTransformationFunction.h"
#include "../internal/InternalKvState.h"
#include "../InternalKeyContext.h"
#include "../KeyGroupRange.h"
#include "../RegisteredKeyValueStateBackendMetaInfo.h"
#include "table/data/binary/BinaryRowData.h"
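/*
* StateTable<K, N, S> stores state values per (key, namespace) pair, partitioned into one StateMap per key group.
* K is the key type, N is the namespace type, and S is the per-entry state type,
* like RowData* for HeapValueState,
* emhash7<RowData*, int>* for HeapMapState,
* vector<int64_t>* for List State
*/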
// Ownership: the destructor deletes metaInfo and every StateMap in keyGroupedStateMaps,
// so a StateTable must never be copied (a copy would delete the same objects twice).
template<typename K, typename N, typename S>
class StateTable {
public:
    StateTable(InternalKeyContext<K> *keyContext, RegisteredKeyValueStateBackendMetaInfo *metaInfo,
        TypeSerializer *keySerializer);
    virtual ~StateTable();

    // Copying would duplicate the owned raw pointers and lead to double deletion in ~StateTable().
    StateTable(const StateTable &) = delete;
    StateTable &operator=(const StateTable &) = delete;

    /// @brief True when size() is zero.
    bool isEmpty()
    {
        return size() == 0;
    }
    /// @brief Number of entries summed over all non-null key group maps.
    int size();

    /// @brief State of the key context's current key under nameSpace (current key group).
    S get(const N &nameSpace)
    {
        return get(keyContext->getCurrentKey(), keyContext->getCurrentKeyGroupIndex(), nameSpace);
    }
    /// @brief Whether the current key has an entry under nameSpace.
    bool containsKey(const N &nameSpace)
    {
        return containsKey(keyContext->getCurrentKey(), keyContext->getCurrentKeyGroupIndex(), nameSpace);
    }
    /// @brief Stores state for the current key under nameSpace.
    void put(const N &nameSpace, const S &state)
    {
        put(keyContext->getCurrentKey(), keyContext->getCurrentKeyGroupIndex(), nameSpace, state);
    }
    /// @brief Stores state for an explicit key in the map of the given key group.
    void put(const K &key, int keyGroup, const N &nameSpace, const S &state)
    {
        getMapForKeyGroup(keyGroup)->put(key, nameSpace, state);
    }
    /// @brief Forwards to StateMap::putAndRemoveDuplicateKey for the current key; returns its result.
    K putAndRemoveDuplicateKey(const N &nameSpace, const S &state)
    {
        return getMapForKeyGroup(keyContext->getCurrentKeyGroupIndex())->putAndRemoveDuplicateKey(
            keyContext->getCurrentKey(), nameSpace, state);
    }
    /// @brief Removes the current key's entry under nameSpace.
    void remove(const N &nameSpace)
    {
        remove(keyContext->getCurrentKey(), keyContext->getCurrentKeyGroupIndex(), nameSpace);
    }
    /// @brief Removes the current key's entry under nameSpace and returns what StateMap::removeAndGetOld returns.
    S removeAndGetOld(const N &nameSpace)
    {
        return removeAndGetOld(keyContext->getCurrentKey(), keyContext->getCurrentKeyGroupIndex(), nameSpace);
    }
    template<typename T>
    void transform(const N &nameSpace, T value, StateTransformationFunction<S, T> transformation);
    S get(const K &key, const N &nameSpace);
    S get(const K &key, int keyGroupIndex, const N &nameSpace);
    /// @brief Returns a newly allocated iterator over all key group maps.
    typename InternalKvState<K, N, S>::StateIncrementalVisitor *
    getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords);
    RegisteredKeyValueStateBackendMetaInfo *getMetaInfo()
    {
        return metaInfo;
    }
    KeyGroupRange *getKeyGroupRange()
    {
        return keyGroupRange;
    }
    int getNumberOfKeyGroups()
    {
        return keyContext->getNumberOfKeyGroups();
    }
    // NOTE(review): the destructor deletes metaInfo, but the previous pointer is not deleted here;
    // confirm whether callers keep ownership of the replaced meta info.
    void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo *newMetaInfo)
    {
        metaInfo = newMetaInfo;
    }
    /// @brief Replaces the key context's current key with a copy of itself (RowData* keys only).
    /// @throws std::runtime_error if the current RowData* key is null.
    void copyCurrentKey()
    {
        if constexpr (std::is_same_v<K, RowData *>) {
            auto currentKey = static_cast<BinaryRowData *>(keyContext->getCurrentKey());
            if (!currentKey) {
                INFO_RELEASE("current key is null")
                throw std::runtime_error("current key is null");
            }
            auto newKey = currentKey->copy();
            keyContext->setCurrentKey(newKey);
        } else if constexpr (std::is_pointer_v<K>) {
            NOT_IMPL_EXCEPTION;
        } else {
            return;
        }
    }
    std::vector<K> *getKeys(const N &nameSpace);
    std::vector<std::tuple<K, N>> *getKeysAndNamespace();
    std::vector<StateMap<K, N, S> *> *getState()
    {
        return &keyGroupedStateMaps;
    }
    /// @brief First key group index of this table's range; subtracted to get a vector offset.
    int getKeyGroupOffset()
    {
        return keyGroupRange->getStartKeyGroup();
    }
    StateMap<K, N, S> *getMapForKeyGroup(int keyGroupIndex);
    TypeSerializer *getKeySerializer()
    {
        return keySerializer;
    }
    TypeSerializer *getStateSerializer()
    {
        return metaInfo->getStateSerializer();
    }
    TypeSerializer *getNamespaceSerializer()
    {
        return metaInfo->getNamespaceSerializer();
    }
    /// @brief Deletes every map in the key group range and nulls its slot (safe to call repeatedly).
    void deleteMaps()
    {
        for (int index = keyGroupRange->getStartKeyGroup(); index <= keyGroupRange->getEndKeyGroup(); index++) {
            int pos = indexToOffset(index);
            if (pos >= 0 && pos < static_cast<int>(keyGroupedStateMaps.size())) {
                delete keyGroupedStateMaps[pos];
                keyGroupedStateMaps[pos] = nullptr;
            }
        }
    }
    InternalKeyContext<K> *getKeyContext()
    {
        return keyContext;
    }

    /// Iterates the key group maps in order, draining each map's own incremental visitor.
    class StateEntryIterator : public InternalKvState<K, N, S>::StateIncrementalVisitor {
    public:
        S nextEntries() override;
        StateEntryIterator(int recommendedMaxNumberOfReturnedRecords, StateTable<K, N, S> *table);
        bool hasNext() override;
    private:
        int recommendedMaxNumberOfReturnedRecords = 0;
        // Offset into table->keyGroupedStateMaps of the next map to visit.
        int keyGroupIndex = 0;
        // Visitor of the map currently being drained; nullptr until one with entries is found.
        // NOTE(review): visitors obtained from StateMap::getStateIncrementalVisitor are never deleted
        // by this iterator; confirm who owns them.
        typename InternalKvState<K, N, S>::StateIncrementalVisitor *stateIncrementalVisitor = nullptr;
        StateTable<K, N, S> *table = nullptr;
        void init();
    };

protected:
    InternalKeyContext<K> *keyContext;
    TypeSerializer *keySerializer;
    KeyGroupRange *keyGroupRange;
    // One map per key group, indexed by indexToOffset(keyGroupIndex); entries may be nullptr.
    std::vector<StateMap<K, N, S> *> keyGroupedStateMaps;
    RegisteredKeyValueStateBackendMetaInfo *metaInfo;
    virtual StateMap<K, N, S> *createStateMap() = 0;
    bool containsKey(const K &key, int keyGroupIndex, const N &nameSpace)
    {
        return getMapForKeyGroup(keyGroupIndex)->containsKey(key, nameSpace);
    }
    void remove(const K &key, int keyGroupIndex, const N &nameSpace)
    {
        getMapForKeyGroup(keyGroupIndex)->remove(key, nameSpace);
    }
    S removeAndGetOld(const K &key, int keyGroupIndex, const N &nameSpace)
    {
        return getMapForKeyGroup(keyGroupIndex)->removeAndGetOld(key, nameSpace);
    }
    /// @brief Converts an absolute key group index into an offset into keyGroupedStateMaps.
    int indexToOffset(int index)
    {
        return index - getKeyGroupOffset();
    }
    virtual void initialize() = 0;
};
template<typename K, typename N, typename S>
StateTable<K, N, S>::StateTable(InternalKeyContext<K> *keyContext, RegisteredKeyValueStateBackendMetaInfo *metaInfo,
    TypeSerializer *keySerializer)
    : keyContext(keyContext),
      keySerializer(keySerializer),
      keyGroupRange(keyContext->getKeyGroupRange()),
      keyGroupedStateMaps(),
      metaInfo(metaInfo)
{
    // The key group range is cached from the key context. The per-key-group map vector starts
    // empty; nothing in this constructor populates it.
}
template<typename K, typename N, typename S>
StateTable<K, N, S>::~StateTable()
{
    // The table owns its meta info and every per-key-group state map.
    delete metaInfo;
    for (StateMap<K, N, S> *ownedMap : keyGroupedStateMaps) {
        // Slots nulled by deleteMaps() are fine: deleting nullptr is a no-op.
        delete ownedMap;
    }
}
/// @brief Total number of entries across all key group maps.
/// Null slots (e.g. after deleteMaps()) contribute nothing.
template<typename K, typename N, typename S>
int StateTable<K, N, S>::size()
{
    int count = 0;
    // Range-for avoids the int vs size_t comparison of an index loop.
    for (StateMap<K, N, S> *stateMap : keyGroupedStateMaps) {
        if (stateMap != nullptr) {
            count += stateMap->size();
        }
    }
    return count;
}
/// @brief Looks up (key, nameSpace) after deriving the key group from std::hash<K>.
template<typename K, typename N, typename S>
S StateTable<K, N, S>::get(const K &key, const N &nameSpace)
{
    // NOTE(review): for pointer key types std::hash hashes the address, not the pointee;
    // confirm this agrees with how the key context assigns key group indices.
    const int keyGroup = std::hash<K>{}(key) % keyContext->getNumberOfKeyGroups();
    return get(key, keyGroup, nameSpace);
}
/// @brief Looks up (key, nameSpace) in the map of the given key group.
/// @return The stored state. If the key group has no map, returns nullptr for pointer S
///         and std::numeric_limits<S>::max() otherwise.
template<typename K, typename N, typename S>
S StateTable<K, N, S>::get(const K &key, int keyGroupIndex, const N &nameSpace)
{
    if (StateMap<K, N, S> *groupMap = getMapForKeyGroup(keyGroupIndex); groupMap != nullptr) {
        return groupMap->get(key, nameSpace);
    }
    if constexpr (std::is_pointer_v<S>) {
        return nullptr;
    } else {
        return std::numeric_limits<S>::max();
    }
}
/// @brief Key enumeration is not implemented here; always returns nullptr.
template<typename K, typename N, typename S>
std::vector<K> *StateTable<K, N, S>::getKeys(const N & /* nameSpace */)
{
    std::vector<K> *keys = nullptr;
    return keys;
}
/// @brief (key, namespace) enumeration is not implemented here; always returns nullptr.
template<typename K, typename N, typename S>
std::vector<std::tuple<K, N>> *StateTable<K, N, S>::getKeysAndNamespace()
{
    std::vector<std::tuple<K, N>> *keysAndNamespaces = nullptr;
    return keysAndNamespaces;
}
/// @brief Creates an iterator over every key group map of this table.
/// The iterator is allocated with `new` and the table keeps no reference to it,
/// so the caller is responsible for deleting it.
template<typename K, typename N, typename S>
typename InternalKvState<K, N, S>::StateIncrementalVisitor *StateTable<K, N, S>::getStateIncrementalVisitor(
    int recommendedMaxNumberOfReturnedRecords)
{
    auto *iterator = new StateEntryIterator(recommendedMaxNumberOfReturnedRecords, this);
    return iterator;
}
/// @brief Returns the state map for an absolute key group index (may be a null slot).
/// @throws std::logic_error if the key context / key group range is null, or if the index
///         falls outside this table's key group range or outside the allocated map vector.
template<typename K, typename N, typename S>
StateMap<K, N, S> *StateTable<K, N, S>::getMapForKeyGroup(int keyGroupIndex)
{
    // Validate before computing the offset: indexToOffset() dereferences keyGroupRange.
    if (keyContext == nullptr || keyContext->getKeyGroupRange() == nullptr || keyGroupRange == nullptr) {
        throw std::logic_error("Key context or key group range is null");
    }
    const int pos = indexToOffset(keyGroupIndex);
    const int numberOfKeyGroups = keyContext->getKeyGroupRange()->getNumberOfKeyGroups();
    // Also bound by the vector's real size so a not-yet-populated table cannot be indexed out of range.
    const int numberOfMaps = static_cast<int>(keyGroupedStateMaps.size());
    if (pos >= 0 && pos < numberOfKeyGroups && pos < numberOfMaps) {
        return keyGroupedStateMaps[pos];
    }
    throw std::logic_error("Key group is not in key group range");
}
/// @brief Creates an iterator over table's key group maps and positions it on the first
///        map whose visitor has entries (if any).
template<typename K, typename N, typename S>
StateTable<K, N, S>::StateEntryIterator::StateEntryIterator(int recommendedMaxNumberOfReturnedRecords,
    StateTable<K, N, S> *table)
    : recommendedMaxNumberOfReturnedRecords(recommendedMaxNumberOfReturnedRecords),
      keyGroupIndex(0),
      // Must start as nullptr: if init() finds no non-empty map it leaves this untouched,
      // and hasNext() then tests it against nullptr.
      stateIncrementalVisitor(nullptr),
      table(table)
{
    init();
}
/// @brief Advances keyGroupIndex to the first map whose visitor reports hasNext() and makes
///        that visitor current. Leaves stateIncrementalVisitor null if none is found.
template<typename K, typename N, typename S>
void StateTable<K, N, S>::StateEntryIterator::init()
{
    stateIncrementalVisitor = nullptr;
    const int numberOfKeyGroups = table->keyContext->getKeyGroupRange()->getNumberOfKeyGroups();
    const int numberOfMaps = static_cast<int>(table->keyGroupedStateMaps.size());
    while (keyGroupIndex < numberOfKeyGroups && keyGroupIndex < numberOfMaps) {
        StateMap<K, N, S> *stateMap = table->keyGroupedStateMaps[keyGroupIndex++];
        // Slots can be null (deleteMaps() nulls them); skip instead of dereferencing.
        if (stateMap == nullptr) {
            continue;
        }
        auto visitor = stateMap->getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);
        if (visitor != nullptr && visitor->hasNext()) {
            stateIncrementalVisitor = visitor;
            return;
        }
        // NOTE(review): empty visitors are dropped without being deleted; confirm ownership.
    }
}
/// @brief True if the current visitor, or the visitor of some later key group map, has entries.
///        Switches stateIncrementalVisitor to that visitor as a side effect.
template<typename K, typename N, typename S>
bool StateTable<K, N, S>::StateEntryIterator::hasNext()
{
    const int numberOfKeyGroups = table->keyContext->getKeyGroupRange()->getNumberOfKeyGroups();
    const int numberOfMaps = static_cast<int>(table->keyGroupedStateMaps.size());
    while (stateIncrementalVisitor == nullptr || !stateIncrementalVisitor->hasNext()) {
        // '>=' (not '==') and the vector-size bound guard against ever indexing past the end.
        if (keyGroupIndex >= numberOfKeyGroups || keyGroupIndex >= numberOfMaps) {
            return false;
        }
        StateMap<K, N, S> *stateMap = table->keyGroupedStateMaps[keyGroupIndex++];
        // Slots can be null (deleteMaps() nulls them); skip instead of dereferencing.
        if (stateMap == nullptr) {
            continue;
        }
        auto visitor = stateMap->getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);
        if (visitor != nullptr && visitor->hasNext()) {
            // NOTE(review): the previous (exhausted) visitor is overwritten without being deleted.
            stateIncrementalVisitor = visitor;
            break;
        }
    }
    return true;
}
/// @brief Returns the next batch from the active key group visitor.
/// @return The visitor's nextEntries() result. Once exhausted, returns nullptr for pointer S
///         and std::numeric_limits<S>::max() otherwise.
template<typename K, typename N, typename S>
S StateTable<K, N, S>::StateEntryIterator::nextEntries()
{
    if (hasNext()) {
        return stateIncrementalVisitor->nextEntries();
    }
    if constexpr (std::is_pointer_v<S>) {
        return nullptr;
    } else {
        return std::numeric_limits<S>::max();
    }
}
/// @brief Forwards to StateMap::transform for the key context's current key and key group.
/// @param nameSpace      namespace of the entry to transform
/// @param value          value passed through to the transformation
/// @param transformation transformation passed through to StateMap::transform
template<typename K, typename N, typename S>
template<typename T>
void StateTable<K, N, S>::transform(const N &nameSpace, T value, StateTransformationFunction<S, T> transformation)
{
    K key = keyContext->getCurrentKey();
    int keyGroup = keyContext->getCurrentKeyGroupIndex();
    // getMapForKeyGroup returns a StateMap pointer, so the member must be reached with '->'.
    getMapForKeyGroup(keyGroup)->transform(key, nameSpace, value, transformation);
}
#endif