* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_BSSSTATETABLE_H
#define OMNISTREAM_BSSSTATETABLE_H
#ifdef WITH_OMNISTATESTORE
#include <cstdint>
#include <limits>
#include <memory>
#include <type_traits>
#include "typeutils/LongSerializer.h"
#include "memory/DataInputDeserializer.h"
#include "utils/VectorBatchSerializationUtils.h"
#include "utils/VectorBatchDeserializationUtils.h"
#include "boost_state_table.h"
#include "boost_state_db.h"
#include "table_description.h"
#include "state/HashCode.h"
template <typename K, typename N, typename S>
class BssStateTable {
public:
BssStateTable(
InternalKeyContext<K>* keyContext,
RegisteredKeyValueStateBackendMetaInfo* metaInfo,
TypeSerializer* keySerializer)
: keyContext(keyContext),
metaInfo(metaInfo),
keySerializer(keySerializer) {};
bool isEmpty()
{
return size == 0;
}
void createTable(ock::bss::BoostStateDBPtr& _dbPtr)
{
this->dbPtr = _dbPtr;
auto tblDesc = std::make_shared<ock::bss::TableDescription>(
ock::bss::StateType::VALUE, "dbTable", -1, ock::bss::TableSerializer{}, dbPtr->GetConfig());
dbTable = std::dynamic_pointer_cast<ock::bss::KVTable>(_dbPtr->GetTableOrCreate(tblDesc));
};
S get(N& nameSpace)
{
LOG("bss state table get");
uint32_t keyHashCode;
DataOutputSerializer serializer;
OutputBufferStatus outputBufferStatus;
serializer.setBackendBuffer(&outputBufferStatus);
ock::bss::BinaryData priKey = GetPriBinaryData(nameSpace, keyHashCode, serializer);
ock::bss::BinaryData readValue;
auto res = dbTable->Get(keyHashCode, priKey, readValue);
if (res != ock::bss::BSS_OK || readValue.Length() == 0) {
if constexpr (std::is_pointer_v<S>) {
return nullptr;
} else {
return std::numeric_limits<S>::max();
}
}
DataInputDeserializer serializedData(reinterpret_cast<const uint8_t*>(readValue.Data()), readValue.Length(), 0);
void* resPtr = getStateSerializer()->deserialize(serializedData);
if constexpr (std::is_pointer_v<S>) {
return (S)resPtr;
} else {
return resPtr == nullptr ? std::numeric_limits<S>::max() : *(S*)resPtr;
}
}
void put(N& nameSpace, const S& state)
{
LOG("BSS state table put");
uint32_t keyHashCode;
DataOutputSerializer serializer;
OutputBufferStatus outputBufferStatus;
serializer.setBackendBuffer(&outputBufferStatus);
ock::bss::BinaryData priKey = GetPriBinaryData(nameSpace, keyHashCode, serializer);
TypeSerializer* vSerializer = getStateSerializer();
DataOutputSerializer valueOutputSerializer;
OutputBufferStatus valueOutputBufferStatus;
valueOutputSerializer.setBackendBuffer(&valueOutputBufferStatus);
S tmpS = state;
if constexpr (std::is_same_v<S, int64_t> || std::is_same_v<S, int32_t>) {
LongSerializer::INSTANCE->serialize(&tmpS, valueOutputSerializer);
} else if constexpr (std::is_pointer_v<S>) {
vSerializer->serialize(tmpS, valueOutputSerializer);
} else {
vSerializer->serialize(&tmpS, valueOutputSerializer);
}
ock::bss::BinaryData priValue(
valueOutputSerializer.getData(), static_cast<int32_t>(valueOutputSerializer.getPosition()));
auto res = dbTable->Put(keyHashCode, priKey, priValue);
if (res != ock::bss::BSS_OK) {
LOG("Warning: put failed");
}
}
void clear(N& nameSpace)
{
LOG("BSS state table clear");
uint32_t keyHashCode;
DataOutputSerializer serializer;
OutputBufferStatus outputBufferStatus;
serializer.setBackendBuffer(&outputBufferStatus);
ock::bss::BinaryData priKey = GetPriBinaryData(nameSpace, keyHashCode, serializer);
auto res = dbTable->Remove(keyHashCode, priKey);
if (res != ock::bss::BSS_OK) {
LOG("Warning: clear failed");
}
}
void add(N& nameSpace, const S& value)
{
LOG("BSS state table add");
uint32_t keyHashCode;
DataOutputSerializer serializer;
OutputBufferStatus outputBufferStatus;
serializer.setBackendBuffer(&outputBufferStatus);
ock::bss::BinaryData priKey = GetPriBinaryData(nameSpace, keyHashCode, serializer);
DataOutputSerializer valueOutputSerializer;
OutputBufferStatus valueOutputBufferStatus;
valueOutputSerializer.setBackendBuffer(&valueOutputBufferStatus);
S tmpS = value;
TypeSerializer* vSerializer = getStateSerializer();
if constexpr (std::is_same_v<S, int64_t> || std::is_same_v<S, int32_t>) {
LongSerializer::INSTANCE->serialize(&tmpS, valueOutputSerializer);
} else if constexpr (std::is_pointer_v<S>) {
vSerializer->serialize(tmpS, valueOutputSerializer);
} else {
vSerializer->serialize(&tmpS, valueOutputSerializer);
}
ock::bss::BinaryData priVal(
valueOutputSerializer.getData(), static_cast<int32_t>(valueOutputSerializer.getPosition()));
auto res = dbTable->Add(keyHashCode, priKey, priVal);
if (res != ock::bss::BSS_OK) {
LOG("Warning: add failed");
}
}
ock::bss::BinaryData GetPriBinaryData(N& nameSpace, uint32_t& keyHashCode, DataOutputSerializer& serializer)
{
auto currentKey = keyContext->getCurrentKey();
if constexpr (std::is_same_v<K, int64_t> || std::is_same_v<K, int32_t>) {
LongSerializer::INSTANCE->serialize(¤tKey, serializer);
} else if constexpr (std::is_pointer_v<K>) {
keySerializer->serialize(currentKey, serializer);
} else {
keySerializer->serialize(¤tKey, serializer);
}
if constexpr (std::is_same_v<N, int64_t> || std::is_same_v<N, int32_t>) {
LongSerializer::INSTANCE->serialize(&nameSpace, serializer);
} else if constexpr (std::is_pointer_v<N>) {
getNamespaceSerializer()->serialize(nameSpace, serializer);
} else {
getNamespaceSerializer()->serialize(&nameSpace, serializer);
}
ock::bss::BinaryData priBinaryData(serializer.getData(), static_cast<int32_t>(serializer.getPosition()));
keyHashCode = HashCode::Hash(serializer.getData(), static_cast<int32_t>(serializer.getPosition()));
return priBinaryData;
}
long getVectorBatchesSize()
{
return vectorBatchId;
}
void addVectorBatch(omnistream::VectorBatch* vectorBatch)
{
LOG("Bss state table addVectorBatch");
DataOutputSerializer keyOutputSerializer;
OutputBufferStatus outputBufferStatus;
keyOutputSerializer.setBackendBuffer(&outputBufferStatus);
LongSerializer longSerializer;
longSerializer.serialize(&vectorBatchId, keyOutputSerializer);
ock::bss::BinaryData priKey(
keyOutputSerializer.getData(), static_cast<int32_t>(keyOutputSerializer.getPosition()));
uint32_t keyHashCode =
HashCode::Hash(keyOutputSerializer.getData(), static_cast<int32_t>(keyOutputSerializer.getPosition()));
int batchSize = omnistream::VectorBatchSerializationUtils::calculateVectorBatchSerializableSize(vectorBatch);
uint8_t* buffer = new uint8_t[batchSize];
omnistream::SerializedBatchInfo serializedBatchInfo =
omnistream::VectorBatchSerializationUtils::serializeVectorBatch(vectorBatch, batchSize, buffer);
ock::bss::BinaryData priVal(serializedBatchInfo.buffer, serializedBatchInfo.size);
auto res = dbTable->Put(keyHashCode, priKey, priVal);
if (res != ock::bss::BSS_OK) {
LOG("Warning: addVectorBatch failed");
}
vectorBatchId++;
}
omnistream::VectorBatch* getVectorBatch(long batchId)
{
LOG("Bss state table getVectorBatch");
DataOutputSerializer keyOutputSerializer;
OutputBufferStatus outputBufferStatus;
keyOutputSerializer.setBackendBuffer(&outputBufferStatus);
LongSerializer longSerializer;
longSerializer.serialize(&batchId, keyOutputSerializer);
uint32_t keyHashCode =
HashCode::Hash(keyOutputSerializer.getData(), static_cast<int32_t>(keyOutputSerializer.getPosition()));
ock::bss::BinaryData priKey(
keyOutputSerializer.getData(), static_cast<uint32_t>(keyOutputSerializer.getPosition()));
ock::bss::BinaryData priVal;
auto res = dbTable->Get(keyHashCode, priKey, priVal);
if (res != ock::bss::BSS_OK) {
LOG("Warning: getVectorBatch failed");
}
uint8_t* address = const_cast<uint8_t*>(priVal.Data() + sizeof(int8_t));
auto batch = omnistream::VectorBatchDeserializationUtils::deserializeVectorBatch(address);
return batch;
}
TypeSerializer* getNamespaceSerializer()
{
return metaInfo->getNamespaceSerializer();
}
TypeSerializer* getStateSerializer()
{
return metaInfo->getStateSerializer();
}
RegisteredKeyValueStateBackendMetaInfo* getMetaInfo()
{
return metaInfo;
}
void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo* newMetaInfo)
{
metaInfo = newMetaInfo;
}
protected:
InternalKeyContext<K>* keyContext;
RegisteredKeyValueStateBackendMetaInfo* metaInfo;
TypeSerializer* keySerializer;
ock::bss::BoostStateDBPtr dbPtr;
ock::bss::ConfigRef config;
ock::bss::KVTableRef dbTable;
int size = 0;
long vectorBatchId = 0;
};
#endif
#endif