* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#pragma once
#include "HeapPriorityQueue.h"
#include "state/KeyGroupedInternalPriorityQueue.h"
#include "state/KeyGroupRange.h"
#include "state/rocksdb/RocksDBCachingPriorityQueueSet.h"
#include "runtime/state/KeyGroupRangeAssignment.h"
#include "core/utils/type_traits_ext.h"
template <typename K, typename T, typename Comparator>
class KeyGroupPartitionedPriorityQueue : public KeyGroupedInternalPriorityQueue<T> {
static_assert(is_shared_ptr_v<T>, "T should be shared ptr.");
using DedupSet = typename KeyGroupedInternalPriorityQueue<T>::DedupSet;
class KeyGroupConcatenationIterator : public omnistream::utils::Iterator<T> {
public:
KeyGroupConcatenationIterator(KeyGroupPartitionedPriorityQueue* self) : self_(self), current(self_->keyGroupedHeaps_[0]->iterator()) {}
bool hasNext() override {
bool currentHasNext = current->hasNext();
while (!currentHasNext && index_ < self_->keyGroupedHeaps_.size() - 1) {
current = self_->keyGroupedHeaps_[++index_]->iterator();
currentHasNext = current->hasNext();
}
return currentHasNext;
}
T next() override {
return current->next();
}
private:
KeyGroupPartitionedPriorityQueue* self_;
int32_t index_ = 0;
std::unique_ptr<omnistream::utils::Iterator<T>> current = nullptr;
};
public:
using RocksDBCachingPQ = RocksDBCachingPriorityQueueSet<K, T, Comparator>;
using PartitionQueueSetFactory = std::function<std::shared_ptr<RocksDBCachingPQ>(int32_t keyGroupId, int32_t totalKeyGroups)>;
KeyGroupPartitionedPriorityQueue(
PartitionQueueSetFactory priorityQueueSetFactory,
KeyGroupRange* keyGroupRange,
int32_t totalKeyGroups)
:
heapOfKeyGroupedHeaps_(keyGroupRange->getNumberOfKeyGroups()),
keyGroupRange_(keyGroupRange),
totalKeyGroups_(totalKeyGroups) {
if (keyGroupRange_->getNumberOfKeyGroups() <= 0) {
THROW_LOGIC_EXCEPTION("KeyGroupRange must have at least one key group.")
}
firstKeyGroup_ = keyGroupRange_->getStartKeyGroup();
keyGroupedHeaps_.reserve(keyGroupRange_->getNumberOfKeyGroups());
for (auto i = 0; i < keyGroupRange_->getNumberOfKeyGroups(); i++) {
auto priorityQueue = priorityQueueSetFactory(firstKeyGroup_ + i, totalKeyGroups_);
keyGroupedHeaps_.emplace_back(priorityQueue);
heapOfKeyGroupedHeaps_.add(priorityQueue);
}
}
T poll() override {
auto headPQ = heapOfKeyGroupedHeaps_.peek();
auto head = headPQ->poll();
heapOfKeyGroupedHeaps_.adjustModifiedElement(headPQ);
return head;
}
T peek() override {
const auto& headPQ = heapOfKeyGroupedHeaps_.peek();
return headPQ->peek();
}
bool add(const T& element) override {
auto pq = getKeyGroupSubHeapForElement(element);
if (pq->add(element)) {
heapOfKeyGroupedHeaps_.adjustModifiedElement(pq);
const auto& newHead = peek();
return !comparator_(element, newHead) && !comparator_(newHead, element);
}
return false;
}
void addAll(const std::vector<T>& elements) override {
if (elements.empty()) {
return;
}
for (const auto& element : elements) {
add(element);
}
}
bool remove(const T& element) override {
auto pq = getKeyGroupSubHeapForElement(element);
auto oldHead = peek();
if (pq->remove(element)) {
heapOfKeyGroupedHeaps_.adjustModifiedElement(pq);
return !comparator_(element, oldHead) && !comparator_(oldHead, element);
}
return false;
}
bool isEmpty() override {
return peek() == nullptr;
}
int32_t size() override {
auto sizeSum = 0;
for (const auto& pq : keyGroupedHeaps_) {
sizeSum += pq->size();
}
return sizeSum;
}
std::unique_ptr<omnistream::utils::Iterator<T>> iterator() override {
return std::make_unique<KeyGroupConcatenationIterator>(this);
}
std::shared_ptr<DedupSet> getSubsetForKeyGroup(int32_t keyGroupId) override {
std::shared_ptr<DedupSet> result = std::make_shared<DedupSet>();
auto& partitionQueue = keyGroupedHeaps_[globalKeyGroupToLocalIndex(keyGroupId)];
auto iter = partitionQueue->iterator();
try {
while (iter->hasNext()) {
result->insert(iter->next());
}
} catch (const std::exception& e) {
THROW_RUNTIME_ERROR("Exception while iterating key group: " + std::string(e.what()));
}
return result;
}
private:
struct PQComparator {
bool operator()(std::shared_ptr<RocksDBCachingPQ> a, const std::shared_ptr<RocksDBCachingPQ>& b) const {
auto left = a->peek();
auto right = b->peek();
if (left == nullptr) {
return false;
}
return right == nullptr ? true : comparator_(left, right);
}
Comparator comparator_;
};
HeapPriorityQueue<std::shared_ptr<RocksDBCachingPQ>, PQComparator> heapOfKeyGroupedHeaps_;
std::vector<std::shared_ptr<RocksDBCachingPQ>> keyGroupedHeaps_;
KeyGroupRange* keyGroupRange_;
int32_t firstKeyGroup_;
int32_t totalKeyGroups_;
Comparator comparator_;
std::shared_ptr<RocksDBCachingPQ> getKeyGroupSubHeapForElement(T element) {
return keyGroupedHeaps_[computeKeyGroupIndex(element)];
}
int32_t computeKeyGroupIndex(const T& element) {
auto key = element->getKey();
int32_t keyGroupId = KeyGroupRangeAssignment<K>::assignToKeyGroup(key, totalKeyGroups_);
return globalKeyGroupToLocalIndex(keyGroupId);
}
int32_t globalKeyGroupToLocalIndex(int32_t keyGroup) const {
if (!keyGroupRange_->contains(keyGroup)) {
THROW_LOGIC_EXCEPTION("Key group " + std::to_string(keyGroup) + " is not included in" +
"[" + std::to_string(keyGroupRange_->getStartKeyGroup()) + ", " + std::to_string(keyGroupRange_->getEndKeyGroup()) + "]")
}
return keyGroup - keyGroupRange_->getStartKeyGroup();
}
};