* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef FLINK_TNEL_JOINRECORDSTATEVIEW_H
#define FLINK_TNEL_JOINRECORDSTATEVIEW_H
#include <set>
#include <vector>
#include <nlohmann/json.hpp>
#include <typeinfo>
#include <xxhash.h>
#include "table/runtime/keyselector/KeySelector.h"
#include "table/data/RowData.h"
#include "core/api/common/state/ValueState.h"
#include "core/api/common/state/StateDescriptor.h"
#include "core/typeutils/TypeSerializer.h"
#include "core/typeutils/XxH128_hashSerializer.h"
#include "core/typeutils/JoinTupleSerializer.h"
#include "streaming/api/operators/StreamingRuntimeContext.h"
#include "table/typeutils/InternalTypeInfo.h"
#include "runtime/state/heap/HeapMapState.h"
#include "table/data/binary/BinaryRowData.h"
#include "table/data/vectorbatch/VectorBatch.h"
#include "table/data/util/VectorBatchUtil.h"
#include "table/data/util/RowDataUtil.h"
template <typename K>
class JoinRecordStateView {
public:
virtual void addOrRectractRecord(
omnistream::VectorBatch* input,
KeySelector<K>* keySelector,
bool OtherIsOuter,
KeyedStateBackend<K>* backend,
bool filterNulls,
const std::vector<int32_t>& numAssociates) = 0;
virtual ~JoinRecordStateView() = default;
virtual void addVectorBatch(omnistream::VectorBatch* vectorBatch) = 0;
virtual int getCurrentBatchId() const = 0;
virtual std::vector<omnistream::VectorBatch*> getVectorBatches() const = 0;
virtual omnistream::VectorBatch* getVectorBatch(int batchId) = 0;
virtual long getVectorBatchesSize() = 0;
virtual void freeDelVectorBatch()
{
this->cachedVb.clear();
if (backendType_ == omnistream::StateType::HEAP) {
this->delVb.clear();
return;
}
for (auto vb : this->delVb) {
delete vb;
}
this->delVb.clear();
};
virtual void cleanEntriesCache() = 0;
protected:
std::unordered_map<int, omnistream::VectorBatch*> cachedVb;
std::set<omnistream::VectorBatch*> delVb;
omnistream::StateType backendType_ = omnistream::StateType::HEAP;
};
template <typename K>
class InputSideHasNoUniqueKey : public JoinRecordStateView<K> {
public:
using UV = std::tuple<int32_t, int64_t>;
using MAP_STATE_TYPE = MapState<XXH128_hash_t, UV>;
using MAP_TYPE = emhash7::HashMap<XXH128_hash_t, UV>;
InputSideHasNoUniqueKey(StreamingRuntimeContext<K>* ctx, std::string stateName, InternalTypeInfo* recordType)
{
auto* recordStateDesc = new MapStateDescriptor<XXH128_hash_t, UV>(
stateName, new XxH128_hashSerializer(), new JoinTupleSerializer());
recordStateDesc->setKeyValueBackendTypeId(BackendDataType::XXHASH128_BK, BackendDataType::TUPLE_INT32_INT64);
this->recordStateVB = ctx->template getMapState<XXH128_hash_t, UV>(recordStateDesc);
#ifdef WITH_OMNISTATESTORE
if (auto* backend = dynamic_cast<BssMapState<K, VoidNamespace, XXH128_hash_t, UV>*>(recordStateVB)) {
INFO_RELEASE("InputSideHasNoUniqueKey backend is bss");
this->backendType_ = omnistream::StateType::BSS;
}
#endif
if (auto* backend = dynamic_cast<RocksdbMapState<K, VoidNamespace, XXH128_hash_t, UV>*>(recordStateVB)) {
INFO_RELEASE("InputSideHasNoUniqueKey backend is rocksdb");
this->backendType_ = omnistream::StateType::ROCKSDB;
} else {
INFO_RELEASE("InputSideHasNoUniqueKey backend is mem");
this->backendType_ = omnistream::StateType::HEAP;
}
}
~InputSideHasNoUniqueKey() override {};
virtual long getVectorBatchesSize()
{
return recordStateVB->getVectorBatchesSize();
};
void addVectorBatch(omnistream::VectorBatch* vectorBatch) override
{
this->delVb.insert(vectorBatch);
recordStateVB->addVectorBatch(vectorBatch);
}
omnistream::VectorBatch* getVectorBatch(int batchId)
{
if (this->cachedVb.end() != this->cachedVb.find((batchId))) {
return this->cachedVb[batchId];
}
auto vb = recordStateVB->getVectorBatch(batchId);
this->delVb.insert(vb);
this->cachedVb.emplace(batchId, vb);
return vb;
}
int getCurrentBatchId() const override
{
return recordStateVB->getVectorBatchesSize();
};
MAP_TYPE* getRecords()
{
return recordStateVB->entries();
}
[[nodiscard]] std::vector<omnistream::VectorBatch*> getVectorBatches() const override
{
return recordStateVB->getVectorBatches();
}
[[nodiscard]] MAP_STATE_TYPE* getState() const
{
return recordStateVB;
};
void addOrRectractRecord(
omnistream::VectorBatch* input,
KeySelector<K>* keySelector,
bool otherIsOuter,
KeyedStateBackend<K>* backend,
bool filterNulls,
const std::vector<int32_t>& numAssociates) override;
void cleanEntriesCache()
{
recordStateVB->clearEntriesCache();
}
private:
MAP_STATE_TYPE* recordStateVB;
void GetRockDBRecordsInBatch(
std::vector<XXH128_hash_t>& xxh128Hashes,
std::vector<K>& keys,
std::unordered_map<std::pair<K, XXH128_hash_t>, UV>& result);
void ProcessRockDBRecordsInBatch(
omnistream::VectorBatch* input, KeySelector<K>* keySelector, bool filterNulls, int batchId);
};
template <typename K>
void InputSideHasNoUniqueKey<K>::addOrRectractRecord(
omnistream::VectorBatch* input,
KeySelector<K>* keySelector,
bool otherIsOuter,
KeyedStateBackend<K>* backend,
bool filterNulls,
const std::vector<int32_t>& numAssociates)
{
LOG(">>>>>>>");
int32_t batchId = getCurrentBatchId();
this->addVectorBatch(input);
if (this->backendType_ == omnistream::StateType::ROCKSDB) {
ProcessRockDBRecordsInBatch(input, keySelector, filterNulls, batchId);
return;
}
std::vector<XXH128_hash_t> xxh128Hashes = input->getXXH128s();
long* comboIDs = new long[input->GetRowCount()];
VectorBatchUtil::getComboId_sve(batchId, input->GetRowCount(), comboIDs);
for (int i = 0; i < input->GetRowCount(); i++) {
if (filterNulls && keySelector->isAnyKeyNull(input, i)) {
continue;
}
auto key = keySelector->getKey(input, i);
backend->setCurrentKey(key);
XXH128_hash_t ukey = xxh128Hashes[i];
int delta = RowDataUtil::isAccumulateMsg(input->getRowKind(i)) ? +1 : -1;
recordStateVB->updateOrCreate(
ukey,
UV{1, comboIDs[i]},
[delta, &numAssociates, i](UV& val) -> std::optional<UV> {
int newCount = std::get<0>(val) + delta;
if (newCount != 0) {
return UV{newCount, std::get<1>(val)};
} else {
return std::nullopt;
}
});
if constexpr (std::is_pointer_v<K>) {
delete key;
}
}
delete[] comboIDs;
}
template <typename K>
void InputSideHasNoUniqueKey<K>::ProcessRockDBRecordsInBatch(
omnistream::VectorBatch* input, KeySelector<K>* keySelector, bool filterNulls, int batchId)
{
std::vector<XXH128_hash_t> xxh128Hashes = input->getXXH128s();
std::unordered_map<std::pair<K, XXH128_hash_t>, UV> rockdbRecords;
std::vector<K> keys;
std::vector<int> deltas;
for (int i = 0; i < input->GetRowCount(); i++) {
if (filterNulls && keySelector->isAnyKeyNull(input, i)) {
if constexpr (std::is_pointer_v<K>) {
keys.push_back(nullptr);
} else {
keys.push_back(K{});
}
deltas.push_back(0);
continue;
}
auto key = keySelector->getKey(input, i);
keys.push_back(key);
int delta = RowDataUtil::isAccumulateMsg(input->getRowKind(i)) ? +1 : -1;
deltas.push_back(delta);
}
GetRockDBRecordsInBatch(xxh128Hashes, keys, rockdbRecords);
std::unordered_map<K, std::unordered_map<XXH128_hash_t, UV>> addOrUpdateRecords;
std::unordered_map<K, std::unordered_set<XXH128_hash_t>> deleteRecords;
for (int i = 0; i < keys.size(); i++) {
if constexpr (std::is_pointer_v<K>) {
if (keys[i] == nullptr) continue;
} else {
if (keys[i] == K{}) {
continue;
}
}
auto keyAndUkey = std::make_pair(keys[i], xxh128Hashes[i]);
auto recordIter = rockdbRecords.find(keyAndUkey);
UV val{0, VectorBatchUtil::getComboId(batchId, i)};
if (recordIter != rockdbRecords.end()) {
val = recordIter->second;
}
int newCount = std::get<0>(val) + deltas[i];
if (newCount > 0) {
UV newVal{newCount, std::get<1>(val)};
rockdbRecords[keyAndUkey] = newVal;
auto& inner = addOrUpdateRecords[keys[i]];
inner.insert_or_assign(xxh128Hashes[i], std::move(newVal));
auto it = deleteRecords.find(keys[i]);
if (it != deleteRecords.end()) {
auto& innerSet = it->second;
innerSet.erase(xxh128Hashes[i]);
if (innerSet.empty()) {
deleteRecords.erase(it);
}
}
} else {
rockdbRecords.erase(keyAndUkey);
auto it = addOrUpdateRecords.find(keys[i]);
if (it != addOrUpdateRecords.end()) {
auto& inner = it->second;
if (inner.find(xxh128Hashes[i]) != inner.end()) {
inner.erase(xxh128Hashes[i]);
if (inner.empty()) {
addOrUpdateRecords.erase(it);
}
}
}
deleteRecords[keys[i]].emplace(xxh128Hashes[i]);
}
}
auto rockState = dynamic_cast<RocksdbMapState<K, VoidNamespace, XXH128_hash_t, UV>*>(recordStateVB);
if (rockState) {
rockState->putByBatch(addOrUpdateRecords);
rockState->removeByBatch(deleteRecords);
}
for (int i = 0; i < keys.size(); i++) {
if constexpr (std::is_pointer_v<K>) {
if (keys[i] != nullptr) {
delete keys[i];
}
}
}
}
template <typename K>
void InputSideHasNoUniqueKey<K>::GetRockDBRecordsInBatch(
std::vector<XXH128_hash_t>& xxh128Hashes,
std::vector<K>& keys,
std::unordered_map<std::pair<K, XXH128_hash_t>, UV>& result)
{
auto rockState = dynamic_cast<RocksdbMapState<K, VoidNamespace, XXH128_hash_t, UV>*>(recordStateVB);
if (rockState) {
std::unordered_map<K, std::unordered_set<XXH128_hash_t>> keyAndUkeysMap;
for (int i = 0; i < keys.size(); i++) {
auto joinKey = keys.at(i);
if constexpr (std::is_pointer_v<K>) {
if (keys[i] == nullptr) {
continue;
}
} else {
if (keys[i] == K{})
{
continue;
}
}
XXH128_hash_t userRecordHash = xxh128Hashes[i];
keyAndUkeysMap[joinKey].insert(userRecordHash);
}
rockState->GetByBatch(keyAndUkeysMap, result);
}
}
class JoinRecordStateViews {
public:
template <typename K>
static JoinRecordStateView<K>* create(
StreamingRuntimeContext<K>* ctx,
std::string stateName,
InternalTypeInfo* recordType,
InternalTypeInfo* UniqueKeyType,
std::vector<int32_t>& uniqueKeyIndex);
};
template <typename K>
JoinRecordStateView<K>* JoinRecordStateViews::create(
StreamingRuntimeContext<K>* ctx,
std::string stateName,
InternalTypeInfo* recordType,
InternalTypeInfo* UniqueKeyType,
std::vector<int32_t>& uniqueKeyIndex)
{
if (stateName.find("JoinKeyContainsUniqueKey") != std::string::npos) {
NOT_IMPL_EXCEPTION;
} else if (stateName.find("HasUnique") != std::string::npos) {
NOT_IMPL_EXCEPTION;
} else {
LOG("creating InputSideHasNoUniqueKey");
return new InputSideHasNoUniqueKey<K>(ctx, stateName, recordType);
}
}
namespace std {
template <typename K>
struct hash<std::pair<K, XXH128_hash_t>> {
size_t operator()(const std::pair<K, XXH128_hash_t>& p) const noexcept
{
size_t h1 = std::hash<K>{}(p.first);
size_t h2 = std::hash<XXH128_hash_t>{}(p.second);
return h1 ^ (h2 << 1);
}
};
template <typename K>
struct equal_to<std::pair<K, XXH128_hash_t>> {
bool operator()(const std::pair<K, XXH128_hash_t>& lhs, const std::pair<K, XXH128_hash_t>& rhs) const noexcept
{
return lhs.first == rhs.first && lhs.second.low64 == rhs.second.low64 && lhs.second.high64 == rhs.second.high64;
}
};
}
#endif