* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_TOPNBUFFER_H
#define OMNISTREAM_TOPNBUFFER_H
#include <map>
#include "table/data/RowData.h"
#include "table/runtime/keyselector/KeySelector.h"
// Function-pointer type for a binary predicate that takes two RowData pointers
// and returns bool.
using Comparator = bool (*)(RowData*, RowData*);
struct DescendingComparator {
bool operator()(RowData* a, RowData* b) const
{
return *(a->getLong(0)) > *(b->getLong(0));
}
};
struct AscendingComparator {
bool operator()(RowData* a, RowData* b) const
{
return *(a->getLong(0)) < *(b->getLong(0));
}
};
/**
 * Ordered buffer that maps a sort key to the list of rows stored under that key and
 * keeps a running count of all buffered rows.
 *
 * Keys are ordered by DescendingComparator, so two keys whose getLong(0) values are
 * equal are treated as the same key by the underlying std::map. The "last" entry is
 * the one that sorts last under that comparator.
 *
 * NOTE(review): lists are held as raw pointers. Lists allocated by put() and lists
 * handed to putAll() are never deleted by this class; who owns them needs to be
 * confirmed against the callers before adding any cleanup here.
 */
class TopNBuffer {
public:
    TopNBuffer() = default;

    /** Iterator to the first (key, list) entry in comparator order. */
    auto begin()
    {
        return treeMap.begin();
    }

    /** Past-the-end iterator of the (key, list) entries. */
    auto end()
    {
        return treeMap.end();
    }

    /**
     * Appends a row under the given sort key, creating the key's list if needed.
     *
     * @param sortKey key the row is filed under
     * @param value   row to append
     * @return number of rows stored under sortKey after the insertion
     */
    int put(RowData* sortKey, RowData* value)
    {
        currentTopNum += 1;
        auto [it, inserted] = treeMap.try_emplace(sortKey, nullptr);
        // A pre-existing entry may carry a null list; treat it like a fresh key
        // instead of dereferencing it.
        if (inserted || it->second == nullptr) {
            it->second = new std::vector<RowData*>(1, value);
            return 1;
        }
        it->second->push_back(value);
        return static_cast<int>(it->second->size());
    }

    /**
     * Looks up the list stored under sortKey.
     *
     * @return the stored list pointer, or nullptr when the key is not present.
     *         The lookup never modifies the buffer.
     */
    std::vector<RowData*>* get(RowData* sortKey) const
    {
        // find() rather than operator[]: operator[] would insert a null entry for
        // a missing key, making contains() true and breaking later put()/removeAll().
        auto it = treeMap.find(sortKey);
        return it == treeMap.end() ? nullptr : it->second;
    }

    /** @return true when an entry exists for sortKey. */
    bool contains(RowData* sortKey) const
    {
        return treeMap.find(sortKey) != treeMap.end();
    }

    /** Removes the entry for sortKey (if any) and subtracts its rows from the count. */
    void removeAll(RowData* sortKey)
    {
        auto it = treeMap.find(sortKey);
        if (it == treeMap.end()) {
            return;
        }
        if (it->second != nullptr) {
            currentTopNum -= static_cast<int>(it->second->size());
        }
        // Erase through the iterator we already have; no second keyed lookup.
        treeMap.erase(it);
    }

    /**
     * Removes the final row of the last entry's list.
     *
     * @return the removed row, or nullptr when the buffer is empty or the last
     *         entry's list is null or empty (nothing is removed in those cases).
     *         When the list becomes empty the entry itself is erased.
     */
    RowData* removeLast()
    {
        if (treeMap.empty()) {
            return nullptr;
        }
        auto last = --treeMap.end();
        std::vector<RowData*>* collection = last->second;
        if (collection == nullptr || collection->empty()) {
            return nullptr;
        }
        RowData* lastElement = collection->back();
        collection->pop_back();
        currentTopNum--;
        if (collection->empty()) {
            treeMap.erase(last);
        }
        return lastElement;
    }

    /**
     * Removes the first occurrence of lastElement (compared by pointer) from collection,
     * decrementing the count when one is found, and erases the entry keyed by element
     * once collection is empty.
     *
     * Does nothing when collection or lastElement is null.
     */
    void removeLast(std::vector<RowData*>* collection, RowData* lastElement, RowData* element)
    {
        if (collection == nullptr || lastElement == nullptr) {
            return;
        }
        for (auto it = collection->begin(); it != collection->end(); ++it) {
            if (*it == lastElement) {
                collection->erase(it);
                currentTopNum--;
                break;
            }
        }
        if (collection->empty()) {
            treeMap.erase(element);
        }
    }

    /**
     * Stores values as the list for sortKey, replacing any list already stored there,
     * and adjusts the row count accordingly.
     *
     * A null values pointer is stored as-is and contributes no rows to the count.
     */
    void putAll(RowData* sortKey, std::vector<RowData*>* values)
    {
        auto [it, inserted] = treeMap.try_emplace(sortKey, values);
        if (!inserted) {
            if (it->second != nullptr) {
                currentTopNum -= static_cast<int>(it->second->size());
            }
            it->second = values;
        }
        if (values != nullptr) {
            currentTopNum += static_cast<int>(values->size());
        }
    }

    /** Reverse iterator to the last entry; equals the map's rend() when the buffer is empty. */
    auto lastEntry()
    {
        return treeMap.rbegin();
    }

    /** @return the running count of buffered rows maintained by put/putAll/remove*. */
    int getCurrentTopNum() const
    {
        return currentTopNum;
    }

    /**
     * @return the final row of the last entry's list without removing it, or nullptr
     *         when the buffer is empty or that list is null or empty.
     */
    RowData* lastElement() const
    {
        if (treeMap.empty()) {
            return nullptr;
        }
        const std::vector<RowData*>* collection = treeMap.rbegin()->second;
        if (collection == nullptr) {
            return nullptr;
        }
        return getLastElement(*collection);
    }

    /** @return the final element of collection, or nullptr when it is empty. */
    RowData* getLastElement(const std::vector<RowData*>& collection) const
    {
        return collection.empty() ? nullptr : collection.back();
    }

    /**
     * Decides whether sortKey falls within the buffer's range.
     *
     * @return true when the buffer is empty, when sortKey orders strictly before the
     *         last key under the map's comparator, or when fewer than topNum rows
     *         are currently buffered.
     */
    bool checkSortKeyInBufferRange(RowData* sortKey, long topNum) const
    {
        if (treeMap.empty()) {
            return true;
        }
        RowData* worstKey = treeMap.rbegin()->first;
        // Use the map's own comparator so this check cannot drift from the key ordering.
        return treeMap.key_comp()(sortKey, worstKey) || (currentTopNum < topNum);
    }

private:
    // Total number of rows across all lists, maintained incrementally.
    int currentTopNum = 0;
    // Sort key -> list of rows stored under that key, ordered by DescendingComparator.
    std::map<RowData*, std::vector<RowData*>*, DescendingComparator> treeMap;
};
#endif