/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */
#ifndef OMNISTREAM_STATEOBJECTCOLLECTION_H
#define OMNISTREAM_STATEOBJECTCOLLECTION_H
#include <algorithm>
#include <functional>
#include <iterator>
#include <list>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "runtime/state/StateObject.h"
#include "runtime/state/CompositeStateHandle.h"
#include "core/include/common.h"
#include "runtime/state/InputChannelStateHandle.h"
#include "runtime/state/ResultSubpartitionStateHandle.h"
#include "runtime/snapshot/RocksDBSnapshotStrategyBase.h"
template<typename T>
class StateObjectCollection : public StateObject {
public:
StateObjectCollection()
{
this->stateObjects = std::vector<std::shared_ptr<T>>();
}
* Creates a new StateObjectCollection wraps the given collection and delegates to it.
*
* @param stateObjects collection of state objects to wrap.
*/
explicit StateObjectCollection(std::vector<std::shared_ptr<T>> stateObjects)
: stateObjects(stateObjects) {}
int Size() const
{
return static_cast<int>(stateObjects.size());
}
bool IsEmpty() const
{
return stateObjects.empty();
}
bool Contains(const std::shared_ptr<T> &element) const
{
return std::find(stateObjects.begin(), stateObjects.end(), element) != stateObjects.end();
}
typename std::vector<std::shared_ptr<T>>::iterator begin()
{
return stateObjects.begin();
}
typename std::vector<std::shared_ptr<T>>::iterator end()
{
return stateObjects.end();
}
typename std::vector<std::shared_ptr<T>>::const_iterator begin() const
{
return stateObjects.begin();
}
typename std::vector<std::shared_ptr<T>>::const_iterator end() const
{
return stateObjects.end();
}
std::vector<std::shared_ptr<T>> ToArray() const
{
return stateObjects;
}
bool Add(const std::shared_ptr<T> &element)
{
stateObjects.push_back(element);
return true;
}
bool Remove(const std::shared_ptr<T> &element)
{
auto it = std::find(stateObjects.begin(), stateObjects.end(), element);
if (it != stateObjects.end()) {
stateObjects.erase(it);
return true;
}
return false;
}
bool ContainsAll(const std::vector<std::shared_ptr<T>> &c) const
{
for (const auto &element: c) {
if (!contains(element)) {
return false;
}
}
return true;
}
bool AddAll(const std::vector<std::shared_ptr<T>> &c)
{
bool changed = false;
for (const auto &element: c) {
if (add(element)) {
changed = true;
}
}
return changed;
}
bool RemoveAll(const std::vector<std::shared_ptr<T>> &c)
{
bool changed = false;
for (const auto &element: c) {
if (remove(element)) {
changed = true;
}
}
return changed;
}
bool RemoveIf(std::function<bool(const std::shared_ptr<T> &)> filter)
{
auto it = std::remove_if(stateObjects.begin(), stateObjects.end(), filter);
bool changed = it != stateObjects.end();
stateObjects.erase(it, stateObjects.end());
return changed;
}
bool RetainAll(const std::vector<std::shared_ptr<T>> &c)
{
std::unordered_set<std::shared_ptr<T>> retainSet(c.begin(), c.end());
auto it = std::remove_if(stateObjects.begin(), stateObjects.end(),
[&retainSet](const std::shared_ptr<T> &element) {
return retainSet.find(element) == retainSet.end();
});
bool changed = it != stateObjects.end();
stateObjects.erase(it, stateObjects.end());
return changed;
}
void Clear()
{
stateObjects.clear();
}
void DiscardState() override
{
for (const auto &object: stateObjects) {
object->DiscardState();
}
}
long GetStateSize() const override
{
return SumAllSizes(stateObjects);
}
long GetCheckpointedSize() const
{
return SumAllCheckpointedSizes(stateObjects);
}
bool HasState() const
{
for (const auto &state: stateObjects) {
if (state != nullptr) {
return true;
}
}
return false;
}
bool operator==(const StateObjectCollection<T> &other) const
{
if (this == &other) {
return true;
}
NOT_IMPL_EXCEPTION
}
bool operator!=(const StateObjectCollection<T> &other) const
{
return !(*this == other);
}
size_t HashCode() const
{
size_t hash = 0;
for (const auto &obj: stateObjects) {
hash ^= std::hash<std::shared_ptr<T>>{}(obj);
}
return hash;
}
std::string ToString() const override
{
nlohmann::json j;
nlohmann::json state_objects_array = nlohmann::json::array();
for (const auto& obj : stateObjects) {
if (obj) {
auto derived = std::dynamic_pointer_cast<BridgeKeyedStateHandle>(obj);
if (derived != nullptr && derived->handle != nullptr) {
std::string jsonStr = derived->handle->ToString();
state_objects_array.push_back(nlohmann::json::parse(jsonStr));
} else{
state_objects_array.push_back(nlohmann::json::parse(obj->ToString()));
}
} else {
state_objects_array.push_back(nullptr);
}
}
j["stateObjects"] = state_objects_array;
return j.dump();
}
std::vector<std::shared_ptr<T>> AsList() const
{
return stateObjects;
}
static std::shared_ptr<StateObjectCollection<T>> Empty()
{
if (!EMPTY) {
EMPTY = std::make_shared<StateObjectCollection<T>>(std::vector<std::shared_ptr<T>>());
}
return EMPTY;
}
static std::shared_ptr<StateObjectCollection<T>> EmptyIfNull(std::shared_ptr<StateObjectCollection<T>> collection)
{
return collection == nullptr ? Empty() : collection;
}
static std::shared_ptr<StateObjectCollection<T>> Singleton(std::shared_ptr<T> stateObject)
{
std::vector<std::shared_ptr<T>> vec = {stateObject};
return std::make_shared<StateObjectCollection<T>>(vec);
}
static std::shared_ptr<StateObjectCollection<T>> SingletonOrEmpty(std::shared_ptr<T> stateObject)
{
return stateObject == nullptr ? Empty() : Singleton(stateObject);
}
private:
static const long serialVersionUID = 1L;
static std::shared_ptr<StateObjectCollection<T>> EMPTY;
std::vector<std::shared_ptr<T>> stateObjects;
static long SumAllSizes(const std::vector<std::shared_ptr<T>> &stateObjects)
{
long size = 0L;
for (const auto &object: stateObjects) {
size += GetSizeNullSafe(object);
}
return size;
}
static long GetSizeNullSafe(const std::shared_ptr<T> &stateObject)
{
return stateObject != nullptr ? stateObject->GetStateSize() : 0L;
}
static long SumAllCheckpointedSizes(const std::vector<std::shared_ptr<T>> &stateObjects)
{
long size = 0L;
for (const auto &object: stateObjects) {
size += GetCheckpointedSizeNullSafe(object);
}
return size;
}
static long GetCheckpointedSizeNullSafe(const std::shared_ptr<T> &stateObject)
{
auto composite = std::dynamic_pointer_cast<CompositeStateHandle>(stateObject);
return composite != nullptr ? composite->GetCheckpointedSize() : GetSizeNullSafe(stateObject);
}
};
// Out-of-class definition of the per-T shared empty collection returned by Empty().
// It is initialized with a default-constructed (empty) StateObjectCollection<T>.
template<typename T>
inline std::shared_ptr<StateObjectCollection<T>> StateObjectCollection<T>::EMPTY =
    std::make_shared<StateObjectCollection<T>>();
#endif